Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
b921c556
Unverified
Commit
b921c556
authored
Jan 25, 2024
by
Yeuoly
Committed by
GitHub
Jan 25, 2024
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Feat/zhipuai function calling (#2199)
Co-authored-by:
Joel
<
iamjoel007@gmail.com
>
parent
bdc5e9ce
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
46 changed files
with
2115 additions
and
138 deletions
+2115
-138
_client.py
api/core/model_runtime/model_providers/zhipuai/_client.py
+0
-61
llm.py
api/core/model_runtime/model_providers/zhipuai/llm/llm.py
+152
-59
text_embedding.py
.../model_providers/zhipuai/text_embedding/text_embedding.py
+11
-12
__init__.py
...l_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py
+17
-0
__version__.py
...untime/model_providers/zhipuai/zhipuai_sdk/__version__.py
+2
-0
_client.py
...el_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py
+71
-0
__init__.py
...el_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py
+5
-0
__init__.py
...oviders/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py
+0
-0
async_completions.py
...hipuai/zhipuai_sdk/api_resource/chat/async_completions.py
+87
-0
chat.py
...l_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py
+16
-0
completions.py
...ders/zhipuai/zhipuai_sdk/api_resource/chat/completions.py
+71
-0
embeddings.py
..._providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py
+49
-0
files.py
...model_providers/zhipuai/zhipuai_sdk/api_resource/files.py
+78
-0
__init__.py
.../zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py
+0
-0
fine_tuning.py
...ipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py
+15
-0
jobs.py
...ders/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py
+115
-0
images.py
...odel_providers/zhipuai/zhipuai_sdk/api_resource/images.py
+55
-0
__init__.py
...time/model_providers/zhipuai/zhipuai_sdk/core/__init__.py
+0
-0
_base_api.py
...ime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py
+17
-0
_base_type.py
...me/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py
+115
-0
_errors.py
...ntime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py
+90
-0
_files.py
...untime/model_providers/zhipuai/zhipuai_sdk/core/_files.py
+46
-0
_http_client.py
.../model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
+377
-0
_jwt_token.py
...me/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py
+30
-0
_request_opt.py
.../model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py
+54
-0
_response.py
...ime/model_providers/zhipuai/zhipuai_sdk/core/_response.py
+121
-0
_sse_client.py
...e/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py
+149
-0
_utils.py
...untime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py
+18
-0
__init__.py
...ime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py
+0
-0
__init__.py
...odel_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py
+0
-0
async_chat_completion.py
...s/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py
+23
-0
chat_completion.py
...oviders/zhipuai/zhipuai_sdk/types/chat/chat_completion.py
+45
-0
chat_completion_chunk.py
...s/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py
+55
-0
chat_completions_create_param.py
...i/zhipuai_sdk/types/chat/chat_completions_create_param.py
+8
-0
embeddings.py
...e/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py
+20
-0
file_object.py
.../model_providers/zhipuai/zhipuai_sdk/types/file_object.py
+24
-0
__init__.py
...oviders/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py
+5
-0
fine_tuning_job.py
.../zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py
+52
-0
fine_tuning_job_event.py
...ai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py
+36
-0
job_create_params.py
...hipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py
+15
-0
image.py
...untime/model_providers/zhipuai/zhipuai_sdk/types/image.py
+18
-0
test_llm.py
...tests/integration_tests/model_runtime/zhipuai/test_llm.py
+47
-1
test_text_embedding.py
...ration_tests/model_runtime/zhipuai/test_text_embedding.py
+1
-1
index.tsx
web/app/components/app/chat/answer/index.tsx
+1
-1
index.tsx
web/app/components/app/configuration/index.tsx
+2
-3
index.ts
web/config/index.ts
+2
-0
No files found.
api/core/model_runtime/model_providers/zhipuai/_client.py
deleted
100644 → 0
View file @
bdc5e9ce
"""Wrapper around ZhipuAI APIs."""
from
__future__
import
annotations
import
logging
import
posixpath
from
pydantic
import
BaseModel
,
Extra
from
zhipuai.model_api.api
import
InvokeType
from
zhipuai.utils
import
jwt_token
from
zhipuai.utils.http_client
import
post
,
stream
from
zhipuai.utils.sse_client
import
SSEClient
logger
=
logging
.
getLogger
(
__name__
)
class
ZhipuModelAPI
(
BaseModel
):
base_url
:
str
=
"https://open.bigmodel.cn/api/paas/v3/model-api"
api_key
:
str
api_timeout_seconds
=
60
class
Config
:
"""Configuration for this pydantic object."""
extra
=
Extra
.
forbid
def
invoke
(
self
,
**
kwargs
):
url
=
self
.
_build_api_url
(
kwargs
,
InvokeType
.
SYNC
)
response
=
post
(
url
,
self
.
_generate_token
(),
kwargs
,
self
.
api_timeout_seconds
)
if
not
response
[
'success'
]:
raise
ValueError
(
f
"Error Code: {response['code']}, Message: {response['msg']} "
)
return
response
def
sse_invoke
(
self
,
**
kwargs
):
url
=
self
.
_build_api_url
(
kwargs
,
InvokeType
.
SSE
)
data
=
stream
(
url
,
self
.
_generate_token
(),
kwargs
,
self
.
api_timeout_seconds
)
return
SSEClient
(
data
)
def
_build_api_url
(
self
,
kwargs
,
*
path
):
if
kwargs
:
if
"model"
not
in
kwargs
:
raise
Exception
(
"model param missed"
)
model
=
kwargs
.
pop
(
"model"
)
else
:
model
=
"-"
return
posixpath
.
join
(
self
.
base_url
,
model
,
*
path
)
def
_generate_token
(
self
):
if
not
self
.
api_key
:
raise
Exception
(
"api_key not provided, you could provide it."
)
try
:
return
jwt_token
.
generate_token
(
self
.
api_key
)
except
Exception
:
raise
ValueError
(
f
"Your api_key is invalid, please check it."
)
api/core/model_runtime/model_providers/zhipuai/llm/llm.py
View file @
b921c556
...
...
@@ -3,13 +3,15 @@ from typing import Any, Dict, Generator, List, Optional, Union
from
core.model_runtime.entities.llm_entities
import
LLMResult
,
LLMResultChunk
,
LLMResultChunkDelta
from
core.model_runtime.entities.message_entities
import
(
AssistantPromptMessage
,
PromptMessage
,
PromptMessageRole
,
PromptMessageTool
,
SystemPromptMessage
,
UserPromptMessage
,
PromptMessageTool
,
SystemPromptMessage
,
UserPromptMessage
,
ToolPromptMessage
,
TextPromptMessageContent
,
ImagePromptMessageContent
,
PromptMessageContentType
)
from
core.model_runtime.errors.validate
import
CredentialsValidateFailedError
from
core.model_runtime.utils
import
helper
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.model_providers.zhipuai._client
import
ZhipuModelAPI
from
core.model_runtime.model_providers.zhipuai._common
import
_CommonZhipuaiAI
from
core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client
import
ZhipuAI
from
core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk
import
ChatCompletionChunk
from
core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion
import
Completion
class
ZhipuAILargeLanguageModel
(
_CommonZhipuaiAI
,
LargeLanguageModel
):
...
...
@@ -35,7 +37,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
credentials_kwargs
=
self
.
_to_credential_kwargs
(
credentials
)
# invoke model
return
self
.
_generate
(
model
,
credentials_kwargs
,
prompt_messages
,
model_parameters
,
stop
,
stream
,
user
)
return
self
.
_generate
(
model
,
credentials_kwargs
,
prompt_messages
,
model_parameters
,
tools
,
stop
,
stream
,
user
)
def
get_num_tokens
(
self
,
model
:
str
,
credentials
:
dict
,
prompt_messages
:
list
[
PromptMessage
],
tools
:
Optional
[
list
[
PromptMessageTool
]]
=
None
)
->
int
:
...
...
@@ -48,7 +50,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
:param tools: tools for tool calling
:return:
"""
prompt
=
self
.
_convert_messages_to_prompt
(
prompt_messages
)
prompt
=
self
.
_convert_messages_to_prompt
(
prompt_messages
,
tools
)
return
self
.
_get_num_tokens_by_gpt2
(
prompt
)
...
...
@@ -72,6 +74,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
model_parameters
=
{
"temperature"
:
0.5
,
},
tools
=
[],
stream
=
False
)
except
Exception
as
ex
:
...
...
@@ -79,6 +82,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
def
_generate
(
self
,
model
:
str
,
credentials_kwargs
:
dict
,
prompt_messages
:
list
[
PromptMessage
],
model_parameters
:
dict
,
tools
:
Optional
[
list
[
PromptMessageTool
]]
=
None
,
stop
:
Optional
[
List
[
str
]]
=
None
,
stream
:
bool
=
True
,
user
:
Optional
[
str
]
=
None
)
->
Union
[
LLMResult
,
Generator
]:
"""
...
...
@@ -97,7 +101,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if
stop
:
extra_model_kwargs
[
'stop_sequences'
]
=
stop
client
=
Zhipu
ModelAP
I
(
client
=
Zhipu
A
I
(
api_key
=
credentials_kwargs
[
'api_key'
]
)
...
...
@@ -128,11 +132,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
# not support image message
continue
if
new_prompt_messages
and
new_prompt_messages
[
-
1
]
.
role
==
PromptMessageRole
.
USER
:
if
new_prompt_messages
and
new_prompt_messages
[
-
1
]
.
role
==
PromptMessageRole
.
USER
and
\
copy_prompt_message
.
role
==
PromptMessageRole
.
USER
:
new_prompt_messages
[
-
1
]
.
content
+=
"
\n\n
"
+
copy_prompt_message
.
content
else
:
if
copy_prompt_message
.
role
==
PromptMessageRole
.
USER
:
new_prompt_messages
.
append
(
copy_prompt_message
)
elif
copy_prompt_message
.
role
==
PromptMessageRole
.
TOOL
:
new_prompt_messages
.
append
(
copy_prompt_message
)
elif
copy_prompt_message
.
role
==
PromptMessageRole
.
SYSTEM
:
new_prompt_message
=
SystemPromptMessage
(
content
=
copy_prompt_message
.
content
)
new_prompt_messages
.
append
(
new_prompt_message
)
else
:
new_prompt_message
=
UserPromptMessage
(
content
=
copy_prompt_message
.
content
)
new_prompt_messages
.
append
(
new_prompt_message
)
...
...
@@ -145,7 +155,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if
model
==
'glm-4v'
:
params
=
{
'model'
:
model
,
'
prompt
'
:
[{
'
messages
'
:
[{
'role'
:
prompt_message
.
role
.
value
,
'content'
:
[
...
...
@@ -171,23 +181,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
else
:
params
=
{
'model'
:
model
,
'prompt'
:
[{
'role'
:
prompt_message
.
role
.
value
,
'content'
:
prompt_message
.
content
,
}
for
prompt_message
in
new_prompt_messages
],
'messages'
:
[],
**
model_parameters
}
# glm model
if
not
model
.
startswith
(
'chatglm'
):
for
prompt_message
in
new_prompt_messages
:
if
prompt_message
.
role
==
PromptMessageRole
.
TOOL
:
params
[
'messages'
]
.
append
({
'role'
:
'tool'
,
'content'
:
prompt_message
.
content
,
'tool_call_id'
:
prompt_message
.
tool_call_id
})
else
:
params
[
'messages'
]
.
append
({
'role'
:
prompt_message
.
role
.
value
,
'content'
:
prompt_message
.
content
})
else
:
# chatglm model
for
prompt_message
in
new_prompt_messages
:
# merge system message to user message
if
prompt_message
.
role
==
PromptMessageRole
.
SYSTEM
or
\
prompt_message
.
role
==
PromptMessageRole
.
TOOL
or
\
prompt_message
.
role
==
PromptMessageRole
.
USER
:
if
len
(
params
[
'messages'
])
>
0
and
params
[
'messages'
][
-
1
][
'role'
]
==
'user'
:
params
[
'messages'
][
-
1
][
'content'
]
+=
"
\n\n
"
+
prompt_message
.
content
else
:
params
[
'messages'
]
.
append
({
'role'
:
'user'
,
'content'
:
prompt_message
.
content
})
else
:
params
[
'messages'
]
.
append
({
'role'
:
prompt_message
.
role
.
value
,
'content'
:
prompt_message
.
content
})
if
tools
and
len
(
tools
)
>
0
:
params
[
'tools'
]
=
[
{
'type'
:
'function'
,
'function'
:
helper
.
dump_model
(
tool
)
}
for
tool
in
tools
]
if
stream
:
response
=
client
.
sse_invoke
(
incremental
=
True
,
**
params
)
.
events
(
)
return
self
.
_handle_generate_stream_response
(
model
,
credentials_kwargs
,
response
,
prompt_messages
)
response
=
client
.
chat
.
completions
.
create
(
stream
=
stream
,
**
params
)
return
self
.
_handle_generate_stream_response
(
model
,
credentials_kwargs
,
tools
,
response
,
prompt_messages
)
response
=
client
.
invok
e
(
**
params
)
return
self
.
_handle_generate_response
(
model
,
credentials_kwargs
,
response
,
prompt_messages
)
response
=
client
.
chat
.
completions
.
creat
e
(
**
params
)
return
self
.
_handle_generate_response
(
model
,
credentials_kwargs
,
tools
,
response
,
prompt_messages
)
def
_handle_generate_response
(
self
,
model
:
str
,
credentials
:
dict
,
response
:
Dict
[
str
,
Any
],
tools
:
Optional
[
list
[
PromptMessageTool
]],
response
:
Completion
,
prompt_messages
:
list
[
PromptMessage
])
->
LLMResult
:
"""
Handle llm response
...
...
@@ -197,26 +247,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response
"""
data
=
response
[
"data"
]
text
=
''
for
res
in
data
[
"choices"
]:
text
+=
res
[
'content'
]
assistant_tool_calls
:
List
[
AssistantPromptMessage
.
ToolCall
]
=
[]
for
choice
in
response
.
choices
:
if
choice
.
message
.
tool_calls
:
for
tool_call
in
choice
.
message
.
tool_calls
:
if
tool_call
.
type
==
'function'
:
assistant_tool_calls
.
append
(
AssistantPromptMessage
.
ToolCall
(
id
=
tool_call
.
id
,
type
=
tool_call
.
type
,
function
=
AssistantPromptMessage
.
ToolCall
.
ToolCallFunction
(
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
)
)
)
text
+=
choice
.
message
.
content
or
''
token_usage
=
data
.
get
(
"usage"
)
if
token_usage
is
not
None
:
if
'prompt_tokens'
not
in
token_usage
:
token_usage
[
'prompt_tokens'
]
=
0
if
'completion_tokens'
not
in
token_usage
:
token_usage
[
'completion_tokens'
]
=
token_usage
[
'total_tokens'
]
prompt_usage
=
response
.
usage
.
prompt_tokens
completion_usage
=
response
.
usage
.
completion_tokens
# transform usage
usage
=
self
.
_calc_response_usage
(
model
,
credentials
,
token_usage
[
'prompt_tokens'
],
token_usage
[
'completion_tokens'
]
)
usage
=
self
.
_calc_response_usage
(
model
,
credentials
,
prompt_usage
,
completion_usage
)
# transform response
result
=
LLMResult
(
model
=
model
,
prompt_messages
=
prompt_messages
,
message
=
AssistantPromptMessage
(
content
=
text
),
message
=
AssistantPromptMessage
(
content
=
text
,
tool_calls
=
assistant_tool_calls
),
usage
=
usage
,
)
...
...
@@ -224,7 +287,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
def
_handle_generate_stream_response
(
self
,
model
:
str
,
credentials
:
dict
,
responses
:
list
[
Generator
],
tools
:
Optional
[
list
[
PromptMessageTool
]],
responses
:
Generator
[
ChatCompletionChunk
,
None
,
None
],
prompt_messages
:
list
[
PromptMessage
])
->
Generator
:
"""
Handle llm stream response
...
...
@@ -234,39 +298,64 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
for
index
,
event
in
enumerate
(
responses
):
if
event
.
event
==
"add"
:
full_assistant_content
=
''
for
chunk
in
responses
:
if
len
(
chunk
.
choices
)
==
0
:
continue
delta
=
chunk
.
choices
[
0
]
if
delta
.
finish_reason
is
None
and
(
delta
.
delta
.
content
is
None
or
delta
.
delta
.
content
==
''
):
continue
assistant_tool_calls
:
List
[
AssistantPromptMessage
.
ToolCall
]
=
[]
for
tool_call
in
delta
.
delta
.
tool_calls
or
[]:
if
tool_call
.
type
==
'function'
:
assistant_tool_calls
.
append
(
AssistantPromptMessage
.
ToolCall
(
id
=
tool_call
.
id
,
type
=
tool_call
.
type
,
function
=
AssistantPromptMessage
.
ToolCall
.
ToolCallFunction
(
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
)
)
)
# transform assistant message to prompt message
assistant_prompt_message
=
AssistantPromptMessage
(
content
=
delta
.
delta
.
content
if
delta
.
delta
.
content
else
''
,
tool_calls
=
assistant_tool_calls
)
full_assistant_content
+=
delta
.
delta
.
content
if
delta
.
delta
.
content
else
''
if
delta
.
finish_reason
is
not
None
and
chunk
.
usage
is
not
None
:
completion_tokens
=
chunk
.
usage
.
completion_tokens
prompt_tokens
=
chunk
.
usage
.
prompt_tokens
# transform usage
usage
=
self
.
_calc_response_usage
(
model
,
credentials
,
prompt_tokens
,
completion_tokens
)
yield
LLMResultChunk
(
model
=
chunk
.
model
,
prompt_messages
=
prompt_messages
,
model
=
model
,
system_fingerprint
=
''
,
delta
=
LLMResultChunkDelta
(
index
=
index
,
message
=
AssistantPromptMessage
(
content
=
event
.
data
)
index
=
delta
.
index
,
message
=
assistant_prompt_message
,
finish_reason
=
delta
.
finish_reason
,
usage
=
usage
)
)
elif
event
.
event
==
"error"
or
event
.
event
==
"interrupted"
:
raise
ValueError
(
f
"{event.data}"
)
elif
event
.
event
==
"finish"
:
meta
=
json
.
loads
(
event
.
meta
)
token_usage
=
meta
[
'usage'
]
if
token_usage
is
not
None
:
if
'prompt_tokens'
not
in
token_usage
:
token_usage
[
'prompt_tokens'
]
=
0
if
'completion_tokens'
not
in
token_usage
:
token_usage
[
'completion_tokens'
]
=
token_usage
[
'total_tokens'
]
usage
=
self
.
_calc_response_usage
(
model
,
credentials
,
token_usage
[
'prompt_tokens'
],
token_usage
[
'completion_tokens'
])
else
:
yield
LLMResultChunk
(
model
=
model
,
model
=
chunk
.
model
,
prompt_messages
=
prompt_messages
,
system_fingerprint
=
''
,
delta
=
LLMResultChunkDelta
(
index
=
index
,
message
=
AssistantPromptMessage
(
content
=
event
.
data
),
finish_reason
=
'finish'
,
usage
=
usage
index
=
delta
.
index
,
message
=
assistant_prompt_message
,
)
)
...
...
@@ -291,11 +380,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
raise
ValueError
(
f
"Got unknown type {message}"
)
return
message_text
def
_convert_messages_to_prompt
(
self
,
messages
:
List
[
PromptMessage
])
->
str
:
"""
Format a list of messages into a full prompt for the Anthropic model
def
_convert_messages_to_prompt
(
self
,
messages
:
List
[
PromptMessage
],
tools
:
Optional
[
list
[
PromptMessageTool
]]
=
None
)
->
str
:
"""
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
...
...
@@ -306,5 +394,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
for
message
in
messages
)
if
tools
and
len
(
tools
)
>
0
:
text
+=
"
\n\n
Tools:"
for
tool
in
tools
:
text
+=
f
"
\n
{tool.json()}"
# trim off the trailing ' ' that might come from the "Assistant: "
return
text
.
rstrip
()
return
text
.
rstrip
()
\ No newline at end of file
api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
View file @
b921c556
...
...
@@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType
from
core.model_runtime.entities.text_embedding_entities
import
EmbeddingUsage
,
TextEmbeddingResult
from
core.model_runtime.errors.validate
import
CredentialsValidateFailedError
from
core.model_runtime.model_providers.__base.text_embedding_model
import
TextEmbeddingModel
from
core.model_runtime.model_providers.zhipuai.
_client
import
ZhipuModelAP
I
from
core.model_runtime.model_providers.zhipuai.
zhipuai_sdk._client
import
ZhipuA
I
from
core.model_runtime.model_providers.zhipuai._common
import
_CommonZhipuaiAI
from
langchain.schema.language_model
import
_get_token_ids_default_method
...
...
@@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
:return: embeddings result
"""
credentials_kwargs
=
self
.
_to_credential_kwargs
(
credentials
)
client
=
Zhipu
ModelAP
I
(
client
=
Zhipu
A
I
(
api_key
=
credentials_kwargs
[
'api_key'
]
)
...
...
@@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
try
:
# transform credentials to kwargs for model instance
credentials_kwargs
=
self
.
_to_credential_kwargs
(
credentials
)
client
=
Zhipu
ModelAP
I
(
client
=
Zhipu
A
I
(
api_key
=
credentials_kwargs
[
'api_key'
]
)
...
...
@@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
except
Exception
as
ex
:
raise
CredentialsValidateFailedError
(
str
(
ex
))
def
embed_documents
(
self
,
model
:
str
,
client
:
Zhipu
ModelAP
I
,
texts
:
List
[
str
])
->
Tuple
[
List
[
List
[
float
]],
int
]:
def
embed_documents
(
self
,
model
:
str
,
client
:
Zhipu
A
I
,
texts
:
List
[
str
])
->
Tuple
[
List
[
List
[
float
]],
int
]:
"""Call out to ZhipuAI's embedding endpoint.
Args:
...
...
@@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
Returns:
List of embeddings, one for each text.
"""
embeddings
=
[]
for
text
in
texts
:
response
=
client
.
invoke
(
model
=
model
,
prompt
=
text
)
data
=
response
[
"data"
]
embeddings
.
append
(
data
.
get
(
'embedding'
))
embedding_used_tokens
=
0
embedding_used_tokens
=
data
.
get
(
'usage'
)
for
text
in
texts
:
response
=
client
.
embeddings
.
create
(
model
=
model
,
input
=
text
)
data
=
response
.
data
[
0
]
embeddings
.
append
(
data
.
embedding
)
embedding_used_tokens
+=
response
.
usage
.
total_tokens
return
[
list
(
map
(
float
,
e
))
for
e
in
embeddings
],
embedding_used_tokens
[
'total_tokens'
]
if
embedding_used_tokens
else
0
return
[
list
(
map
(
float
,
e
))
for
e
in
embeddings
],
embedding_used_tokens
def
embed_query
(
self
,
text
:
str
)
->
List
[
float
]:
"""Call out to ZhipuAI's embedding endpoint.
...
...
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py
0 → 100644
View file @
b921c556
from
._client
import
ZhipuAI
from
.core._errors
import
(
ZhipuAIError
,
APIStatusError
,
APIRequestFailedError
,
APIAuthenticationError
,
APIReachLimitError
,
APIInternalError
,
APIServerFlowExceedError
,
APIResponseError
,
APIResponseValidationError
,
APITimeoutError
,
)
from
.__version__
import
__version__
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py
0 → 100644
View file @
b921c556
__version__
=
'v2.0.1'
\ No newline at end of file
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Union
,
Mapping
from
typing_extensions
import
override
from
.core
import
_jwt_token
from
.core._errors
import
ZhipuAIError
from
.core._http_client
import
HttpClient
,
ZHIPUAI_DEFAULT_MAX_RETRIES
from
.core._base_type
import
NotGiven
,
NOT_GIVEN
from
.
import
api_resource
import
os
import
httpx
from
httpx
import
Timeout
class
ZhipuAI
(
HttpClient
):
chat
:
api_resource
.
chat
api_key
:
str
def
__init__
(
self
,
*
,
api_key
:
str
|
None
=
None
,
base_url
:
str
|
httpx
.
URL
|
None
=
None
,
timeout
:
Union
[
float
,
Timeout
,
None
,
NotGiven
]
=
NOT_GIVEN
,
max_retries
:
int
=
ZHIPUAI_DEFAULT_MAX_RETRIES
,
http_client
:
httpx
.
Client
|
None
=
None
,
custom_headers
:
Mapping
[
str
,
str
]
|
None
=
None
)
->
None
:
# if api_key is None:
# api_key = os.environ.get("ZHIPUAI_API_KEY")
if
api_key
is
None
:
raise
ZhipuAIError
(
"未提供api_key,请通过参数或环境变量提供"
)
self
.
api_key
=
api_key
if
base_url
is
None
:
base_url
=
os
.
environ
.
get
(
"ZHIPUAI_BASE_URL"
)
if
base_url
is
None
:
base_url
=
f
"https://open.bigmodel.cn/api/paas/v4"
from
.__version__
import
__version__
super
()
.
__init__
(
version
=
__version__
,
base_url
=
base_url
,
timeout
=
timeout
,
custom_httpx_client
=
http_client
,
custom_headers
=
custom_headers
,
)
self
.
chat
=
api_resource
.
chat
.
Chat
(
self
)
self
.
images
=
api_resource
.
images
.
Images
(
self
)
self
.
embeddings
=
api_resource
.
embeddings
.
Embeddings
(
self
)
self
.
files
=
api_resource
.
files
.
Files
(
self
)
self
.
fine_tuning
=
api_resource
.
fine_tuning
.
FineTuning
(
self
)
@
property
@
override
def
_auth_headers
(
self
)
->
dict
[
str
,
str
]:
api_key
=
self
.
api_key
return
{
"Authorization"
:
f
"{_jwt_token.generate_token(api_key)}"
}
def
__del__
(
self
)
->
None
:
if
(
not
hasattr
(
self
,
"_has_custom_http_client"
)
or
not
hasattr
(
self
,
"close"
)
or
not
hasattr
(
self
,
"_client"
)):
# if the '__init__' method raised an error, self would not have client attr
return
if
self
.
_has_custom_http_client
:
return
self
.
close
()
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py
0 → 100644
View file @
b921c556
from
.chat
import
chat
from
.images
import
Images
from
.embeddings
import
Embeddings
from
.files
import
Files
from
.fine_tuning
import
fine_tuning
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py
0 → 100644
View file @
b921c556
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Union
,
List
,
Optional
,
TYPE_CHECKING
import
httpx
from
typing_extensions
import
Literal
from
...core._base_api
import
BaseAPI
from
...core._base_type
import
NotGiven
,
NOT_GIVEN
,
Headers
from
...core._http_client
import
make_user_request_input
from
...types.chat.async_chat_completion
import
AsyncTaskStatus
,
AsyncCompletion
if
TYPE_CHECKING
:
from
..._client
import
ZhipuAI
class
AsyncCompletions
(
BaseAPI
):
def
__init__
(
self
,
client
:
"ZhipuAI"
)
->
None
:
super
()
.
__init__
(
client
)
def
create
(
self
,
*
,
model
:
str
,
request_id
:
Optional
[
str
]
|
NotGiven
=
NOT_GIVEN
,
do_sample
:
Optional
[
Literal
[
False
]]
|
Literal
[
True
]
|
NotGiven
=
NOT_GIVEN
,
temperature
:
Optional
[
float
]
|
NotGiven
=
NOT_GIVEN
,
top_p
:
Optional
[
float
]
|
NotGiven
=
NOT_GIVEN
,
max_tokens
:
int
|
NotGiven
=
NOT_GIVEN
,
seed
:
int
|
NotGiven
=
NOT_GIVEN
,
messages
:
Union
[
str
,
List
[
str
],
List
[
int
],
List
[
List
[
int
]],
None
],
stop
:
Optional
[
Union
[
str
,
List
[
str
],
None
]]
|
NotGiven
=
NOT_GIVEN
,
sensitive_word_check
:
Optional
[
object
]
|
NotGiven
=
NOT_GIVEN
,
tools
:
Optional
[
object
]
|
NotGiven
=
NOT_GIVEN
,
tool_choice
:
str
|
NotGiven
=
NOT_GIVEN
,
extra_headers
:
Headers
|
None
=
None
,
disable_strict_validation
:
Optional
[
bool
]
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
AsyncTaskStatus
:
_cast_type
=
AsyncTaskStatus
if
disable_strict_validation
:
_cast_type
=
object
return
self
.
_post
(
"/async/chat/completions"
,
body
=
{
"model"
:
model
,
"request_id"
:
request_id
,
"temperature"
:
temperature
,
"top_p"
:
top_p
,
"do_sample"
:
do_sample
,
"max_tokens"
:
max_tokens
,
"seed"
:
seed
,
"messages"
:
messages
,
"stop"
:
stop
,
"sensitive_word_check"
:
sensitive_word_check
,
"tools"
:
tools
,
"tool_choice"
:
tool_choice
,
},
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
),
cast_type
=
_cast_type
,
enable_stream
=
False
,
)
def
retrieve_completion_result
(
self
,
id
:
str
,
extra_headers
:
Headers
|
None
=
None
,
disable_strict_validation
:
Optional
[
bool
]
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
Union
[
AsyncCompletion
,
AsyncTaskStatus
]:
_cast_type
=
Union
[
AsyncCompletion
,
AsyncTaskStatus
]
if
disable_strict_validation
:
_cast_type
=
object
return
self
.
_get
(
path
=
f
"/async-result/{id}"
,
cast_type
=
_cast_type
,
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
)
)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py
0 → 100644
View file @
b921c556
from
typing
import
TYPE_CHECKING
from
.completions
import
Completions
from
.async_completions
import
AsyncCompletions
from
...core._base_api
import
BaseAPI
if
TYPE_CHECKING
:
from
..._client
import
ZhipuAI
class
Chat
(
BaseAPI
):
completions
:
Completions
def
__init__
(
self
,
client
:
"ZhipuAI"
)
->
None
:
super
()
.
__init__
(
client
)
self
.
completions
=
Completions
(
client
)
self
.
asyncCompletions
=
AsyncCompletions
(
client
)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Union
,
List
,
Optional
,
TYPE_CHECKING
import
httpx
from
typing_extensions
import
Literal
from
...core._base_api
import
BaseAPI
from
...core._base_type
import
NotGiven
,
NOT_GIVEN
,
Headers
from
...core._http_client
import
make_user_request_input
from
...core._sse_client
import
StreamResponse
from
...types.chat.chat_completion
import
Completion
from
...types.chat.chat_completion_chunk
import
ChatCompletionChunk
if
TYPE_CHECKING
:
from
..._client
import
ZhipuAI
class
Completions
(
BaseAPI
):
def
__init__
(
self
,
client
:
"ZhipuAI"
)
->
None
:
super
()
.
__init__
(
client
)
def
create
(
self
,
*
,
model
:
str
,
request_id
:
Optional
[
str
]
|
NotGiven
=
NOT_GIVEN
,
do_sample
:
Optional
[
Literal
[
False
]]
|
Literal
[
True
]
|
NotGiven
=
NOT_GIVEN
,
stream
:
Optional
[
Literal
[
False
]]
|
Literal
[
True
]
|
NotGiven
=
NOT_GIVEN
,
temperature
:
Optional
[
float
]
|
NotGiven
=
NOT_GIVEN
,
top_p
:
Optional
[
float
]
|
NotGiven
=
NOT_GIVEN
,
max_tokens
:
int
|
NotGiven
=
NOT_GIVEN
,
seed
:
int
|
NotGiven
=
NOT_GIVEN
,
messages
:
Union
[
str
,
List
[
str
],
List
[
int
],
object
,
None
],
stop
:
Optional
[
Union
[
str
,
List
[
str
],
None
]]
|
NotGiven
=
NOT_GIVEN
,
sensitive_word_check
:
Optional
[
object
]
|
NotGiven
=
NOT_GIVEN
,
tools
:
Optional
[
object
]
|
NotGiven
=
NOT_GIVEN
,
tool_choice
:
str
|
NotGiven
=
NOT_GIVEN
,
extra_headers
:
Headers
|
None
=
None
,
disable_strict_validation
:
Optional
[
bool
]
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
Completion
|
StreamResponse
[
ChatCompletionChunk
]:
_cast_type
=
Completion
_stream_cls
=
StreamResponse
[
ChatCompletionChunk
]
if
disable_strict_validation
:
_cast_type
=
object
_stream_cls
=
StreamResponse
[
object
]
return
self
.
_post
(
"/chat/completions"
,
body
=
{
"model"
:
model
,
"request_id"
:
request_id
,
"temperature"
:
temperature
,
"top_p"
:
top_p
,
"do_sample"
:
do_sample
,
"max_tokens"
:
max_tokens
,
"seed"
:
seed
,
"messages"
:
messages
,
"stop"
:
stop
,
"sensitive_word_check"
:
sensitive_word_check
,
"stream"
:
stream
,
"tools"
:
tools
,
"tool_choice"
:
tool_choice
,
},
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
),
cast_type
=
_cast_type
,
enable_stream
=
stream
or
False
,
stream_cls
=
_stream_cls
,
)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Union
,
List
,
Optional
,
TYPE_CHECKING
import
httpx
from
..core._base_api
import
BaseAPI
from
..core._base_type
import
NotGiven
,
NOT_GIVEN
,
Headers
from
..core._http_client
import
make_user_request_input
from
..types.embeddings
import
EmbeddingsResponded
if
TYPE_CHECKING
:
from
.._client
import
ZhipuAI
class
Embeddings
(
BaseAPI
):
def
__init__
(
self
,
client
:
"ZhipuAI"
)
->
None
:
super
()
.
__init__
(
client
)
def
create
(
self
,
*
,
input
:
Union
[
str
,
List
[
str
],
List
[
int
],
List
[
List
[
int
]]],
model
:
Union
[
str
],
encoding_format
:
str
|
NotGiven
=
NOT_GIVEN
,
user
:
str
|
NotGiven
=
NOT_GIVEN
,
sensitive_word_check
:
Optional
[
object
]
|
NotGiven
=
NOT_GIVEN
,
extra_headers
:
Headers
|
None
=
None
,
disable_strict_validation
:
Optional
[
bool
]
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
EmbeddingsResponded
:
_cast_type
=
EmbeddingsResponded
if
disable_strict_validation
:
_cast_type
=
object
return
self
.
_post
(
"/embeddings"
,
body
=
{
"input"
:
input
,
"model"
:
model
,
"encoding_format"
:
encoding_format
,
"user"
:
user
,
"sensitive_word_check"
:
sensitive_word_check
,
},
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
),
cast_type
=
_cast_type
,
enable_stream
=
False
,
)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
TYPE_CHECKING
import
httpx
from
..core._base_api
import
BaseAPI
from
..core._base_type
import
NOT_GIVEN
,
Body
,
Query
,
Headers
,
NotGiven
,
FileTypes
from
..core._files
import
is_file_content
from
..core._http_client
import
(
make_user_request_input
,
)
from
..types.file_object
import
FileObject
,
ListOfFileObject
if
TYPE_CHECKING
:
from
.._client
import
ZhipuAI
__all__
=
[
"Files"
]
class
Files
(
BaseAPI
):
def
__init__
(
self
,
client
:
"ZhipuAI"
)
->
None
:
super
()
.
__init__
(
client
)
def
create
(
self
,
*
,
file
:
FileTypes
,
purpose
:
str
,
extra_headers
:
Headers
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
FileObject
:
if
not
is_file_content
(
file
):
prefix
=
f
"Expected file input `{file!r}`"
raise
RuntimeError
(
f
"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
)
from
None
files
=
[(
"file"
,
file
)]
extra_headers
=
{
"Content-Type"
:
"multipart/form-data"
,
**
(
extra_headers
or
{})}
return
self
.
_post
(
"/files"
,
body
=
{
"purpose"
:
purpose
,
},
files
=
files
,
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
),
cast_type
=
FileObject
,
)
def
list
(
self
,
*
,
purpose
:
str
|
NotGiven
=
NOT_GIVEN
,
limit
:
int
|
NotGiven
=
NOT_GIVEN
,
after
:
str
|
NotGiven
=
NOT_GIVEN
,
order
:
str
|
NotGiven
=
NOT_GIVEN
,
extra_headers
:
Headers
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
ListOfFileObject
:
return
self
.
_get
(
"/files"
,
cast_type
=
ListOfFileObject
,
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
,
query
=
{
"purpose"
:
purpose
,
"limit"
:
limit
,
"after"
:
after
,
"order"
:
order
,
},
),
)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py
0 → 100644
View file @
b921c556
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py
0 → 100644
View file @
b921c556
from
typing
import
TYPE_CHECKING
from
.jobs
import
Jobs
from
...core._base_api
import
BaseAPI
if
TYPE_CHECKING
:
from
..._client
import
ZhipuAI
class
FineTuning
(
BaseAPI
):
jobs
:
Jobs
def
__init__
(
self
,
client
:
"ZhipuAI"
)
->
None
:
super
()
.
__init__
(
client
)
self
.
jobs
=
Jobs
(
client
)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Optional
,
TYPE_CHECKING
import
httpx
from
...core._base_api
import
BaseAPI
from
...core._base_type
import
NOT_GIVEN
,
Headers
,
NotGiven
from
...core._http_client
import
(
make_user_request_input
,
)
from
...types.fine_tuning
import
(
FineTuningJob
,
job_create_params
,
ListOfFineTuningJob
,
FineTuningJobEvent
,
)
if
TYPE_CHECKING
:
from
..._client
import
ZhipuAI
__all__
=
[
"Jobs"
]
class
Jobs
(
BaseAPI
):
def
__init__
(
self
,
client
:
"ZhipuAI"
)
->
None
:
super
()
.
__init__
(
client
)
def
create
(
self
,
*
,
model
:
str
,
training_file
:
str
,
hyperparameters
:
job_create_params
.
Hyperparameters
|
NotGiven
=
NOT_GIVEN
,
suffix
:
Optional
[
str
]
|
NotGiven
=
NOT_GIVEN
,
request_id
:
Optional
[
str
]
|
NotGiven
=
NOT_GIVEN
,
validation_file
:
Optional
[
str
]
|
NotGiven
=
NOT_GIVEN
,
extra_headers
:
Headers
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
FineTuningJob
:
return
self
.
_post
(
"/fine_tuning/jobs"
,
body
=
{
"model"
:
model
,
"training_file"
:
training_file
,
"hyperparameters"
:
hyperparameters
,
"suffix"
:
suffix
,
"validation_file"
:
validation_file
,
"request_id"
:
request_id
,
},
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
),
cast_type
=
FineTuningJob
,
)
def
retrieve
(
self
,
fine_tuning_job_id
:
str
,
*
,
extra_headers
:
Headers
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
FineTuningJob
:
return
self
.
_get
(
f
"/fine_tuning/jobs/{fine_tuning_job_id}"
,
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
),
cast_type
=
FineTuningJob
,
)
def
list
(
self
,
*
,
after
:
str
|
NotGiven
=
NOT_GIVEN
,
limit
:
int
|
NotGiven
=
NOT_GIVEN
,
extra_headers
:
Headers
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
ListOfFineTuningJob
:
return
self
.
_get
(
"/fine_tuning/jobs"
,
cast_type
=
ListOfFineTuningJob
,
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
,
query
=
{
"after"
:
after
,
"limit"
:
limit
,
},
),
)
def
list_events
(
self
,
fine_tuning_job_id
:
str
,
*
,
after
:
str
|
NotGiven
=
NOT_GIVEN
,
limit
:
int
|
NotGiven
=
NOT_GIVEN
,
extra_headers
:
Headers
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
FineTuningJobEvent
:
return
self
.
_get
(
f
"/fine_tuning/jobs/{fine_tuning_job_id}/events"
,
cast_type
=
FineTuningJobEvent
,
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
,
query
=
{
"after"
:
after
,
"limit"
:
limit
,
},
),
)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Union
,
List
,
Optional
,
TYPE_CHECKING
import
httpx
from
..core._base_api
import
BaseAPI
from
..core._base_type
import
NotGiven
,
NOT_GIVEN
,
Headers
from
..core._http_client
import
make_user_request_input
from
..types.image
import
ImagesResponded
if
TYPE_CHECKING
:
from
.._client
import
ZhipuAI
class
Images
(
BaseAPI
):
def
__init__
(
self
,
client
:
"ZhipuAI"
)
->
None
:
super
()
.
__init__
(
client
)
def
generations
(
self
,
*
,
prompt
:
str
,
model
:
str
|
NotGiven
=
NOT_GIVEN
,
n
:
Optional
[
int
]
|
NotGiven
=
NOT_GIVEN
,
quality
:
Optional
[
str
]
|
NotGiven
=
NOT_GIVEN
,
response_format
:
Optional
[
str
]
|
NotGiven
=
NOT_GIVEN
,
size
:
Optional
[
str
]
|
NotGiven
=
NOT_GIVEN
,
style
:
Optional
[
str
]
|
NotGiven
=
NOT_GIVEN
,
user
:
str
|
NotGiven
=
NOT_GIVEN
,
extra_headers
:
Headers
|
None
=
None
,
disable_strict_validation
:
Optional
[
bool
]
|
None
=
None
,
timeout
:
float
|
httpx
.
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
)
->
ImagesResponded
:
_cast_type
=
ImagesResponded
if
disable_strict_validation
:
_cast_type
=
object
return
self
.
_post
(
"/images/generations"
,
body
=
{
"prompt"
:
prompt
,
"model"
:
model
,
"n"
:
n
,
"quality"
:
quality
,
"response_format"
:
response_format
,
"size"
:
size
,
"style"
:
style
,
"user"
:
user
,
},
options
=
make_user_request_input
(
extra_headers
=
extra_headers
,
timeout
=
timeout
),
cast_type
=
_cast_type
,
enable_stream
=
False
,
)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py
0 → 100644
View file @
b921c556
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
TYPE_CHECKING
if
TYPE_CHECKING
:
from
.._client
import
ZhipuAI
class
BaseAPI
:
_client
:
ZhipuAI
def
__init__
(
self
,
client
:
ZhipuAI
)
->
None
:
self
.
_client
=
client
self
.
_delete
=
client
.
delete
self
.
_get
=
client
.
get
self
.
_post
=
client
.
post
self
.
_put
=
client
.
put
self
.
_patch
=
client
.
patch
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
os
import
PathLike
from
typing
import
(
TYPE_CHECKING
,
Type
,
Union
,
Mapping
,
TypeVar
,
IO
,
Tuple
,
Sequence
,
Any
,
List
,
)
import
pydantic
from
typing_extensions
import
(
Literal
,
override
,
)
Query
=
Mapping
[
str
,
object
]
Body
=
object
AnyMapping
=
Mapping
[
str
,
object
]
PrimitiveData
=
Union
[
str
,
int
,
float
,
bool
,
None
]
Data
=
Union
[
PrimitiveData
,
List
[
Any
],
Tuple
[
Any
],
"Mapping[str, Any]"
]
ModelT
=
TypeVar
(
"ModelT"
,
bound
=
pydantic
.
BaseModel
)
_T
=
TypeVar
(
"_T"
)
if
TYPE_CHECKING
:
NoneType
:
Type
[
None
]
else
:
NoneType
=
type
(
None
)
# Sentinel class used until PEP 0661 is accepted
class
NotGiven
(
pydantic
.
BaseModel
):
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).
For example:
```py
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
get(timeout=1) # 1s timeout
get(timeout=None) # No timeout
get() # Default timeout behavior, which may not be statically known at the method definition.
```
"""
def
__bool__
(
self
)
->
Literal
[
False
]:
return
False
@
override
def
__repr__
(
self
)
->
str
:
return
"NOT_GIVEN"
NotGivenOr
=
Union
[
_T
,
NotGiven
]
NOT_GIVEN
=
NotGiven
()
class
Omit
(
pydantic
.
BaseModel
):
"""In certain situations you need to be able to represent a case where a default value has
to be explicitly removed and `None` is not an appropriate substitute, for example:
```py
# as the default `Content-Type` header is `application/json` that will be sent
client.post('/upload/files', files={'file': b'my raw file content'})
# you can't explicitly override the header as it has to be dynamically generated
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
client.post(..., headers={'Content-Type': 'multipart/form-data'})
# instead you can remove the default `application/json` header by passing Omit
client.post(..., headers={'Content-Type': Omit()})
```
"""
def
__bool__
(
self
)
->
Literal
[
False
]:
return
False
Headers
=
Mapping
[
str
,
Union
[
str
,
Omit
]]
ResponseT
=
TypeVar
(
"ResponseT"
,
bound
=
"Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]"
,
)
# for user input files
if
TYPE_CHECKING
:
FileContent
=
Union
[
IO
[
bytes
],
bytes
,
PathLike
[
str
]]
else
:
FileContent
=
Union
[
IO
[
bytes
],
bytes
,
PathLike
]
FileTypes
=
Union
[
FileContent
,
# file content
Tuple
[
str
,
FileContent
],
# (filename, file)
Tuple
[
str
,
FileContent
,
str
],
# (filename, file , content_type)
Tuple
[
str
,
FileContent
,
str
,
Mapping
[
str
,
str
]],
# (filename, file , content_type, headers)
]
RequestFiles
=
Union
[
Mapping
[
str
,
FileTypes
],
Sequence
[
Tuple
[
str
,
FileTypes
]]]
# for httpx client supported files
HttpxFileContent
=
Union
[
bytes
,
IO
[
bytes
]]
HttpxFileTypes
=
Union
[
FileContent
,
# file content
Tuple
[
str
,
HttpxFileContent
],
# (filename, file)
Tuple
[
str
,
HttpxFileContent
,
str
],
# (filename, file , content_type)
Tuple
[
str
,
HttpxFileContent
,
str
,
Mapping
[
str
,
str
]],
# (filename, file , content_type, headers)
]
HttpxRequestFiles
=
Union
[
Mapping
[
str
,
HttpxFileTypes
],
Sequence
[
Tuple
[
str
,
HttpxFileTypes
]]]
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
import
httpx
__all__
=
[
"ZhipuAIError"
,
"APIStatusError"
,
"APIRequestFailedError"
,
"APIAuthenticationError"
,
"APIReachLimitError"
,
"APIInternalError"
,
"APIServerFlowExceedError"
,
"APIResponseError"
,
"APIResponseValidationError"
,
"APITimeoutError"
,
]
class
ZhipuAIError
(
Exception
):
def
__init__
(
self
,
message
:
str
,
)
->
None
:
super
()
.
__init__
(
message
)
class
APIStatusError
(
Exception
):
response
:
httpx
.
Response
status_code
:
int
def
__init__
(
self
,
message
:
str
,
*
,
response
:
httpx
.
Response
)
->
None
:
super
()
.
__init__
(
message
)
self
.
response
=
response
self
.
status_code
=
response
.
status_code
class
APIRequestFailedError
(
APIStatusError
):
...
class
APIAuthenticationError
(
APIStatusError
):
...
class
APIReachLimitError
(
APIStatusError
):
...
class
APIInternalError
(
APIStatusError
):
...
class
APIServerFlowExceedError
(
APIStatusError
):
...
class
APIResponseError
(
Exception
):
message
:
str
request
:
httpx
.
Request
json_data
:
object
def
__init__
(
self
,
message
:
str
,
request
:
httpx
.
Request
,
json_data
:
object
):
self
.
message
=
message
self
.
request
=
request
self
.
json_data
=
json_data
super
()
.
__init__
(
message
)
class
APIResponseValidationError
(
APIResponseError
):
status_code
:
int
response
:
httpx
.
Response
def
__init__
(
self
,
response
:
httpx
.
Response
,
json_data
:
object
|
None
,
*
,
message
:
str
|
None
=
None
)
->
None
:
super
()
.
__init__
(
message
=
message
or
"Data returned by API invalid for expected schema."
,
request
=
response
.
request
,
json_data
=
json_data
)
self
.
response
=
response
self
.
status_code
=
response
.
status_code
class
APITimeoutError
(
Exception
):
request
:
httpx
.
Request
def
__init__
(
self
,
request
:
httpx
.
Request
):
self
.
request
=
request
super
()
.
__init__
(
"Request Timeout"
)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
import
io
import
os
from
pathlib
import
Path
from
typing
import
Mapping
,
Sequence
from
._base_type
import
(
FileTypes
,
HttpxFileTypes
,
HttpxRequestFiles
,
RequestFiles
,
)
def
is_file_content
(
obj
:
object
)
->
bool
:
return
isinstance
(
obj
,
(
bytes
,
tuple
,
io
.
IOBase
,
os
.
PathLike
))
def
_transform_file
(
file
:
FileTypes
)
->
HttpxFileTypes
:
if
is_file_content
(
file
):
if
isinstance
(
file
,
os
.
PathLike
):
path
=
Path
(
file
)
return
path
.
name
,
path
.
read_bytes
()
else
:
return
file
if
isinstance
(
file
,
tuple
):
if
isinstance
(
file
[
1
],
os
.
PathLike
):
return
(
file
[
0
],
Path
(
file
[
1
])
.
read_bytes
(),
*
file
[
2
:])
else
:
return
(
file
[
0
],
file
[
1
],
*
file
[
2
:])
else
:
raise
TypeError
(
f
"Unexpected input file with type {type(file)},Expected FileContent type or tuple type"
)
def
make_httpx_files
(
files
:
RequestFiles
|
None
)
->
HttpxRequestFiles
|
None
:
if
files
is
None
:
return
None
if
isinstance
(
files
,
Mapping
):
files
=
{
key
:
_transform_file
(
file
)
for
key
,
file
in
files
.
items
()}
elif
isinstance
(
files
,
Sequence
):
files
=
[(
key
,
_transform_file
(
file
))
for
key
,
file
in
files
]
else
:
raise
TypeError
(
f
"Unexpected input file with type {type(files)}, excepted Mapping or Sequence"
)
return
files
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
0 → 100644
View file @
b921c556
# -*- coding:utf-8 -*-
from
__future__
import
annotations
import
inspect
from
typing
import
(
Any
,
Type
,
Union
,
cast
,
Mapping
,
)
import
httpx
import
pydantic
from
httpx
import
URL
,
Timeout
from
.
import
_errors
from
._base_type
import
NotGiven
,
ResponseT
,
Body
,
Headers
,
NOT_GIVEN
,
RequestFiles
,
Query
,
Data
from
._errors
import
APIResponseValidationError
,
APIStatusError
,
APITimeoutError
from
._files
import
make_httpx_files
from
._request_opt
import
ClientRequestParam
,
UserRequestInput
from
._response
import
HttpResponse
from
._sse_client
import
StreamResponse
from
._utils
import
flatten
headers
=
{
"Accept"
:
"application/json"
,
"Content-Type"
:
"application/json; charset=UTF-8"
,
}
def
_merge_map
(
map1
:
Mapping
,
map2
:
Mapping
)
->
Mapping
:
merged
=
{
**
map1
,
**
map2
}
return
{
key
:
val
for
key
,
val
in
merged
.
items
()
if
val
is
not
None
}
from
httpx._config
import
DEFAULT_TIMEOUT_CONFIG
as
HTTPX_DEFAULT_TIMEOUT
ZHIPUAI_DEFAULT_TIMEOUT
=
httpx
.
Timeout
(
timeout
=
300.0
,
connect
=
8.0
)
ZHIPUAI_DEFAULT_MAX_RETRIES
=
3
ZHIPUAI_DEFAULT_LIMITS
=
httpx
.
Limits
(
max_connections
=
50
,
max_keepalive_connections
=
10
)
class
HttpClient
:
_client
:
httpx
.
Client
_version
:
str
_base_url
:
URL
timeout
:
Union
[
float
,
Timeout
,
None
]
_limits
:
httpx
.
Limits
_has_custom_http_client
:
bool
_default_stream_cls
:
type
[
StreamResponse
[
Any
]]
|
None
=
None
def
__init__
(
self
,
*
,
version
:
str
,
base_url
:
URL
,
timeout
:
Union
[
float
,
Timeout
,
None
],
custom_httpx_client
:
httpx
.
Client
|
None
=
None
,
custom_headers
:
Mapping
[
str
,
str
]
|
None
=
None
,
)
->
None
:
if
timeout
is
None
or
isinstance
(
timeout
,
NotGiven
):
if
custom_httpx_client
and
custom_httpx_client
.
timeout
!=
HTTPX_DEFAULT_TIMEOUT
:
timeout
=
custom_httpx_client
.
timeout
else
:
timeout
=
ZHIPUAI_DEFAULT_TIMEOUT
self
.
timeout
=
cast
(
Timeout
,
timeout
)
self
.
_has_custom_http_client
=
bool
(
custom_httpx_client
)
self
.
_client
=
custom_httpx_client
or
httpx
.
Client
(
base_url
=
base_url
,
timeout
=
self
.
timeout
,
limits
=
ZHIPUAI_DEFAULT_LIMITS
,
)
self
.
_version
=
version
url
=
URL
(
url
=
base_url
)
if
not
url
.
raw_path
.
endswith
(
b
"/"
):
url
=
url
.
copy_with
(
raw_path
=
url
.
raw_path
+
b
"/"
)
self
.
_base_url
=
url
self
.
_custom_headers
=
custom_headers
or
{}
def
_prepare_url
(
self
,
url
:
str
)
->
URL
:
sub_url
=
URL
(
url
)
if
sub_url
.
is_relative_url
:
request_raw_url
=
self
.
_base_url
.
raw_path
+
sub_url
.
raw_path
.
lstrip
(
b
"/"
)
return
self
.
_base_url
.
copy_with
(
raw_path
=
request_raw_url
)
return
sub_url
@
property
def
_default_headers
(
self
):
return
\
{
"Accept"
:
"application/json"
,
"Content-Type"
:
"application/json; charset=UTF-8"
,
"ZhipuAI-SDK-Ver"
:
self
.
_version
,
"source_type"
:
"zhipu-sdk-python"
,
"x-request-sdk"
:
"zhipu-sdk-python"
,
**
self
.
_auth_headers
,
**
self
.
_custom_headers
,
}
@
property
def
_auth_headers
(
self
):
return
{}
def
_prepare_headers
(
self
,
request_param
:
ClientRequestParam
)
->
httpx
.
Headers
:
custom_headers
=
request_param
.
headers
or
{}
headers_dict
=
_merge_map
(
self
.
_default_headers
,
custom_headers
)
httpx_headers
=
httpx
.
Headers
(
headers_dict
)
return
httpx_headers
def
_prepare_request
(
self
,
request_param
:
ClientRequestParam
)
->
httpx
.
Request
:
kwargs
:
dict
[
str
,
Any
]
=
{}
json_data
=
request_param
.
json_data
headers
=
self
.
_prepare_headers
(
request_param
)
url
=
self
.
_prepare_url
(
request_param
.
url
)
json_data
=
request_param
.
json_data
if
headers
.
get
(
"Content-Type"
)
==
"multipart/form-data"
:
headers
.
pop
(
"Content-Type"
)
if
json_data
:
kwargs
[
"data"
]
=
self
.
_make_multipartform
(
json_data
)
return
self
.
_client
.
build_request
(
headers
=
headers
,
timeout
=
self
.
timeout
if
isinstance
(
request_param
.
timeout
,
NotGiven
)
else
request_param
.
timeout
,
method
=
request_param
.
method
,
url
=
url
,
json
=
json_data
,
files
=
request_param
.
files
,
params
=
request_param
.
params
,
**
kwargs
,
)
def
_object_to_formfata
(
self
,
key
:
str
,
value
:
Data
|
Mapping
[
object
,
object
])
->
list
[
tuple
[
str
,
str
]]:
items
=
[]
if
isinstance
(
value
,
Mapping
):
for
k
,
v
in
value
.
items
():
items
.
extend
(
self
.
_object_to_formfata
(
f
"{key}[{k}]"
,
v
))
return
items
if
isinstance
(
value
,
(
list
,
tuple
)):
for
v
in
value
:
items
.
extend
(
self
.
_object_to_formfata
(
key
+
"[]"
,
v
))
return
items
def
_primitive_value_to_str
(
val
)
->
str
:
# copied from httpx
if
val
is
True
:
return
"true"
elif
val
is
False
:
return
"false"
elif
val
is
None
:
return
""
return
str
(
val
)
str_data
=
_primitive_value_to_str
(
value
)
if
not
str_data
:
return
[]
return
[(
key
,
str_data
)]
def
_make_multipartform
(
self
,
data
:
Mapping
[
object
,
object
])
->
dict
[
str
,
object
]:
items
=
flatten
([
self
.
_object_to_formfata
(
k
,
v
)
for
k
,
v
in
data
.
items
()])
serialized
:
dict
[
str
,
object
]
=
{}
for
key
,
value
in
items
:
if
key
in
serialized
:
raise
ValueError
(
f
"存在重复的键: {key};"
)
serialized
[
key
]
=
value
return
serialized
def
_parse_response
(
self
,
*
,
cast_type
:
Type
[
ResponseT
],
response
:
httpx
.
Response
,
enable_stream
:
bool
,
request_param
:
ClientRequestParam
,
stream_cls
:
type
[
StreamResponse
[
Any
]]
|
None
=
None
,
)
->
HttpResponse
:
http_response
=
HttpResponse
(
raw_response
=
response
,
cast_type
=
cast_type
,
client
=
self
,
enable_stream
=
enable_stream
,
stream_cls
=
stream_cls
)
return
http_response
.
parse
()
def
_process_response_data
(
self
,
*
,
data
:
object
,
cast_type
:
type
[
ResponseT
],
response
:
httpx
.
Response
,
)
->
ResponseT
:
if
data
is
None
:
return
cast
(
ResponseT
,
None
)
try
:
if
inspect
.
isclass
(
cast_type
)
and
issubclass
(
cast_type
,
pydantic
.
BaseModel
):
return
cast
(
ResponseT
,
cast_type
.
validate
(
data
))
return
cast
(
ResponseT
,
pydantic
.
TypeAdapter
(
cast_type
)
.
validate_python
(
data
))
except
pydantic
.
ValidationError
as
err
:
raise
APIResponseValidationError
(
response
=
response
,
json_data
=
data
)
from
err
def
is_closed
(
self
)
->
bool
:
return
self
.
_client
.
is_closed
def
close
(
self
):
self
.
_client
.
close
()
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
close
()
def
request
(
self
,
*
,
cast_type
:
Type
[
ResponseT
],
params
:
ClientRequestParam
,
enable_stream
:
bool
=
False
,
stream_cls
:
type
[
StreamResponse
[
Any
]]
|
None
=
None
,
)
->
ResponseT
|
StreamResponse
:
request
=
self
.
_prepare_request
(
params
)
try
:
response
=
self
.
_client
.
send
(
request
,
stream
=
enable_stream
,
)
response
.
raise_for_status
()
except
httpx
.
TimeoutException
as
err
:
raise
APITimeoutError
(
request
=
request
)
from
err
except
httpx
.
HTTPStatusError
as
err
:
err
.
response
.
read
()
# raise err
raise
self
.
_make_status_error
(
err
.
response
)
from
None
except
Exception
as
err
:
raise
err
return
self
.
_parse_response
(
cast_type
=
cast_type
,
request_param
=
params
,
response
=
response
,
enable_stream
=
enable_stream
,
stream_cls
=
stream_cls
,
)
def
get
(
self
,
path
:
str
,
*
,
cast_type
:
Type
[
ResponseT
],
options
:
UserRequestInput
=
{},
enable_stream
:
bool
=
False
,
)
->
ResponseT
|
StreamResponse
:
opts
=
ClientRequestParam
.
construct
(
method
=
"get"
,
url
=
path
,
**
options
)
return
self
.
request
(
cast_type
=
cast_type
,
params
=
opts
,
enable_stream
=
enable_stream
)
def
post
(
self
,
path
:
str
,
*
,
body
:
Body
|
None
=
None
,
cast_type
:
Type
[
ResponseT
],
options
:
UserRequestInput
=
{},
files
:
RequestFiles
|
None
=
None
,
enable_stream
:
bool
=
False
,
stream_cls
:
type
[
StreamResponse
[
Any
]]
|
None
=
None
,
)
->
ResponseT
|
StreamResponse
:
opts
=
ClientRequestParam
.
construct
(
method
=
"post"
,
json_data
=
body
,
files
=
make_httpx_files
(
files
),
url
=
path
,
**
options
)
return
self
.
request
(
cast_type
=
cast_type
,
params
=
opts
,
enable_stream
=
enable_stream
,
stream_cls
=
stream_cls
)
def
patch
(
self
,
path
:
str
,
*
,
body
:
Body
|
None
=
None
,
cast_type
:
Type
[
ResponseT
],
options
:
UserRequestInput
=
{},
)
->
ResponseT
:
opts
=
ClientRequestParam
.
construct
(
method
=
"patch"
,
url
=
path
,
json_data
=
body
,
**
options
)
return
self
.
request
(
cast_type
=
cast_type
,
params
=
opts
,
)
def
put
(
self
,
path
:
str
,
*
,
body
:
Body
|
None
=
None
,
cast_type
:
Type
[
ResponseT
],
options
:
UserRequestInput
=
{},
files
:
RequestFiles
|
None
=
None
,
)
->
ResponseT
|
StreamResponse
:
opts
=
ClientRequestParam
.
construct
(
method
=
"put"
,
url
=
path
,
json_data
=
body
,
files
=
make_httpx_files
(
files
),
**
options
)
return
self
.
request
(
cast_type
=
cast_type
,
params
=
opts
,
)
def
delete
(
self
,
path
:
str
,
*
,
body
:
Body
|
None
=
None
,
cast_type
:
Type
[
ResponseT
],
options
:
UserRequestInput
=
{},
)
->
ResponseT
|
StreamResponse
:
opts
=
ClientRequestParam
.
construct
(
method
=
"delete"
,
url
=
path
,
json_data
=
body
,
**
options
)
return
self
.
request
(
cast_type
=
cast_type
,
params
=
opts
,
)
def
_make_status_error
(
self
,
response
)
->
APIStatusError
:
response_text
=
response
.
text
.
strip
()
status_code
=
response
.
status_code
error_msg
=
f
"Error code: {status_code}, with error text {response_text}"
if
status_code
==
400
:
return
_errors
.
APIRequestFailedError
(
message
=
error_msg
,
response
=
response
)
elif
status_code
==
401
:
return
_errors
.
APIAuthenticationError
(
message
=
error_msg
,
response
=
response
)
elif
status_code
==
429
:
return
_errors
.
APIReachLimitError
(
message
=
error_msg
,
response
=
response
)
elif
status_code
==
500
:
return
_errors
.
APIInternalError
(
message
=
error_msg
,
response
=
response
)
elif
status_code
==
503
:
return
_errors
.
APIServerFlowExceedError
(
message
=
error_msg
,
response
=
response
)
return
APIStatusError
(
message
=
error_msg
,
response
=
response
)
def
make_user_request_input
(
max_retries
:
int
|
None
=
None
,
timeout
:
float
|
Timeout
|
None
|
NotGiven
=
NOT_GIVEN
,
extra_headers
:
Headers
=
None
,
query
:
Query
|
None
=
None
,
)
->
UserRequestInput
:
options
:
UserRequestInput
=
{}
if
extra_headers
is
not
None
:
options
[
"headers"
]
=
extra_headers
if
max_retries
is
not
None
:
options
[
"max_retries"
]
=
max_retries
if
not
isinstance
(
timeout
,
NotGiven
):
options
[
'timeout'
]
=
timeout
if
query
is
not
None
:
options
[
"params"
]
=
query
return
options
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py
0 → 100644
View file @
b921c556
# -*- coding:utf-8 -*-
import
time
import
cachetools.func
import
jwt
API_TOKEN_TTL_SECONDS
=
3
*
60
CACHE_TTL_SECONDS
=
API_TOKEN_TTL_SECONDS
-
30
@
cachetools
.
func
.
ttl_cache
(
maxsize
=
10
,
ttl
=
CACHE_TTL_SECONDS
)
def
generate_token
(
apikey
:
str
):
try
:
api_key
,
secret
=
apikey
.
split
(
"."
)
except
Exception
as
e
:
raise
Exception
(
"invalid api_key"
,
e
)
payload
=
{
"api_key"
:
api_key
,
"exp"
:
int
(
round
(
time
.
time
()
*
1000
))
+
API_TOKEN_TTL_SECONDS
*
1000
,
"timestamp"
:
int
(
round
(
time
.
time
()
*
1000
)),
}
ret
=
jwt
.
encode
(
payload
,
secret
,
algorithm
=
"HS256"
,
headers
=
{
"alg"
:
"HS256"
,
"sign_type"
:
"SIGN"
},
)
return
ret
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Union
,
Any
,
cast
import
pydantic.generics
from
httpx
import
Timeout
from
pydantic
import
ConfigDict
from
typing_extensions
import
(
Unpack
,
ClassVar
,
TypedDict
)
from
._base_type
import
Body
,
NotGiven
,
Headers
,
HttpxRequestFiles
,
Query
from
._utils
import
remove_notgiven_indict
class
UserRequestInput
(
TypedDict
,
total
=
False
):
max_retries
:
int
timeout
:
float
|
Timeout
|
None
headers
:
Headers
params
:
Query
|
None
class
ClientRequestParam
():
method
:
str
url
:
str
max_retries
:
Union
[
int
,
NotGiven
]
=
NotGiven
()
timeout
:
Union
[
float
,
NotGiven
]
=
NotGiven
()
headers
:
Union
[
Headers
,
NotGiven
]
=
NotGiven
()
json_data
:
Union
[
Body
,
None
]
=
None
files
:
Union
[
HttpxRequestFiles
,
None
]
=
None
params
:
Query
=
{}
model_config
:
ClassVar
[
ConfigDict
]
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
def
get_max_retries
(
self
,
max_retries
)
->
int
:
if
isinstance
(
self
.
max_retries
,
NotGiven
):
return
max_retries
return
self
.
max_retries
@
classmethod
def
construct
(
# type: ignore
cls
,
_fields_set
:
set
[
str
]
|
None
=
None
,
**
values
:
Unpack
[
UserRequestInput
],
)
->
ClientRequestParam
:
kwargs
:
dict
[
str
,
Any
]
=
{
key
:
remove_notgiven_indict
(
value
)
for
key
,
value
in
values
.
items
()
}
client
=
cls
()
client
.
__dict__
.
update
(
kwargs
)
return
client
model_construct
=
construct
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
import
datetime
from
typing
import
TypeVar
,
Generic
,
cast
,
Any
,
TYPE_CHECKING
import
httpx
import
pydantic
from
typing_extensions
import
ParamSpec
,
get_origin
,
get_args
from
._base_type
import
NoneType
from
._sse_client
import
StreamResponse
if
TYPE_CHECKING
:
from
._http_client
import
HttpClient
P
=
ParamSpec
(
"P"
)
R
=
TypeVar
(
"R"
)
class
HttpResponse
(
Generic
[
R
]):
_cast_type
:
type
[
R
]
_client
:
"HttpClient"
_parsed
:
R
|
None
_enable_stream
:
bool
_stream_cls
:
type
[
StreamResponse
[
Any
]]
http_response
:
httpx
.
Response
def
__init__
(
self
,
*
,
raw_response
:
httpx
.
Response
,
cast_type
:
type
[
R
],
client
:
"HttpClient"
,
enable_stream
:
bool
=
False
,
stream_cls
:
type
[
StreamResponse
[
Any
]]
|
None
=
None
,
)
->
None
:
self
.
_cast_type
=
cast_type
self
.
_client
=
client
self
.
_parsed
=
None
self
.
_stream_cls
=
stream_cls
self
.
_enable_stream
=
enable_stream
self
.
http_response
=
raw_response
def
parse
(
self
)
->
R
:
self
.
_parsed
=
self
.
_parse
()
return
self
.
_parsed
def
_parse
(
self
)
->
R
:
if
self
.
_enable_stream
:
self
.
_parsed
=
cast
(
R
,
self
.
_stream_cls
(
cast_type
=
cast
(
type
,
get_args
(
self
.
_stream_cls
)[
0
]),
response
=
self
.
http_response
,
client
=
self
.
_client
)
)
return
self
.
_parsed
cast_type
=
self
.
_cast_type
if
cast_type
is
NoneType
:
return
cast
(
R
,
None
)
http_response
=
self
.
http_response
if
cast_type
==
str
:
return
cast
(
R
,
http_response
.
text
)
content_type
,
*
_
=
http_response
.
headers
.
get
(
"content-type"
,
"application/json"
)
.
split
(
";"
)
origin
=
get_origin
(
cast_type
)
or
cast_type
if
content_type
!=
"application/json"
:
if
issubclass
(
origin
,
pydantic
.
BaseModel
):
data
=
http_response
.
json
()
return
self
.
_client
.
_process_response_data
(
data
=
data
,
cast_type
=
cast_type
,
# type: ignore
response
=
http_response
,
)
return
http_response
.
text
data
=
http_response
.
json
()
return
self
.
_client
.
_process_response_data
(
data
=
data
,
cast_type
=
cast_type
,
# type: ignore
response
=
http_response
,
)
@
property
def
headers
(
self
)
->
httpx
.
Headers
:
return
self
.
http_response
.
headers
@
property
def
http_request
(
self
)
->
httpx
.
Request
:
return
self
.
http_response
.
request
@
property
def
status_code
(
self
)
->
int
:
return
self
.
http_response
.
status_code
@
property
def
url
(
self
)
->
httpx
.
URL
:
return
self
.
http_response
.
url
@
property
def
method
(
self
)
->
str
:
return
self
.
http_request
.
method
@
property
def
content
(
self
)
->
bytes
:
return
self
.
http_response
.
content
@
property
def
text
(
self
)
->
str
:
return
self
.
http_response
.
text
@
property
def
http_version
(
self
)
->
str
:
return
self
.
http_response
.
http_version
@
property
def
elapsed
(
self
)
->
datetime
.
timedelta
:
return
self
.
http_response
.
elapsed
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py
0 → 100644
View file @
b921c556
# -*- coding:utf-8 -*-
from
__future__
import
annotations
import
json
from
typing
import
Generic
,
Iterator
,
TYPE_CHECKING
,
Mapping
import
httpx
from
._base_type
import
ResponseT
from
._errors
import
APIResponseError
_FIELD_SEPARATOR
=
":"
if
TYPE_CHECKING
:
from
._http_client
import
HttpClient
class
StreamResponse
(
Generic
[
ResponseT
]):
response
:
httpx
.
Response
_cast_type
:
type
[
ResponseT
]
def
__init__
(
self
,
*
,
cast_type
:
type
[
ResponseT
],
response
:
httpx
.
Response
,
client
:
HttpClient
,
)
->
None
:
self
.
response
=
response
self
.
_cast_type
=
cast_type
self
.
_data_process_func
=
client
.
_process_response_data
self
.
_stream_chunks
=
self
.
__stream__
()
def
__next__
(
self
)
->
ResponseT
:
return
self
.
_stream_chunks
.
__next__
()
def
__iter__
(
self
)
->
Iterator
[
ResponseT
]:
for
item
in
self
.
_stream_chunks
:
yield
item
def
__stream__
(
self
)
->
Iterator
[
ResponseT
]:
sse_line_parser
=
SSELineParser
()
iterator
=
sse_line_parser
.
iter_lines
(
self
.
response
.
iter_lines
())
for
sse
in
iterator
:
if
sse
.
data
.
startswith
(
"[DONE]"
):
break
if
sse
.
event
is
None
:
data
=
sse
.
json_data
()
if
isinstance
(
data
,
Mapping
)
and
data
.
get
(
"error"
):
raise
APIResponseError
(
message
=
"An error occurred during streaming"
,
request
=
self
.
response
.
request
,
json_data
=
data
[
"error"
],
)
yield
self
.
_data_process_func
(
data
=
data
,
cast_type
=
self
.
_cast_type
,
response
=
self
.
response
)
for
sse
in
iterator
:
pass
class
Event
(
object
):
def
__init__
(
self
,
event
:
str
|
None
=
None
,
data
:
str
|
None
=
None
,
id
:
str
|
None
=
None
,
retry
:
int
|
None
=
None
):
self
.
_event
=
event
self
.
_data
=
data
self
.
_id
=
id
self
.
_retry
=
retry
def
__repr__
(
self
):
data_len
=
len
(
self
.
_data
)
if
self
.
_data
else
0
return
f
"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}"
@
property
def
event
(
self
):
return
self
.
_event
@
property
def
data
(
self
):
return
self
.
_data
def
json_data
(
self
):
return
json
.
loads
(
self
.
_data
)
@
property
def
id
(
self
):
return
self
.
_id
@
property
def
retry
(
self
):
return
self
.
_retry
class
SSELineParser
:
_data
:
list
[
str
]
_event
:
str
|
None
_retry
:
int
|
None
_id
:
str
|
None
def
__init__
(
self
):
self
.
_event
=
None
self
.
_data
=
[]
self
.
_id
=
None
self
.
_retry
=
None
def
iter_lines
(
self
,
lines
:
Iterator
[
str
])
->
Iterator
[
Event
]:
for
line
in
lines
:
line
=
line
.
rstrip
(
'
\n
'
)
if
not
line
:
if
self
.
_event
is
None
and
\
not
self
.
_data
and
\
self
.
_id
is
None
and
\
self
.
_retry
is
None
:
continue
sse_event
=
Event
(
event
=
self
.
_event
,
data
=
'
\n
'
.
join
(
self
.
_data
),
id
=
self
.
_id
,
retry
=
self
.
_retry
)
self
.
_event
=
None
self
.
_data
=
[]
self
.
_id
=
None
self
.
_retry
=
None
yield
sse_event
self
.
decode_line
(
line
)
def
decode_line
(
self
,
line
:
str
):
if
line
.
startswith
(
":"
)
or
not
line
:
return
field
,
_p
,
value
=
line
.
partition
(
":"
)
if
value
.
startswith
(
' '
):
value
=
value
[
1
:]
if
field
==
"data"
:
self
.
_data
.
append
(
value
)
elif
field
==
"event"
:
self
.
_event
=
value
elif
field
==
"retry"
:
try
:
self
.
_retry
=
int
(
value
)
except
(
TypeError
,
ValueError
):
pass
return
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Mapping
,
Iterable
,
TypeVar
from
._base_type
import
NotGiven
def
remove_notgiven_indict
(
obj
):
if
obj
is
None
or
(
not
isinstance
(
obj
,
Mapping
)):
return
obj
return
{
key
:
value
for
key
,
value
in
obj
.
items
()
if
not
isinstance
(
value
,
NotGiven
)}
_T
=
TypeVar
(
"_T"
)
def
flatten
(
t
:
Iterable
[
Iterable
[
_T
]])
->
list
[
_T
]:
return
[
item
for
sublist
in
t
for
item
in
sublist
]
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py
0 → 100644
View file @
b921c556
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py
0 → 100644
View file @
b921c556
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py
0 → 100644
View file @
b921c556
from
typing
import
List
,
Optional
from
pydantic
import
BaseModel
from
.chat_completion
import
CompletionChoice
,
CompletionUsage
__all__
=
[
"AsyncTaskStatus"
]
class
AsyncTaskStatus
(
BaseModel
):
id
:
Optional
[
str
]
=
None
request_id
:
Optional
[
str
]
=
None
model
:
Optional
[
str
]
=
None
task_status
:
Optional
[
str
]
=
None
class
AsyncCompletion
(
BaseModel
):
id
:
Optional
[
str
]
=
None
request_id
:
Optional
[
str
]
=
None
model
:
Optional
[
str
]
=
None
task_status
:
str
choices
:
List
[
CompletionChoice
]
usage
:
CompletionUsage
\ No newline at end of file
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py
0 → 100644
View file @
b921c556
from
typing
import
List
,
Optional
from
pydantic
import
BaseModel
__all__
=
[
"Completion"
,
"CompletionUsage"
]
class
Function
(
BaseModel
):
arguments
:
str
name
:
str
class
CompletionMessageToolCall
(
BaseModel
):
id
:
str
function
:
Function
type
:
str
class
CompletionMessage
(
BaseModel
):
content
:
Optional
[
str
]
=
None
role
:
str
tool_calls
:
Optional
[
List
[
CompletionMessageToolCall
]]
=
None
class
CompletionUsage
(
BaseModel
):
prompt_tokens
:
int
completion_tokens
:
int
total_tokens
:
int
class
CompletionChoice
(
BaseModel
):
index
:
int
finish_reason
:
str
message
:
CompletionMessage
class
Completion
(
BaseModel
):
model
:
Optional
[
str
]
=
None
created
:
Optional
[
int
]
=
None
choices
:
List
[
CompletionChoice
]
request_id
:
Optional
[
str
]
=
None
id
:
Optional
[
str
]
=
None
usage
:
CompletionUsage
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py
0 → 100644
View file @
b921c556
from
typing
import
List
,
Optional
from
pydantic
import
BaseModel
__all__
=
[
"ChatCompletionChunk"
,
"Choice"
,
"ChoiceDelta"
,
"ChoiceDeltaFunctionCall"
,
"ChoiceDeltaToolCall"
,
"ChoiceDeltaToolCallFunction"
,
]
class
ChoiceDeltaFunctionCall
(
BaseModel
):
arguments
:
Optional
[
str
]
=
None
name
:
Optional
[
str
]
=
None
class
ChoiceDeltaToolCallFunction
(
BaseModel
):
arguments
:
Optional
[
str
]
=
None
name
:
Optional
[
str
]
=
None
class
ChoiceDeltaToolCall
(
BaseModel
):
index
:
int
id
:
Optional
[
str
]
=
None
function
:
Optional
[
ChoiceDeltaToolCallFunction
]
=
None
type
:
Optional
[
str
]
=
None
class
ChoiceDelta
(
BaseModel
):
content
:
Optional
[
str
]
=
None
role
:
Optional
[
str
]
=
None
tool_calls
:
Optional
[
List
[
ChoiceDeltaToolCall
]]
=
None
class
Choice
(
BaseModel
):
delta
:
ChoiceDelta
finish_reason
:
Optional
[
str
]
=
None
index
:
int
class
CompletionUsage
(
BaseModel
):
prompt_tokens
:
int
completion_tokens
:
int
total_tokens
:
int
class
ChatCompletionChunk
(
BaseModel
):
id
:
Optional
[
str
]
=
None
choices
:
List
[
Choice
]
created
:
Optional
[
int
]
=
None
model
:
Optional
[
str
]
=
None
usage
:
Optional
[
CompletionUsage
]
=
None
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py
0 → 100644
View file @
b921c556
from
typing
import
Optional
from
typing_extensions
import
TypedDict
class
Reference
(
TypedDict
,
total
=
False
):
enable
:
Optional
[
bool
]
search_query
:
Optional
[
str
]
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Optional
,
List
from
pydantic
import
BaseModel
from
.chat.chat_completion
import
CompletionUsage
__all__
=
[
"Embedding"
,
"EmbeddingsResponded"
]
class
Embedding
(
BaseModel
):
object
:
str
index
:
Optional
[
int
]
=
None
embedding
:
List
[
float
]
class
EmbeddingsResponded
(
BaseModel
):
object
:
str
data
:
List
[
Embedding
]
model
:
str
usage
:
CompletionUsage
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py
0 → 100644
View file @
b921c556
from
typing
import
Optional
,
List
from
pydantic
import
BaseModel
__all__
=
[
"FileObject"
]
class
FileObject
(
BaseModel
):
id
:
Optional
[
str
]
=
None
bytes
:
Optional
[
int
]
=
None
created_at
:
Optional
[
int
]
=
None
filename
:
Optional
[
str
]
=
None
object
:
Optional
[
str
]
=
None
purpose
:
Optional
[
str
]
=
None
status
:
Optional
[
str
]
=
None
status_details
:
Optional
[
str
]
=
None
class
ListOfFileObject
(
BaseModel
):
object
:
Optional
[
str
]
=
None
data
:
List
[
FileObject
]
has_more
:
Optional
[
bool
]
=
None
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
.fine_tuning_job
import
FineTuningJob
as
FineTuningJob
from
.fine_tuning_job
import
ListOfFineTuningJob
as
ListOfFineTuningJob
from
.fine_tuning_job_event
import
FineTuningJobEvent
as
FineTuningJobEvent
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py
0 → 100644
View file @
b921c556
from
typing
import
List
,
Union
,
Optional
from
typing_extensions
import
Literal
from
pydantic
import
BaseModel
__all__
=
[
"FineTuningJob"
,
"Error"
,
"Hyperparameters"
,
"ListOfFineTuningJob"
]
class
Error
(
BaseModel
):
code
:
str
message
:
str
param
:
Optional
[
str
]
=
None
class
Hyperparameters
(
BaseModel
):
n_epochs
:
Union
[
str
,
int
,
None
]
=
None
class
FineTuningJob
(
BaseModel
):
id
:
Optional
[
str
]
=
None
request_id
:
Optional
[
str
]
=
None
created_at
:
Optional
[
int
]
=
None
error
:
Optional
[
Error
]
=
None
fine_tuned_model
:
Optional
[
str
]
=
None
finished_at
:
Optional
[
int
]
=
None
hyperparameters
:
Optional
[
Hyperparameters
]
=
None
model
:
Optional
[
str
]
=
None
object
:
Optional
[
str
]
=
None
result_files
:
List
[
str
]
status
:
str
trained_tokens
:
Optional
[
int
]
=
None
training_file
:
str
validation_file
:
Optional
[
str
]
=
None
class
ListOfFineTuningJob
(
BaseModel
):
object
:
Optional
[
str
]
=
None
data
:
List
[
FineTuningJob
]
has_more
:
Optional
[
bool
]
=
None
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py
0 → 100644
View file @
b921c556
from
typing
import
List
,
Union
,
Optional
from
typing_extensions
import
Literal
from
pydantic
import
BaseModel
__all__
=
[
"FineTuningJobEvent"
,
"Metric"
,
"JobEvent"
]
class
Metric
(
BaseModel
):
epoch
:
Optional
[
Union
[
str
,
int
,
float
]]
=
None
current_steps
:
Optional
[
int
]
=
None
total_steps
:
Optional
[
int
]
=
None
elapsed_time
:
Optional
[
str
]
=
None
remaining_time
:
Optional
[
str
]
=
None
trained_tokens
:
Optional
[
int
]
=
None
loss
:
Optional
[
Union
[
str
,
int
,
float
]]
=
None
eval_loss
:
Optional
[
Union
[
str
,
int
,
float
]]
=
None
acc
:
Optional
[
Union
[
str
,
int
,
float
]]
=
None
eval_acc
:
Optional
[
Union
[
str
,
int
,
float
]]
=
None
learning_rate
:
Optional
[
Union
[
str
,
int
,
float
]]
=
None
class
JobEvent
(
BaseModel
):
object
:
Optional
[
str
]
=
None
id
:
Optional
[
str
]
=
None
type
:
Optional
[
str
]
=
None
created_at
:
Optional
[
int
]
=
None
level
:
Optional
[
str
]
=
None
message
:
Optional
[
str
]
=
None
data
:
Optional
[
Metric
]
=
None
class
FineTuningJobEvent
(
BaseModel
):
object
:
Optional
[
str
]
=
None
data
:
List
[
JobEvent
]
has_more
:
Optional
[
bool
]
=
None
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Union
from
typing_extensions
import
Literal
,
TypedDict
__all__
=
[
"Hyperparameters"
]
class
Hyperparameters
(
TypedDict
,
total
=
False
):
batch_size
:
Union
[
Literal
[
"auto"
],
int
]
learning_rate_multiplier
:
Union
[
Literal
[
"auto"
],
float
]
n_epochs
:
Union
[
Literal
[
"auto"
],
int
]
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py
0 → 100644
View file @
b921c556
from
__future__
import
annotations
from
typing
import
Optional
,
List
from
pydantic
import
BaseModel
__all__
=
[
"GeneratedImage"
,
"ImagesResponded"
]
class
GeneratedImage
(
BaseModel
):
b64_json
:
Optional
[
str
]
=
None
url
:
Optional
[
str
]
=
None
revised_prompt
:
Optional
[
str
]
=
None
class
ImagesResponded
(
BaseModel
):
created
:
int
data
:
List
[
GeneratedImage
]
api/tests/integration_tests/model_runtime/zhipuai/test_llm.py
View file @
b921c556
...
...
@@ -3,7 +3,8 @@ from typing import Generator
import
pytest
from
core.model_runtime.entities.llm_entities
import
LLMResult
,
LLMResultChunk
,
LLMResultChunkDelta
from
core.model_runtime.entities.message_entities
import
AssistantPromptMessage
,
SystemPromptMessage
,
UserPromptMessage
from
core.model_runtime.entities.message_entities
import
(
AssistantPromptMessage
,
SystemPromptMessage
,
UserPromptMessage
,
PromptMessageTool
)
from
core.model_runtime.errors.validate
import
CredentialsValidateFailedError
from
core.model_runtime.model_providers.zhipuai.llm.llm
import
ZhipuAILargeLanguageModel
...
...
@@ -102,3 +103,48 @@ def test_get_num_tokens():
)
assert
num_tokens
==
14
def
test_get_tools_num_tokens
():
model
=
ZhipuAILargeLanguageModel
()
num_tokens
=
model
.
get_num_tokens
(
model
=
'tools'
,
credentials
=
{
'api_key'
:
os
.
environ
.
get
(
'ZHIPUAI_API_KEY'
)
},
tools
=
[
PromptMessageTool
(
name
=
'get_current_weather'
,
description
=
'Get the current weather in a given location'
,
parameters
=
{
"type"
:
"object"
,
"properties"
:
{
"location"
:
{
"type"
:
"string"
,
"description"
:
"The city and state e.g. San Francisco, CA"
},
"unit"
:
{
"type"
:
"string"
,
"enum"
:
[
"c"
,
"f"
]
}
},
"required"
:
[
"location"
]
}
)
],
prompt_messages
=
[
SystemPromptMessage
(
content
=
'You are a helpful AI assistant.'
,
),
UserPromptMessage
(
content
=
'Hello World!'
)
]
)
assert
num_tokens
==
108
\ No newline at end of file
api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py
View file @
b921c556
...
...
@@ -42,7 +42,7 @@ def test_invoke_model():
assert
isinstance
(
result
,
TextEmbeddingResult
)
assert
len
(
result
.
embeddings
)
==
2
assert
result
.
usage
.
total_tokens
==
2
assert
result
.
usage
.
total_tokens
>
0
def
test_get_num_tokens
():
...
...
web/app/components/app/chat/answer/index.tsx
View file @
b921c556
...
...
@@ -229,7 +229,7 @@ const Answer: FC<IAnswerProps> = ({
<
Thought
thought=
{
item
}
allToolIcons=
{
allToolIcons
||
{}
}
isFinished=
{
!!
item
.
observation
}
isFinished=
{
!!
item
.
observation
||
!
isResponsing
}
/>
)
}
...
...
web/app/components/app/configuration/index.tsx
View file @
b921c556
...
...
@@ -43,7 +43,7 @@ import { fetchDatasets } from '@/service/datasets'
import
{
useProviderContext
}
from
'@/context/provider-context'
import
{
AgentStrategy
,
AppType
,
ModelModeType
,
RETRIEVE_TYPE
,
Resolution
,
TransferMethod
}
from
'@/types/app'
import
{
PromptMode
}
from
'@/models/debug'
import
{
ANNOTATION_DEFAULT
,
DEFAULT_AGENT_SETTING
,
DEFAULT_CHAT_PROMPT_CONFIG
,
DEFAULT_COMPLETION_PROMPT_CONFIG
}
from
'@/config'
import
{
ANNOTATION_DEFAULT
,
DEFAULT_AGENT_SETTING
,
DEFAULT_CHAT_PROMPT_CONFIG
,
DEFAULT_COMPLETION_PROMPT_CONFIG
,
supportFunctionCallModels
}
from
'@/config'
import
SelectDataSet
from
'@/app/components/app/configuration/dataset-config/select-dataset'
import
I18n
from
'@/context/i18n'
import
{
useModalContext
}
from
'@/context/modal-context'
...
...
@@ -163,8 +163,7 @@ const Configuration: FC = () => {
doSetModelConfig
(
newModelConfig
)
}
const
isOpenAI
=
modelConfig
.
provider
===
'openai'
const
isFunctionCall
=
isOpenAI
&&
modelConfig
.
mode
===
ModelModeType
.
chat
const
isFunctionCall
=
(
isOpenAI
&&
modelConfig
.
mode
===
ModelModeType
.
chat
)
||
supportFunctionCallModels
.
includes
(
modelConfig
.
model_id
)
const
[
collectionList
,
setCollectionList
]
=
useState
<
Collection
[]
>
([])
useEffect
(()
=>
{
...
...
web/config/index.ts
View file @
b921c556
...
...
@@ -160,6 +160,8 @@ export const DEFAULT_AGENT_SETTING = {
tools
:
[],
}
export
const
supportFunctionCallModels
=
[
'glm-3-turbo'
,
'glm-4'
]
export
const
DEFAULT_AGENT_PROMPT
=
{
chat
:
`Respond to the human as helpfully and accurately as possible.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment