ai-tech / dify — Commit 6d5b3863 (Unverified)
Authored Jan 30, 2024 by Yeuoly; committed by GitHub on Jan 30, 2024.
Feat/blocking function call (#2247)
Parent: 1ea18a29
Showing 33 changed files with 430 additions and 95 deletions (+430 / -95).
Changed files:
  api/core/app_runner/assistant_app_runner.py  (+11 / -3)
  api/core/features/assistant_base_runner.py  (+14 / -1)
  api/core/features/assistant_cot_runner.py  (+4 / -3)
  api/core/features/assistant_fc_runner.py  (+105 / -23)
  api/core/model_runtime/entities/model_entities.py  (+1 / -0)
  api/core/model_runtime/model_providers/azure_openai/_constant.py  (+5 / -0)
  api/core/model_runtime/model_providers/azure_openai/llm/llm.py  (+24 / -4)
  api/core/model_runtime/model_providers/chatglm/llm/llm.py  (+5 / -1)
  api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml  (+2 / -0)
  api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml  (+2 / -0)
  api/core/model_runtime/model_providers/minimax/llm/chat_completion.py  (+1 / -2)
  api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py  (+37 / -9)
  api/core/model_runtime/model_providers/minimax/llm/llm.py  (+42 / -1)
  api/core/model_runtime/model_providers/minimax/llm/types.py  (+10 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml  (+1 / -0)
  api/core/model_runtime/model_providers/openai/llm/llm.py  (+1 / -1)
  api/core/model_runtime/model_providers/xinference/llm/llm.py  (+27 / -4)
  api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py  (+19 / -4)
  api/core/model_runtime/model_providers/xinference/xinference_helper.py  (+19 / -6)
  api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml  (+4 / -0)
  api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml  (+4 / -0)
  api/core/model_runtime/model_providers/zhipuai/llm/llm.py  (+21 / -0)
  api/requirements.txt  (+1 / -1)
  api/tests/integration_tests/model_runtime/__mock/xinference.py  (+61 / -32)
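In short: model schemas gain a stream-tool-call feature flag, and the function-call agent runner now invokes the model in blocking mode when that flag is absent, converting the single LLMResult back into one LLMResultChunk so the rest of the pipeline stays stream-shaped. Before the per-file diffs, a minimal self-contained sketch of that branching (stand-in names only, not the actual dify classes):

    # Hedged sketch of the pattern this commit introduces; the real logic lives in
    # api/core/features/assistant_fc_runner.py and the model_runtime entities.
    from dataclasses import dataclass
    from typing import Iterator, Union

    @dataclass
    class Chunk:          # stands in for LLMResultChunk
        text: str

    @dataclass
    class Result:         # stands in for a blocking LLMResult
        text: str

    def invoke_llm(stream: bool) -> Union[Iterator[Chunk], Result]:
        if stream:
            return iter([Chunk('Hello'), Chunk(' world')])
        return Result('Hello world')

    def run(stream_tool_call: bool) -> Iterator[Chunk]:
        output = invoke_llm(stream=stream_tool_call)
        if stream_tool_call:
            # streaming path: pass chunks straight through
            for chunk in output:
                yield chunk
        else:
            # blocking path: wrap the single result as one chunk
            yield Chunk(output.text)

    print([c.text for c in run(stream_tool_call=False)])   # ['Hello world']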
api/core/app_runner/assistant_app_runner.py

@@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool

@@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner):
                 memory=memory,
             )

+            # change function call strategy based on LLM model
+            llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+            model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+
+            if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
+                agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
+
             # start agent runner
             if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
                 assistant_cot_runner = AssistantCotApplicationRunner(

@@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner):
                 prompt_messages=prompt_message,
                 variables_pool=tool_variables,
                 db_variables=tool_conversation_variables,
+                model_instance=model_instance
             )
             invoke_result = assistant_cot_runner.run(
-                model_instance=model_instance,
                 conversation=conversation,
                 message=message,
                 query=query,

@@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner):
                 memory=memory,
                 prompt_messages=prompt_message,
                 variables_pool=tool_variables,
-                db_variables=tool_conversation_variables
+                db_variables=tool_conversation_variables,
+                model_instance=model_instance
             )
             invoke_result = assistant_fc_runner.run(
-                model_instance=model_instance,
                 conversation=conversation,
                 message=message,
                 query=query,
api/core/features/assistant_base_runner.py

 import logging
 import json

-from typing import Optional, List, Tuple, Union
+from typing import Optional, List, Tuple, Union, cast

 from datetime import datetime
 from mimetypes import guess_extension

@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
     AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_manager import ModelInstance
 from core.file.message_file_parser import FileTransferMethod

 logger = logging.getLogger(__name__)

@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                  prompt_messages: Optional[List[PromptMessage]] = None,
                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
                  db_variables: Optional[ToolConversationVariables] = None,
                  model_instance: ModelInstance = None) -> None:
         """
         Agent runner

@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.history_prompt_messages = prompt_messages
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
+        self.model_instance = model_instance

         # init callback
         self.agent_callback = DifyAgentCallbackHandler()

@@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner):
             MessageAgentThought.message_id == self.message.id,
         ).count()

+        # check if model supports stream tool call
+        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
+            self.stream_tool_call = True
+        else:
+            self.stream_tool_call = False
+
     def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
         """
         Repacket app orchestration config
api/core/features/assistant_cot_runner.py

@@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
 from models.model import Conversation, Message

 class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance, conversation: Conversation,
+    def run(self, conversation: Conversation,
             message: Message, query: str,
         ) -> Union[Generator, LLMResult]:

@@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             llm_usage.prompt_price += usage.prompt_price
             llm_usage.completion_price += usage.completion_price

+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             # continue to run until there is not any tool call
             function_call_state = False

@@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             # remove Action: xxx from agent thought
             agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

-        if action_name and action_input:
+        if action_name and action_input is not None:
             return AgentScratchpadUnit(
                 agent_response=content,
                 thought=agent_thought,
api/core/features/assistant_fc_runner.py

@@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
 from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, \
     SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
-from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
 from core.model_manager import ModelInstance
 from core.application_queue_manager import PublishFrom

@@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
 logger = logging.getLogger(__name__)

 class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
-    def run(self, model_instance: ModelInstance, conversation: Conversation,
+    def run(self, conversation: Conversation,
             message: Message, query: str,
         ) -> Generator[LLMResultChunk, None, None]:

@@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             llm_usage.prompt_price += usage.prompt_price
             llm_usage.completion_price += usage.completion_price

+        model_instance = self.model_instance
+
         while function_call_state and iteration_step <= max_iteration_steps:
             function_call_state = False

@@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             # recale llm max tokens
             self.recale_llm_max_tokens(self.model_config, prompt_messages)

             # invoke model
-            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
+            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=app_orchestration_config.model_config.parameters,
                 tools=prompt_messages_tools,
                 stop=app_orchestration_config.model_config.stop,
-                stream=True,
+                stream=self.stream_tool_call,
                 user=self.user_id,
                 callbacks=[],
             )

@@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
@@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             current_llm_usage = None

-            for chunk in chunks:
+            if self.stream_tool_call:
+                for chunk in chunks:
                     # check if there is any tool call
                     if self.check_tool_calls(chunk):
                         function_call_state = True
                         tool_calls.extend(self.extract_tool_calls(chunk))
                         tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                         try:
                             tool_call_inputs = json.dumps({
                                 tool_call[1]: tool_call[2] for tool_call in tool_calls
                             }, ensure_ascii=False)
                         except json.JSONDecodeError as e:
                             # ensure ascii to avoid encoding error
                             tool_call_inputs = json.dumps({
                                 tool_call[1]: tool_call[2] for tool_call in tool_calls
                             })

                     if chunk.delta.message and chunk.delta.message.content:
                         if isinstance(chunk.delta.message.content, list):
                             for content in chunk.delta.message.content:
                                 response += content.data
                         else:
                             response += chunk.delta.message.content

                     if chunk.delta.usage:
                         increase_usage(llm_usage, chunk.delta.usage)
                         current_llm_usage = chunk.delta.usage

                     yield chunk
+            else:
+                result: LLMResult = chunks
+                # check if there is any tool call
+                if self.check_blocking_tool_calls(result):
+                    function_call_state = True
+                    tool_calls.extend(self.extract_blocking_tool_calls(result))
+                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                    try:
+                        tool_call_inputs = json.dumps({
+                            tool_call[1]: tool_call[2] for tool_call in tool_calls
+                        }, ensure_ascii=False)
+                    except json.JSONDecodeError as e:
+                        # ensure ascii to avoid encoding error
+                        tool_call_inputs = json.dumps({
+                            tool_call[1]: tool_call[2] for tool_call in tool_calls
+                        })
+
+                if result.usage:
+                    increase_usage(llm_usage, result.usage)
+                    current_llm_usage = result.usage
+
+                if result.message and result.message.content:
+                    if isinstance(result.message.content, list):
+                        for content in result.message.content:
+                            response += content.data
+                    else:
+                        response += result.message.content
+
+                if not result.message.content:
+                    result.message.content = ''
+
+                yield LLMResultChunk(
+                    model=model_instance.model,
+                    prompt_messages=result.prompt_messages,
+                    system_fingerprint=result.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=result.message,
+                        usage=result.usage,
+                    )
+                )

             if tool_calls:
                 prompt_messages.append(AssistantPromptMessage(
                     content='',
                     name='',
                     tool_calls=[AssistantPromptMessage.ToolCall(
                         id=tool_call[0],
                         type='function',
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                             name=tool_call[1],
                             arguments=json.dumps(tool_call[2], ensure_ascii=False)
                         )
                     ) for tool_call in tool_calls]
                 ))

             # save thought
             self.save_agent_thought(

@@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             final_answer += response + '\n'

+            # update prompt messages
+            if response.strip():
+                prompt_messages.append(AssistantPromptMessage(
+                    content=response,
+                ))
+
             # call tools
             tool_responses = []
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:

@@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             )
             self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

-            # update prompt messages
-            if response.strip():
-                prompt_messages.append(AssistantPromptMessage(
-                    content=response,
-                ))
-
             # update prompt tool
             for prompt_tool in prompt_messages_tools:
                 self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

@@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             if llm_result_chunk.delta.message.tool_calls:
                 return True
         return False

+    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
+        """
+        Check if there is any blocking tool call in llm result
+        """
+        if llm_result.message.tool_calls:
+            return True
+        return False
+
     def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
         """

@@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
             ))
         return tool_calls

+    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
+        """
+        Extract blocking tool calls from llm result
+
+        Returns:
+            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
+        """
+        tool_calls = []
+        for prompt_message in llm_result.message.tool_calls:
+            tool_calls.append((
+                prompt_message.id,
+                prompt_message.function.name,
+                json.loads(prompt_message.function.arguments),
+            ))
+
+        return tool_calls
+
     def organize_prompt_messages(self, prompt_template: str,
                                  query: str = None,
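For reference, the new blocking helpers mirror the streaming pair: check_blocking_tool_calls inspects llm_result.message.tool_calls instead of a chunk delta, and extract_blocking_tool_calls returns (id, name, parsed-arguments) tuples. A hedged, self-contained sketch of that extraction shape (the stub classes below are stand-ins, not the real dify entities):

    # Illustrative only: mimics the (tool_call_id, tool_call_name, tool_call_args) tuples
    # produced by extract_blocking_tool_calls in assistant_fc_runner.py.
    import json
    from dataclasses import dataclass, field
    from typing import Any, Dict, List, Tuple

    @dataclass
    class FnStub:
        name: str
        arguments: str              # JSON-encoded, as in provider responses

    @dataclass
    class CallStub:
        id: str
        function: FnStub

    @dataclass
    class MsgStub:
        tool_calls: List[CallStub] = field(default_factory=list)

    def extract_blocking_tool_calls(message: MsgStub) -> List[Tuple[str, str, Dict[str, Any]]]:
        return [(c.id, c.function.name, json.loads(c.function.arguments)) for c in message.tool_calls]

    msg = MsgStub([CallStub('call-1', FnStub('get_weather', '{"city": "Berlin"}'))])
    print(extract_blocking_tool_calls(msg))
    # [('call-1', 'get_weather', {'city': 'Berlin'})]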
api/core/model_runtime/entities/model_entities.py

@@ -78,6 +78,7 @@ class ModelFeature(Enum):
     MULTI_TOOL_CALL = "multi-tool-call"
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
+    STREAM_TOOL_CALL = "stream-tool-call"

 class DefaultParameterName(Enum):
api/core/model_runtime/model_providers/azure_openai/_constant.py

@@ -36,6 +36,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={

@@ -80,6 +81,7 @@ LLM_BASE_MODELS = [
@@ -124,6 +126,7 @@ LLM_BASE_MODELS = [
@@ -198,6 +201,7 @@ LLM_BASE_MODELS = [
@@ -272,6 +276,7 @@ LLM_BASE_MODELS = [
(each of these hunks adds the same `ModelFeature.STREAM_TOOL_CALL,` line to the corresponding features list)
api/core/model_runtime/model_providers/azure_openai/llm/llm.py

@@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                          tools: Optional[list[PromptMessageTool]] = None) -> Generator:
         index = 0
         full_assistant_content = ''
+        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''

@@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             delta = chunk.choices[0]

-            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
+                delta.delta.function_call is None:
                 continue

             # assistant_message_tool_calls = delta.delta.tool_calls
             assistant_message_function_call = delta.delta.function_call

+            # extract tool calls from response
+            if delta_assistant_message_function_call_storage is not None:
+                # handle process of stream function call
+                if assistant_message_function_call:
+                    # message has not ended ever
+                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
+                    continue
+                else:
+                    # message has ended
+                    assistant_message_function_call = delta_assistant_message_function_call_storage
+                    delta_assistant_message_function_call_storage = None
+            else:
+                if assistant_message_function_call:
+                    # start of stream function call
+                    delta_assistant_message_function_call_storage = assistant_message_function_call
+                    if delta_assistant_message_function_call_storage.arguments is None:
+                        delta_assistant_message_function_call_storage.arguments = ''
+                    continue
+
             # extract tool calls from response
             # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
             function_call = self._extract_response_function_call(assistant_message_function_call)

@@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name

         return message_dict

@@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         num_tokens = 0
         for tool in tools:
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(tool.get("type")))
             num_tokens += len(encoding.encode('function'))

             # calculate num tokens for function object
api/core/model_runtime/model_providers/chatglm/llm/llm.py

@@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
-                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError

@@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            # check if last message is user message
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "function", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml

@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16384

api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml

@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
api/core/model_runtime/model_providers/minimax/llm/chat_completion.py

@@ -16,7 +16,7 @@ class MinimaxChatCompletion(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion

@@ -162,7 +162,6 @@ class MinimaxChatCompletion(object):
                 continue

             for choice in choices:
-                print(choice)
                 message = choice['delta']
                 yield MinimaxMessage(
                     content=message,
api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py

@@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion

@@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object):
             **extra_kwargs
         }

+        if tools:
+            body['functions'] = tools
+            body['function_call'] = {'type': 'auto'}
+
         try:
             response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))

@@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object):
         """
         handle stream chat generate response
         """
+        function_call_storage = None
         for line in response.iter_lines():
             if not line:
                 continue

@@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object):
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)

-            if data['reply']:
+            if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
                 message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,

@@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object):
                     'total_tokens': total_tokens
                 }
                 message.stop_reason = data['choices'][0]['finish_reason']
+
+                if function_call_storage:
+                    function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                    function_call_message.function_call = function_call_storage
+                    yield function_call_message
+
                 yield message
                 return

@@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object):
                 continue

             for choice in choices:
-                message = choice['messages'][0]['text']
-                if not message:
-                    continue
+                message = choice['messages'][0]
+
+                if 'function_call' in message:
+                    if not function_call_storage:
+                        function_call_storage = message['function_call']
+                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
+                            function_call_storage['arguments'] = ''
+                            continue
+                    else:
+                        function_call_storage['arguments'] += message['function_call']['arguments']
+                        continue
+                else:
+                    if function_call_storage:
+                        message['function_call'] = function_call_storage
+                        function_call_storage = None

-                yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value)
+                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+
+                if 'function_call' in message:
+                    minimax_message.function_call = message['function_call']
+
+                if 'text' in message:
+                    minimax_message.content = message['text']
+
+                yield minimax_message
\ No newline at end of file
api/core/model_runtime/model_providers/minimax/llm/llm.py

@@ -2,7 +2,7 @@ from typing import Generator, List
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError

@@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         """
         client: MinimaxChatCompletionPro = self.model_apis[model]()

+        if tools:
+            tools = [{
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters
+            } for tool in tools]
+
         response = client.generate(
             model=model,
             api_key=credentials['minimax_api_key'],

@@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         elif isinstance(prompt_message, UserPromptMessage):
             return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
         elif isinstance(prompt_message, AssistantPromptMessage):
+            if prompt_message.tool_calls:
+                message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content='')
+                message.function_call = {
+                    'name': prompt_message.tool_calls[0].function.name,
+                    'arguments': prompt_message.tool_calls[0].function.arguments
+                }
+                return message
             return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
+        elif isinstance(prompt_message, ToolPromptMessage):
+            return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
         else:
             raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')

@@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                     finish_reason=message.stop_reason if message.stop_reason else None,
                 ),
             )
+        elif message.function_call:
+            if 'name' not in message.function_call or 'arguments' not in message.function_call:
+                continue
+
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=0,
+                    message=AssistantPromptMessage(
+                        content='',
+                        tool_calls=[
+                            AssistantPromptMessage.ToolCall(
+                                id='',
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=message.function_call['name'],
+                                    arguments=message.function_call['arguments']
+                                )
+                            )]
+                    ),
+                ),
+            )
         else:
             yield LLMResultChunk(
                 model=model,
api/core/model_runtime/model_providers/minimax/llm/types.py

@@ -7,13 +7,23 @@ class MinimaxMessage:
         USER = 'USER'
         ASSISTANT = 'BOT'
         SYSTEM = 'SYSTEM'
+        FUNCTION = 'FUNCTION'

     role: str = Role.USER.value
     content: str
     usage: Dict[str, int] = None
     stop_reason: str = ''
+    function_call: Dict[str, Any] = None

     def to_dict(self) -> Dict[str, Any]:
+        if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
+            return {
+                'sender_type': 'BOT',
+                'sender_name': '专家',
+                'text': '',
+                'function_call': self.function_call
+            }
+
         return {
             'sender_type': self.role,
             'sender_name': '我' if self.role == 'USER' else '专家',
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml

@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096

The same `- stream-tool-call` feature line is added, in an identical @@ -6,6 +6,7 @@ hunk, to each of the other OpenAI model definitions:

api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml  (context_size: 16385)
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml  (context_size: 16385)
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml  (context_size: 16385)
api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml  (context_size: 4096)
api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml  (context_size: 128000)
api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml  (context_size: 128000)
api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml  (context_size: 32768)
api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml  (context_size: 128000)
api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml  (context_size: 8192)
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name

         return message_dict
api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
-                                                         ParameterRule, ParameterType)
+                                                         ParameterRule, ParameterType, ModelFeature)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
+from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
                                                                                   XinferenceModelExtraParameter)
 from core.model_runtime.utils import helper
 from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,

@@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
         """
+        if 'temperature' in model_parameters:
+            if model_parameters['temperature'] < 0.01:
+                model_parameters['temperature'] = 0.01
+            elif model_parameters['temperature'] > 1.0:
+                model_parameters['temperature'] = 0.99
+
         return self._generate(
             model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
             tools=tools, stop=stop, stream=stream, user=user,

@@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 credentials['completion_type'] = 'completion'
             else:
                 raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')

+            if extra_param.support_function_call:
+                credentials['support_function_call'] = True
+
         except RuntimeError as e:
             raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')

@@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")

@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 label=I18nObject(
                     zh_Hans='温度',
                     en_US='Temperature'
                 )
-            )
+            ),
             ParameterRule(
                 name='top_p',

@@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             completion_type = LLMMode.COMPLETION.value
         else:
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')

+        support_function_call = credentials.get('support_function_call', False)
+
         entity = AIModelEntity(
             model=model,

@@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
+            features=[
+                ModelFeature.TOOL_CALL
+            ] if support_function_call else [],
             model_properties={
                 ModelPropertyKey.MODE: completion_type,
             },

@@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
         """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+
+        if credentials['server_url'].endswith('/'):
+            credentials['server_url'] = credentials['server_url'][:-1]
+
         client = OpenAI(
             base_url=f'{credentials["server_url"]}/v1',
             api_key='abc',
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py

@@ -2,7 +2,7 @@ import time
 from typing import Optional

 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)

@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle
+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper

 class XinferenceTextEmbeddingModel(TextEmbeddingModel):
     """

@@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         server_url = credentials['server_url']
         model_uid = credentials['model_uid']

+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
         client = Client(base_url=server_url)

         try:

@@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         try:
+            server_url = credentials['server_url']
+            model_uid = credentials['model_uid']
+            extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
+
+            if extra_args.max_tokens:
+                credentials['max_tokens'] = extra_args.max_tokens
+
             self._invoke(model=model, credentials=credentials, texts=['ping'])
-        except InvokeAuthorizationError:
+        except (InvokeAuthorizationError, RuntimeError):
             raise CredentialsValidateFailedError('Invalid api key')

     @property

@@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         used to define customizable model schema
         """
         entity = AIModelEntity(
             model=model,
             label=I18nObject(

@@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.TEXT_EMBEDDING,
-            model_properties={},
+            model_properties={
+                ModelPropertyKey.MAX_CHUNKS: 1,
+                ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
+            },
             parameter_rules=[]
         )
api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py → api/core/model_runtime/model_providers/xinference/xinference_helper.py

 from threading import Lock
 from time import time
 from typing import List
+from os import path

 from requests import get
 from requests.adapters import HTTPAdapter

@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object):
     model_format: str
     model_handle_type: str
     model_ability: List[str]
+    max_tokens: int = 512
+    support_function_call: bool = False

-    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
+                 support_function_call: bool, max_tokens: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
+        self.support_function_call = support_function_call
+        self.max_tokens = max_tokens

 cache = {}
 cache_lock = Lock()

@@ -49,7 +55,7 @@ class XinferenceHelper:
         get xinference model extra parameter like model_format and model_handle_type
         """
-        url = f'{server_url}/v1/models/{model_uid}'
+        url = path.join(server_url, 'v1/models', model_uid)

         # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         session = Session()

@@ -66,10 +72,12 @@ class XinferenceHelper:
         response_json = response.json()

-        model_format = response_json['model_format']
-        model_ability = response_json['model_ability']
+        model_format = response_json.get('model_format', 'ggmlv3')
+        model_ability = response_json.get('model_ability', [])

-        if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
+        if response_json.get('model_type') == 'embedding':
+            model_handle_type = 'embedding'
+        elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
             model_handle_type = 'chatglm'
         elif 'generate' in model_ability:
             model_handle_type = 'generate'

@@ -78,8 +86,13 @@ class XinferenceHelper:
         else:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')

+        support_function_call = 'tools' in model_ability
+        max_tokens = response_json.get('max_tokens', 512)
+
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
-            model_ability=model_ability
+            model_ability=model_ability,
+            support_function_call=support_function_call,
+            max_tokens=max_tokens
         )
\ No newline at end of file
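The helper now also reports whether the deployed xinference model advertises the `tools` ability and its `max_tokens`; downstream, llm.py turns that into a ModelFeature.TOOL_CALL entry and text_embedding.py uses max_tokens for the context size. A hedged usage sketch of reading the new fields (assumes the dify `api` package on the path and a reachable xinference server; URL and model UID below are placeholders):

    from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper

    extra = XinferenceHelper.get_xinference_extra_parameter(
        server_url='http://127.0.0.1:9997',   # placeholder server URL
        model_uid='my-model-uid',             # placeholder model UID
    )
    if extra.support_function_call:           # new field: True when 'tools' is in model_ability
        print('model supports (blocking) tool calls')
    print('context size fallback:', extra.max_tokens)   # new field, defaults to 512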
api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml

@@ -2,6 +2,10 @@ model: glm-3-turbo
 label:
   en_US: glm-3-turbo
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:

api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml

@@ -2,6 +2,10 @@ model: glm-4
 label:
   en_US: glm-4
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:
api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     'content': prompt_message.content,
                     'tool_call_id': prompt_message.tool_call_id
                 })
             elif isinstance(prompt_message, AssistantPromptMessage):
                 if prompt_message.tool_calls:
                     params['messages'].append({
                         'role': 'assistant',
                         'content': prompt_message.content,
                         'tool_calls': [
                             {
                                 'id': tool_call.id,
                                 'type': tool_call.type,
                                 'function': {
                                     'name': tool_call.function.name,
                                     'arguments': tool_call.function.arguments
                                 }
                             } for tool_call in prompt_message.tool_calls
                         ]
                     })
                 else:
                     params['messages'].append({
                         'role': 'assistant',
                         'content': prompt_message.content
                     })
             else:
                 params['messages'].append({
                     'role': prompt_message.role.value,
api/requirements.txt

@@ -47,7 +47,7 @@ dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 pandas==1.5.3
-xinference-client~=0.6.4
+xinference-client~=0.8.1
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug~=3.0.1
api/tests/integration_tests/model_runtime/__mock/xinference.py

@@ -19,58 +19,86 @@ class MockXinferenceClass(object):
             raise RuntimeError('404 Not Found')

         if 'generate' == model_uid:
-            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'chat' == model_uid:
-            return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'embedding' == model_uid:
-            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         if 'rerank' == model_uid:
-            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
+            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         raise RuntimeError('404 Not Found')

     def get(self: Session, url: str, **kwargs):
-        if '/v1/models/' in url:
-            response = Response()
+        response = Response()
+        if 'v1/models/' in url:
             # get model uid
             model_uid = url.split('/')[-1]
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
+                return response

             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
-                raise ConnectionError('404 Not Found')
+                return response

             if model_uid in ['generate', 'chat']:
                 response.status_code = 200
                 response._content = b'''{
     "model_type": "LLM",
     "address": "127.0.0.1:43877",
     "accelerators": [
         "0",
         "1"
     ],
     "model_name": "chatglm3-6b",
     "model_lang": [
         "en"
     ],
     "model_ability": [
         "generate",
         "chat"
     ],
     "model_description": "latest chatglm3",
     "model_format": "pytorch",
     "model_size_in_billions": 7,
     "quantization": "none",
     "model_hub": "huggingface",
     "revision": null,
     "context_length": 2048,
     "replica": 1
 }'''
                 return response
+            elif model_uid == 'embedding':
+                response.status_code = 200
+                response._content = b'''{
+    "model_type": "embedding",
+    "address": "127.0.0.1:43877",
+    "accelerators": [
+        "0",
+        "1"
+    ],
+    "model_name": "bge",
+    "model_lang": [
+        "en"
+    ],
+    "revision": null,
+    "max_tokens": 512
+}'''
+                return response
+        elif 'v1/cluster/auth' in url:
+            response.status_code = 200
+            response._content = b'''{
+    "auth": true
+}'''
+            return response

+    def _check_cluster_authenticated(self):
+        self._cluster_authed = True
+
     def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
         # check if self._model_uid is a valid uuid
         if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \

@@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
         monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
+        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
         monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
         monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
         monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)