ai-tech / dify · Commit 7c9b585a (Unverified)
Authored Oct 18, 2023 by takatost; committed by GitHub on Oct 18, 2023
feat: support weixin ernie-bot-4 and chat mode (#1375)
Parent: c039f4af

Showing 6 changed files with 174 additions and 77 deletions (+174 −77)
api/core/model_providers/models/llm/wenxin_model.py (+14 −5)
api/core/model_providers/providers/wenxin_provider.py (+13 −5)
api/core/model_providers/rules/wenxin.json (+6 −0)
api/core/third_party/langchain/llms/wenxin.py (+135 −63)
api/tests/integration_tests/models/llm/test_wenxin_model.py (+2 −3)
api/tests/unit_tests/model_providers/test_wenxin_provider.py (+4 −1)
api/core/model_providers/models/llm/wenxin_model.py

@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.wenxin import Wenxin


 class WenxinModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT

     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        # TODO load price_config from configs(db)
         return Wenxin(
             model=self.name,
             streaming=self.streaming,
@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+        generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)

     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):

     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
+
+    @property
+    def support_streaming(self):
+        return True
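Note on the generation change above: routing arguments through `generate_kwargs` lets callers forward Baidu's function-calling payload opportunistically, without widening the method signature. A minimal caller-side sketch, assuming the surrounding `run` plumbing forwards `**kwargs`; the tool schema below is hypothetical and not part of this commit:

# Hypothetical function-calling payload forwarded via **kwargs;
# the patched method only attaches it when the caller supplies one.
functions = [{
    "name": "get_weather",  # hypothetical tool name
    "description": "Look up the current weather for a city",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}]
# With a WenxinModel instance `model` and a List[PromptMessage] `messages`:
# result = model.run(messages, functions=functions)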
api/core/model_providers/providers/wenxin_provider.py

@@ -2,6 +2,8 @@ import json
 from json import JSONDecodeError
 from typing import Type

+from langchain.schema import HumanMessage
+
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
             return [
+                {
+                    'id': 'ernie-bot-4',
+                    'name': 'ERNIE-Bot-4',
+                    'mode': ModelMode.CHAT.value,
+                },
                 {
                     'id': 'ernie-bot',
                     'name': 'ERNIE-Bot',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'ernie-bot-turbo',
                     'name': 'ERNIE-Bot-turbo',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'bloomz-7b',
                     'name': 'BLOOMZ-7B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
+            'ernie-bot-4': 4800,
             'ernie-bot': 4800,
             'ernie-bot-turbo': 11200,
         }

-        if model_name in ['ernie-bot', 'ernie-bot-turbo']:
+        if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
                 temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
                 top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
                 **credential_kwargs
             )

-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
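Since `Wenxin` is now a chat model (see the wenxin.py diff below), the credential probe must send a message list rather than a bare prompt string, hence `llm([HumanMessage(content='ping')])`. A minimal sketch of the same pattern in isolation, with placeholder credentials:

from langchain.schema import HumanMessage
from core.third_party.langchain.llms.wenxin import Wenxin

# Placeholder credentials; the real check passes the decrypted provider kwargs.
llm = Wenxin(model='ernie-bot', api_key='<api_key>', secret_key='<secret_key>')
llm([HumanMessage(content='ping')])  # raises on invalid credentials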
api/core/model_providers/rules/wenxin.json

@@ -5,6 +5,12 @@
     "system_config": null,
     "model_flexibility": "fixed",
     "price_config": {
+        "ernie-bot-4": {
+            "prompt": "0",
+            "completion": "0",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
         "ernie-bot": {
             "prompt": "0.012",
             "completion": "0.012",
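The new ernie-bot-4 entry sets both prompt and completion prices to "0" for now, while keeping the shared unit of 0.001, i.e. prices are quoted per 1,000 tokens in RMB. An illustrative helper (not part of this commit) showing how such an entry translates into a call cost:

from decimal import Decimal

# Illustrative only: cost of one call under a price_config entry,
# where 'unit' = 0.001 means prompt/completion prices apply per 1,000 tokens.
def calc_cost(entry: dict, prompt_tokens: int, completion_tokens: int) -> Decimal:
    unit = Decimal(entry['unit'])
    return (Decimal(entry['prompt']) * prompt_tokens
            + Decimal(entry['completion']) * completion_tokens) * unit

ernie_bot = {'prompt': '0.012', 'completion': '0.012', 'unit': '0.001', 'currency': 'RMB'}
print(calc_cost(ernie_bot, 1000, 500))  # 0.018 RMB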
api/core/third_party/langchain/llms/wenxin.py

@@ -8,12 +8,15 @@ from typing import (
     Any,
     Dict,
     List,
     Optional, Iterator,
+    Tuple,
 )

 import requests
+from langchain.chat_models.base import BaseChatModel
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.schema.output import GenerationChunk
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
 from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator

 from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
             raise ValueError(f"Wenxin Model name is required")

         model_url_map = {
+            'ernie-bot-4': 'completions_pro',
             'ernie-bot': 'completions',
             'ernie-bot-turbo': 'eb-instant',
             'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):
         access_token = self.get_access_token()
         api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
+        del request['model']

         headers = {"Content-Type": "application/json"}
         response = requests.post(api_url,
@@ -86,22 +91,21 @@ class _WenxinEndpointClient(BaseModel):
                     f"Wenxin API {json_response['error_code']}"
                     f" error: {json_response['error_msg']}"
                 )
-            return json_response["result"]
+            return json_response
         else:
             return response


-class Wenxin(LLM):
-    """Wrapper around Wenxin large language models.
-
-    To use, you should have the environment variable
-    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
-    or pass them as a named parameter to the constructor.
-
-    Example:
-        .. code-block:: python
-
-            from langchain.llms.wenxin import Wenxin
-            wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
-             secret_key="my-group-id")
-    """
+class Wenxin(BaseChatModel):
+    """Wrapper around Wenxin large language models."""
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True

     _client: _WenxinEndpointClient = PrivateAttr()
     model: str = "ernie-bot"
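With `_WenxinEndpointClient.post` now returning the full JSON body instead of only `json_response["result"]`, callers gain access to the `usage` block that the chat-result builder added further down consumes. Roughly, a non-streaming response looks like this (illustrative values; the wrapper derives `completion_tokens` from the prompt and total counts):

# Illustrative non-streaming response body from an ERNIE-Bot endpoint.
response = {
    'result': '1 + 1 = 2',
    'usage': {'prompt_tokens': 9, 'total_tokens': 11},
}
# _create_chat_result (see the final hunk) turns this into a ChatResult with
# usage {'prompt_tokens': 9, 'total_tokens': 11, 'completion_tokens': 2}.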
@@ -161,64 +165,89 @@ class Wenxin(LLM):
             secret_key=self.secret_key,
         )

-    def _call(
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
+    def _create_message_dicts(
+            self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        dict_messages = []
+        system = None
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if message['role'] == 'system':
+                if not system:
+                    system = message['content']
+                else:
+                    system += f"\n{message['content']}"
+                continue
+
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages, system
+
+    def _generate(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        r"""Call out to Wenxin's completion endpoint to chat
-
-        Args:
-            prompt: The prompt to pass into the model.
-
-        Returns:
-            The string generated by the model.
-
-        Example:
-            .. code-block:: python
-
-                response = wenxin("Tell me a joke.")
-        """
+    ) -> ChatResult:
         if self.streaming:
-            completion = ""
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
             for chunk in self._stream(
-                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
             ):
-                completion += chunk.text
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'],
+                                  "model_name": self.model}
+
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
         else:
+            message_dicts, system = self._create_message_dicts(messages)
             request = self._default_params
-            request["messages"] = [{"role": "user", "content": prompt}]
+            request["messages"] = message_dicts
+            if system:
+                request["system"] = system
             request.update(kwargs)
-            completion = self._client.post(request)
-
-        if stop is not None:
-            completion = enforce_stop_tokens(completion, stop)
-
-        return completion
+            response = self._client.post(request)
+            return self._create_chat_result(response)

     def _stream(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        r"""Call wenxin completion_stream and return the resulting generator.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-
-        Returns:
-            A generator representing the stream of tokens from Wenxin.
-
-        Example:
-            .. code-block:: python
-
-                prompt = "Write a poem about a stream."
-                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
-                generator = wenxin.stream(prompt)
-                for token in generator:
-                    yield token
-        """
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, system = self._create_message_dicts(messages)
         request = self._default_params
-        request["messages"] = [{"role": "user", "content": prompt}]
+        request["messages"] = message_dicts
+        if system:
+            request["system"] = system
         request.update(kwargs)

         for token in self._client.post(request).iter_lines():
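The new `_create_message_dicts` encodes two Wenxin API constraints: system prompts travel in a separate `system` field, and message roles must alternate, so consecutive same-role messages are merged with a newline. A small sketch of that behavior:

from langchain.schema import SystemMessage, HumanMessage, AIMessage

messages = [
    SystemMessage(content='You are terse.'),
    HumanMessage(content='Hi'),
    HumanMessage(content='What is 1 + 1?'),  # same role as the previous message
    AIMessage(content='2'),
]
# With a constructed Wenxin instance `wenxin`:
# dicts, system = wenxin._create_message_dicts(messages)
# dicts  == [{'role': 'user', 'content': 'Hi\nWhat is 1 + 1?'},
#            {'role': 'assistant', 'content': '2'}]
# system == 'You are terse.'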
@@ -228,12 +257,18 @@ class Wenxin(LLM):
             if token.startswith('data:'):
                 completion = json.loads(token[5:])

-                yield GenerationChunk(text=completion['result'])
-                if run_manager:
-                    run_manager.on_llm_new_token(completion['result'])
+                chunk_dict = {
+                    'message': AIMessageChunk(content=completion['result']),
+                }

                 if completion['is_end']:
-                    break
+                    token_usage = completion['usage']
+                    token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+                    chunk_dict['generation_info'] = dict({'token_usage': token_usage})
+
+                yield ChatGenerationChunk(**chunk_dict)
+                if run_manager:
+                    run_manager.on_llm_new_token(completion['result'])
             else:
                 try:
                     json_response = json.loads(token)
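The streaming branch parses server-sent events: each `data:` line carries a JSON chunk whose `result` is the new text, and the terminal chunk (`is_end: true`) also carries `usage`, which is attached as `generation_info` so `_generate` can surface token counts. An illustrative sketch of the wire format being parsed (payload values are made up):

import json

# Illustrative SSE lines in the shape the streaming endpoint emits.
lines = [
    'data: {"result": "2", "is_end": false}',
    'data: {"result": "", "is_end": true, "usage": {"prompt_tokens": 9, "total_tokens": 10}}',
]
for token in lines:
    if token.startswith('data:'):
        completion = json.loads(token[5:])
        print(repr(completion['result']), completion['is_end'])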
@@ -245,3 +280,40 @@ class Wenxin(LLM):
                         f" error: {json_response['error_msg']}, "
                         f"please confirm if the model you have chosen is already paid for."
                     )
+
+    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
+        generations = [ChatGeneration(
+            message=AIMessage(content=response['result']),
+        )]
+        token_usage = response.get("usage")
+        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage, "model_name": self.model}
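`_combine_llm_outputs` is the `BaseChatModel` hook that aggregates per-call `llm_output` dicts when `generate` runs over several message lists; streaming calls may contribute `None`, which is skipped. A quick sketch of the summation it performs:

# Sketch of the aggregation _combine_llm_outputs performs.
outputs = [
    {'token_usage': {'prompt_tokens': 9, 'completion_tokens': 1, 'total_tokens': 10}},
    None,  # a streaming call that produced no llm_output
    {'token_usage': {'prompt_tokens': 4, 'completion_tokens': 2, 'total_tokens': 6}},
]
overall = {}
for output in outputs:
    if output is None:
        continue
    for k, v in output['token_usage'].items():
        overall[k] = overall.get(k, 0) + v
assert overall == {'prompt_tokens': 13, 'completion_tokens': 3, 'total_tokens': 16}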
api/tests/integration_tests/models/llm/test_wenxin_model.py

@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

     model = get_mock_model('ernie-bot')
-    messages = [PromptMessage(content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
+    messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
     rst = model.run(
-        messages,
-        stop=['\nHuman:'],
+        messages
     )
     assert len(rst.content) > 0
api/tests/unit_tests/model_providers/test_wenxin_provider.py

@@ -2,6 +2,8 @@ import pytest
 from unittest.mock import patch
 import json

+from langchain.schema import AIMessage, ChatGeneration, ChatResult
+
 from core.model_providers.providers.base import CredentialsValidateFailedError
 from core.model_providers.providers.wenxin_provider import WenxinProvider
 from models.provider import ProviderType, Provider
@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 def test_is_provider_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call',
-                 return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))

     MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)