ai-tech / dify · Commits · 1d4f019d

Unverified commit 1d4f019d, authored Oct 10, 2023 by takatost, committed by GitHub on Oct 10, 2023

feat: add baichuan llm support (#1294)

Co-authored-by: zxhlyh <jasonapring2015@outlook.com>

parent 677aacc8

Showing 9 changed files with 745 additions and 1 deletion (+745 -1)
api/core/model_providers/model_provider_factory.py              +3   -0
api/core/model_providers/models/llm/baichuan_model.py           +61  -0
api/core/model_providers/providers/baichuan_provider.py         +167 -0
api/core/model_providers/rules/_providers.json                  +2   -1
api/core/model_providers/rules/baichuan.json                    +15  -0
api/core/third_party/langchain/llms/baichuan_llm.py             +315 -0
api/tests/integration_tests/.env.example                        +4   -0
api/tests/integration_tests/models/llm/test_baichuan_model.py   +81  -0
api/tests/unit_tests/model_providers/test_baichuan_provider.py  +97  -0
api/core/model_providers/model_provider_factory.py
View file @ 1d4f019d

@@ -51,6 +51,9 @@ class ModelProviderFactory:
         elif provider_name == 'chatglm':
             from core.model_providers.providers.chatglm_provider import ChatGLMProvider
             return ChatGLMProvider
+        elif provider_name == 'baichuan':
+            from core.model_providers.providers.baichuan_provider import BaichuanProvider
+            return BaichuanProvider
         elif provider_name == 'azure_openai':
             from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
             return AzureOpenAIProvider
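The hunk extends the factory's lazy-import dispatch chain. Below is a minimal sketch of that pattern in isolation; resolve_provider_class is a hypothetical stand-in, since the enclosing method's name is not visible in this hunk:

# Hypothetical stand-in for the factory method this hunk extends.
def resolve_provider_class(provider_name: str):
    if provider_name == 'baichuan':
        # Lazy import, as in the factory: a provider module is only
        # imported when that provider is actually requested.
        from core.model_providers.providers.baichuan_provider import BaichuanProvider
        return BaichuanProvider
    raise NotImplementedError(f"unsupported provider: {provider_name}")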
api/core/model_providers/models/llm/baichuan_model.py
0 → 100644
View file @ 1d4f019d

from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM


class BaichuanModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return BaichuanChatLLM(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens_from_messages(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Baichuan: {str(ex)}")

    @property
    def support_streaming(self):
        return True
api/core/model_providers/providers/baichuan_provider.py
0 → 100644
View file @ 1d4f019d

import json
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
from models.provider import ProviderType


class BaichuanProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of the provider.
        """
        return 'baichuan'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'baichuan2-53b',
                    'name': 'Baichuan2-53B',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = BaichuanModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Baichuan api_key must be provided.')

        if 'secret_key' not in credentials:
            raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key'],
                'secret_key': credentials['secret_key'],
            }

            llm = BaichuanChatLLM(
                temperature=0,
                **credential_kwargs
            )

            llm([HumanMessage(content='ping')])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_key': None,
                    'secret_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            if credentials['secret_key']:
                credentials['secret_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['secret_key']
                )

                if obfuscated:
                    credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Check whether the model credentials are valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        Encrypt model credentials for saving.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Get credentials for LLM use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
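A minimal sketch of how this provider's credential check is driven, assuming real keys are available in the environment. Note that validation sends a one-message 'ping' chat request, so it performs a (tiny) live API call:

import os
from core.model_providers.providers.baichuan_provider import BaichuanProvider
from core.model_providers.providers.base import CredentialsValidateFailedError

try:
    BaichuanProvider.is_provider_credentials_valid_or_raise({
        'api_key': os.environ['BAICHUAN_API_KEY'],
        'secret_key': os.environ['BAICHUAN_SECRET_KEY'],
    })
except CredentialsValidateFailedError as e:
    print(f"credentials rejected: {e}")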
api/core/model_providers/rules/_providers.json
View file @ 1d4f019d

@@ -7,6 +7,7 @@
     "spark",
     "wenxin",
     "zhipuai",
+    "baichuan",
     "chatglm",
     "replicate",
     "huggingface_hub",
api/core/model_providers/rules/baichuan.json
0 → 100644
View file @ 1d4f019d

{
    "support_provider_types": ["custom"],
    "system_config": null,
    "model_flexibility": "fixed",
    "price_config": {
        "baichuan2-53b": {
            "prompt": "0.01",
            "completion": "0.01",
            "unit": "0.001",
            "currency": "RMB"
        }
    }
}
\ No newline at end of file
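The price_config strings are decimal amounts in RMB. A worked example of the implied arithmetic, assuming dify's usual convention that cost = tokens × price × unit (so a unit of 0.001 makes the listed price a per-1,000-token rate); the helper name is illustrative, not part of this commit:

from decimal import Decimal

# Assumed convention: cost = tokens * price * unit, so unit "0.001"
# turns the listed price into a per-1,000-token rate.
def baichuan_message_cost(prompt_tokens: int, completion_tokens: int) -> Decimal:
    unit = Decimal("0.001")
    prompt_price = Decimal("0.01")      # RMB, from price_config above
    completion_price = Decimal("0.01")  # RMB, from price_config above
    return (prompt_tokens * prompt_price
            + completion_tokens * completion_price) * unit

print(baichuan_message_cost(1000, 500))  # 0.01500 RMB under this assumption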
api/core/third_party/langchain/llms/baichuan_llm.py
0 → 100644
View file @ 1d4f019d

"""Wrapper around Baichuan APIs."""
from __future__ import annotations

import hashlib
import json
import logging
import time
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Iterator,
)

import requests
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel

from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class BaichuanModelAPI(BaseModel):
    api_key: str
    secret_key: str

    base_url: str = "https://api.baichuan-ai.com/v1"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def do_request(self, model: str, messages: list[dict], parameters: dict, **kwargs: Any):
        stream = 'stream' in kwargs and kwargs['stream']

        url = self.base_url + ("/stream/chat" if stream else "/chat")

        data = {
            "model": model,
            "messages": messages,
            "parameters": parameters
        }

        json_data = json.dumps(data)
        time_stamp = int(time.time())
        signature = self._calculate_md5(self.secret_key + json_data + str(time_stamp))

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
            "X-BC-Request-Id": "your requestId",
            "X-BC-Timestamp": str(time_stamp),
            "X-BC-Signature": signature,
            "X-BC-Sign-Algo": "MD5",
        }

        response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=(5, 60))

        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")

        if not stream:
            json_response = response.json()
            if json_response['code'] != 0:
                raise ValueError(
                    f"API {json_response['code']}"
                    f" error: {json_response['msg']}"
                )
            return json_response
        else:
            return response

    def _calculate_md5(self, input_string):
        md5 = hashlib.md5()
        md5.update(input_string.encode('utf-8'))
        encrypted = md5.hexdigest()
        return encrypted


class BaichuanChatLLM(BaseChatModel):
    """Wrapper around Baichuan large language models.

    To use, you should pass the api_key as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM

            model = BaichuanChatLLM(model="<model_name>", api_key="my-api-key", secret_key="my-secret-key")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any = None  #: :meta private:
    model: str = "Baichuan2-53B"
    """Model name to use."""
    temperature: float = 0.3
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.85
    """Total probability mass of tokens to consider at each step."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["api_key"] = get_from_dict_or_env(values, "api_key", "BAICHUAN_API_KEY")
        values["secret_key"] = get_from_dict_or_env(values, "secret_key", "BAICHUAN_SECRET_KEY")

        values['client'] = BaichuanModelAPI(
            api_key=values['api_key'],
            secret_key=values['secret_key']
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Baichuan API."""
        return {
            "model": self.model,
            "parameters": {
                "temperature": self.temperature,
                "top_p": self.top_p
            }
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "baichuan"

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "user", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_dict

    def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
        role = _dict["role"]
        if role == "user":
            return HumanMessage(content=_dict["content"])
        elif role == "assistant":
            return AIMessage(content=_dict["content"])
        elif role == "system":
            return SystemMessage(content=_dict["content"])
        else:
            return ChatMessage(content=_dict["content"], role=role)

    def _create_message_dicts(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]:
        dict_messages = []
        for m in messages:
            message = self._convert_message_to_dict(m)
            if dict_messages:
                previous_message = dict_messages[-1]
                if previous_message['role'] == message['role']:
                    dict_messages[-1]['content'] += f"\n{message['content']}"
                else:
                    dict_messages.append(message)
            else:
                dict_messages.append(message)

        return dict_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk

                if chunk.generation_info is not None \
                        and 'token_usage' in chunk.generation_info:
                    llm_output = {
                        "token_usage": chunk.generation_info['token_usage'],
                        "model_name": self.model
                    }

            assert generation is not None
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts = self._create_message_dicts(messages)
            params = self._default_params
            params["messages"] = message_dicts
            params.update(kwargs)

            response = self.client.do_request(**params)
            return self._create_chat_result(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages)
        params = self._default_params
        params["messages"] = message_dicts
        params.update(kwargs)

        for event in self.client.do_request(stream=True, **params).iter_lines():
            if event:
                event = event.decode("utf-8")
                meta = json.loads(event)
                if meta['code'] != 0:
                    raise ValueError(
                        f"API {meta['code']}"
                        f" error: {meta['msg']}"
                    )

                content = meta['data']['messages'][0]['content']
                chunk_kwargs = {
                    'message': AIMessageChunk(content=content),
                }

                if 'usage' in meta:
                    token_usage = meta['usage']
                    overall_token_usage = {
                        'prompt_tokens': token_usage.get('prompt_tokens', 0),
                        'completion_tokens': token_usage.get('answer_tokens', 0),
                        'total_tokens': token_usage.get('total_tokens', 0)
                    }
                    chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}

                yield ChatGenerationChunk(**chunk_kwargs)
                if run_manager:
                    run_manager.on_llm_new_token(content)

    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
        data = response["data"]
        generations = []
        for res in data["messages"]:
            message = self._convert_dict_to_message(res)
            gen = ChatGeneration(message=message)
            generations.append(gen)

        usage = response.get("usage")
        token_usage = {
            'prompt_tokens': usage.get('prompt_tokens', 0),
            'completion_tokens': usage.get('answer_tokens', 0),
            'total_tokens': usage.get('total_tokens', 0)
        }

        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum([self.get_num_tokens(m.content) for m in messages])

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]

        return {"token_usage": token_usage, "model_name": self.model}
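The only non-obvious part of the wrapper is the request signing in BaichuanModelAPI.do_request: X-BC-Signature is the MD5 hex digest of secret_key + JSON body + timestamp, and the server must receive exactly the bytes that were signed (hence posting data=json_data instead of letting requests re-serialize a dict). A self-contained sketch with dummy credentials and no network call:

import hashlib
import json
import time

# Dummy credentials, for illustration only.
secret_key = "sk-dummy"

body = json.dumps({
    "model": "Baichuan2-53B",
    "messages": [{"role": "user", "content": "ping"}],
    "parameters": {"temperature": 0.3, "top_p": 0.85},
})
timestamp = str(int(time.time()))

# Mirrors BaichuanModelAPI._calculate_md5: MD5 over secret_key + body + timestamp.
signature = hashlib.md5((secret_key + body + timestamp).encode("utf-8")).hexdigest()

headers = {
    "X-BC-Timestamp": timestamp,
    "X-BC-Signature": signature,
    "X-BC-Sign-Algo": "MD5",
}
print(headers)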
api/tests/integration_tests/.env.example
View file @ 1d4f019d

@@ -35,6 +35,10 @@ WENXIN_SECRET_KEY=
 # ZhipuAI Credentials
 ZHIPUAI_API_KEY=
 
+# Baichuan Credentials
+BAICHUAN_API_KEY=
+BAICHUAN_SECRET_KEY=
+
 # ChatGLM Credentials
 CHATGLM_API_BASE=
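With these variables set, BaichuanChatLLM needs no explicit keys: validate_environment falls back to BAICHUAN_API_KEY and BAICHUAN_SECRET_KEY via get_from_dict_or_env. A minimal sketch (it performs a live API call, so real credentials are required):

from langchain.schema import HumanMessage
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM

llm = BaichuanChatLLM()  # api_key / secret_key resolved from the environment
print(llm([HumanMessage(content='ping')]).content)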
api/tests/integration_tests/models/llm/test_baichuan_model.py
0 → 100644
View file @ 1d4f019d

import json
import os
from unittest.mock import patch

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.baichuan_provider import BaichuanProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key, valid_secret_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='baichuan',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({
            'api_key': valid_api_key,
            'secret_key': valid_secret_key,
        }),
        is_valid=True,
    )


def get_mock_model(model_name: str, streaming: bool = False):
    model_kwargs = ModelKwargs(
        temperature=0.01,
    )
    valid_api_key = os.environ['BAICHUAN_API_KEY']
    valid_secret_key = os.environ['BAICHUAN_SECRET_KEY']
    model_provider = BaichuanProvider(provider=get_mock_provider(valid_api_key, valid_secret_key))
    return BaichuanModel(
        model_provider=model_provider,
        name=model_name,
        model_kwargs=model_kwargs,
        streaming=streaming
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    model = get_mock_model('baichuan2-53b')
    rst = model.get_num_tokens([
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ])
    assert rst > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used',
                 return_value=None)

    model = get_mock_model('baichuan2-53b')
    messages = [PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')]
    rst = model.run(
        messages,
    )
    assert len(rst.content) > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used',
                 return_value=None)

    model = get_mock_model('baichuan2-53b', streaming=True)
    messages = [PromptMessage(
        type=MessageType.HUMAN,
        content='Are you Human? you MUST only answer `y` or `n`?'
    )]
    rst = model.run(messages)
    assert len(rst.content) > 0
api/tests/unit_tests/model_providers/test_baichuan_provider.py
0 → 100644
View file @ 1d4f019d

import pytest
from unittest.mock import patch
import json

from langchain.schema import ChatResult, ChatGeneration, AIMessage

from core.model_providers.providers.baichuan_provider import BaichuanProvider
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType, Provider

PROVIDER_NAME = 'baichuan'
MODEL_PROVIDER_CLASS = BaichuanProvider
VALIDATE_CREDENTIAL = {
    'api_key': 'valid_key',
    'secret_key': 'valid_key',
}


def encrypt_side_effect(tenant_id, encrypt_key):
    return f'encrypted_{encrypt_key}'


def decrypt_side_effect(tenant_id, encrypted_key):
    return encrypted_key.replace('encrypted_', '')


def test_is_provider_credentials_valid_or_raise_valid(mocker):
    mocker.patch('core.third_party.langchain.llms.baichuan_llm.BaichuanChatLLM._generate',
                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))

    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)


def test_is_provider_credentials_valid_or_raise_invalid():
    # raise CredentialsValidateFailedError if api_key is not in credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})

    credential = VALIDATE_CREDENTIAL.copy()
    credential['api_key'] = 'invalid_key'
    credential['secret_key'] = 'invalid_key'

    # raise CredentialsValidateFailedError if api_key is invalid
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
    assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials()
    assert result['api_key'] == 'valid_key'
    assert result['secret_key'] == 'valid_key'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials(obfuscated=True)
    middle_token = result['api_key'][6:-2]
    secret_key_middle_token = result['secret_key'][6:-2]
    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
    assert len(secret_key_middle_token) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
    assert all(char == '*' for char in middle_token)
    assert all(char == '*' for char in secret_key_middle_token)
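The slicing in test_get_credentials_obfuscated pins down the expected shape of encrypter.obfuscated_token: the first 6 and last 2 characters survive, and everything in between becomes '*'. A hypothetical re-implementation that satisfies exactly what the test asserts (the real helper lives in core.helper.encrypter and is not part of this commit):

def obfuscated_token_sketch(token: str) -> str:
    # Keep the first 6 and last 2 characters, star the middle,
    # matching the [6:-2] slicing asserted in the test above.
    return token[:6] + '*' * max(len(token) - 8, 0) + token[-2:]

assert obfuscated_token_sketch('valid_key') == 'valid_*ey'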