ai-tech / dify · Commit 7c9b585a (Unverified)
Authored Oct 18, 2023 by takatost, committed by GitHub on Oct 18, 2023

feat: support weixin ernie-bot-4 and chat mode (#1375)
Parent: c039f4af

Showing 6 changed files with 174 additions and 77 deletions:
- api/core/model_providers/models/llm/wenxin_model.py (+14, -5)
- api/core/model_providers/providers/wenxin_provider.py (+13, -5)
- api/core/model_providers/rules/wenxin.json (+6, -0)
- api/core/third_party/langchain/llms/wenxin.py (+135, -63)
- api/tests/integration_tests/models/llm/test_wenxin_model.py (+2, -3)
- api/tests/unit_tests/model_providers/test_wenxin_provider.py (+4, -1)

api/core/model_providers/models/llm/wenxin_model.py

@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.wenxin import Wenxin


 class WenxinModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT

     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        # TODO load price_config from configs(db)
         return Wenxin(
             model=self.name,
             streaming=self.streaming,

@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """

@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)

     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)

@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
+
+    @property
+    def support_streaming(self):
+        return True
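
Note on the _run change above: with WenxinModel switched to chat mode, arguments are assembled into a kwargs dict so that optional capabilities such as function calling pass through only when the caller actually supplies them. A minimal, self-contained sketch of that forwarding pattern (FakeChatClient and run are illustrative stand-ins, not repository code; only the dict-building mirrors the committed _run):

# Illustrative sketch of optional-kwarg forwarding; not repository code.
from typing import Any, List, Optional


class FakeChatClient:
    def generate(self, messages: List[Any],
                 stop: Optional[List[str]] = None,
                 callbacks: Any = None,
                 functions: Optional[List[dict]] = None) -> str:
        return f"generate(messages={messages!r}, functions={functions!r})"


def run(client: FakeChatClient, prompts: Any,
        stop: Optional[List[str]] = None, callbacks: Any = None,
        **kwargs: Any) -> str:
    # Build the base arguments, then forward optional extras only when the
    # caller supplied them, so the client never sees unexpected keys.
    generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
    if 'functions' in kwargs:
        generate_kwargs['functions'] = kwargs['functions']
    return client.generate(**generate_kwargs)


print(run(FakeChatClient(), prompts='hi', functions=[{'name': 'add'}]))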

api/core/model_providers/providers/wenxin_provider.py

@@ -2,6 +2,8 @@ import json
 from json import JSONDecodeError
 from typing import Type

+from langchain.schema import HumanMessage
+
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode

@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
             return [
+                {
+                    'id': 'ernie-bot-4',
+                    'name': 'ERNIE-Bot-4',
+                    'mode': ModelMode.CHAT.value,
+                },
                 {
                     'id': 'ernie-bot',
                     'name': 'ERNIE-Bot',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'ernie-bot-turbo',
                     'name': 'ERNIE-Bot-turbo',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'bloomz-7b',
                     'name': 'BLOOMZ-7B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:

@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
+            'ernie-bot-4': 4800,
             'ernie-bot': 4800,
             'ernie-bot-turbo': 11200,
         }

-        if model_name in ['ernie-bot', 'ernie-bot-turbo']:
+        if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
                 temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
                 top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),

@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
                 **credential_kwargs
             )

-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
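
The credential check above also reflects the completion-to-chat switch: a LangChain chat model is invoked with a list of message objects rather than a bare prompt string, hence llm([HumanMessage(content='ping')]). A self-contained sketch of that validation shape (FakeWenxinChat and the local HumanMessage stand-in are hypothetical, not the library types):

# Sketch of the chat-style "ping" validation; only the call shape mirrors the diff.
from dataclasses import dataclass


@dataclass
class HumanMessage:  # stand-in for langchain.schema.HumanMessage
    content: str


class CredentialsValidateFailedError(Exception):
    pass


class FakeWenxinChat:
    def __init__(self, api_key: str):
        self.api_key = api_key

    def __call__(self, messages):
        # A real wrapper would hit the provider API here.
        if not self.api_key:
            raise ValueError("missing api key")
        return f"pong ({len(messages)} message(s))"


def validate(credential_kwargs: dict) -> None:
    try:
        llm = FakeWenxinChat(**credential_kwargs)
        llm([HumanMessage(content='ping')])  # chat models take message lists
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))


validate({'api_key': 'sk-test'})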

api/core/model_providers/rules/wenxin.json

@@ -5,6 +5,12 @@
     "system_config": null,
     "model_flexibility": "fixed",
     "price_config": {
+        "ernie-bot-4": {
+            "prompt": "0",
+            "completion": "0",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
         "ernie-bot": {
             "prompt": "0.012",
             "completion": "0.012",
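
On the pricing entry added above: elsewhere in dify the cost of a call appears to be computed as tokens × unit × price, so a unit of "0.001" effectively quotes the price per 1,000 tokens, and ernie-bot-4 is registered at zero cost in this commit. A small sketch under that assumption (calc_price is illustrative, not repository code):

# Assumes dify's tokens * unit * price convention; calc_price is illustrative.
from decimal import Decimal


def calc_price(tokens: int, price: str, unit: str) -> Decimal:
    # With unit "0.001", `price` reads as the cost per 1,000 tokens.
    return Decimal(tokens) * Decimal(unit) * Decimal(price)


assert calc_price(1000, "0.012", "0.001") == Decimal("0.012")  # ernie-bot prompt side
assert calc_price(1000, "0", "0.001") == Decimal("0")          # ernie-bot-4: free here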

api/core/third_party/langchain/llms/wenxin.py

(This diff is collapsed in the page view and is not shown.)
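
The collapsed file carries the bulk of this commit (+135, -63): the Wenxin wrapper is reworked from a completion-style LLM into a chat-style model. Since the implementation is not visible here, the following is only a sketch of the surface the other diffs call into (constructor kwargs, generate(), get_num_tokens_from_messages()); the message classes are local stand-ins for the langchain.schema types and every method body is an assumption:

# Sketch of the interface implied by the other diffs; bodies are assumptions.
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class AIMessage:  # stand-ins for the langchain.schema types
    content: str


@dataclass
class ChatGeneration:
    message: AIMessage


@dataclass
class ChatResult:
    generations: List[ChatGeneration]


class Wenxin:
    def __init__(self, model: str, streaming: bool = False, **kwargs: Any):
        self.model = model  # e.g. 'ernie-bot-4'
        self.streaming = streaming

    def generate(self, messages: List[Any],
                 stop: Optional[List[str]] = None,
                 callbacks: Any = None,
                 functions: Optional[List[dict]] = None) -> ChatResult:
        # A real implementation would POST `messages` (and optional
        # `functions`) to Baidu's ERNIE endpoint and wrap the reply.
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content='<reply>'))])

    def get_num_tokens_from_messages(self, messages: List[Any]) -> int:
        # Placeholder count; the real wrapper presumably tokenizes properly.
        return sum(len(getattr(m, 'content', str(m))) for m in messages)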

api/tests/integration_tests/models/llm/test_wenxin_model.py

@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

     model = get_mock_model('ernie-bot')
-    messages = [PromptMessage(content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
+    messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
     rst = model.run(
-        messages,
-        stop=['\nHuman:'],
+        messages
     )
     assert len(rst.content) > 0

api/tests/unit_tests/model_providers/test_wenxin_provider.py

@@ -2,6 +2,8 @@ import pytest
 from unittest.mock import patch
 import json

+from langchain.schema import AIMessage, ChatGeneration, ChatResult
+
 from core.model_providers.providers.base import CredentialsValidateFailedError
 from core.model_providers.providers.wenxin_provider import WenxinProvider
 from models.provider import ProviderType, Provider

@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 def test_is_provider_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))

     MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
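
The mock change above follows directly from the interface change: a completion-style _call returned a plain string, while a chat-style _generate returns a ChatResult wrapping message objects, so extracting the text goes through one more layer. Using the same langchain.schema types the test imports:

# How text is recovered from the chat-style result the new mock returns.
from langchain.schema import AIMessage, ChatGeneration, ChatResult

result = ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))])
text = result.generations[0].message.content  # -> 'abc'
assert text == 'abc'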