ai-tech / dify — Commit 7c9b585a (Unverified)

feat: support weixin ernie-bot-4 and chat mode (#1375)

Authored Oct 18, 2023 by takatost; committed by GitHub on Oct 18, 2023.
Parent: c039f4af

Showing 6 changed files with 174 additions and 77 deletions (+174, -77).
api/core/model_providers/models/llm/wenxin_model.py             +14   -5
api/core/model_providers/providers/wenxin_provider.py           +13   -5
api/core/model_providers/rules/wenxin.json                       +6   -0
api/core/third_party/langchain/llms/wenxin.py                  +135  -63
api/tests/integration_tests/models/llm/test_wenxin_model.py      +2   -3
api/tests/unit_tests/model_providers/test_wenxin_provider.py     +4   -1
api/core/model_providers/models/llm/wenxin_model.py

@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.wenxin import Wenxin


 class WenxinModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT

     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         # TODO load price_config from configs(db)
         return Wenxin(
             model=self.name,
             streaming=self.streaming,

@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+        generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """

@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)

     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)

@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
+
+    @property
+    def support_streaming(self):
+        return True
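Note on the change above: WenxinModel now runs in chat mode, so callers hand it chat-style PromptMessages and no longer need completion-style stop words, and an optional functions kwarg is forwarded to the underlying Wenxin chat client. A hypothetical usage sketch, mirroring the integration test further down (not code from this commit):

    from core.model_providers.models.entity.message import PromptMessage, MessageType

    # `model` is assumed to be an already configured WenxinModel, e.g. for 'ernie-bot-4'
    messages = [PromptMessage(type=MessageType.USER, content='1 + 1 = ?')]
    result = model.run(messages)   # chat mode: no stop-word list required
    print(result.content)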
api/core/model_providers/providers/wenxin_provider.py

@@ -2,6 +2,8 @@ import json
 from json import JSONDecodeError
 from typing import Type

+from langchain.schema import HumanMessage
+
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode

@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
             return [
+                {
+                    'id': 'ernie-bot-4',
+                    'name': 'ERNIE-Bot-4',
+                    'mode': ModelMode.CHAT.value,
+                },
                 {
                     'id': 'ernie-bot',
                     'name': 'ERNIE-Bot',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'ernie-bot-turbo',
                     'name': 'ERNIE-Bot-turbo',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'bloomz-7b',
                     'name': 'BLOOMZ-7B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:

@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
+            'ernie-bot-4': 4800,
             'ernie-bot': 4800,
             'ernie-bot-turbo': 11200,
         }

-        if model_name in ['ernie-bot', 'ernie-bot-turbo']:
+        if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
                 temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
                 top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),

@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
                 **credential_kwargs
             )

-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
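Because Wenxin is now a LangChain chat model (see wenxin.py below), the provider's credential check pings it with a HumanMessage instead of a bare prompt string. A minimal sketch of that validation path with placeholder credentials (the constructor fields come from the diff; everything else is illustrative):

    from langchain.schema import HumanMessage
    from core.third_party.langchain.llms.wenxin import Wenxin

    llm = Wenxin(model='ernie-bot-4', api_key='<api_key>', secret_key='<secret_key>')
    llm([HumanMessage(content='ping')])   # raises on bad credentials, mirroring the try/except above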
api/core/model_providers/rules/wenxin.json

@@ -5,6 +5,12 @@
     "system_config": null,
     "model_flexibility": "fixed",
     "price_config": {
+        "ernie-bot-4": {
+            "prompt": "0",
+            "completion": "0",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
         "ernie-bot": {
             "prompt": "0.012",
             "completion": "0.012",
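The new price_config entry registers ERNIE-Bot-4 with zero prices for now (the TODO in wenxin_model.py notes that price_config should eventually be loaded from configuration/DB). Assuming the usual convention that "unit": "0.001" means the quoted price is per 1,000 tokens, a cost estimate would look roughly like this (illustrative only, not code from this commit):

    from decimal import Decimal

    def estimate_cost(tokens: int, price_per_unit: str, unit: str = "0.001") -> Decimal:
        # tokens * unit * price, e.g. 1000 * 0.001 * 0.012 RMB = 0.012 RMB
        return Decimal(tokens) * Decimal(unit) * Decimal(price_per_unit)

    estimate_cost(1000, "0.012")   # ernie-bot prompt side: 0.012 RMB per 1k tokens
    estimate_cost(1000, "0")       # ernie-bot-4: currently configured as free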
api/core/third_party/langchain/llms/wenxin.py

@@ -8,12 +8,15 @@ from typing import (
     Any,
     Dict,
     List,
     Optional,
     Iterator,
+    Tuple,
 )

 import requests
+from langchain.chat_models.base import BaseChatModel
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.schema.output import GenerationChunk
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
 from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator

 from langchain.callbacks.manager import (

@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
             raise ValueError(f"Wenxin Model name is required")

         model_url_map = {
+            'ernie-bot-4': 'completions_pro',
             'ernie-bot': 'completions',
             'ernie-bot-turbo': 'eb-instant',
             'bloomz-7b': 'bloomz_7b1',

@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):
         access_token = self.get_access_token()
         api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
+        del request['model']

         headers = {"Content-Type": "application/json"}
         response = requests.post(api_url,
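The single mapping entry added above is what routes ERNIE-Bot-4 traffic: Baidu exposes each chat model under its own path segment, and the client builds the request URL from base_url plus that segment. An illustration of the resulting URL (base_url is not shown in this hunk; the value below assumes it points at Baidu's wenxinworkshop chat endpoint):

    model_url_map = {
        'ernie-bot-4': 'completions_pro',
        'ernie-bot': 'completions',
        'ernie-bot-turbo': 'eb-instant',
        'bloomz-7b': 'bloomz_7b1',
    }
    base_url = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/'  # assumed
    api_url = f"{base_url}{model_url_map['ernie-bot-4']}?access_token=<access_token>"
    # -> .../wenxinworkshop/chat/completions_pro?access_token=<access_token>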
@@ -86,22 +91,21 @@
                     f"Wenxin API {json_response['error_code']}"
                     f" error: {json_response['error_msg']}")
-            return json_response["result"]
+            return json_response
         else:
             return response


-class Wenxin(LLM):
-    """Wrapper around Wenxin large language models.
-    To use, you should have the environment variable
-    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
-    or pass them as a named parameter to the constructor.
-    Example:
-        .. code-block:: python
-            from langchain.llms.wenxin import Wenxin
-            wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
-             secret_key="my-group-id")
-    """
+class Wenxin(BaseChatModel):
+    """Wrapper around Wenxin large language models."""

     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

     @property
     def lc_serializable(self) -> bool:
         return True

     _client: _WenxinEndpointClient = PrivateAttr()
     model: str = "ernie-bot"
@@ -161,64 +165,89 @@ class Wenxin(LLM):
             secret_key=self.secret_key,
         )

-    def _call(
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
+    def _create_message_dicts(
+            self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        dict_messages = []
+        system = None
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if message['role'] == 'system':
+                if not system:
+                    system = message['content']
+                else:
+                    system += f"\n{message['content']}"
+                continue
+
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages, system
+
+    def _generate(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        r"""Call out to Wenxin's completion endpoint to chat
-        Args:
-            prompt: The prompt to pass into the model.
-        Returns:
-            The string generated by the model.
-        Example:
-            .. code-block:: python
-                response = wenxin("Tell me a joke.")
-        """
+    ) -> ChatResult:
         if self.streaming:
             completion = ""
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
             for chunk in self._stream(
-                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
             ):
                 completion += chunk.text
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {
+                        "token_usage": chunk.generation_info['token_usage'],
+                        "model_name": self.model
+                    }
+
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
         else:
+            message_dicts, system = self._create_message_dicts(messages)
             request = self._default_params
-            request["messages"] = [{"role": "user", "content": prompt}]
+            request["messages"] = message_dicts
+            if system:
+                request["system"] = system
             request.update(kwargs)
-            completion = self._client.post(request)
-
-        if stop is not None:
-            completion = enforce_stop_tokens(completion, stop)
-
-        return completion
+            response = self._client.post(request)
+            return self._create_chat_result(response)

     def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        r"""Call wenxin completion_stream and return the resulting generator.
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-        Returns:
-            A generator representing the stream of tokens from Wenxin.
-        Example:
-            .. code-block:: python
-                prompt = "Write a poem about a stream."
-                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
-                generator = wenxin.stream(prompt)
-                for token in generator:
-                    yield token
-        """
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, system = self._create_message_dicts(messages)
         request = self._default_params
-        request["messages"] = [{"role": "user", "content": prompt}]
+        request["messages"] = message_dicts
+        if system:
+            request["system"] = system
         request.update(kwargs)

         for token in self._client.post(request).iter_lines():
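The two new helpers are what let the rewritten _generate/_stream speak Baidu's chat protocol: system messages are pulled out into a top-level system field, and consecutive messages with the same role are merged, which keeps the user/assistant turns strictly alternating as Baidu's chat API appears to require. A quick sketch of the behaviour, assuming `wenxin` is an initialised Wenxin instance:

    from langchain.schema import SystemMessage, HumanMessage, AIMessage

    history = [
        SystemMessage(content='You are terse.'),
        HumanMessage(content='Hi'),
        HumanMessage(content='Are you there?'),  # same role as the previous message -> merged
        AIMessage(content='Yes.'),
    ]
    message_dicts, system = wenxin._create_message_dicts(history)
    # message_dicts == [{'role': 'user', 'content': 'Hi\nAre you there?'},
    #                   {'role': 'assistant', 'content': 'Yes.'}]
    # system == 'You are terse.'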
@@ -228,12 +257,18 @@ class Wenxin(LLM):
             if token.startswith('data:'):
                 completion = json.loads(token[5:])

-                yield GenerationChunk(text=completion['result'])
-                if run_manager:
-                    run_manager.on_llm_new_token(completion['result'])
+                chunk_dict = {
+                    'message': AIMessageChunk(content=completion['result']),
+                }

                 if completion['is_end']:
+                    token_usage = completion['usage']
+                    token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+                    chunk_dict['generation_info'] = dict({'token_usage': token_usage})
+
+                yield ChatGenerationChunk(**chunk_dict)
+                if run_manager:
+                    run_manager.on_llm_new_token(completion['result'])
             else:
                 try:
                     json_response = json.loads(token)
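In the streaming path, token usage is only attached to the final chunk (is_end), and completion_tokens is not read from the payload but derived by subtraction; the non-streaming _create_chat_result below repeats the same step. A worked example of that arithmetic (made-up numbers):

    token_usage = {'prompt_tokens': 11, 'total_tokens': 53}
    token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
    assert token_usage['completion_tokens'] == 42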
@@ -245,3 +280,40 @@ class Wenxin(LLM):
                     f" error: {json_response['error_msg']}, "
                     f"please confirm if the model you have chosen is already paid for."
                 )
+
+    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
+        generations = [ChatGeneration(
+            message=AIMessage(content=response['result']),
+        )]
+        token_usage = response.get("usage")
+        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+
+        return {"token_usage": overall_token_usage, "model_name": self.model}
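_combine_llm_outputs is the BaseChatModel hook that merges per-call llm_output dicts when generate() runs over several message lists; here it simply sums the token counters and re-attaches the model name. A small sketch of the result, assuming a Wenxin instance `wenxin` configured for 'ernie-bot':

    outputs = [
        {'token_usage': {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}},
        None,  # streaming runs may contribute no llm_output and are skipped
        {'token_usage': {'prompt_tokens': 7, 'completion_tokens': 3, 'total_tokens': 10}},
    ]
    wenxin._combine_llm_outputs(outputs)
    # {'token_usage': {'prompt_tokens': 17, 'completion_tokens': 8, 'total_tokens': 25},
    #  'model_name': 'ernie-bot'}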
api/tests/integration_tests/models/llm/test_wenxin_model.py

@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

     model = get_mock_model('ernie-bot')
-    messages = [PromptMessage(content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
+    messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
     rst = model.run(
-        messages,
-        stop=['\nHuman:'],
+        messages
     )
     assert len(rst.content) > 0
api/tests/unit_tests/model_providers/test_wenxin_provider.py

@@ -2,6 +2,8 @@ import pytest
 from unittest.mock import patch
 import json

+from langchain.schema import AIMessage, ChatGeneration, ChatResult
+
 from core.model_providers.providers.base import CredentialsValidateFailedError
 from core.model_providers.providers.wenxin_provider import WenxinProvider
 from models.provider import ProviderType, Provider

@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 def test_is_provider_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call',
-                 return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))

     MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)