Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
1d4f019d
Unverified
Commit
1d4f019d
authored
Oct 10, 2023
by
takatost
Committed by
GitHub
Oct 10, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: add baichuan llm support (#1294)
Co-authored-by:
zxhlyh
<
jasonapring2015@outlook.com
>
parent
677aacc8
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
745 additions
and
1 deletion
+745
-1
model_provider_factory.py
api/core/model_providers/model_provider_factory.py
+3
-0
baichuan_model.py
api/core/model_providers/models/llm/baichuan_model.py
+61
-0
baichuan_provider.py
api/core/model_providers/providers/baichuan_provider.py
+167
-0
_providers.json
api/core/model_providers/rules/_providers.json
+2
-1
baichuan.json
api/core/model_providers/rules/baichuan.json
+15
-0
baichuan_llm.py
api/core/third_party/langchain/llms/baichuan_llm.py
+315
-0
.env.example
api/tests/integration_tests/.env.example
+4
-0
test_baichuan_model.py
...tests/integration_tests/models/llm/test_baichuan_model.py
+81
-0
test_baichuan_provider.py
...ests/unit_tests/model_providers/test_baichuan_provider.py
+97
-0
No files found.
api/core/model_providers/model_provider_factory.py
View file @
1d4f019d
...
@@ -51,6 +51,9 @@ class ModelProviderFactory:
...
@@ -51,6 +51,9 @@ class ModelProviderFactory:
elif
provider_name
==
'chatglm'
:
elif
provider_name
==
'chatglm'
:
from
core.model_providers.providers.chatglm_provider
import
ChatGLMProvider
from
core.model_providers.providers.chatglm_provider
import
ChatGLMProvider
return
ChatGLMProvider
return
ChatGLMProvider
elif
provider_name
==
'baichuan'
:
from
core.model_providers.providers.baichuan_provider
import
BaichuanProvider
return
BaichuanProvider
elif
provider_name
==
'azure_openai'
:
elif
provider_name
==
'azure_openai'
:
from
core.model_providers.providers.azure_openai_provider
import
AzureOpenAIProvider
from
core.model_providers.providers.azure_openai_provider
import
AzureOpenAIProvider
return
AzureOpenAIProvider
return
AzureOpenAIProvider
...
...
api/core/model_providers/models/llm/baichuan_model.py
0 → 100644
View file @
1d4f019d
from
typing
import
List
,
Optional
,
Any
from
langchain.callbacks.manager
import
Callbacks
from
langchain.schema
import
LLMResult
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.entity.message
import
PromptMessage
from
core.model_providers.models.entity.model_params
import
ModelMode
,
ModelKwargs
from
core.third_party.langchain.llms.baichuan_llm
import
BaichuanChatLLM
class BaichuanModel(BaseLLM):
    """LLM wrapper for Baichuan chat models, backed by BaichuanChatLLM."""

    # Baichuan exposes a chat-style (message-based) interface.
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        """Build the underlying BaichuanChatLLM client from credentials and model kwargs."""
        client_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return BaichuanChatLLM(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **client_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompt = self._get_prompt_from_messages(messages)
        return self._client.generate([prompt], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompt = self._get_prompt_from_messages(messages)
        # Clamp so a negative count from the client is never reported.
        return max(self._client.get_num_tokens_from_messages(prompt), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Push updated model kwargs onto the already-constructed client."""
        client_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for name, value in client_kwargs.items():
            # Only set attributes the client actually declares.
            if hasattr(self.client, name):
                setattr(self.client, name, value)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Wrap any provider-side error into an LLMBadRequestError."""
        return LLMBadRequestError(f"Baichuan: {str(ex)}")

    @property
    def support_streaming(self):
        """Baichuan chat models support token streaming."""
        return True
api/core/model_providers/providers/baichuan_provider.py
0 → 100644
View file @
1d4f019d
import
json
from
json
import
JSONDecodeError
from
typing
import
Type
from
langchain.schema
import
HumanMessage
from
core.helper
import
encrypter
from
core.model_providers.models.base
import
BaseProviderModel
from
core.model_providers.models.entity.model_params
import
ModelKwargsRules
,
KwargRule
,
ModelType
from
core.model_providers.models.llm.baichuan_model
import
BaichuanModel
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
from
core.third_party.langchain.llms.baichuan_llm
import
BaichuanChatLLM
from
models.provider
import
ProviderType
class BaichuanProvider(BaseModelProvider):
    """Model provider for Baichuan hosted chat models."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'baichuan'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the fixed model catalogue for the given model type."""
        if model_type != ModelType.TEXT_GENERATION:
            return []
        return [
            {
                'id': 'baichuan2-53b',
                'name': 'Baichuan2-53B',
            }
        ]

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type != ModelType.TEXT_GENERATION:
            raise NotImplementedError
        return BaichuanModel

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Baichuan api_key must be provided.')

        if 'secret_key' not in credentials:
            raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')

        try:
            # Round-trip a minimal prompt to prove the key pair actually works.
            llm = BaichuanChatLLM(
                temperature=0,
                api_key=credentials['api_key'],
                secret_key=credentials['secret_key'],
            )
            llm([HumanMessage(content='ping')])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """Encrypt both credential tokens in place and return the dict."""
        for key in ('api_key', 'secret_key'):
            credentials[key] = encrypter.encrypt_token(tenant_id, credentials[key])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """Load, decrypt and (optionally) obfuscate the stored provider credentials."""
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            return {}

        try:
            credentials = json.loads(self.provider.encrypted_config)
        except JSONDecodeError:
            # Unreadable config: fall back to empty credential slots.
            credentials = {
                'api_key': None,
                'secret_key': None,
            }

        for key in ('api_key', 'secret_key'):
            if credentials[key]:
                credentials[key] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials[key]
                )

                if obfuscated:
                    credentials[key] = encrypter.obfuscated_token(credentials[key])

        return credentials

    def should_deduct_quota(self):
        """Usage on this provider counts against tenant quota."""
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        # Per-model credentials are just the provider-level ones.
        return self.get_provider_credentials(obfuscated)
api/core/model_providers/rules/_providers.json
View file @
1d4f019d
...
@@ -7,10 +7,11 @@
...
@@ -7,10 +7,11 @@
"spark"
,
"spark"
,
"wenxin"
,
"wenxin"
,
"zhipuai"
,
"zhipuai"
,
"baichuan"
,
"chatglm"
,
"chatglm"
,
"replicate"
,
"replicate"
,
"huggingface_hub"
,
"huggingface_hub"
,
"xinference"
,
"xinference"
,
"openllm"
,
"openllm"
,
"localai"
"localai"
]
]
\ No newline at end of file
api/core/model_providers/rules/baichuan.json
0 → 100644
View file @
1d4f019d
{
"support_provider_types"
:
[
"custom"
],
"system_config"
:
null
,
"model_flexibility"
:
"fixed"
,
"price_config"
:
{
"baichuan2-53b"
:
{
"prompt"
:
"0.01"
,
"completion"
:
"0.01"
,
"unit"
:
"0.001"
,
"currency"
:
"RMB"
}
}
}
\ No newline at end of file
api/core/third_party/langchain/llms/baichuan_llm.py
0 → 100644
View file @
1d4f019d
This diff is collapsed.
Click to expand it.
api/tests/integration_tests/.env.example
View file @
1d4f019d
...
@@ -35,6 +35,10 @@ WENXIN_SECRET_KEY=
...
@@ -35,6 +35,10 @@ WENXIN_SECRET_KEY=
# ZhipuAI Credentials
# ZhipuAI Credentials
ZHIPUAI_API_KEY=
ZHIPUAI_API_KEY=
# Baichuan Credentials
BAICHUAN_API_KEY=
BAICHUAN_SECRET_KEY=
# ChatGLM Credentials
# ChatGLM Credentials
CHATGLM_API_BASE=
CHATGLM_API_BASE=
...
...
api/tests/integration_tests/models/llm/test_baichuan_model.py
0 → 100644
View file @
1d4f019d
import
json
import
os
from
unittest.mock
import
patch
from
core.model_providers.models.entity.message
import
PromptMessage
,
MessageType
from
core.model_providers.models.entity.model_params
import
ModelKwargs
from
core.model_providers.models.llm.baichuan_model
import
BaichuanModel
from
core.model_providers.providers.baichuan_provider
import
BaichuanProvider
from
models.provider
import
Provider
,
ProviderType
def get_mock_provider(valid_api_key, valid_secret_key):
    """Build an in-memory custom Provider row carrying the given Baichuan keys."""
    config = {
        'api_key': valid_api_key,
        'secret_key': valid_secret_key,
    }
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='baichuan',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(config),
        is_valid=True,
    )
def get_mock_model(model_name: str, streaming: bool = False):
    """Construct a BaichuanModel wired to real credentials from the environment."""
    provider = BaichuanProvider(
        provider=get_mock_provider(
            os.environ['BAICHUAN_API_KEY'],
            os.environ['BAICHUAN_SECRET_KEY'],
        )
    )
    return BaichuanModel(
        model_provider=provider,
        name=model_name,
        model_kwargs=ModelKwargs(temperature=0.01),
        streaming=streaming
    )
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Identity stand-in for encrypter.decrypt_token in these tests."""
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    """Counting tokens for a system + human message pair yields a positive count."""
    model = get_mock_model('baichuan2-53b')
    prompt = [
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?'),
    ]
    assert model.get_num_tokens(prompt) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    """A plain (non-streaming) run returns non-empty content."""
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used',
                 return_value=None)

    model = get_mock_model('baichuan2-53b')
    result = model.run(
        [PromptMessage(type=MessageType.HUMAN,
                       content='Are you Human? you MUST only answer `y` or `n`?')],
    )
    assert len(result.content) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
    """A streaming run also returns non-empty content."""
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used',
                 return_value=None)

    model = get_mock_model('baichuan2-53b', streaming=True)
    result = model.run(
        [PromptMessage(type=MessageType.HUMAN,
                       content='Are you Human? you MUST only answer `y` or `n`?')]
    )
    assert len(result.content) > 0
api/tests/unit_tests/model_providers/test_baichuan_provider.py
0 → 100644
View file @
1d4f019d
import
pytest
from
unittest.mock
import
patch
import
json
from
langchain.schema
import
ChatResult
,
ChatGeneration
,
AIMessage
from
core.model_providers.providers.baichuan_provider
import
BaichuanProvider
from
core.model_providers.providers.base
import
CredentialsValidateFailedError
from
models.provider
import
ProviderType
,
Provider
# Provider under test and the credential fixture shared by all cases below.
PROVIDER_NAME = 'baichuan'
MODEL_PROVIDER_CLASS = BaichuanProvider
VALIDATE_CREDENTIAL = {
    'api_key': 'valid_key',
    'secret_key': 'valid_key',
}
def encrypt_side_effect(tenant_id, encrypt_key):
    """Fake encrypter.encrypt_token: tag the plain key with an 'encrypted_' prefix."""
    return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
    """Fake encrypter.decrypt_token: strip the 'encrypted_' tag again."""
    return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
    """With _generate stubbed out, a well-formed credential dict validates cleanly."""
    mocker.patch(
        'core.third_party.langchain.llms.baichuan_llm.BaichuanChatLLM._generate',
        return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))])
    )
    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
    """Missing or bogus keys must raise CredentialsValidateFailedError."""
    # raise CredentialsValidateFailedError if api_key is not in credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})

    bad_credential = dict(VALIDATE_CREDENTIAL, api_key='invalid_key', secret_key='invalid_key')

    # raise CredentialsValidateFailedError if api_key is invalid
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(bad_credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
    """Both credential tokens come back prefixed by the fake encrypter."""
    encrypted = MODEL_PROVIDER_CLASS.encrypt_provider_credentials(
        'tenant_id', VALIDATE_CREDENTIAL.copy())
    assert encrypted['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
    assert encrypted['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
    """Stored encrypted credentials decrypt back to the raw key values."""
    stored = {key: 'encrypted_' + value for key, value in VALIDATE_CREDENTIAL.items()}
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(stored),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials()
    assert result['api_key'] == 'valid_key'
    assert result['secret_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
    """Obfuscated tokens keep 6+2 visible characters and star out the middle."""
    stored = {key: 'encrypted_' + value for key, value in VALIDATE_CREDENTIAL.items()}
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(stored),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials(obfuscated=True)

    for key in ('api_key', 'secret_key'):
        middle = result[key][6:-2]
        assert len(middle) == max(len(VALIDATE_CREDENTIAL[key]) - 8, 0)
        assert all(char == '*' for char in middle)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment