ai-tech / dify

Commit 6d5b3863 (Unverified), authored Jan 30, 2024 by Yeuoly, committed by GitHub on Jan 30, 2024

Feat/blocking function call (#2247)

parent 1ea18a29

Showing 33 changed files with 430 additions and 95 deletions (+430 −95)
Changed files:

- api/core/app_runner/assistant_app_runner.py (+11 −3)
- api/core/features/assistant_base_runner.py (+14 −1)
- api/core/features/assistant_cot_runner.py (+4 −3)
- api/core/features/assistant_fc_runner.py (+105 −23)
- api/core/model_runtime/entities/model_entities.py (+1 −0)
- api/core/model_runtime/model_providers/azure_openai/_constant.py (+5 −0)
- api/core/model_runtime/model_providers/azure_openai/llm/llm.py (+24 −4)
- api/core/model_runtime/model_providers/chatglm/llm/llm.py (+5 −1)
- api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml (+2 −0)
- api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml (+2 −0)
- api/core/model_runtime/model_providers/minimax/llm/chat_completion.py (+1 −2)
- api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py (+37 −9)
- api/core/model_runtime/model_providers/minimax/llm/llm.py (+42 −1)
- api/core/model_runtime/model_providers/minimax/llm/types.py (+10 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml (+1 −0)
- api/core/model_runtime/model_providers/openai/llm/llm.py (+1 −1)
- api/core/model_runtime/model_providers/xinference/llm/llm.py (+27 −4)
- api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py (+19 −4)
- api/core/model_runtime/model_providers/xinference/xinference_helper.py (+19 −6)
- api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml (+4 −0)
- api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml (+4 −0)
- api/core/model_runtime/model_providers/zhipuai/llm/llm.py (+21 −0)
- api/requirements.txt (+1 −1)
- …ests/integration_tests/model_runtime/__mock/xinference.py (+61 −32)
api/core/app_runner/assistant_app_runner.py

```diff
@@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
@@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner):
                 memory=memory,
             )

+            # change function call strategy based on LLM model
+            llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+            model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+
+            if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
+                agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
+
             # start agent runner
             if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
                 assistant_cot_runner = AssistantCotApplicationRunner(
@@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner):
                     prompt_messages=prompt_message,
                     variables_pool=tool_variables,
                     db_variables=tool_conversation_variables,
+                    model_instance=model_instance
                 )

                 invoke_result = assistant_cot_runner.run(
-                    model_instance=model_instance,
                     conversation=conversation,
                     message=message,
                     query=query,
@@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner):
                     memory=memory,
                     prompt_messages=prompt_message,
                     variables_pool=tool_variables,
-                    db_variables=tool_conversation_variables
+                    db_variables=tool_conversation_variables,
+                    model_instance=model_instance
                 )

                 invoke_result = assistant_fc_runner.run(
-                    model_instance=model_instance,
                     conversation=conversation,
                     message=message,
                     query=query,
```
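A quick way to see what the strategy override above does in isolation: the sketch below reproduces the same decision with simplified stand-in types (the `ModelFeature` values mirror `model_entities.py`; `AgentStrategy` is a stand-in for `AgentEntity.Strategy`, which this page does not show in full).

```python
from enum import Enum


class ModelFeature(Enum):
    # values mirror api/core/model_runtime/entities/model_entities.py
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"
    STREAM_TOOL_CALL = "stream-tool-call"


class AgentStrategy(Enum):
    # simplified stand-in for AgentEntity.Strategy
    CHAIN_OF_THOUGHT = "chain-of-thought"
    FUNCTION_CALLING = "function-calling"


def pick_strategy(features: set[ModelFeature]) -> AgentStrategy:
    # prefer native function calling whenever the model declares any
    # tool-call capability; otherwise keep the ReACT-style CoT strategy
    if {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} & features:
        return AgentStrategy.FUNCTION_CALLING
    return AgentStrategy.CHAIN_OF_THOUGHT


assert pick_strategy({ModelFeature.TOOL_CALL}) is AgentStrategy.FUNCTION_CALLING
assert pick_strategy(set()) is AgentStrategy.CHAIN_OF_THOUGHT
```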
api/core/features/assistant_base_runner.py

```diff
 import logging
 import json

-from typing import Optional, List, Tuple, Union
+from typing import Optional, List, Tuple, Union, cast
 from datetime import datetime
 from mimetypes import guess_extension
@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
     AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_manager import ModelInstance
 from core.file.message_file_parser import FileTransferMethod

 logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                  prompt_messages: Optional[List[PromptMessage]] = None,
                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
                  db_variables: Optional[ToolConversationVariables] = None,
+                 model_instance: ModelInstance = None
                  ) -> None:
         """
         Agent runner
@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.history_prompt_messages = prompt_messages
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
+        self.model_instance = model_instance

         # init callback
         self.agent_callback = DifyAgentCallbackHandler()
@@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner):
             MessageAgentThought.message_id == self.message.id,
         ).count()

+        # check if model supports stream tool call
+        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
+            self.stream_tool_call = True
+        else:
+            self.stream_tool_call = False
+
     def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
         """
         Repacket app orchestration config
```
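The `stream_tool_call` flag set above is what the function-call runner later branches on. A standalone sketch of the capability probe, with a stand-in schema type (in the real runner the schema comes from `llm_model.get_model_schema(...)`):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelSchema:
    # stand-in for the model schema entity returned by get_model_schema();
    # `features` may legitimately be None, hence the `or []` guard below
    features: Optional[List[str]] = None


def supports_stream_tool_call(schema: Optional[ModelSchema]) -> bool:
    return bool(schema and "stream-tool-call" in (schema.features or []))


assert supports_stream_tool_call(ModelSchema(features=["stream-tool-call"]))
assert not supports_stream_tool_call(ModelSchema(features=None))
assert not supports_stream_tool_call(None)
```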
api/core/features/assistant_cot_runner.py

```diff
@@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
 from models.model import Conversation, Message


 class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance,
+    def run(self,
             conversation: Conversation,
             message: Message,
             query: str,
         ) -> Union[Generator, LLMResult]:
@@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             llm_usage.prompt_price += usage.prompt_price
             llm_usage.completion_price += usage.completion_price

+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             # continue to run until there is not any tool call
             function_call_state = False
@@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         # remove Action: xxx from agent thought
         agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

-        if action_name and action_input:
+        if action_name and action_input is not None:
             return AgentScratchpadUnit(
                 agent_response=content,
                 thought=agent_thought,
```
api/core/features/assistant_fc_runner.py

```diff
@@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
 from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, \
     SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
-from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
 from core.model_manager import ModelInstance
 from core.application_queue_manager import PublishFrom
@@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
 logger = logging.getLogger(__name__)


 class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance,
+    def run(self,
             conversation: Conversation,
             message: Message,
             query: str,
     ) -> Generator[LLMResultChunk, None, None]:
@@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             llm_usage.prompt_price += usage.prompt_price
             llm_usage.completion_price += usage.completion_price

+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             function_call_state = False
@@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             # recale llm max tokens
             self.recale_llm_max_tokens(self.model_config, prompt_messages)

             # invoke model
-            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
+            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=app_orchestration_config.model_config.parameters,
                 tools=prompt_messages_tools,
                 stop=app_orchestration_config.model_config.stop,
-                stream=True,
+                stream=self.stream_tool_call,
                 user=self.user_id,
                 callbacks=[],
             )
@@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             current_llm_usage = None

-            for chunk in chunks:
-                # check if there is any tool call
-                if self.check_tool_calls(chunk):
-                    function_call_state = True
-                    tool_calls.extend(self.extract_tool_calls(chunk))
-                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
-                    try:
-                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False)
-                    except json.JSONDecodeError as e:
-                        # ensure ascii to avoid encoding error
-                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
+            if self.stream_tool_call:
+                for chunk in chunks:
+                    # check if there is any tool call
+                    if self.check_tool_calls(chunk):
+                        function_call_state = True
+                        tool_calls.extend(self.extract_tool_calls(chunk))
+                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                        try:
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False)
+                        except json.JSONDecodeError as e:
+                            # ensure ascii to avoid encoding error
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
@@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
-                if chunk.delta.message and chunk.delta.message.content:
-                    if isinstance(chunk.delta.message.content, list):
-                        for content in chunk.delta.message.content:
-                            response += content.data
-                    else:
-                        response += chunk.delta.message.content
-
-                if chunk.delta.usage:
-                    increase_usage(llm_usage, chunk.delta.usage)
-                    current_llm_usage = chunk.delta.usage
-
-                yield chunk
+                    if chunk.delta.message and chunk.delta.message.content:
+                        if isinstance(chunk.delta.message.content, list):
+                            for content in chunk.delta.message.content:
+                                response += content.data
+                        else:
+                            response += chunk.delta.message.content
+
+                    if chunk.delta.usage:
+                        increase_usage(llm_usage, chunk.delta.usage)
+                        current_llm_usage = chunk.delta.usage
+
+                    yield chunk
+            else:
+                result: LLMResult = chunks
+                # check if there is any tool call
+                if self.check_blocking_tool_calls(result):
+                    function_call_state = True
+                    tool_calls.extend(self.extract_blocking_tool_calls(result))
+                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                    try:
+                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False)
+                    except json.JSONDecodeError as e:
+                        # ensure ascii to avoid encoding error
+                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
+
+                if result.usage:
+                    increase_usage(llm_usage, result.usage)
+                    current_llm_usage = result.usage
+
+                if result.message and result.message.content:
+                    if isinstance(result.message.content, list):
+                        for content in result.message.content:
+                            response += content.data
+                    else:
+                        response += result.message.content
+
+                if not result.message.content:
+                    result.message.content = ''
+
+                yield LLMResultChunk(
+                    model=model_instance.model,
+                    prompt_messages=result.prompt_messages,
+                    system_fingerprint=result.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=result.message,
+                        usage=result.usage,
+                    )
+                )
+
+            if tool_calls:
+                prompt_messages.append(AssistantPromptMessage(
+                    content='',
+                    name='',
+                    tool_calls=[AssistantPromptMessage.ToolCall(
+                        id=tool_call[0],
+                        type='function',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool_call[1],
+                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        )
+                    ) for tool_call in tool_calls]
+                ))

             # save thought
             self.save_agent_thought(
@@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                 final_answer += response + '\n'

+            # update prompt messages
+            if response.strip():
+                prompt_messages.append(AssistantPromptMessage(
+                    content=response,
+                ))
+
             # call tools
             tool_responses = []
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                 )
                 self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

-            # update prompt messages
-            if response.strip():
-                prompt_messages.append(AssistantPromptMessage(
-                    content=response,
-                ))
-
             # update prompt tool
             for prompt_tool in prompt_messages_tools:
                 self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
@@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
         if llm_result_chunk.delta.message.tool_calls:
             return True
         return False

+    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
+        """
+        Check if there is any blocking tool call in llm result
+        """
+        if llm_result.message.tool_calls:
+            return True
+        return False
+
     def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
         """
@@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             ))

         return tool_calls

+    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
+        """
+        Extract blocking tool calls from llm result
+
+        Returns:
+            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
+        """
+        tool_calls = []
+        for prompt_message in llm_result.message.tool_calls:
+            tool_calls.append((
+                prompt_message.id,
+                prompt_message.function.name,
+                json.loads(prompt_message.function.arguments),
+            ))
+
+        return tool_calls
+
     def organize_prompt_messages(self, prompt_template: str,
                                  query: str = None,
```
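The heart of the change is the `if self.stream_tool_call:` split above: providers that cannot stream tool calls are invoked with `stream=False`, and the single `LLMResult` they return is normalized back into one `LLMResultChunk`, so everything downstream still consumes a stream. A standalone sketch of that normalization, with simplified stand-ins for dify's LLM entities:

```python
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class Message:
    content: str
    tool_calls: Optional[list] = None


@dataclass
class Result:  # stand-in for LLMResult
    model: str
    message: Message
    usage: Optional[dict] = None
    system_fingerprint: Optional[str] = None


@dataclass
class ChunkDelta:  # stand-in for LLMResultChunkDelta
    index: int
    message: Message
    usage: Optional[dict] = None


@dataclass
class Chunk:  # stand-in for LLMResultChunk
    model: str
    delta: ChunkDelta


def as_stream(result: Result) -> Iterator[Chunk]:
    # mirror the diff's guard: a blocking result with no text still needs an
    # empty-string content before it is re-emitted as a single chunk
    if not result.message.content:
        result.message.content = ''
    yield Chunk(model=result.model,
                delta=ChunkDelta(index=0, message=result.message, usage=result.usage))


chunks = list(as_stream(Result(model="m", message=Message(content="hi"))))
assert len(chunks) == 1 and chunks[0].delta.message.content == "hi"
```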
api/core/model_runtime/entities/model_entities.py

```diff
@@ -78,6 +78,7 @@ class ModelFeature(Enum):
     MULTI_TOOL_CALL = "multi-tool-call"
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
+    STREAM_TOOL_CALL = "stream-tool-call"


 class DefaultParameterName(Enum):
```
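The provider YAML files further down declare capabilities with these string values; a quick sketch of how such a features list round-trips through the enum (values taken from the diff):

```python
from enum import Enum


class ModelFeature(Enum):
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"
    AGENT_THOUGHT = "agent-thought"
    VISION = "vision"
    STREAM_TOOL_CALL = "stream-tool-call"


# yaml feature lists are plain strings; they deserialize via value lookup
features = [ModelFeature(v) for v in ["multi-tool-call", "agent-thought", "stream-tool-call"]]
assert ModelFeature.STREAM_TOOL_CALL in features
```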
api/core/model_runtime/model_providers/azure_openai/_constant.py

```diff
@@ -36,6 +36,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -80,6 +81,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -124,6 +126,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -198,6 +201,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -272,6 +276,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
```
api/core/model_runtime/model_providers/azure_openai/llm/llm.py

```diff
@@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 tools: Optional[list[PromptMessageTool]] = None) -> Generator:
         index = 0
         full_assistant_content = ''
+        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
@@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             delta = chunk.choices[0]

-            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
+                delta.delta.function_call is None:
                 continue

             # assistant_message_tool_calls = delta.delta.tool_calls
             assistant_message_function_call = delta.delta.function_call

+            # extract tool calls from response
+            if delta_assistant_message_function_call_storage is not None:
+                # handle process of stream function call
+                if assistant_message_function_call:
+                    # message has not ended ever
+                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
+                    continue
+                else:
+                    # message has ended
+                    assistant_message_function_call = delta_assistant_message_function_call_storage
+                    delta_assistant_message_function_call_storage = None
+            else:
+                if assistant_message_function_call:
+                    # start of stream function call
+                    delta_assistant_message_function_call_storage = assistant_message_function_call
+                    if delta_assistant_message_function_call_storage.arguments is None:
+                        delta_assistant_message_function_call_storage.arguments = ''
+                    continue
+
             # extract tool calls from response
             # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
             function_call = self._extract_response_function_call(assistant_message_function_call)
@@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name

         return message_dict
@@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         num_tokens = 0
         for tool in tools:
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(tool.get("type")))
             num_tokens += len(encoding.encode('function'))

             # calculate num tokens for function object
```
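The diff above introduces `delta_assistant_message_function_call_storage` to stitch a function call that Azure OpenAI streams across many deltas back into one object: the first delta carrying a `function_call` opens a buffer, later deltas append to its arguments, and the buffered call is released on the first delta without one. A standalone sketch of the same accumulator, using plain dicts in place of the SDK's `ChoiceDeltaFunctionCall` objects (an assumption for illustration):

```python
from typing import Iterable, Iterator, Optional


def merge_function_call_deltas(deltas: Iterable[Optional[dict]]) -> Iterator[dict]:
    storage: Optional[dict] = None
    for fc in deltas:
        if storage is not None:
            if fc:  # still streaming: extend the buffered arguments
                storage["arguments"] += fc.get("arguments") or ""
                continue
            yield storage  # call finished: release the buffered call
            storage = None
        elif fc:  # start of a streamed function call: open the buffer
            storage = {"name": fc.get("name"), "arguments": fc.get("arguments") or ""}


# a trailing None stands in for the post-call delta that flushes the buffer
stream = [{"name": "get_weather", "arguments": '{"ci'}, {"arguments": 'ty": "SF"}'}, None]
assert list(merge_function_call_deltas(stream)) == [
    {"name": "get_weather", "arguments": '{"city": "SF"}'}
]
```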
api/core/model_runtime/model_providers/chatglm/llm/llm.py

```diff
@@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
-                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage,
+                                                          ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            # check if last message is user message
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "function", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
```
api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml

```diff
@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16384
```
api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml

```diff
@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
```
api/core/model_runtime/model_providers/minimax/llm/chat_completion.py

```diff
@@ -16,7 +16,7 @@ class MinimaxChatCompletion(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion
@@ -162,7 +162,6 @@ class MinimaxChatCompletion(object):
                 continue
             for choice in choices:
-                print(choice)
                 message = choice['delta']
                 yield MinimaxMessage(
                     content=message,
```
api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py

```diff
@@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion
@@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object):
             **extra_kwargs
         }

+        if tools:
+            body['functions'] = tools
+            body['function_call'] = {'type': 'auto'}
+
         try:
             response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
@@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object):
         """
         handle stream chat generate response
         """
+        function_call_storage = None
         for line in response.iter_lines():
             if not line:
                 continue
@@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object):
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)

-            if data['reply']:
+            if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
                 message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
@@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object):
                     'total_tokens': total_tokens
                 }
                 message.stop_reason = data['choices'][0]['finish_reason']
+
+                if function_call_storage:
+                    function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                    function_call_message.function_call = function_call_storage
+                    yield function_call_message
+
                 yield message
                 return
@@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object):
                 continue

             for choice in choices:
-                message = choice['messages'][0]['text']
-                yield MinimaxMessage(
-                    content=message,
-                    role=MinimaxMessage.Role.ASSISTANT.value
-                )
+                message = choice['messages'][0]
+
+                if not message:
+                    continue
+
+                if 'function_call' in message:
+                    if not function_call_storage:
+                        function_call_storage = message['function_call']
+                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
+                            function_call_storage['arguments'] = ''
+                            continue
+                    else:
+                        function_call_storage['arguments'] += message['function_call']['arguments']
+                        continue
+                else:
+                    if function_call_storage:
+                        message['function_call'] = function_call_storage
+                        function_call_storage = None
+
+                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+
+                if 'function_call' in message:
+                    minimax_message.function_call = message['function_call']
+                if 'text' in message:
+                    minimax_message.content = message['text']
+
+                yield minimax_message
\ No newline at end of file
```
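On the request side, the change is small enough to restate as a sketch: when tools are present they are attached under the `functions` key and `function_call` is forced to auto. Field names follow the diff; the real body also carries the model, messages, and extra kwargs, which are elided here.

```python
def build_body(messages: list, tools: list | None) -> dict:
    # minimal sketch of the Minimax request body assembly from the diff
    body = {'messages': messages}
    if tools:
        body['functions'] = tools
        body['function_call'] = {'type': 'auto'}
    return body


assert 'functions' not in build_body([], None)
assert build_body([], [{'name': 'f'}])['function_call'] == {'type': 'auto'}
```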
api/core/model_runtime/model_providers/minimax/llm/llm.py

```diff
@@ -2,7 +2,7 @@ from typing import Generator, List
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                          SystemPromptMessage, UserPromptMessage)
+                                                          SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         """
         client: MinimaxChatCompletionPro = self.model_apis[model]()

+        if tools:
+            tools = [{
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters
+            } for tool in tools]
+
         response = client.generate(
             model=model,
             api_key=credentials['minimax_api_key'],
@@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         elif isinstance(prompt_message, UserPromptMessage):
             return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
         elif isinstance(prompt_message, AssistantPromptMessage):
+            if prompt_message.tool_calls:
+                message = MinimaxMessage(
+                    role=MinimaxMessage.Role.ASSISTANT.value,
+                    content=''
+                )
+                message.function_call = {
+                    'name': prompt_message.tool_calls[0].function.name,
+                    'arguments': prompt_message.tool_calls[0].function.arguments
+                }
+                return message
             return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
+        elif isinstance(prompt_message, ToolPromptMessage):
+            return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
         else:
             raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
@@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                     finish_reason=message.stop_reason if message.stop_reason else None,
                 ),
             )
+        elif message.function_call:
+            if 'name' not in message.function_call or 'arguments' not in message.function_call:
+                continue
+
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=0,
+                    message=AssistantPromptMessage(
+                        content='',
+                        tool_calls=[AssistantPromptMessage.ToolCall(
+                            id='',
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=message.function_call['name'],
+                                arguments=message.function_call['arguments']
+                            )
+                        )]
+                    ),
+                ),
+            )
         else:
             yield LLMResultChunk(
                 model=model,
```
api/core/model_runtime/model_providers/minimax/llm/types.py

```diff
@@ -7,13 +7,23 @@ class MinimaxMessage:
         USER = 'USER'
         ASSISTANT = 'BOT'
         SYSTEM = 'SYSTEM'
+        FUNCTION = 'FUNCTION'

     role: str = Role.USER.value
     content: str
     usage: Dict[str, int] = None
     stop_reason: str = ''
+    function_call: Dict[str, Any] = None

     def to_dict(self) -> Dict[str, Any]:
+        if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
+            return {
+                'sender_type': 'BOT',
+                'sender_name': '专家',
+                'text': '',
+                'function_call': self.function_call
+            }
         return {
             'sender_type': self.role,
             'sender_name': '我' if self.role == 'USER' else '专家',
```
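The new `to_dict` branch serializes an assistant message that carries a function call with an empty text body. A standalone sketch with a simplified stand-in for `MinimaxMessage` (the sender names '专家' ("expert") and '我' ("me") are the literals the provider code already uses):

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class MiniMessage:  # simplified stand-in for MinimaxMessage
    role: str = 'USER'
    content: str = ''
    function_call: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        # an assistant (BOT) message holding a function call is sent with
        # empty text, mirroring the new branch in types.py
        if self.function_call and self.role == 'BOT':
            return {'sender_type': 'BOT', 'sender_name': '专家',
                    'text': '', 'function_call': self.function_call}
        return {'sender_type': self.role,
                'sender_name': '我' if self.role == 'USER' else '专家',
                'text': self.content}


call = {'name': 'get_weather', 'arguments': '{"city": "SF"}'}
assert MiniMessage(role='BOT', function_call=call).to_dict()['function_call'] == call
```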
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096
```

api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385
```

api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385
```

api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385
```

api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096
```

api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000
```

api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000
```

api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
```

api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000
```

api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml

```diff
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 8192
```
api/core/model_runtime/model_providers/openai/llm/llm.py

```diff
@@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name

         return message_dict
```
api/core/model_runtime/model_providers/xinference/llm/llm.py

```diff
@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                          SystemPromptMessage, UserPromptMessage)
+                                                          SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
-                                                        ParameterRule, ParameterType)
+                                                        ParameterRule, ParameterType, ModelFeature)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
+from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
                                                                               XinferenceModelExtraParameter)
 from core.model_runtime.utils import helper
 from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,
@@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
         """
+        if 'temperature' in model_parameters:
+            if model_parameters['temperature'] < 0.01:
+                model_parameters['temperature'] = 0.01
+            elif model_parameters['temperature'] > 1.0:
+                model_parameters['temperature'] = 0.99
+
         return self._generate(
             model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
             tools=tools, stop=stop, stream=stream, user=user,
@@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 credentials['completion_type'] = 'completion'
             else:
                 raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
+
+            if extra_param.support_function_call:
+                credentials['support_function_call'] = True
         except RuntimeError as e:
             raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
@@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 label=I18nObject(
                     zh_Hans='温度',
                     en_US='Temperature'
-                )
+                ),
             ),
             ParameterRule(
                 name='top_p',
@@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             completion_type = LLMMode.COMPLETION.value
         else:
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
+
+        support_function_call = credentials.get('support_function_call', False)

         entity = AIModelEntity(
             model=model,
@@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
+            features=[
+                ModelFeature.TOOL_CALL
+            ] if support_function_call else [],
             model_properties={
                 ModelPropertyKey.MODE: completion_type,
             },
@@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
         """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+
+        if credentials['server_url'].endswith('/'):
+            credentials['server_url'] = credentials['server_url'][:-1]
+
         client = OpenAI(
             base_url=f'{credentials["server_url"]}/v1',
             api_key='abc',
```
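The `_invoke` change above pins the temperature into a range the Xinference backend accepts (presumably it rejects or misbehaves on out-of-range values). The same clamp as a standalone helper, straight from the diff:

```python
def clamp_temperature(params: dict) -> dict:
    # pin temperature into [0.01, 0.99], mirroring the diff's bounds
    if 'temperature' in params:
        if params['temperature'] < 0.01:
            params['temperature'] = 0.01
        elif params['temperature'] > 1.0:
            params['temperature'] = 0.99
    return params


assert clamp_temperature({'temperature': 0.0})['temperature'] == 0.01
assert clamp_temperature({'temperature': 1.5})['temperature'] == 0.99
assert clamp_temperature({'temperature': 0.7})['temperature'] == 0.7
```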
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
View file @
6d5b3863
...
@@ -2,7 +2,7 @@ import time
...
@@ -2,7 +2,7 @@ import time
from
typing
import
Optional
from
typing
import
Optional
from
core.model_runtime.entities.common_entities
import
I18nObject
from
core.model_runtime.entities.common_entities
import
I18nObject
from
core.model_runtime.entities.model_entities
import
AIModelEntity
,
FetchFrom
,
ModelType
,
PriceType
from
core.model_runtime.entities.model_entities
import
AIModelEntity
,
FetchFrom
,
ModelType
,
PriceType
,
ModelPropertyKey
from
core.model_runtime.entities.text_embedding_entities
import
EmbeddingUsage
,
TextEmbeddingResult
from
core.model_runtime.entities.text_embedding_entities
import
EmbeddingUsage
,
TextEmbeddingResult
from
core.model_runtime.errors.invoke
import
(
InvokeAuthorizationError
,
InvokeBadRequestError
,
InvokeConnectionError
,
from
core.model_runtime.errors.invoke
import
(
InvokeAuthorizationError
,
InvokeBadRequestError
,
InvokeConnectionError
,
InvokeError
,
InvokeRateLimitError
,
InvokeServerUnavailableError
)
InvokeError
,
InvokeRateLimitError
,
InvokeServerUnavailableError
)
...
@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper

 class XinferenceTextEmbeddingModel(TextEmbeddingModel):
     """
...
@@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         server_url = credentials['server_url']
         model_uid = credentials['model_uid']
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
         client = Client(base_url=server_url)

         try:
...
@@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         try:
+            server_url = credentials['server_url']
+            model_uid = credentials['model_uid']
+            extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
+            if extra_args.max_tokens:
+                credentials['max_tokens'] = extra_args.max_tokens
+
             self._invoke(model=model, credentials=credentials, texts=['ping'])
-        except InvokeAuthorizationError:
+        except (InvokeAuthorizationError, RuntimeError):
             raise CredentialsValidateFailedError('Invalid api key')

     @property
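Catching RuntimeError alongside InvokeAuthorizationError matters with the upgraded client: as the mock at the bottom of this commit shows, xinference-client raises a plain RuntimeError('404 Not Found') for unknown model UIDs. A sketch of the same pattern, with hypothetical stand-in error types:

class CredentialsValidateFailedError(Exception):  # stand-in for dify's error type
    pass

def validate_by_ping(invoke) -> None:
    # a client-level RuntimeError (e.g. '404 Not Found') is treated
    # the same as an explicit authorization failure
    try:
        invoke()
    except (PermissionError, RuntimeError) as e:  # PermissionError stands in for InvokeAuthorizationError
        raise CredentialsValidateFailedError('Invalid api key') from e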
...
@@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
             used to define customizable model schema
         """
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
...
@@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.TEXT_EMBEDDING,
-            model_properties={},
+            model_properties={
+                ModelPropertyKey.MAX_CHUNKS: 1,
+                ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
+            },
             parameter_rules=[]
         )
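The CONTEXT_SIZE expression is the classic and/or conditional: it yields 512 not only when the key is absent but also when the stored value is falsy (0, None, ''). An equivalent rewrite, shown only for illustration:

def context_size(credentials: dict) -> int:
    # dict.get returns None for a missing key; `or 512` then covers
    # both the missing case and any falsy stored value
    return credentials.get('max_tokens') or 512

assert context_size({}) == 512
assert context_size({'max_tokens': 0}) == 512   # falsy value also falls back
assert context_size({'max_tokens': 2048}) == 2048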
...
api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py → api/core/model_runtime/model_providers/xinference/xinference_helper.py
 from threading import Lock
 from time import time
 from typing import List
+from os import path

 from requests import get
 from requests.adapters import HTTPAdapter
...
@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object):
     model_format: str
     model_handle_type: str
     model_ability: List[str]
+    max_tokens: int = 512
+    support_function_call: bool = False

-    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
+                 support_function_call: bool, max_tokens: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
+        self.support_function_call = support_function_call
+        self.max_tokens = max_tokens

 cache = {}
 cache_lock = Lock()
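Since support_function_call and max_tokens are now required constructor arguments, every call site must pass them (the return statement later in this file is updated accordingly). A hypothetical dataclass rendering of the same shape, for comparison only:

from dataclasses import dataclass, field
from typing import List

@dataclass
class ExtraParameterSketch:  # not the project's class, just the same fields
    model_format: str
    model_handle_type: str
    model_ability: List[str] = field(default_factory=list)
    support_function_call: bool = False
    max_tokens: int = 512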
...
@@ -49,7 +55,7 @@ class XinferenceHelper:
         get xinference model extra parameter like model_format and model_handle_type
         """
-        url = f'{server_url}/v1/models/{model_uid}'
+        url = path.join(server_url, 'v1/models', model_uid)

         # this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
         session = Session()
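os.path.join happens to build a correct URL here because the POSIX path separator is '/'; on Windows it would join with backslashes. A hedged comparison, with urljoin as a hypothetical separator-safe alternative:

from os import path
from urllib.parse import urljoin

server_url, model_uid = 'http://127.0.0.1:9997', 'chat'

# what the diff does (correct on POSIX, where the separator is '/'):
assert path.join(server_url, 'v1/models', model_uid) == 'http://127.0.0.1:9997/v1/models/chat'

# a separator-safe equivalent:
assert urljoin(server_url + '/', f'v1/models/{model_uid}') == 'http://127.0.0.1:9997/v1/models/chat'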
...
@@ -66,10 +72,12 @@ class XinferenceHelper:
         response_json = response.json()

-        model_format = response_json['model_format']
-        model_ability = response_json['model_ability']
+        model_format = response_json.get('model_format', 'ggmlv3')
+        model_ability = response_json.get('model_ability', [])

-        if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
+        if response_json.get('model_type') == 'embedding':
+            model_handle_type = 'embedding'
+        elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
             model_handle_type = 'chatglm'
         elif 'generate' in model_ability:
             model_handle_type = 'generate'
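Embedding models report no 'chat'/'generate' ability, so checking model_type == 'embedding' first keeps them from falling through to NotImplementedError. A condensed sketch of the resolution order (the trailing 'chat' branch is elided in this hunk and assumed here):

def resolve_handle_type(response_json: dict) -> str:
    model_format = response_json.get('model_format', 'ggmlv3')
    model_ability = response_json.get('model_ability', [])
    if response_json.get('model_type') == 'embedding':
        return 'embedding'
    if model_format == 'ggmlv3' and 'chatglm' in response_json.get('model_name', ''):
        return 'chatglm'
    if 'generate' in model_ability:
        return 'generate'
    if 'chat' in model_ability:       # assumed branch, not shown above
        return 'chat'
    raise NotImplementedError('unsupported handle type')

assert resolve_handle_type({'model_type': 'embedding'}) == 'embedding'
assert resolve_handle_type({'model_ability': ['generate', 'chat']}) == 'generate'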
...
@@ -78,8 +86,13 @@ class XinferenceHelper:
         else:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')

+        support_function_call = 'tools' in model_ability
+        max_tokens = response_json.get('max_tokens', 512)
+
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
-            model_ability=model_ability
+            model_ability=model_ability,
+            support_function_call=support_function_call,
+            max_tokens=max_tokens
         )
\ No newline at end of file
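Function-call support is inferred purely from the 'tools' entry in the server-reported ability list, with max_tokens falling back to 512 when absent. A self-contained sketch of that derivation over a /v1/models/<uid> response dict:

def derive_extras(response_json: dict):
    model_ability = response_json.get('model_ability', [])
    support_function_call = 'tools' in model_ability  # 'tools' ability implies function calling
    max_tokens = response_json.get('max_tokens', 512)
    return support_function_call, max_tokens

assert derive_extras({'model_ability': ['chat', 'tools'], 'max_tokens': 2048}) == (True, 2048)
assert derive_extras({'model_ability': ['generate']}) == (False, 512)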
api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml
...
@@ -2,6 +2,10 @@ model: glm-3-turbo
 label:
   en_US: glm-3-turbo
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:
...
api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml
...
@@ -2,6 +2,10 @@ model: glm-4
 label:
   en_US: glm-4
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:
...
api/core/model_runtime/model_providers/zhipuai/llm/llm.py
...
@@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     'content': prompt_message.content,
                     'tool_call_id': prompt_message.tool_call_id
                 })
+            elif isinstance(prompt_message, AssistantPromptMessage):
+                if prompt_message.tool_calls:
+                    params['messages'].append({
+                        'role': 'assistant',
+                        'content': prompt_message.content,
+                        'tool_calls': [
+                            {
+                                'id': tool_call.id,
+                                'type': tool_call.type,
+                                'function': {
+                                    'name': tool_call.function.name,
+                                    'arguments': tool_call.function.arguments
+                                }
+                            } for tool_call in prompt_message.tool_calls
+                        ]
+                    })
+                else:
+                    params['messages'].append({
+                        'role': 'assistant',
+                        'content': prompt_message.content
+                    })
             else:
                 params['messages'].append({
                     'role': prompt_message.role.value,
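The new AssistantPromptMessage branch serializes prior tool invocations back into the OpenAI-style message shape zhipuai expects, so the model can see its own earlier calls in multi-turn function calling. For one tool call, the appended dict looks like this (all values are hypothetical):

message = {
    'role': 'assistant',
    'content': '',
    'tool_calls': [
        {
            'id': 'call_0',                      # made-up id for illustration
            'type': 'function',
            'function': {
                'name': 'get_weather',           # hypothetical tool name
                'arguments': '{"city": "Beijing"}',
            },
        }
    ],
}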
...
api/requirements.txt
...
@@ -47,7 +47,7 @@ dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 pandas==1.5.3
-xinference-client~=0.6.4
+xinference-client~=0.8.1
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug~=3.0.1
...
api/tests/integration_tests/model_runtime/__mock/xinference.py
...
@@ -19,58 +19,86 @@ class MockXinferenceClass(object):
             raise RuntimeError('404 Not Found')

         if 'generate' == model_uid:
-            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'chat' == model_uid:
-            return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'embedding' == model_uid:
-            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'rerank' == model_uid:
-            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         raise RuntimeError('404 Not Found')

     def get(self: Session, url: str, **kwargs):
-        if '/v1/models/' in url:
-            response = Response()
+        response = Response()
+        if 'v1/models/' in url:
             # get model uid
             model_uid = url.split('/')[-1]
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
+                return response

             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
+                return response

-            response.status_code = 200
-            response._content = b'''{
-    "model_type": "LLM",
-    "address": "127.0.0.1:43877",
-    "accelerators": [
-        "0",
-        "1"
-    ],
-    "model_name": "chatglm3-6b",
-    "model_lang": [
-        "en"
-    ],
-    "model_ability": [
-        "generate",
-        "chat"
-    ],
-    "model_description": "latest chatglm3",
-    "model_format": "pytorch",
-    "model_size_in_billions": 7,
-    "quantization": "none",
-    "model_hub": "huggingface",
-    "revision": null,
-    "context_length": 2048,
-    "replica": 1
-}'''
-            return response
+            if model_uid in ['generate', 'chat']:
+                response.status_code = 200
+                response._content = b'''{
+    "model_type": "LLM",
+    "address": "127.0.0.1:43877",
+    "accelerators": [
+        "0",
+        "1"
+    ],
+    "model_name": "chatglm3-6b",
+    "model_lang": [
+        "en"
+    ],
+    "model_ability": [
+        "generate",
+        "chat"
+    ],
+    "model_description": "latest chatglm3",
+    "model_format": "pytorch",
+    "model_size_in_billions": 7,
+    "quantization": "none",
+    "model_hub": "huggingface",
+    "revision": null,
+    "context_length": 2048,
+    "replica": 1
+}'''
+                return response
+            elif model_uid == 'embedding':
+                response.status_code = 200
+                response._content = b'''{
+    "model_type": "embedding",
+    "address": "127.0.0.1:43877",
+    "accelerators": [
+        "0",
+        "1"
+    ],
+    "model_name": "bge",
+    "model_lang": [
+        "en"
+    ],
+    "revision": null,
+    "max_tokens": 512
+}'''
+                return response
+        elif 'v1/cluster/auth' in url:
+            response.status_code = 200
+            response._content = b'''{
+    "auth": true
+}'''
+            return response
+
+    def _check_cluster_authenticated(self):
+        self._cluster_authed = True

     def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
         # check if self._model_uid is a valid uuid
         if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
...
@@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
         monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
+        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
         monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
         monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
         monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
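For reference, these fixtures only take effect when the MOCK_SWITCH environment variable is set, per the MOCK line in the hunk header above; a minimal usage sketch:

import os

os.environ['MOCK_SWITCH'] = 'true'  # enable the xinference mocks before tests run
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
assert MOCK  # setup_xinference_mock patches Client/Session only in this case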
...