ai-tech / dify · Commits

Commit da3f10a5 (unverified)
feat: server xinference support (#927)

Authored Aug 20, 2023 by takatost; committed by GitHub on Aug 20, 2023.
Parent: 8c991b5b

Showing 18 changed files with 456 additions and 17 deletions (+456, -17).
Changed files:

  api/core/model_providers/model_provider_factory.py                     +3   -0
  api/core/model_providers/models/llm/xinference_model.py                +69  -0
  api/core/model_providers/providers/xinference_provider.py              +141 -0
  api/core/model_providers/rules/_providers.json                         +2   -1
  api/core/model_providers/rules/xinference.json                         +7   -0
  api/requirements.txt                                                   +2   -1
  api/tests/integration_tests/.env.example                               +5   -1
  api/tests/integration_tests/models/llm/test_anthropic_model.py         +3   -2
  api/tests/integration_tests/models/llm/test_azure_openai_model.py      +2   -1
  api/tests/integration_tests/models/llm/test_huggingface_hub_model.py   +4   -1
  api/tests/integration_tests/models/llm/test_minimax_model.py           +3   -2
  api/tests/integration_tests/models/llm/test_openai_model.py            +6   -3
  api/tests/integration_tests/models/llm/test_replicate_model.py         +2   -0
  api/tests/integration_tests/models/llm/test_spark_model.py             +3   -2
  api/tests/integration_tests/models/llm/test_tongyi_model.py            +3   -1
  api/tests/integration_tests/models/llm/test_wenxin_model.py            +3   -2
  api/tests/integration_tests/models/llm/test_xinference_model.py        +74  -0
  api/tests/unit_tests/model_providers/test_xinference_provider.py       +124 -0
api/core/model_providers/model_provider_factory.py

```diff
@@ -57,6 +57,9 @@ class ModelProviderFactory:
         elif provider_name == 'huggingface_hub':
             from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
             return HuggingfaceHubProvider
+        elif provider_name == 'xinference':
+            from core.model_providers.providers.xinference_provider import XinferenceProvider
+            return XinferenceProvider
         else:
             raise NotImplementedError
```
api/core/model_providers/models/llm/xinference_model.py (new file, mode 100644)

```python
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.llms import Xinference
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class XinferenceModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        client = Xinference(
            **self.credentials,
        )
        client.callbacks = self.callbacks
        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate(
            [prompts],
            stop,
            callbacks,
            generate_config={
                "stop": stop,
                "stream": self.streaming,
                **self.provider_model_kwargs,
            }
        )

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        pass

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Xinference: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        return True
```
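For orientation, here is a minimal standalone sketch of what `_init_client` and `_run` amount to at the LangChain level. The server URL and model UID are placeholders (the real values come from the provider's stored credentials), and the `generate_config` keys shown mirror the parameter rules defined in `xinference_provider.py` below; this is an illustration under those assumptions, not code from the commit.

```python
from langchain.llms import Xinference

# _init_client: the stored credentials are exactly the kwargs the LangChain
# wrapper expects (placeholder values shown here).
client = Xinference(
    server_url='http://127.0.0.1:9997',   # placeholder
    model_uid='my-model-uid',             # placeholder
)

# _run: the single prompt is wrapped in a list, and the stop words, streaming
# flag and provider kwargs are forwarded through generate_config.
result = client.generate(
    ['Human: 1+1=?\nAnswer: '],
    None,   # stop
    None,   # callbacks
    generate_config={
        'stop': None,
        'stream': False,
        'max_tokens': 256,    # example values taken from the rules below
        'temperature': 1,
    },
)
print(result.generations[0][0].text)
```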
api/core/model_providers/providers/xinference_provider.py (new file, mode 100644)

```python
import json
from typing import Type

from langchain.llms import Xinference

from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from models.provider import ProviderType


class XinferenceProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'xinference'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = XinferenceModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=0.7),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
            max_tokens=KwargRule[int](alias='max_token', min=10, max=4000, default=256),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        if 'server_url' not in credentials:
            raise CredentialsValidateFailedError('Xinference Server URL must be provided.')

        if 'model_uid' not in credentials:
            raise CredentialsValidateFailedError('Xinference Model UID must be provided.')

        try:
            credential_kwargs = {
                'server_url': credentials['server_url'],
                'model_uid': credentials['model_uid'],
            }

            llm = Xinference(
                **credential_kwargs
            )

            llm("ping", generate_config={'max_tokens': 10})
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'server_url': None,
                'model_uid': None,
            }

        credentials = json.loads(provider_model.encrypted_config)

        if credentials['server_url']:
            credentials['server_url'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['server_url']
            )

            if obfuscated:
                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}
```
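Taken together, the provider's credential handling follows a validate, encrypt, decrypt/obfuscate cycle. A rough sketch of that flow, assuming an Xinference server is reachable at the example URL; the tenant id, model name and model uid below are placeholders, not values from the commit.

```python
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.xinference_provider import XinferenceProvider

credentials = {
    'server_url': 'http://127.0.0.1:9997/',   # placeholder
    'model_uid': 'my-model-uid',              # placeholder
}

# 1. Validation checks both keys are present, then issues a tiny "ping"
#    completion (max_tokens=10) against the server; any failure surfaces
#    as CredentialsValidateFailedError.
XinferenceProvider.is_model_credentials_valid_or_raise(
    model_name='my-model',
    model_type=ModelType.TEXT_GENERATION,
    credentials=credentials.copy(),
)

# 2. Only server_url is encrypted before the record is stored;
#    model_uid is kept as plain text.
stored = XinferenceProvider.encrypt_model_credentials(
    tenant_id='tenant_id',                    # placeholder
    model_name='my-model',
    model_type=ModelType.TEXT_GENERATION,
    credentials=credentials.copy(),
)

# 3. get_model_credentials() later decrypts server_url again and, with
#    obfuscated=True, masks most of it with '*' characters for display.
```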
api/core/model_providers/rules/_providers.json

```diff
@@ -8,5 +8,6 @@
   "wenxin",
   "chatglm",
   "replicate",
-  "huggingface_hub"
+  "huggingface_hub",
+  "xinference"
 ]
\ No newline at end of file
```
api/core/model_providers/rules/xinference.json (new file, mode 100644, no newline at end of file)

```json
{
  "support_provider_types": [
    "custom"
  ],
  "system_config": null,
  "model_flexibility": "configurable"
}
```
api/requirements.txt

```diff
@@ -48,4 +48,5 @@ dashscope~=1.5.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
-pandas==1.5.3
\ No newline at end of file
+pandas==1.5.3
+xinference==0.2.0
\ No newline at end of file
```
api/tests/integration_tests/.env.example

```diff
@@ -32,4 +32,8 @@ WENXIN_API_KEY=
 WENXIN_SECRET_KEY=
 
 # ChatGLM Credentials
-CHATGLM_API_BASE=
\ No newline at end of file
+CHATGLM_API_BASE=
+
+# Xinference Credentials
+XINFERENCE_SERVER_URL=
+XINFERENCE_MODEL_UID=
\ No newline at end of file
```
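The two new variables identify a running Xinference server and the UID of a model already launched on it; the integration test further down reads them straight from the environment. A minimal sketch of that lookup (the commented values are examples, not defaults shipped with the commit):

```python
import os

# Mirrors get_mock_model() in test_xinference_model.py below.
server_url = os.environ['XINFERENCE_SERVER_URL']   # e.g. http://127.0.0.1:9997
model_uid = os.environ['XINFERENCE_MODEL_UID']     # uid of a model launched on that server
```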
api/tests/integration_tests/models/llm/test_anthropic_model.py

```diff
@@ -50,7 +50,9 @@ def test_get_num_tokens(mock_decrypt):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('claude-2')
     messages = [PromptMessage(content='Human: 1 + 1=?\nAssistant: ')]
     rst = model.run(
@@ -58,4 +60,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'
```
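The same two-part change recurs in each of the LLM integration tests that follow: `test_run` gains pytest-mock's `mocker` fixture, stubs out `BaseModelProvider.update_last_used` (presumably so the run path does not need a real database session), and, in most of these files, the assertion on the exact model answer is dropped in favour of the length check alone. The recurring stub, shown in isolation:

```python
# Shared pattern across the updated tests below; requires the pytest-mock
# plugin, which provides the `mocker` fixture.
def test_run(mock_decrypt, mocker):
    mocker.patch(
        'core.model_providers.providers.base.BaseModelProvider.update_last_used',
        return_value=None,
    )
    # ... build the mock model and call model.run(...) as before ...
```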
api/tests/integration_tests/models/llm/test_azure_openai_model.py

```diff
@@ -76,6 +76,8 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
     messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`?\nAssistant: ')]
     rst = openai_model.run(
@@ -83,4 +85,3 @@ def test_run(mock_decrypt, mocker):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'
```
api/tests/integration_tests/models/llm/test_huggingface_hub_model.py

```diff
@@ -95,6 +95,8 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocker):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_hosted_inference_api_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model(
         'google/flan-t5-base',
         'hosted_inference_api',
@@ -111,6 +113,8 @@ def test_hosted_inference_api_run(mock_decrypt, mocker):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_inference_endpoints_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model(
         '',
         'inference_endpoints',
@@ -121,4 +125,3 @@ def test_inference_endpoints_run(mock_decrypt, mocker):
         [PromptMessage(content='Answer the following yes/no question. Can you write a whole Haiku in a single tweet?')],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'no'
```
api/tests/integration_tests/models/llm/test_minimax_model.py

```diff
@@ -54,11 +54,12 @@ def test_get_num_tokens(mock_decrypt):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('abab5.5-chat')
     rst = model.run(
         [PromptMessage(content='Human: Are you a real Human? you MUST only answer `y` or `n`?\nAssistant: ')],
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'
```
api/tests/integration_tests/models/llm/test_openai_model.py

```diff
@@ -58,7 +58,9 @@ def test_chat_get_num_tokens(mock_decrypt):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_openai_model('text-davinci-003')
     rst = openai_model.run(
         [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`?\nAssistant: ')],
@@ -69,7 +71,9 @@ def test_run(mock_decrypt):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_chat_run(mock_decrypt):
+def test_chat_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_openai_model('gpt-3.5-turbo')
     messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`?\nAssistant: ')]
     rst = openai_model.run(
@@ -77,4 +81,3 @@ def test_chat_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'
```
api/tests/integration_tests/models/llm/test_replicate_model.py

```diff
@@ -65,6 +65,8 @@ def test_get_num_tokens(mock_decrypt, mocker):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
     messages = [PromptMessage(content='Human: 1+1=?\nAnswer: ')]
     rst = model.run(
```
api/tests/integration_tests/models/llm/test_spark_model.py

```diff
@@ -58,7 +58,9 @@ def test_get_num_tokens(mock_decrypt):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('spark')
     messages = [PromptMessage(content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
     rst = model.run(
@@ -66,4 +68,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'
```
api/tests/integration_tests/models/llm/test_tongyi_model.py

```diff
@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('qwen-v1')
     rst = model.run(
         [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`?\nAssistant: ')],
```
api/tests/integration_tests/models/llm/test_wenxin_model.py

```diff
@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('ernie-bot')
     messages = [PromptMessage(content='Human: 1 + 1=?\nAssistant: Integer answer is:')]
     rst = model.run(
@@ -60,4 +62,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'
```
api/tests/integration_tests/models/llm/test_xinference_model.py (new file, mode 100644)

```python
import json
import os
from unittest.mock import patch, MagicMock

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.xinference_provider import XinferenceProvider
from models.provider import Provider, ProviderType, ProviderModel


def get_mock_provider():
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='xinference',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config='',
        is_valid=True,
    )


def get_mock_model(model_name, mocker):
    model_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0.01
    )
    server_url = os.environ['XINFERENCE_SERVER_URL']
    model_uid = os.environ['XINFERENCE_MODEL_UID']
    model_provider = XinferenceProvider(provider=get_mock_provider())

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        provider_name='xinference',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=json.dumps({
            'server_url': server_url,
            'model_uid': model_uid
        }),
        is_valid=True,
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    return XinferenceModel(
        model_provider=model_provider,
        name=model_name,
        model_kwargs=model_kwargs
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    model = get_mock_model('llama-2-chat', mocker)
    rst = model.get_num_tokens([
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ])
    assert rst == 5


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('llama-2-chat', mocker)
    messages = [PromptMessage(content='Human: 1+1=?\nAnswer: ')]
    rst = model.run(
        messages
    )
    assert len(rst.content) > 0
```
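Unlike the unit tests that follow, this module talks to a live server: `XINFERENCE_SERVER_URL` must point at a running Xinference instance and `XINFERENCE_MODEL_UID` at a model launched on it. A sketch of wiring that up and running only this module; the values and the invocation are illustrative, not part of the commit.

```python
import os
import pytest

# Placeholders: point at your own Xinference server and launched model.
os.environ['XINFERENCE_SERVER_URL'] = 'http://127.0.0.1:9997'
os.environ['XINFERENCE_MODEL_UID'] = 'my-model-uid'

pytest.main(['api/tests/integration_tests/models/llm/test_xinference_model.py'])
```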
api/tests/unit_tests/model_providers/test_xinference_provider.py (new file, mode 100644)

```python
import pytest
from unittest.mock import patch, MagicMock
import json

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.replicate_provider import ReplicateProvider
from core.model_providers.providers.xinference_provider import XinferenceProvider
from models.provider import ProviderType, Provider, ProviderModel

PROVIDER_NAME = 'xinference'
MODEL_PROVIDER_CLASS = XinferenceProvider
VALIDATE_CREDENTIAL = {
    'model_uid': 'fake-model-uid',
    'server_url': 'http://127.0.0.1:9997/'
}


def encrypt_side_effect(tenant_id, encrypt_key):
    return f'encrypted_{encrypt_key}'


def decrypt_side_effect(tenant_id, encrypted_key):
    return encrypted_key.replace('encrypted_', '')


def test_is_credentials_valid_or_raise_valid(mocker):
    mocker.patch('langchain.llms.xinference.Xinference._call', return_value="abc")

    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
        model_name='username/test_model_name',
        model_type=ModelType.TEXT_GENERATION,
        credentials=VALIDATE_CREDENTIAL.copy()
    )


def test_is_credentials_valid_or_raise_invalid():
    # raise CredentialsValidateFailedError if replicate_api_token is not in credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
            model_name='test_model_name',
            model_type=ModelType.TEXT_GENERATION,
            credentials={}
        )

    # raise CredentialsValidateFailedError if replicate_api_token is invalid
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
            model_name='test_model_name',
            model_type=ModelType.TEXT_GENERATION,
            credentials={'server_url': 'invalid'})


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
    api_key = 'http://127.0.0.1:9997/'
    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
        tenant_id='tenant_id',
        model_name='test_model_name',
        model_type=ModelType.TEXT_GENERATION,
        credentials=VALIDATE_CREDENTIAL.copy()
    )
    mock_encrypt.assert_called_with('tenant_id', api_key)
    assert result['server_url'] == f'encrypted_{api_key}'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(encrypted_credential)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.TEXT_GENERATION
    )
    assert result['server_url'] == 'http://127.0.0.1:9997/'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(encrypted_credential)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.TEXT_GENERATION,
        obfuscated=True
    )
    middle_token = result['server_url'][6:-2]
    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
    assert all(char == '*' for char in middle_token)
```