ai-tech / dify · Commits · 7c9b585a

Unverified commit 7c9b585a, authored Oct 18, 2023 by takatost; committed by GitHub on Oct 18, 2023.

feat: support weixin ernie-bot-4 and chat mode (#1375)

Parent: c039f4af
Showing 6 changed files with 174 additions and 77 deletions (+174 −77):

- api/core/model_providers/models/llm/wenxin_model.py (+14 −5)
- api/core/model_providers/providers/wenxin_provider.py (+13 −5)
- api/core/model_providers/rules/wenxin.json (+6 −0)
- api/core/third_party/langchain/llms/wenxin.py (+135 −63)
- api/tests/integration_tests/models/llm/test_wenxin_model.py (+2 −3)
- api/tests/unit_tests/model_providers/test_wenxin_provider.py (+4 −1)
api/core/model_providers/models/llm/wenxin_model.py  View file @ 7c9b585a

@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.wenxin import Wenxin


 class WenxinModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT

     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         # TODO load price_config from configs(db)
         return Wenxin(
             model=self.name,
             streaming=self.streaming,

@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+        generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """

@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)

     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)

@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
+
+    @property
+    def support_streaming(self):
+        return True
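The switch from the positional call generate([prompts], stop, callbacks) to a kwargs dict is what lets optional arguments such as functions flow through to the chat client only when the caller supplies them. A minimal standalone sketch of that forwarding pattern, with a stub standing in for the real Wenxin chat client (everything except the dict-building logic is illustrative):

# Sketch of the kwargs-forwarding pattern from the diff above.
# StubChatClient is a stand-in for the real Wenxin chat client.
class StubChatClient:
    def generate(self, messages, stop=None, callbacks=None, functions=None):
        return {'messages': messages, 'stop': stop, 'functions': functions}


def run(client, prompts, stop=None, callbacks=None, **kwargs):
    generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}

    # Forward optional arguments only when the caller actually supplied
    # them, so clients that lack the parameter never see an unexpected kwarg.
    if 'functions' in kwargs:
        generate_kwargs['functions'] = kwargs['functions']

    return client.generate(**generate_kwargs)


print(run(StubChatClient(), ['hello'], stop=['\nHuman:'],
          functions=[{'name': 'add', 'parameters': {}}]))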
api/core/model_providers/providers/wenxin_provider.py  View file @ 7c9b585a

@@ -2,6 +2,8 @@ import json
 from json import JSONDecodeError
 from typing import Type

+from langchain.schema import HumanMessage
+
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode

@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
             return [
+                {
+                    'id': 'ernie-bot-4',
+                    'name': 'ERNIE-Bot-4',
+                    'mode': ModelMode.CHAT.value,
+                },
                 {
                     'id': 'ernie-bot',
                     'name': 'ERNIE-Bot',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'ernie-bot-turbo',
                     'name': 'ERNIE-Bot-turbo',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'bloomz-7b',
                     'name': 'BLOOMZ-7B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:

@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
+            'ernie-bot-4': 4800,
             'ernie-bot': 4800,
             'ernie-bot-turbo': 11200,
         }

-        if model_name in ['ernie-bot', 'ernie-bot-turbo']:
+        if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
                 temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
                 top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),

@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
                 **credential_kwargs
             )

-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
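With the models now in chat mode, credential validation pings the model with a list of messages instead of a bare prompt string. A self-contained sketch of that check, with a fake chat callable standing in for the real Wenxin LLM (all names here are illustrative, not dify's actual classes):

class CredentialsValidateFailedError(Exception):
    pass


class FakeChatLLM:
    """Stand-in for the Wenxin chat LLM; a real client would call the ERNIE API."""
    def __call__(self, messages):
        if not isinstance(messages, list):
            raise TypeError('chat models expect a list of messages, not a string')
        return 'pong'


def validate_credentials(llm):
    try:
        # Chat-mode models take a message list; llm("ping") would now be invalid.
        llm([{'role': 'user', 'content': 'ping'}])
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))


validate_credentials(FakeChatLLM())  # passes silently when the ping succeeds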
api/core/model_providers/rules/wenxin.json  View file @ 7c9b585a

@@ -5,6 +5,12 @@
     "system_config": null,
     "model_flexibility": "fixed",
     "price_config": {
+        "ernie-bot-4": {
+            "prompt": "0",
+            "completion": "0",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
         "ernie-bot": {
             "prompt": "0.012",
             "completion": "0.012",
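Each price_config entry quotes a per-unit price: with "unit": "0.001", the prompt and completion rates are effectively per 1,000 tokens. A hypothetical helper showing how such a config could be applied (calc_price is not part of the codebase, and the ernie-bot unit/currency fields are assumed to match the ernie-bot-4 entry, since the diff cuts off before them):

from decimal import Decimal

PRICE_CONFIG = {
    "ernie-bot-4": {"prompt": "0", "completion": "0", "unit": "0.001", "currency": "RMB"},
    "ernie-bot": {"prompt": "0.012", "completion": "0.012", "unit": "0.001", "currency": "RMB"},
}


def calc_price(model: str, prompt_tokens: int, completion_tokens: int) -> Decimal:
    """Hypothetical: cost = tokens * unit * rate, summed over prompt and completion."""
    cfg = PRICE_CONFIG[model]
    unit = Decimal(cfg["unit"])  # unit 0.001 means the rate applies per 1,000 tokens
    return (Decimal(prompt_tokens) * unit * Decimal(cfg["prompt"])
            + Decimal(completion_tokens) * unit * Decimal(cfg["completion"]))


print(calc_price("ernie-bot", 1000, 500))  # 0.012 + 0.006 = 0.018 RMB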
api/core/third_party/langchain/llms/wenxin.py  View file @ 7c9b585a

(Diff collapsed in the original view: +135 −63.)
api/tests/integration_tests/models/llm/test_wenxin_model.py  View file @ 7c9b585a

@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

     model = get_mock_model('ernie-bot')
-    messages = [PromptMessage(content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
+    messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
     rst = model.run(
-        messages,
-        stop=['\nHuman:'],
+        messages
     )

     assert len(rst.content) > 0
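The integration test now constructs PromptMessage with an explicit message type and drops the stop-word argument, presumably because chat-mode requests carry role information instead of relying on Human:/Assistant: text framing. A rough illustration of the typed-message idea with stand-in classes (these PromptMessage/MessageType definitions are simplified mock-ups, not dify's real ones):

from dataclasses import dataclass
from enum import Enum


class MessageType(Enum):
    USER = 'user'          # MessageType.USER appears in the diff; the other
    ASSISTANT = 'assistant'  # members are assumed for illustration
    SYSTEM = 'system'


@dataclass
class PromptMessage:
    type: MessageType
    content: str


messages = [PromptMessage(type=MessageType.USER, content='1 + 1 = ?')]
# A chat client maps each typed message onto a role in the request payload.
payload = [{'role': m.type.value, 'content': m.content} for m in messages]
print(payload)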
api/tests/unit_tests/model_providers/test_wenxin_provider.py  View file @ 7c9b585a

@@ -2,6 +2,8 @@ import pytest
 from unittest.mock import patch
 import json

+from langchain.schema import AIMessage, ChatGeneration, ChatResult
+
 from core.model_providers.providers.base import CredentialsValidateFailedError
 from core.model_providers.providers.wenxin_provider import WenxinProvider
 from models.provider import ProviderType, Provider

@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):

 def test_is_provider_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))

     MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
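Since Wenxin now implements the chat interface, the unit test mocks _generate (rather than the completion-style _call) and returns LangChain's ChatResult envelope. A standalone sketch of building and unpacking that structure, assuming the legacy langchain.schema import path used in the diff:

from langchain.schema import AIMessage, ChatGeneration, ChatResult

# The same shape the mocked Wenxin._generate returns in the test above.
result = ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))])

# Callers read the reply text back off the first generation's message.
assert result.generations[0].message.content == 'abc'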