ai-tech / dify / Commits / 827c97f0

Commit 827c97f0 (Unverified)
Authored Sep 18, 2023 by takatost, committed by GitHub on Sep 18, 2023

feat: add zhipuai (#1188)

parent c8bd76cd

Showing 36 changed files with 1089 additions and 113 deletions (+1089 -113)
api/controllers/console/workspace/model_providers.py (+8 -2)
api/core/callback_handler/llm_callback_handler.py (+12 -1)
api/core/chain/sensitive_word_avoidance_chain.py (+7 -6)
api/core/completion.py (+51 -30)
api/core/helper/moderation.py (+19 -17)
api/core/model_providers/model_provider_factory.py (+3 -0)
api/core/model_providers/models/embedding/zhipuai_embedding.py (+22 -0)
api/core/model_providers/models/entity/model_params.py (+1 -0)
api/core/model_providers/models/llm/zhipuai_model.py (+61 -0)
api/core/model_providers/models/moderation/openai_moderation.py (+10 -6)
api/core/model_providers/providers/anthropic_provider.py (+3 -3)
api/core/model_providers/providers/azure_openai_provider.py (+5 -5)
api/core/model_providers/providers/chatglm_provider.py (+3 -3)
api/core/model_providers/providers/hosted.py (+18 -0)
api/core/model_providers/providers/huggingface_hub_provider.py (+3 -3)
api/core/model_providers/providers/localai_provider.py (+3 -3)
api/core/model_providers/providers/minimax_provider.py (+3 -3)
api/core/model_providers/providers/openai_provider.py (+5 -5)
api/core/model_providers/providers/openllm_provider.py (+5 -5)
api/core/model_providers/providers/replicate_provider.py (+2 -0)
api/core/model_providers/providers/spark_provider.py (+2 -2)
api/core/model_providers/providers/tongyi_provider.py (+2 -2)
api/core/model_providers/providers/wenxin_provider.py (+2 -2)
api/core/model_providers/providers/xinference_provider.py (+11 -11)
api/core/model_providers/providers/zhipuai_provider.py (+176 -0)
api/core/model_providers/rules/_providers.json (+1 -0)
api/core/model_providers/rules/zhipuai.json (+44 -0)
api/core/third_party/langchain/embeddings/zhipuai_embedding.py (+64 -0)
api/core/third_party/langchain/llms/zhipuai_llm.py (+315 -0)
api/requirements.txt (+2 -1)
api/services/provider_service.py (+5 -2)
api/tests/integration_tests/.env.example (+3 -0)
api/tests/integration_tests/models/embedding/test_zhipuai_embedding.py (+50 -0)
api/tests/integration_tests/models/llm/test_zhipuai_model.py (+79 -0)
api/tests/unit_tests/model_providers/test_spark_provider.py (+1 -1)
api/tests/unit_tests/model_providers/test_zhipuai_provider.py (+88 -0)
api/controllers/console/workspace/model_providers.py  (View file @ 827c97f0)

@@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
                 'enabled': v.enabled,
                 'min': v.min,
                 'max': v.max,
-                'default': v.default
+                'default': v.default,
+                'precision': v.precision
             }
             for k, v in vars(parameter_rules).items()
         }

@@ -290,10 +291,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider_name: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', type=str, required=False, nullable=True, location='args')
+        args = parser.parse_args()
+
         provider_service = ProviderService()
         result = provider_service.free_quota_qualification_verify(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name
+            provider_name=provider_name,
+            token=args['token']
         )

         return result
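For illustration, a minimal sketch (not part of this commit) of what one entry of the parameter-rule payload built by the loop above could look like once the new precision field is exposed; the concrete values are invented.

# Hypothetical fragment of the rules response for a single parameter (values made up):
parameter_rules_response = {
    'temperature': {
        'enabled': True,
        'min': 0,
        'max': 2,
        'default': 1,
        'precision': 2,  # new field: how many decimal places the client UI should allow
    }
}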
api/core/callback_handler/llm_callback_handler.py  (View file @ 827c97f0)

@@ -63,7 +63,18 @@ class LLMCallbackHandler(BaseCallbackHandler):
             self.conversation_message_task.append_message_text(response.generations[0][0].text)
             self.llm_message.completion = response.generations[0][0].text

-        self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
+        if response.llm_output and 'token_usage' in response.llm_output:
+            if 'prompt_tokens' in response.llm_output['token_usage']:
+                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+
+            if 'completion_tokens' in response.llm_output['token_usage']:
+                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+            else:
+                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                    [PromptMessage(content=self.llm_message.completion)])
+        else:
+            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                [PromptMessage(content=self.llm_message.completion)])

         self.conversation_message_task.save_message(self.llm_message)
api/core/chain/sensitive_word_avoidance_chain.py  (View file @ 827c97f0)

@@ -2,13 +2,8 @@ import enum
 import logging
 from typing import List, Dict, Optional, Any

-import openai
 from flask import current_app
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from openai import InvalidRequestError
-from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
-    AuthenticationError, OpenAIError
 from pydantic import BaseModel

 from core.model_providers.error import LLMBadRequestError

@@ -86,6 +81,12 @@ class SensitiveWordAvoidanceChain(Chain):
         result = self._check_moderation(text)
         if not result:
-            raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
+            raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)

         return {self.output_key: text}
+
+
+class SensitiveWordAvoidanceError(Exception):
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message
api/core/completion.py  (View file @ 827c97f0)

@@ -7,6 +7,7 @@ from requests.exceptions import ChunkedEncodingError
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
+from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \

@@ -76,28 +77,53 @@ class Completion:
             app_model_config=app_model_config
         )

-        # parse sensitive_word_avoidance_chain
-        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
-        if sensitive_word_avoidance_chain:
-            query = sensitive_word_avoidance_chain.run(query)
-
-        # get agent executor
-        agent_executor = orchestrator_rule_parser.to_agent_executor(
-            conversation_message_task=conversation_message_task,
-            memory=memory,
-            rest_tokens=rest_tokens_for_context_and_memory,
-            chain_callback=chain_callback
-        )
-
-        # run agent executor
-        agent_execute_result = None
-        if agent_executor:
-            should_use_agent = agent_executor.should_use_agent(query)
-            if should_use_agent:
-                agent_execute_result = agent_executor.run(query)
-
-        # run the final llm
         try:
+            # parse sensitive_word_avoidance_chain
+            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+            sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
+                final_model_instance, [chain_callback])
+            if sensitive_word_avoidance_chain:
+                try:
+                    query = sensitive_word_avoidance_chain.run(query)
+                except SensitiveWordAvoidanceError as ex:
+                    cls.run_final_llm(
+                        model_instance=final_model_instance,
+                        mode=app.mode,
+                        app_model_config=app_model_config,
+                        query=query,
+                        inputs=inputs,
+                        agent_execute_result=None,
+                        conversation_message_task=conversation_message_task,
+                        memory=memory,
+                        fake_response=ex.message
+                    )
+                    return
+
+            # get agent executor
+            agent_executor = orchestrator_rule_parser.to_agent_executor(
+                conversation_message_task=conversation_message_task,
+                memory=memory,
+                rest_tokens=rest_tokens_for_context_and_memory,
+                chain_callback=chain_callback,
+                retriever_from=retriever_from
+            )
+
+            # run agent executor
+            agent_execute_result = None
+            if agent_executor:
+                should_use_agent = agent_executor.should_use_agent(query)
+                if should_use_agent:
+                    agent_execute_result = agent_executor.run(query)
+
+            # When no extra pre prompt is specified,
+            # the output of the agent can be used directly as the main output content without calling LLM again
+            fake_response = None
+            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
+                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
+                fake_response = agent_execute_result.output
+
+            # run the final llm
             cls.run_final_llm(
                 model_instance=final_model_instance,
                 mode=app.mode,

@@ -106,7 +132,8 @@ class Completion:
                 inputs=inputs,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
-                memory=memory
+                memory=memory,
+                fake_response=fake_response
             )
         except ConversationTaskStoppedException:
             return

@@ -121,14 +148,8 @@ class Completion:
                       inputs: dict,
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
-                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
-        # When no extra pre prompt is specified,
-        # the output of the agent can be used directly as the main output content without calling LLM again
-        fake_response = None
-        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
-                and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
-            fake_response = agent_execute_result.output
-
+                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
+                      fake_response: Optional[str]):
         # get llm prompt
         prompt_messages, stop_words = model_instance.get_prompt(
             mode=mode,
api/core/helper/moderation.py  (View file @ 827c97f0)

 import logging

 import openai
 from flask import current_app

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.providers.base import BaseModelProvider
 from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
 from models.provider import ProviderType


 def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
-    if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
-        moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
+    if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
         if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
-                and model_provider.provider_name in moderation_providers:
+                and model_provider.provider_name in hosted_config.moderation.providers:
             # 2000 text per chunk
             length = 2000
-            chunks = [text[i:i + length] for i in range(0, len(text), length)]
-
-            try:
-                moderation_result = openai.Moderation.create(input=chunks,
-                                                             api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
-            except Exception as ex:
-                logging.exception(ex)
-                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
-
-            for result in moderation_result.results:
-                if result['flagged'] is True:
-                    return False
+            text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+            max_text_chunks = 32
+            chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
+
+            for text_chunk in chunks:
+                try:
+                    moderation_result = openai.Moderation.create(input=text_chunk,
+                                                                 api_key=hosted_model_providers.openai.api_key)
+                except Exception as ex:
+                    logging.exception(ex)
+                    raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+
+                for result in moderation_result.results:
+                    if result['flagged'] is True:
+                        return False

     return True
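A quick worked sketch (not from the commit) of the chunking arithmetic above: the text is split into 2000-character chunks, and those chunks are then sent to the moderation endpoint in batches of at most 32.

# Hypothetical input of 130,000 characters:
length, max_text_chunks = 2000, 32
text = "x" * 130_000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]          # 65 chunks
chunks = [text_chunks[i:i + max_text_chunks]
          for i in range(0, len(text_chunks), max_text_chunks)]                  # 3 batches: 32 + 32 + 1
assert len(text_chunks) == 65 and len(chunks) == 3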
api/core/model_providers/model_provider_factory.py  (View file @ 827c97f0)

@@ -45,6 +45,9 @@ class ModelProviderFactory:
         elif provider_name == 'wenxin':
             from core.model_providers.providers.wenxin_provider import WenxinProvider
             return WenxinProvider
+        elif provider_name == 'zhipuai':
+            from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
+            return ZhipuAIProvider
         elif provider_name == 'chatglm':
             from core.model_providers.providers.chatglm_provider import ChatGLMProvider
             return ChatGLMProvider
api/core/model_providers/models/embedding/zhipuai_embedding.py  0 → 100644  (View file @ 827c97f0)

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings


class ZhipuAIEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = ZhipuAIEmbeddings(
            model=name,
            **credentials,
        )

        super().__init__(model_provider, client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")
api/core/model_providers/models/entity/model_params.py  (View file @ 827c97f0)

@@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
     max: Optional[T] = None
     default: Optional[T] = None
     alias: Optional[str] = None
+    precision: Optional[int] = None


 class ModelKwargsRules(BaseModel):
api/core/model_providers/models/llm/zhipuai_model.py  0 → 100644  (View file @ 827c97f0)

from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM


class ZhipuAIModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ZhipuAIChatLLM(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens_from_messages(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"ZhipuAI: {str(ex)}")

    @property
    def support_streaming(self):
        return True
api/core/model_providers/models/moderation/openai_moderation.py  (View file @ 827c97f0)

@@ -23,14 +23,18 @@ class OpenAIModeration(BaseModeration):
         # 2000 text per chunk
         length = 2000
-        chunks = [text[i:i + length] for i in range(0, len(text), length)]
+        text_chunks = [text[i:i + length] for i in range(0, len(text), length)]

-        moderation_result = self._client.create(input=chunks,
-                                                api_key=credentials['openai_api_key'])
+        max_text_chunks = 32
+        chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]

-        for result in moderation_result.results:
-            if result['flagged'] is True:
-                return False
+        for text_chunk in chunks:
+            moderation_result = self._client.create(input=text_chunk,
+                                                    api_key=credentials['openai_api_key'])
+
+            for result in moderation_result.results:
+                if result['flagged'] is True:
+                    return False

         return True
api/core/model_providers/providers/anthropic_provider.py  (View file @ 827c97f0)

@@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=1, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
+            temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
+            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
         )

     @classmethod
api/core/model_providers/providers/azure_openai_provider.py  (View file @ 827c97f0)

@@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider):
         model_credentials = self.get_model_credentials(model_name, model_type)

         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=1),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
             max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
                 model_credentials['base_model_name'],
                 4097
-            ), default=16),
+            ), default=16, precision=0),
         )

     @classmethod
api/core/model_providers/providers/chatglm_provider.py  (View file @ 827c97f0)

@@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider):
         }

         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
+            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
         )

     @classmethod
api/core/model_providers/providers/hosted.py  (View file @ 827c97f0)

@@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel):
 hosted_model_providers = HostedModelProviders()


+class HostedModerationConfig(BaseModel):
+    enabled: bool = False
+    providers: list[str] = []
+
+
+class HostedConfig(BaseModel):
+    moderation = HostedModerationConfig()
+
+
+hosted_config = HostedConfig()
+
+
 def init_app(app: Flask):
     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True

@@ -78,3 +90,9 @@ def init_app(app: Flask):
             paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
             paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
         )
+
+    if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
+        hosted_config.moderation = HostedModerationConfig(
+            enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
+            providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
+        )
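For context, a minimal sketch (assumed values, not from the commit) of the Flask config that feeds the moderation block added to init_app above.

# Assumed example values; the provider list is illustrative.
app.config["HOSTED_MODERATION_ENABLED"] = "true"
app.config["HOSTED_MODERATION_PROVIDERS"] = "openai"

# After init_app(app), hosted_config.moderation would then hold:
#   enabled   -> True (pydantic coerces the truthy string)
#   providers -> ["openai"]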
api/core/model_providers/providers/huggingface_hub_provider.py  (View file @ 827c97f0)

@@ -47,11 +47,11 @@ class HuggingfaceHubProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
         )

     @classmethod
api/core/model_providers/providers/localai_provider.py  (View file @ 827c97f0)

@@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=0.7),
-            top_p=KwargRule[float](min=0, max=1, default=1),
-            max_tokens=KwargRule[int](min=10, max=4097, default=16),
+            temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
+            max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
         )

     @classmethod
api/core/model_providers/providers/minimax_provider.py  (View file @ 827c97f0)

@@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider):
         }

         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0.01, max=1, default=0.9),
-            top_p=KwargRule[float](min=0, max=1, default=0.95),
+            temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
+            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
         )

     @classmethod
api/core/model_providers/providers/openai_provider.py  (View file @ 827c97f0)

@@ -133,11 +133,11 @@ class OpenAIProvider(BaseModelProvider):
         }

         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=1),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
         )

     @classmethod
api/core/model_providers/providers/openllm_provider.py  (View file @ 827c97f0)

@@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0.01, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
+            temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
         )

     @classmethod
api/core/model_providers/providers/replicate_provider.py  (View file @ 827c97f0)

@@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider):
                     min=float(value.get('minimum')) if value.get('minimum') is not None else None,
                     max=float(value.get('maximum')) if value.get('maximum') is not None else None,
                     default=float(value.get('default')) if value.get('default') is not None else None,
+                    precision=2
                 )

                 if key == 'temperature':
                     model_kwargs_rules.temperature = kwarg_rule

@@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider):
                     min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
                     max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
                     default=int(value.get('default')) if value.get('default') is not None else 500,
+                    precision=0
                 )

         return model_kwargs_rules
api/core/model_providers/providers/spark_provider.py  (View file @ 827c97f0)

@@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=1, default=0.5),
+            temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
             top_p=KwargRule[float](enabled=False),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=4096, default=2048),
+            max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
         )

     @classmethod
api/core/model_providers/providers/tongyi_provider.py  (View file @ 827c97f0)

@@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider):
         return ModelKwargsRules(
             temperature=KwargRule[float](enabled=False),
-            top_p=KwargRule[float](min=0, max=1, default=0.8),
+            top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
+            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
         )

     @classmethod
api/core/model_providers/providers/wenxin_provider.py  (View file @ 827c97f0)

@@ -63,8 +63,8 @@ class WenxinProvider(BaseModelProvider):
         """
         if model_name in ['ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=1, default=0.95),
-                top_p=KwargRule[float](min=0.01, max=1, default=0.8),
+                temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
+                top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
                 max_tokens=KwargRule[int](enabled=False),
api/core/model_providers/providers/xinference_provider.py  (View file @ 827c97f0)

@@ -53,27 +53,27 @@ class XinferenceProvider(BaseModelProvider):
         credentials = self.get_model_credentials(model_name, model_type)
         if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=2, default=1),
-                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
-                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
             )
         elif credentials['model_format'] == "ggmlv3":
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=2, default=1),
-                top_p=KwargRule[float](min=0, max=1, default=0.7),
-                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
             )
         else:
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=2, default=1),
-                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
-                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
             )
api/core/model_providers/providers/zhipuai_provider.py  0 → 100644  (View file @ 827c97f0)

import json
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
from models.provider import ProviderType, ProviderQuotaType


class ZhipuAIProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'zhipuai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'chatglm_pro',
                    'name': 'chatglm_pro',
                },
                {
                    'id': 'chatglm_std',
                    'name': 'chatglm_std',
                },
                {
                    'id': 'chatglm_lite',
                    'name': 'chatglm_lite',
                },
                {
                    'id': 'chatglm_lite_32k',
                    'name': 'chatglm_lite_32k',
                }
            ]
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'text_embedding',
                    'name': 'text_embedding',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = ZhipuAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = ZhipuAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
            top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key']
            }

            llm = ZhipuAIChatLLM(
                temperature=0.01,
                **credential_kwargs
            )

            llm([HumanMessage(content='ping')])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                or (self.provider.provider_type == ProviderType.SYSTEM.value
                    and self.provider.quota_type == ProviderQuotaType.FREE.value):
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
api/core/model_providers/rules/_providers.json  (View file @ 827c97f0)

@@ -6,6 +6,7 @@
   "tongyi",
   "spark",
   "wenxin",
+  "zhipuai",
   "chatglm",
   "replicate",
   "huggingface_hub",
api/core/model_providers/rules/zhipuai.json  0 → 100644  (View file @ 827c97f0)

{
    "support_provider_types": ["system", "custom"],
    "system_config": {
        "supported_quota_types": ["free"],
        "quota_unit": "tokens"
    },
    "model_flexibility": "fixed",
    "price_config": {
        "chatglm_pro": {
            "prompt": "0.01",
            "completion": "0.01",
            "unit": "0.001",
            "currency": "RMB"
        },
        "chatglm_std": {
            "prompt": "0.005",
            "completion": "0.005",
            "unit": "0.001",
            "currency": "RMB"
        },
        "chatglm_lite": {
            "prompt": "0.002",
            "completion": "0.002",
            "unit": "0.001",
            "currency": "RMB"
        },
        "chatglm_lite_32k": {
            "prompt": "0.0004",
            "completion": "0.0004",
            "unit": "0.001",
            "currency": "RMB"
        },
        "text_embedding": {
            "completion": "0",
            "unit": "0.001",
            "currency": "RMB"
        }
    }
}
\ No newline at end of file
api/core/third_party/langchain/embeddings/zhipuai_embedding.py  0 → 100644  (View file @ 827c97f0)

"""Wrapper around ZhipuAI embedding models."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env

from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI


class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around ZhipuAI embedding models.

    1024 dimensions.
    """

    client: Any  #: :meta private:
    model: str
    """Model name to use."""

    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
    api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "ZHIPUAI_API_KEY"
        )

        values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to ZhipuAI's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        for text in texts:
            response = self.client.invoke(model=self.model, prompt=text)
            data = response["data"]
            embeddings.append(data.get('embedding'))

        return [list(map(float, e)) for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Call out to ZhipuAI's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
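A minimal usage sketch of the wrapper above (assuming a valid key; the key value is a placeholder and this snippet is not part of the commit):

from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings

# api_key can also be picked up from the ZHIPUAI_API_KEY environment variable
embeddings = ZhipuAIEmbeddings(model="text_embedding", api_key="my-api-key")
query_vector = embeddings.embed_query("hello")            # one 1024-dimensional vector
doc_vectors = embeddings.embed_documents(["a", "b"])      # one vector per input text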
api/core/third_party/langchain/llms/zhipuai_llm.py  0 → 100644  (View file @ 827c97f0)

"""Wrapper around ZhipuAI APIs."""
from __future__ import annotations

import json
import logging
import posixpath
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Iterator,
    Sequence,
)

import zhipuai
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel

from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env
from zhipuai.model_api.api import InvokeType
from zhipuai.utils import jwt_token
from zhipuai.utils.http_client import post, stream
from zhipuai.utils.sse_client import SSEClient

logger = logging.getLogger(__name__)


class ZhipuModelAPI(BaseModel):
    base_url: str
    api_key: str
    api_timeout_seconds = 60

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def invoke(self, **kwargs):
        url = self._build_api_url(kwargs, InvokeType.SYNC)
        response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
        if not response['success']:
            raise ValueError(
                f"Error Code: {response['code']}, Message: {response['msg']} "
            )
        return response

    def sse_invoke(self, **kwargs):
        url = self._build_api_url(kwargs, InvokeType.SSE)
        data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
        return SSEClient(data)

    def _build_api_url(self, kwargs, *path):
        if kwargs:
            if "model" not in kwargs:
                raise Exception("model param missed")
            model = kwargs.pop("model")
        else:
            model = "-"

        return posixpath.join(self.base_url, model, *path)

    def _generate_token(self):
        if not self.api_key:
            raise Exception(
                "api_key not provided, you could provide it."
            )

        try:
            return jwt_token.generate_token(self.api_key)
        except Exception:
            raise ValueError(
                f"Your api_key is invalid, please check it."
            )


class ZhipuAIChatLLM(BaseChatModel):
    """Wrapper around ZhipuAI large language models.

    To use, you should pass the api_key as a named parameter to the constructor.

    Example:
         .. code-block:: python

             from core.third_party.langchain.llms.zhipuai import ZhipuAI

             model = ZhipuAI(model="<model_name>", api_key="my-api-key")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any = None  #: :meta private:
    model: str = "chatglm_lite"
    """Model name to use."""
    temperature: float = 0.95
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.7
    """Total probability mass of tokens to consider at each step."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    api_key: Optional[str] = None
    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "ZHIPUAI_API_KEY"
        )

        if 'test' in values['base_url']:
            values['model'] = 'chatglm_130b_test'

        values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "zhipuai"

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "user", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_dict

    def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
        role = _dict["role"]
        if role == "user":
            return HumanMessage(content=_dict["content"])
        elif role == "assistant":
            return AIMessage(content=_dict["content"])
        elif role == "system":
            return SystemMessage(content=_dict["content"])
        else:
            return ChatMessage(content=_dict["content"], role=role)

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        dict_messages = []
        for m in messages:
            message = self._convert_message_to_dict(m)
            if dict_messages:
                previous_message = dict_messages[-1]
                if previous_message['role'] == message['role']:
                    dict_messages[-1]['content'] += f"\n{message['content']}"
                else:
                    dict_messages.append(message)
            else:
                dict_messages.append(message)

        return dict_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if chunk.generation_info is not None \
                        and 'token_usage' in chunk.generation_info:
                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
                    continue

                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts = self._create_message_dicts(messages)
            request = self._default_params
            request["prompt"] = message_dicts
            request.update(kwargs)
            response = self.client.invoke(**request)
            return self._create_chat_result(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages)
        request = self._default_params
        request["prompt"] = message_dicts
        request.update(kwargs)

        for event in self.client.sse_invoke(incremental=True, **request).events():
            if event.event == "add":
                yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
                if run_manager:
                    run_manager.on_llm_new_token(event.data)
            elif event.event == "error" or event.event == "interrupted":
                raise ValueError(f"{event.data}")
            elif event.event == "finish":
                meta = json.loads(event.meta)
                token_usage = meta['usage']
                if token_usage is not None:
                    if 'prompt_tokens' not in token_usage:
                        token_usage['prompt_tokens'] = 0
                    if 'completion_tokens' not in token_usage:
                        token_usage['completion_tokens'] = token_usage['total_tokens']

                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=event.data),
                    generation_info=dict({'token_usage': token_usage})
                )

    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
        data = response["data"]
        generations = []
        for res in data["choices"]:
            message = self._convert_dict_to_message(res)
            gen = ChatGeneration(message=message)
            generations.append(gen)
        token_usage = data.get("usage")
        if token_usage is not None:
            if 'prompt_tokens' not in token_usage:
                token_usage['prompt_tokens'] = 0
            if 'completion_tokens' not in token_usage:
                token_usage['completion_tokens'] = token_usage['total_tokens']

        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    # def get_token_ids(self, text: str) -> List[int]:
    #     """Return the ordered ids of the tokens in a text.
    #
    #     Args:
    #         text: The string input to tokenize.
    #
    #     Returns:
    #         A list of ids corresponding to the tokens in the text, in order they occur
    #             in the text.
    #     """
    #     from core.third_party.transformers.Token import ChatGLMTokenizer
    #
    #     tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
    #     return tokenizer.encode(text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum([self.get_num_tokens(m.content) for m in messages])

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            for k, v in token_usage.items():
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v
        return {"token_usage": overall_token_usage, "model_name": self.model}
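A minimal usage sketch of ZhipuAIChatLLM, mirroring the class docstring and the provider's credential check; the model name and key are placeholders, and this snippet is not part of the commit:

from langchain.schema import HumanMessage

from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM

llm = ZhipuAIChatLLM(model="chatglm_lite", api_key="my-api-key", streaming=False)
reply = llm([HumanMessage(content="ping")])   # returns an AIMessage built by _generate()
print(reply.content)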
api/requirements.txt  (View file @ 827c97f0)

@@ -50,4 +50,5 @@ transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
 xinference==0.4.2
-safetensors==0.3.2
\ No newline at end of file
+safetensors==0.3.2
+zhipuai==1.0.7
api/services/provider_service.py  (View file @ 827c97f0)

@@ -548,7 +548,7 @@ class ProviderService:
             'result': 'success'
         }

-    def free_quota_qualification_verify(self, tenant_id: str, provider_name: str):
+    def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
         api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
         api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
         api_url = api_base_url + '/api/v1/providers/qualification-verify'

@@ -557,8 +557,11 @@ class ProviderService:
             'Content-Type': 'application/json',
             'Authorization': f"Bearer {api_key}"
         }

+        json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
+        if token:
+            json_data['token'] = token
+
         response = requests.post(api_url,
                                  headers=headers,
-                                 json={'workspace_id': tenant_id, 'provider_name': provider_name})
+                                 json=json_data)

         if not response.ok:
             logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
             raise ValueError(f"Error: {response.status_code} ")
api/tests/integration_tests/.env.example  (View file @ 827c97f0)

@@ -31,6 +31,9 @@ TONGYI_DASHSCOPE_API_KEY=
 WENXIN_API_KEY=
 WENXIN_SECRET_KEY=

+# ZhipuAI Credentials
+ZHIPUAI_API_KEY=
+
 # ChatGLM Credentials
 CHATGLM_API_BASE=
api/tests/integration_tests/models/embedding/test_zhipuai_embedding.py  0 → 100644  (View file @ 827c97f0)

import json
import os
from unittest.mock import patch

from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='zhipuai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({'api_key': valid_api_key}),
        is_valid=True,
    )


def get_mock_embedding_model():
    model_name = 'text_embedding'
    valid_api_key = os.environ['ZHIPUAI_API_KEY']
    provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
    return ZhipuAIEmbedding(
        model_provider=provider,
        name=model_name
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
    embedding_model = get_mock_embedding_model()
    rst = embedding_model.client.embed_query('test')

    assert isinstance(rst, list)
    assert len(rst) == 1024


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_doc_embedding(mock_decrypt):
    embedding_model = get_mock_embedding_model()
    rst = embedding_model.client.embed_documents(['test', 'test2'])

    assert isinstance(rst, list)
    assert len(rst[0]) == 1024
api/tests/integration_tests/models/llm/test_zhipuai_model.py  0 → 100644  (View file @ 827c97f0)

import json
import os
from unittest.mock import patch

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='zhipuai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({'api_key': valid_api_key}),
        is_valid=True,
    )


def get_mock_model(model_name: str, streaming: bool = False):
    model_kwargs = ModelKwargs(
        temperature=0.01,
    )
    valid_api_key = os.environ['ZHIPUAI_API_KEY']
    model_provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
    return ZhipuAIModel(
        model_provider=model_provider,
        name=model_name,
        model_kwargs=model_kwargs,
        streaming=streaming
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    model = get_mock_model('chatglm_lite')
    rst = model.get_num_tokens([
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ])
    assert rst > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('chatglm_lite')
    messages = [PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')]
    rst = model.run(
        messages,
    )
    assert len(rst.content) > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('chatglm_lite', streaming=True)
    messages = [PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')]
    rst = model.run(messages)
    assert len(rst.content) > 0
api/tests/unit_tests/model_providers/test_spark_provider.py  (View file @ 827c97f0)

@@ -39,7 +39,7 @@ def test_is_provider_credentials_valid_or_raise_invalid():
         MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})

     credential = VALIDATE_CREDENTIAL.copy()
-    credential['api_key'] = 'invalid_key'
+    del credential['api_key']

     # raise CredentialsValidateFailedError if api_key is invalid
     with pytest.raises(CredentialsValidateFailedError):
api/tests/unit_tests/model_providers/test_zhipuai_provider.py  0 → 100644  (View file @ 827c97f0)

import pytest
from unittest.mock import patch
import json

from langchain.schema import ChatResult, ChatGeneration, AIMessage

from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import ProviderType, Provider

PROVIDER_NAME = 'zhipuai'
MODEL_PROVIDER_CLASS = ZhipuAIProvider
VALIDATE_CREDENTIAL = {
    'api_key': 'valid_key',
}


def encrypt_side_effect(tenant_id, encrypt_key):
    return f'encrypted_{encrypt_key}'


def decrypt_side_effect(tenant_id, encrypted_key):
    return encrypted_key.replace('encrypted_', '')


def test_is_provider_credentials_valid_or_raise_valid(mocker):
    mocker.patch('core.third_party.langchain.llms.zhipuai_llm.ZhipuAIChatLLM._generate',
                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))

    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)


def test_is_provider_credentials_valid_or_raise_invalid():
    # raise CredentialsValidateFailedError if api_key is not in credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})

    credential = VALIDATE_CREDENTIAL.copy()
    credential['api_key'] = 'invalid_key'

    # raise CredentialsValidateFailedError if api_key is invalid
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials()
    assert result['api_key'] == 'valid_key'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials(obfuscated=True)
    middle_token = result['api_key'][6:-2]
    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
    assert all(char == '*' for char in middle_token)