ai-tech / dify, commit 827c97f0 (unverified)

feat: add zhipuai (#1188)

Authored Sep 18, 2023 by takatost; committed via GitHub on Sep 18, 2023.
Parent: c8bd76cd

Showing 36 changed files, with 1089 additions and 113 deletions (+1089 / -113).
Changed files:

api/controllers/console/workspace/model_providers.py (+8, -2)
api/core/callback_handler/llm_callback_handler.py (+12, -1)
api/core/chain/sensitive_word_avoidance_chain.py (+7, -6)
api/core/completion.py (+51, -30)
api/core/helper/moderation.py (+19, -17)
api/core/model_providers/model_provider_factory.py (+3, -0)
api/core/model_providers/models/embedding/zhipuai_embedding.py (+22, -0)
api/core/model_providers/models/entity/model_params.py (+1, -0)
api/core/model_providers/models/llm/zhipuai_model.py (+61, -0)
api/core/model_providers/models/moderation/openai_moderation.py (+10, -6)
api/core/model_providers/providers/anthropic_provider.py (+3, -3)
api/core/model_providers/providers/azure_openai_provider.py (+5, -5)
api/core/model_providers/providers/chatglm_provider.py (+3, -3)
api/core/model_providers/providers/hosted.py (+18, -0)
api/core/model_providers/providers/huggingface_hub_provider.py (+3, -3)
api/core/model_providers/providers/localai_provider.py (+3, -3)
api/core/model_providers/providers/minimax_provider.py (+3, -3)
api/core/model_providers/providers/openai_provider.py (+5, -5)
api/core/model_providers/providers/openllm_provider.py (+5, -5)
api/core/model_providers/providers/replicate_provider.py (+2, -0)
api/core/model_providers/providers/spark_provider.py (+2, -2)
api/core/model_providers/providers/tongyi_provider.py (+2, -2)
api/core/model_providers/providers/wenxin_provider.py (+2, -2)
api/core/model_providers/providers/xinference_provider.py (+11, -11)
api/core/model_providers/providers/zhipuai_provider.py (+176, -0)
api/core/model_providers/rules/_providers.json (+1, -0)
api/core/model_providers/rules/zhipuai.json (+44, -0)
api/core/third_party/langchain/embeddings/zhipuai_embedding.py (+64, -0)
api/core/third_party/langchain/llms/zhipuai_llm.py (+315, -0)
api/requirements.txt (+2, -1)
api/services/provider_service.py (+5, -2)
api/tests/integration_tests/.env.example (+3, -0)
api/tests/integration_tests/models/embedding/test_zhipuai_embedding.py (+50, -0)
api/tests/integration_tests/models/llm/test_zhipuai_model.py (+79, -0)
api/tests/unit_tests/model_providers/test_spark_provider.py (+1, -1)
api/tests/unit_tests/model_providers/test_zhipuai_provider.py (+88, -0)
api/controllers/console/workspace/model_providers.py

@@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
                 'enabled': v.enabled,
                 'min': v.min,
                 'max': v.max,
-                'default': v.default
+                'default': v.default,
+                'precision': v.precision
             } for k, v in vars(parameter_rules).items()
         }

@@ -290,10 +291,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider_name: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', type=str, required=False, nullable=True, location='args')
+        args = parser.parse_args()
+
         provider_service = ProviderService()
         result = provider_service.free_quota_qualification_verify(
             tenant_id=current_user.current_tenant_id,
-            provider_name=provider_name
+            provider_name=provider_name,
+            token=args['token']
         )

         return result
api/core/callback_handler/llm_callback_handler.py

@@ -63,7 +63,18 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.conversation_message_task.append_message_text(response.generations[0][0].text)
         self.llm_message.completion = response.generations[0][0].text
-        self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
+
+        if response.llm_output and 'token_usage' in response.llm_output:
+            if 'prompt_tokens' in response.llm_output['token_usage']:
+                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+
+            if 'completion_tokens' in response.llm_output['token_usage']:
+                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+            else:
+                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                    [PromptMessage(content=self.llm_message.completion)])
+        else:
+            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                [PromptMessage(content=self.llm_message.completion)])

         self.conversation_message_task.save_message(self.llm_message)
api/core/chain/sensitive_word_avoidance_chain.py

@@ -2,13 +2,8 @@ import enum
 import logging
 from typing import List, Dict, Optional, Any

-import openai
-from flask import current_app
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from openai import InvalidRequestError
-from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
-    AuthenticationError, OpenAIError
 from pydantic import BaseModel

 from core.model_providers.error import LLMBadRequestError

@@ -86,6 +81,12 @@ class SensitiveWordAvoidanceChain(Chain):
         result = self._check_moderation(text)
         if not result:
-            raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
+            raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)

         return {self.output_key: text}
+
+
+class SensitiveWordAvoidanceError(Exception):
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message
api/core/completion.py

@@ -7,6 +7,7 @@ from requests.exceptions import ChunkedEncodingError
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
+from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \

@@ -76,28 +77,53 @@ class Completion:
             app_model_config=app_model_config
         )

-        # parse sensitive_word_avoidance_chain
-        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
-        if sensitive_word_avoidance_chain:
-            query = sensitive_word_avoidance_chain.run(query)
-
-        # get agent executor
-        agent_executor = orchestrator_rule_parser.to_agent_executor(
-            conversation_message_task=conversation_message_task,
-            memory=memory,
-            rest_tokens=rest_tokens_for_context_and_memory,
-            chain_callback=chain_callback
-        )
-
-        # run agent executor
-        agent_execute_result = None
-        if agent_executor:
-            should_use_agent = agent_executor.should_use_agent(query)
-            if should_use_agent:
-                agent_execute_result = agent_executor.run(query)
-
-        # run the final llm
         try:
+            # parse sensitive_word_avoidance_chain
+            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+            sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
+                final_model_instance, [chain_callback])
+            if sensitive_word_avoidance_chain:
+                try:
+                    query = sensitive_word_avoidance_chain.run(query)
+                except SensitiveWordAvoidanceError as ex:
+                    cls.run_final_llm(
+                        model_instance=final_model_instance,
+                        mode=app.mode,
+                        app_model_config=app_model_config,
+                        query=query,
+                        inputs=inputs,
+                        agent_execute_result=None,
+                        conversation_message_task=conversation_message_task,
+                        memory=memory,
+                        fake_response=ex.message
+                    )
+                    return
+
+            # get agent executor
+            agent_executor = orchestrator_rule_parser.to_agent_executor(
+                conversation_message_task=conversation_message_task,
+                memory=memory,
+                rest_tokens=rest_tokens_for_context_and_memory,
+                chain_callback=chain_callback,
+                retriever_from=retriever_from
+            )
+
+            # run agent executor
+            agent_execute_result = None
+            if agent_executor:
+                should_use_agent = agent_executor.should_use_agent(query)
+                if should_use_agent:
+                    agent_execute_result = agent_executor.run(query)
+
+            # When no extra pre prompt is specified,
+            # the output of the agent can be used directly as the main output content without calling LLM again
+            fake_response = None
+            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
+                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
+                fake_response = agent_execute_result.output
+
+            # run the final llm
             cls.run_final_llm(
                 model_instance=final_model_instance,
                 mode=app.mode,

@@ -106,7 +132,8 @@ class Completion:
                 inputs=inputs,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
-                memory=memory
+                memory=memory,
+                fake_response=fake_response
             )
         except ConversationTaskStoppedException:
             return

@@ -121,14 +148,8 @@ class Completion:
                       inputs: dict,
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
-                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
-        # When no extra pre prompt is specified,
-        # the output of the agent can be used directly as the main output content without calling LLM again
-        fake_response = None
-        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
-                and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
-            fake_response = agent_execute_result.output
-
+                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
+                      fake_response: Optional[str]):
         # get llm prompt
         prompt_messages, stop_words = model_instance.get_prompt(
             mode=mode,
api/core/helper/moderation.py

 import logging
 import openai

-from flask import current_app
-
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.providers.base import BaseModelProvider
+from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
 from models.provider import ProviderType


 def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
-    if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
-        moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
+    if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
         if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
-                and model_provider.provider_name in moderation_providers:
+                and model_provider.provider_name in hosted_config.moderation.providers:
             # 2000 text per chunk
             length = 2000
-            chunks = [text[i:i + length] for i in range(0, len(text), length)]
+            text_chunks = [text[i:i + length] for i in range(0, len(text), length)]

-            try:
-                moderation_result = openai.Moderation.create(input=chunks,
-                                                             api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
-            except Exception as ex:
-                logging.exception(ex)
-                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+            max_text_chunks = 32
+            chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]

-            for result in moderation_result.results:
-                if result['flagged'] is True:
-                    return False
+            for text_chunk in chunks:
+                try:
+                    moderation_result = openai.Moderation.create(input=text_chunk,
+                                                                 api_key=hosted_model_providers.openai.api_key)
+                except Exception as ex:
+                    logging.exception(ex)
+                    raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+
+                for result in moderation_result.results:
+                    if result['flagged'] is True:
+                        return False

     return True
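The batching this hunk introduces works in two stages: the input text is split into 2,000-character chunks, and the chunks are then grouped into batches of at most 32 per openai.Moderation.create call, since the endpoint accepts a list of inputs. A standalone sketch of that arithmetic (batch_text is a hypothetical helper for illustration, not part of this commit):

def batch_text(text: str, chunk_len: int = 2000, batch_size: int = 32) -> list:
    # Split into fixed-size character chunks, then group chunks into batches,
    # mirroring the two list comprehensions in the hunk above.
    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len)]
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]

# A 100,000-character text yields 50 chunks, i.e. two batches of 32 and 18.
assert [len(b) for b in batch_text('x' * 100_000)] == [32, 18]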
api/core/model_providers/model_provider_factory.py

@@ -45,6 +45,9 @@ class ModelProviderFactory:
         elif provider_name == 'wenxin':
             from core.model_providers.providers.wenxin_provider import WenxinProvider
             return WenxinProvider
+        elif provider_name == 'zhipuai':
+            from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
+            return ZhipuAIProvider
         elif provider_name == 'chatglm':
             from core.model_providers.providers.chatglm_provider import ChatGLMProvider
             return ChatGLMProvider
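The factory dispatches with an if/elif ladder and function-local imports, so a provider's dependencies are only imported when that provider is actually requested. A hypothetical table-driven equivalent with the same lazy-import behaviour (not the project's code; the module paths are the ones touched in this commit):

import importlib

_PROVIDERS = {
    'wenxin': ('core.model_providers.providers.wenxin_provider', 'WenxinProvider'),
    'zhipuai': ('core.model_providers.providers.zhipuai_provider', 'ZhipuAIProvider'),
    'chatglm': ('core.model_providers.providers.chatglm_provider', 'ChatGLMProvider'),
}

def resolve_provider_class(provider_name: str):
    # Import the provider module only on first use, like the ladder above.
    module_path, class_name = _PROVIDERS[provider_name]
    return getattr(importlib.import_module(module_path), class_name)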
api/core/model_providers/models/embedding/zhipuai_embedding.py (new file, mode 100644)

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings


class ZhipuAIEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = ZhipuAIEmbeddings(
            model=name,
            **credentials,
        )

        super().__init__(model_provider, client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")
api/core/model_providers/models/entity/model_params.py

@@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
     max: Optional[T] = None
     default: Optional[T] = None
     alias: Optional[str] = None
+    precision: Optional[int] = None


 class ModelKwargsRules(BaseModel):
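The new precision field is surfaced through the parameter-rule API (see the model_providers.py hunk above). Judging by the values assigned elsewhere in this commit (2 for float parameters, 0 for integer ones, 1 for ZhipuAI's coarse top_p), it plausibly tells a client how many decimal places a parameter control should accept; that interpretation is an assumption, not stated in the commit:

def clamp_and_round(value: float, min_: float, max_: float, precision: int) -> float:
    # Hypothetical consumer of a KwargRule: clamp to [min, max], then
    # round to the advertised number of decimal places.
    return round(min(max(value, min_), max_), precision)

# With a temperature rule of min=0, max=2, default=1, precision=2:
assert clamp_and_round(0.12345, 0, 2, 2) == 0.12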
api/core/model_providers/models/llm/zhipuai_model.py (new file, mode 100644)

from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM


class ZhipuAIModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ZhipuAIChatLLM(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens_from_messages(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"ZhipuAI: {str(ex)}")

    @property
    def support_streaming(self):
        return True
api/core/model_providers/models/moderation/openai_moderation.py

@@ -23,14 +23,18 @@ class OpenAIModeration(BaseModeration):
         # 2000 text per chunk
         length = 2000
-        chunks = [text[i:i + length] for i in range(0, len(text), length)]
+        text_chunks = [text[i:i + length] for i in range(0, len(text), length)]

-        moderation_result = self._client.create(input=chunks,
-                                                api_key=credentials['openai_api_key'])
+        max_text_chunks = 32
+        chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]

-        for result in moderation_result.results:
-            if result['flagged'] is True:
-                return False
+        for text_chunk in chunks:
+            moderation_result = self._client.create(input=text_chunk,
+                                                    api_key=credentials['openai_api_key'])
+
+            for result in moderation_result.results:
+                if result['flagged'] is True:
+                    return False

         return True
api/core/model_providers/providers/anthropic_provider.py

@@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=1, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
+            temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
+            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
         )

     @classmethod
api/core/model_providers/providers/azure_openai_provider.py

@@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider):
         model_credentials = self.get_model_credentials(model_name, model_type)

         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=1),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
             max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
                 model_credentials['base_model_name'],
                 4097
-            ), default=16),
+            ), default=16, precision=0),
         )

     @classmethod
api/core/model_providers/providers/chatglm_provider.py

@@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider):
         }

         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
+            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
         )

     @classmethod
api/core/model_providers/providers/hosted.py

@@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel):
 hosted_model_providers = HostedModelProviders()


+class HostedModerationConfig(BaseModel):
+    enabled: bool = False
+    providers: list[str] = []
+
+
+class HostedConfig(BaseModel):
+    moderation = HostedModerationConfig()
+
+
+hosted_config = HostedConfig()
+
+
 def init_app(app: Flask):
     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True

@@ -78,3 +90,9 @@ def init_app(app: Flask):
             paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
             paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
         )
+
+    if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
+        hosted_config.moderation = HostedModerationConfig(
+            enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
+            providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
+        )
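Both settings read here are plain app-config keys, so hosted moderation can be switched on through the environment. An illustrative .env fragment (the provider list value is an example, not prescribed by the commit):

HOSTED_MODERATION_ENABLED=true
HOSTED_MODERATION_PROVIDERS=openai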
api/core/model_providers/providers/huggingface_hub_provider.py

@@ -47,11 +47,11 @@ class HuggingfaceHubProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
         )

     @classmethod
api/core/model_providers/providers/localai_provider.py

@@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=0.7),
-            top_p=KwargRule[float](min=0, max=1, default=1),
-            max_tokens=KwargRule[int](min=10, max=4097, default=16),
+            temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
+            max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
         )

     @classmethod
api/core/model_providers/providers/minimax_provider.py

@@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider):
         }

         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0.01, max=1, default=0.9),
-            top_p=KwargRule[float](min=0, max=1, default=0.95),
+            temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
+            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
         )

     @classmethod
api/core/model_providers/providers/openai_provider.py

@@ -133,11 +133,11 @@ class OpenAIProvider(BaseModelProvider):
         }

         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=1),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
+            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
         )

     @classmethod
api/core/model_providers/providers/openllm_provider.py

@@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0.01, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
+            temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
         )

     @classmethod
api/core/model_providers/providers/replicate_provider.py

@@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider):
                     min=float(value.get('minimum')) if value.get('minimum') is not None else None,
                     max=float(value.get('maximum')) if value.get('maximum') is not None else None,
                     default=float(value.get('default')) if value.get('default') is not None else None,
+                    precision=2
                 )

                 if key == 'temperature':
                     model_kwargs_rules.temperature = kwarg_rule

@@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider):
                     min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
                     max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
                     default=int(value.get('default')) if value.get('default') is not None else 500,
+                    precision=0
                 )

         return model_kwargs_rules
api/core/model_providers/providers/spark_provider.py

@@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=1, default=0.5),
+            temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
             top_p=KwargRule[float](enabled=False),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=4096, default=2048),
+            max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
         )

     @classmethod
api/core/model_providers/providers/tongyi_provider.py

@@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider):
         return ModelKwargsRules(
             temperature=KwargRule[float](enabled=False),
-            top_p=KwargRule[float](min=0, max=1, default=0.8),
+            top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
+            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
         )

     @classmethod
api/core/model_providers/providers/wenxin_provider.py

@@ -63,8 +63,8 @@ class WenxinProvider(BaseModelProvider):
         """
         if model_name in ['ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=1, default=0.95),
-                top_p=KwargRule[float](min=0.01, max=1, default=0.8),
+                temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
+                top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
                 max_tokens=KwargRule[int](enabled=False),
api/core/model_providers/providers/xinference_provider.py

@@ -53,27 +53,27 @@ class XinferenceProvider(BaseModelProvider):
         credentials = self.get_model_credentials(model_name, model_type)
         if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=2, default=1),
-                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
-                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
             )
         elif credentials['model_format'] == "ggmlv3":
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=2, default=1),
-                top_p=KwargRule[float](min=0, max=1, default=0.7),
-                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
             )
         else:
             return ModelKwargsRules(
-                temperature=KwargRule[float](min=0.01, max=2, default=1),
-                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
+                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
                 presence_penalty=KwargRule[float](enabled=False),
                 frequency_penalty=KwargRule[float](enabled=False),
-                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
             )
api/core/model_providers/providers/zhipuai_provider.py (new file, mode 100644)

import json
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
from models.provider import ProviderType, ProviderQuotaType


class ZhipuAIProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'zhipuai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'chatglm_pro',
                    'name': 'chatglm_pro',
                },
                {
                    'id': 'chatglm_std',
                    'name': 'chatglm_std',
                },
                {
                    'id': 'chatglm_lite',
                    'name': 'chatglm_lite',
                },
                {
                    'id': 'chatglm_lite_32k',
                    'name': 'chatglm_lite_32k',
                }
            ]
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'text_embedding',
                    'name': 'text_embedding',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = ZhipuAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = ZhipuAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
            top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key']
            }

            llm = ZhipuAIChatLLM(
                temperature=0.01,
                **credential_kwargs
            )

            llm([HumanMessage(content='ping')])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                or (self.provider.provider_type == ProviderType.SYSTEM.value
                    and self.provider.quota_type == ProviderQuotaType.FREE.value):
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
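For a sense of how this provider class is driven, here is a minimal sketch mirroring the unit-test fixtures further down (the API key value is a placeholder):

import json

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType

provider = Provider(
    id='provider_id',
    tenant_id='tenant_id',
    provider_name='zhipuai',
    provider_type=ProviderType.CUSTOM.value,
    encrypted_config=json.dumps({'api_key': 'encrypted_valid_key'}),  # placeholder
    is_valid=True,
)
model_provider = ZhipuAIProvider(provider=provider)

# Parameter rules are fixed for all ZhipuAI models in this commit.
rules = model_provider.get_model_parameter_rules('chatglm_lite', ModelType.TEXT_GENERATION)
assert rules.temperature.default == 0.95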
api/core/model_providers/rules/_providers.json

@@ -6,6 +6,7 @@
     "tongyi",
     "spark",
     "wenxin",
+    "zhipuai",
     "chatglm",
     "replicate",
     "huggingface_hub",
api/core/model_providers/rules/zhipuai.json (new file, mode 100644)

{
    "support_provider_types": ["system", "custom"],
    "system_config": {
        "supported_quota_types": ["free"],
        "quota_unit": "tokens"
    },
    "model_flexibility": "fixed",
    "price_config": {
        "chatglm_pro": {
            "prompt": "0.01",
            "completion": "0.01",
            "unit": "0.001",
            "currency": "RMB"
        },
        "chatglm_std": {
            "prompt": "0.005",
            "completion": "0.005",
            "unit": "0.001",
            "currency": "RMB"
        },
        "chatglm_lite": {
            "prompt": "0.002",
            "completion": "0.002",
            "unit": "0.001",
            "currency": "RMB"
        },
        "chatglm_lite_32k": {
            "prompt": "0.0004",
            "completion": "0.0004",
            "unit": "0.001",
            "currency": "RMB"
        },
        "text_embedding": {
            "completion": "0",
            "unit": "0.001",
            "currency": "RMB"
        }
    }
}
\ No newline at end of file
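Reading the price table: each price is quoted in RMB, and "unit": "0.001" suggests the quoted price applies per 1,000 tokens. A hedged sketch of the implied cost formula (inferred from the fields; the formula itself is not part of this commit):

from decimal import Decimal

def token_cost(tokens: int, price: str, unit: str = '0.001') -> Decimal:
    # Inferred: cost = tokens * unit * price, i.e. `price` per (1 / unit) tokens.
    return Decimal(tokens) * Decimal(unit) * Decimal(price)

# 10,000 prompt tokens on chatglm_std (0.005 RMB per 1k tokens) cost 0.05 RMB.
assert token_cost(10_000, '0.005') == Decimal('0.05')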
api/core/third_party/langchain/embeddings/zhipuai_embedding.py (new file, mode 100644)

"""Wrapper around ZhipuAI embedding models."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env

from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI


class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around ZhipuAI embedding models.

    1024 dimensions.
    """

    client: Any  #: :meta private:
    model: str
    """Model name to use."""

    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
    api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "ZHIPUAI_API_KEY"
        )

        values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to ZhipuAI's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        for text in texts:
            response = self.client.invoke(model=self.model, prompt=text)
            data = response["data"]
            embeddings.append(data.get('embedding'))

        return [list(map(float, e)) for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Call out to ZhipuAI's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
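A minimal usage sketch of the wrapper above (the key is a placeholder; when the argument is omitted, get_from_dict_or_env falls back to the ZHIPUAI_API_KEY environment variable):

from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings

embeddings = ZhipuAIEmbeddings(model='text_embedding', api_key='your-api-key')  # placeholder key
query_vector = embeddings.embed_query('hello')  # 1024 floats, per the class docstring
doc_vectors = embeddings.embed_documents(['hello', 'world'])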
api/core/third_party/langchain/llms/zhipuai_llm.py (new file, mode 100644)

(Diff collapsed in the source view; the 315 added lines are not shown.)
api/requirements.txt

@@ -50,4 +50,5 @@ transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
 xinference==0.4.2
-safetensors==0.3.2
\ No newline at end of file
+safetensors==0.3.2
+zhipuai==1.0.7
api/services/provider_service.py

@@ -548,7 +548,7 @@ class ProviderService:
             'result': 'success'
         }

-    def free_quota_qualification_verify(self, tenant_id: str, provider_name: str):
+    def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
         api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
         api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
         api_url = api_base_url + '/api/v1/providers/qualification-verify'

@@ -557,8 +557,11 @@ class ProviderService:
             'Content-Type': 'application/json',
             'Authorization': f"Bearer {api_key}"
         }

+        json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
+        if token:
+            json_data['token'] = token
+
         response = requests.post(api_url, headers=headers,
-                                 json={'workspace_id': tenant_id, 'provider_name': provider_name})
+                                 json=json_data)

         if not response.ok:
             logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
             raise ValueError(f"Error: {response.status_code} ")
api/tests/integration_tests/.env.example

@@ -31,6 +31,9 @@ TONGYI_DASHSCOPE_API_KEY=
 WENXIN_API_KEY=
 WENXIN_SECRET_KEY=

+# ZhipuAI Credentials
+ZHIPUAI_API_KEY=
+
 # ChatGLM Credentials
 CHATGLM_API_BASE=
api/tests/integration_tests/models/embedding/test_zhipuai_embedding.py (new file, mode 100644)

import json
import os
from unittest.mock import patch

from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='zhipuai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({'api_key': valid_api_key}),
        is_valid=True,
    )


def get_mock_embedding_model():
    model_name = 'text_embedding'
    valid_api_key = os.environ['ZHIPUAI_API_KEY']
    provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
    return ZhipuAIEmbedding(
        model_provider=provider,
        name=model_name
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
    embedding_model = get_mock_embedding_model()
    rst = embedding_model.client.embed_query('test')
    assert isinstance(rst, list)
    assert len(rst) == 1024


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_doc_embedding(mock_decrypt):
    embedding_model = get_mock_embedding_model()
    rst = embedding_model.client.embed_documents(['test', 'test2'])
    assert isinstance(rst, list)
    assert len(rst[0]) == 1024
api/tests/integration_tests/models/llm/test_zhipuai_model.py (new file, mode 100644)

import json
import os
from unittest.mock import patch

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='zhipuai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({'api_key': valid_api_key}),
        is_valid=True,
    )


def get_mock_model(model_name: str, streaming: bool = False):
    model_kwargs = ModelKwargs(
        temperature=0.01,
    )
    valid_api_key = os.environ['ZHIPUAI_API_KEY']
    model_provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
    return ZhipuAIModel(
        model_provider=model_provider,
        name=model_name,
        model_kwargs=model_kwargs,
        streaming=streaming
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    model = get_mock_model('chatglm_lite')
    rst = model.get_num_tokens([
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ])
    assert rst > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('chatglm_lite')
    messages = [PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')]
    rst = model.run(
        messages,
    )
    assert len(rst.content) > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('chatglm_lite', streaming=True)
    messages = [PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')]
    rst = model.run(messages)
    assert len(rst.content) > 0
api/tests/unit_tests/model_providers/test_spark_provider.py

@@ -39,7 +39,7 @@ def test_is_provider_credentials_valid_or_raise_invalid():
         MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})

     credential = VALIDATE_CREDENTIAL.copy()
-    credential['api_key'] = 'invalid_key'
+    del credential['api_key']

     # raise CredentialsValidateFailedError if api_key is invalid
     with pytest.raises(CredentialsValidateFailedError):
api/tests/unit_tests/model_providers/test_zhipuai_provider.py (new file, mode 100644)

import pytest
from unittest.mock import patch
import json

from langchain.schema import ChatResult, ChatGeneration, AIMessage

from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import ProviderType, Provider

PROVIDER_NAME = 'zhipuai'
MODEL_PROVIDER_CLASS = ZhipuAIProvider
VALIDATE_CREDENTIAL = {
    'api_key': 'valid_key',
}


def encrypt_side_effect(tenant_id, encrypt_key):
    return f'encrypted_{encrypt_key}'


def decrypt_side_effect(tenant_id, encrypted_key):
    return encrypted_key.replace('encrypted_', '')


def test_is_provider_credentials_valid_or_raise_valid(mocker):
    mocker.patch('core.third_party.langchain.llms.zhipuai_llm.ZhipuAIChatLLM._generate',
                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)


def test_is_provider_credentials_valid_or_raise_invalid():
    # raise CredentialsValidateFailedError if api_key is not in credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})

    credential = VALIDATE_CREDENTIAL.copy()
    credential['api_key'] = 'invalid_key'

    # raise CredentialsValidateFailedError if api_key is invalid
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials()
    assert result['api_key'] == 'valid_key'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials(obfuscated=True)
    middle_token = result['api_key'][6:-2]
    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
    assert all(char == '*' for char in middle_token)