ai-tech / dify

Commit ad84b996 (Unverified), authored Feb 25, 2024 by Yeuoly
Parent: 282922f3

feat: model invoke api

Showing 3 changed files with 279 additions and 5 deletions (+279 −5)
api/controllers/inner_api/model_runtime.py    +46 −4
api/core/app_runner/model_runner.py           +132 −0
api/services/completion_service.py            +101 −1
api/controllers/inner_api/model_runtime.py  (+46 −4)

Before this commit the controller imported only Resource from flask_restful and the post() handler was a bare pass stub; after the change the file reads:

import json

from flask_restful import Resource, reqparse
from flask import Response
from flask.helpers import stream_with_context

from controllers.console.setup import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import inner_api_only
from services.completion_service import CompletionService
from typing import Generator, Union


class EnterpriseModelInvokeLLMApi(Resource):
    """Model invoke API for enterprise edition"""

    @setup_required
    @inner_api_only
    def post(self):
        request_parser = reqparse.RequestParser()
        request_parser.add_argument('tenant_id', type=str, required=True, nullable=False, location='json')
        request_parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
        request_parser.add_argument('model', type=str, required=True, nullable=False, location='json')
        request_parser.add_argument('completion_params', type=dict, required=True, nullable=False, location='json')
        request_parser.add_argument('prompt_messages', type=list, required=True, nullable=False, location='json')
        request_parser.add_argument('tools', type=list, required=False, nullable=True, location='json')
        request_parser.add_argument('stop', type=list, required=False, nullable=True, location='json')
        request_parser.add_argument('stream', type=bool, required=False, nullable=True, location='json')
        request_parser.add_argument('user', type=str, required=False, nullable=True, location='json')

        args = request_parser.parse_args()

        response = CompletionService.invoke_model(
            tenant_id=args['tenant_id'],
            provider=args['provider'],
            model=args['model'],
            completion_params=args['completion_params'],
            prompt_messages=args['prompt_messages'],
            tools=args['tools'],
            stop=args['stop'],
            stream=args['stream'],
            user=args['user'],
        )

        return compact_response(response)


def compact_response(response: Union[dict, Generator]) -> Response:
    if isinstance(response, dict):
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            yield from response

        return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')


api.add_resource(EnterpriseModelInvokeLLMApi, '/model/invoke/llm')
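For orientation, a minimal sketch of a request to the new endpoint, based only on the arguments declared in the RequestParser above. The host, URL prefix, and the authentication header checked by inner_api_only are not part of this diff and are placeholder assumptions here.

# Hypothetical client call; host, URL prefix, and auth header name are assumptions.
import requests

payload = {
    "tenant_id": "tenant-123",
    "provider": "openai",
    "model": "gpt-3.5-turbo",
    "completion_params": {"temperature": 0.7},
    "prompt_messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "tools": [],           # optional
    "stop": ["\n\n"],      # optional
    "stream": True,        # optional
    "user": "end-user-42"  # optional
}

resp = requests.post(
    "http://localhost:5001/inner/api/model/invoke/llm",  # assumed host and prefix
    json=payload,
    headers={"X-Inner-Api-Key": "<inner-api-key>"},       # assumed header name
    stream=True,
)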
api/core/app_runner/model_runner.py  (new file, mode 100644, +132 −0)

from core.provider_manager import ProviderManager, ProviderModelBundle
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from events.inner_event import model_was_invoked
from typing import Generator, Union, cast, Optional


class ModelRunner:
    """
    Model runner
    """

    @staticmethod
    def run(provider_model_bundle: ProviderModelBundle,
            model: str,
            prompt_messages: list[PromptMessage],
            model_parameters: Optional[dict] = None,
            tools: Optional[list[PromptMessageTool]] = None,
            stop: Optional[list[str]] = None,
            stream: bool = True,
            user: Optional[str] = None,
    ) -> Union[Generator, dict]:
        """
        Run model
        """
        llm_model = cast(LargeLanguageModel, provider_model_bundle.model_type_instance)

        credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=ModelType.LLM,
            model=model,
        )

        if not credentials:
            raise ValueError('No credentials found for model')

        response = llm_model.invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

        if stream:
            return ModelRunner.handle_streaming_response(
                tenant_id=provider_model_bundle.configuration.tenant_id,
                provider=provider_model_bundle.configuration.provider,
                model=model,
                model_type=ModelType.LLM.value,
                response=response,
            )

        return ModelRunner.handle_blocking_response(
            tenant_id=provider_model_bundle.configuration.tenant_id,
            provider=provider_model_bundle.configuration.provider,
            model=model,
            model_type=ModelType.LLM.value,
            response=response,
        )

    def handle_streaming_response(tenant_id: str,
                                  provider: str,
                                  model: str,
                                  model_type: str,
                                  response: Generator[LLMResultChunk, None, None],
    ) -> Generator[dict]:
        """
        Handle streaming response
        """
        usage = LLMUsage.empty_usage()

        for chunk in response:
            if chunk.delta.usage:
                usage.completion_price += chunk.delta.usage.completion_price
                usage.prompt_price += chunk.delta.usage.prompt_price
                usage.prompt_price_unit = chunk.delta.usage.prompt_price_unit
                usage.prompt_unit_price = chunk.delta.usage.prompt_unit_price
                usage.completion_price_unit = chunk.delta.usage.completion_price_unit
                usage.completion_unit_price = chunk.delta.usage.completion_unit_price
                usage.prompt_tokens += chunk.delta.usage.prompt_tokens
                usage.completion_tokens += chunk.delta.usage.completion_tokens
                usage.currency = chunk.delta.usage.currency

            yield jsonable_encoder(chunk)

        model_was_invoked(
            None,
            tenant_id=tenant_id,
            model_config={
                'provider': provider,
                'model_type': model_type,
                'model': model,
            },
            message_tokens=usage.prompt_tokens,
            answer_tokens=usage.completion_tokens,
        )

    def handle_blocking_response(tenant_id: str,
                                 provider: str,
                                 model: str,
                                 model_type: str,
                                 response: LLMResult,
    ) -> dict:
        """
        Handle blocking response
        """
        usage = response.usage or LLMUsage.empty_usage()

        model_was_invoked(
            None,
            tenant_id=tenant_id,
            model_config={
                'provider': provider,
                'model_type': model_type,
                'model': model,
            },
            message_tokens=usage.prompt_tokens,
            answer_tokens=usage.completion_tokens,
        )

        return jsonable_encoder(response)
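ModelRunner.run returns either a dict (blocking) or a generator of JSON-encodable chunk dicts (streaming), and both paths report token usage through model_was_invoked. A minimal caller sketch, assuming the Dify api package is importable and using placeholder tenant, provider, and model values:

# Sketch only; tenant, provider, and model values are placeholders.
from core.provider_manager import ProviderManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.app_runner.model_runner import ModelRunner

bundle = ProviderManager().get_provider_model_bundle(
    tenant_id="tenant-123",
    provider="openai",
    model_type=ModelType.LLM,
)

result = ModelRunner.run(
    provider_model_bundle=bundle,
    model="gpt-3.5-turbo",
    prompt_messages=[UserPromptMessage(content="Hello!")],
    model_parameters={"temperature": 0.7},
    stream=True,
)

if isinstance(result, dict):
    print(result)            # blocking mode: a jsonable-encoded LLMResult
else:
    for chunk in result:     # streaming mode: jsonable-encoded LLMResultChunk dicts
        print(chunk)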
api/services/completion_service.py  (+101 −1)

@@ -5,8 +5,22 @@ from typing import Any, Union

from sqlalchemy import and_

from core.application_manager import ApplicationManager
from core.provider_manager import ProviderManager
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.entities.application_entities import InvokeFrom
from core.entities.model_entities import ModelStatus
from core.model_runtime.entities.model_entities import ModelType
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    UserPromptMessage,
    SystemPromptMessage,
    AssistantPromptMessage,
    ToolPromptMessage,
    PromptMessageRole,
    PromptMessageTool
)
from core.app_runner.model_runner import ModelRunner
from extensions.ext_database import db
from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
from services.app_model_config_service import AppModelConfigService

@@ -15,7 +29,6 @@ from services.errors.app_model_config import AppModelConfigBrokenError

from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
from services.errors.message import MessageNotExistsError


class CompletionService:

    @classmethod

@@ -256,3 +269,90 @@ class CompletionService:

                filtered_inputs[variable] = value.replace('\x00', '') if value else None

        return filtered_inputs

    @staticmethod
    def invoke_model(tenant_id: str,
                     provider: str,
                     model: str,
                     completion_params: dict,
                     prompt_messages: list[dict],
                     tools: list[dict],
                     stop: list[str],
                     stream: bool,
                     user: str) -> Union[Generator, dict]:
        """
        invoke model

        :param tenant_id: the tenant id
        :param provider: the provider
        :param model: the model
        :param mode: the mode
        :param completion_params: the completion params
        :param prompt_messages: the prompt messages
        :param stream: the stream
        :return: the model result
        """
        converted_prompt_messages: list[PromptMessage] = []
        for message in prompt_messages:
            role = message.get('role')
            if not role:
                raise ValueError('role is required')

            if role == PromptMessageRole.USER.value:
                converted_prompt_messages.append(UserPromptMessage(content=message['content']))
            elif role == PromptMessageRole.ASSISTANT.value:
                converted_prompt_messages.append(AssistantPromptMessage(
                    content=message['content'],
                    tool_calls=message.get('tool_calls', [])
                ))
            elif role == PromptMessageRole.SYSTEM.value:
                converted_prompt_messages.append(SystemPromptMessage(content=message['content']))
            elif role == PromptMessageRole.TOOL.value:
                converted_prompt_messages.append(ToolPromptMessage(
                    content=message['content'],
                    tool_call_id=message['tool_call_id']
                ))
            else:
                raise ValueError(f'Unknown role: {role}')

        # check if the model is available
        bundle = ProviderManager().get_provider_model_bundle(
            tenant_id=tenant_id,
            provider=provider,
            model_type=ModelType.LLM,
        )

        provider_model = bundle.configuration.get_provider_model(
            model_type=ModelType.LLM,
            model=model,
        )

        if not provider_model:
            raise ModelCurrentlyNotSupportError(f"Could not find model {model} in provider {provider}.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model} currently not support.")
        if provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise QuotaExceededError(f"Model provider {provider} quota exceeded.")

        converted_tools = []
        for tool in tools:
            converted_tools.append(PromptMessageTool(
                name=tool['name'],
                description=tool['description'],
                parameters=tool['parameters']
            ))

        # invoke model
        return ModelRunner.run(
            provider_model_bundle=bundle,
            model=model,
            prompt_messages=converted_prompt_messages,
            model_parameters=completion_params,
            tools=converted_tools,
            stop=stop,
            stream=stream,
            user=user
        )
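Tying the layers together, a service-level call could look like the sketch below. The role strings follow the PromptMessageRole values matched by the conversion loop above, and the tool dict uses the name/description/parameters keys it reads; tenant, provider, model, and tool names are placeholders.

# Illustrative values only; all identifiers in the payload are placeholders.
from services.completion_service import CompletionService

response = CompletionService.invoke_model(
    tenant_id="tenant-123",
    provider="openai",
    model="gpt-3.5-turbo",
    completion_params={"temperature": 0.7, "max_tokens": 256},
    prompt_messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the weather like in Berlin?"},
    ],
    tools=[{
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }],
    stop=[],
    stream=False,
    user="end-user-42",
)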