ai-tech / dify, commit e0a48c49 (unverified)
Authored Aug 21, 2023 by takatost; committed by GitHub on Aug 21, 2023
fix: xinference chat support (#939)
parent f53242c0
Showing 4 changed files with 204 additions and 16 deletions (+204, -16):
xinference_model.py          api/core/model_providers/models/llm/xinference_model.py          +4   -3
xinference_provider.py       api/core/model_providers/providers/xinference_provider.py        +59  -10
xinference_llm.py            api/core/third_party/langchain/llms/xinference_llm.py            +132 -0
test_xinference_provider.py  ...ts/unit_tests/model_providers/test_xinference_provider.py     +9   -3
api/core/model_providers/models/llm/xinference_model.py
 from typing import List, Optional, Any

 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Xinference
 from langchain.schema import LLMResult

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM


 class XinferenceModel(BaseLLM):
...
@@ -16,8 +16,9 @@ class XinferenceModel(BaseLLM):
     def _init_client(self) -> Any:
         self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        client = Xinference(
-            **self.credentials,
+        client = XinferenceLLM(
+            server_url=self.credentials['server_url'],
+            model_uid=self.credentials['model_uid'],
         )

         client.callbacks = self.callbacks
...
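The change above swaps LangChain's Xinference wrapper (completion-only) for the project's XinferenceLLM subclass and passes the server URL and model UID explicitly instead of unpacking the whole credentials dict. A minimal sketch of how the new client is expected to be constructed and called, assuming a reachable Xinference server; the URL and model UID are placeholders:

# Hedged sketch, not part of the commit: illustrates the new client construction.
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

llm = XinferenceLLM(
    server_url="http://127.0.0.1:9997",  # placeholder server URL
    model_uid="my-model-uid",            # placeholder model UID
)
print(llm("Hello"))  # routed to chat() or generate() depending on the model handle type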
api/core/model_providers/providers/xinference_provider.py
 import json
 from typing import Type

-from langchain.llms import Xinference
+import requests
+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle

 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
...
@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType
...
@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
         :param model_type:
         :return:
         """
-        return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=256),
-        )
+        credentials = self.get_model_credentials(model_name, model_type)
+
+        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        elif credentials['model_format'] == "ggmlv3":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        else:
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            )

     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
...
@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
                 'model_uid': credentials['model_uid'],
             }

-            llm = Xinference(
+            llm = XinferenceLLM(
                 **credential_kwargs
             )

-            llm("ping", generate_config={'max_tokens': 10})
+            llm("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
...
@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
         :param credentials:
         :return:
         """
+        extra_credentials = cls._get_extra_credentials(credentials)
+        credentials.update(extra_credentials)
+
         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])

         return credentials

     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
...
@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):

         return credentials

+    @classmethod
+    def _get_extra_credentials(self, credentials: dict) -> dict:
+        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
+        response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(f"Failed to get the model description, detail: {response.json()['detail']}")
+
+        desc = response.json()
+        extra_credentials = {
+            'model_format': desc['model_format'],
+        }
+
+        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
+            extra_credentials['model_handle_type'] = 'chatglm'
+        elif "generate" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'generate'
+        elif "chat" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'chat'
+        else:
+            raise NotImplementedError(f"Model handle type not supported.")
+
+        return extra_credentials
+
     @classmethod
     def is_provider_credentials_valid_or_raise(cls, credentials: dict):
         return
...
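The new _get_extra_credentials helper fetches the model description from the server's /v1/models/{model_uid} endpoint and derives a model_handle_type, which the parameter-rule branches above key off. A worked example with a hypothetical description payload (the field values are assumptions for illustration, not taken from the commit):

# Hypothetical description returned by GET {server_url}/v1/models/{model_uid}
desc = {
    "model_format": "ggmlv3",
    "model_name": "llama-2-chat",           # assumed name; not a chatglm model
    "model_ability": ["generate", "chat"],
}

# Applying the same mapping as _get_extra_credentials:
extra = {"model_format": desc["model_format"]}
if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
    extra["model_handle_type"] = "chatglm"
elif "generate" in desc["model_ability"]:
    extra["model_handle_type"] = "generate"
elif "chat" in desc["model_ability"]:
    extra["model_handle_type"] = "chat"

print(extra)  # {'model_format': 'ggmlv3', 'model_handle_type': 'generate'}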
api/core/third_party/langchain/llms/xinference_llm.py (new file, mode 100644)
from typing import Optional, List, Any, Union, Generator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Xinference
from langchain.llms.utils import enforce_stop_tokens
from xinference.client import RESTfulChatglmCppChatModelHandle, \
    RESTfulChatModelHandle, RESTfulGenerateModelHandle


class XinferenceLLM(Xinference):
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the xinference model and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Returns:
            The generated string by the model.
        """
        model = self.client.get_model(self.model_uid)

        if isinstance(model, RESTfulChatModelHandle):
            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

            if stop:
                generate_config["stop"] = stop

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                    model=model,
                    prompt=prompt,
                    run_manager=run_manager,
                    generate_config=generate_config,
                ):
                    combined_text_output += token
                return combined_text_output
            else:
                completion = model.chat(prompt=prompt, generate_config=generate_config)
                return completion["choices"][0]["text"]
        elif isinstance(model, RESTfulGenerateModelHandle):
            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

            if stop:
                generate_config["stop"] = stop

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                    model=model,
                    prompt=prompt,
                    run_manager=run_manager,
                    generate_config=generate_config,
                ):
                    combined_text_output += token
                return combined_text_output
            else:
                completion = model.generate(prompt=prompt, generate_config=generate_config)
                return completion["choices"][0]["text"]
        elif isinstance(model, RESTfulChatglmCppChatModelHandle):
            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                    model=model,
                    prompt=prompt,
                    run_manager=run_manager,
                    generate_config=generate_config,
                ):
                    combined_text_output += token
                completion = combined_text_output
            else:
                completion = model.chat(prompt=prompt, generate_config=generate_config)
                completion = completion["choices"][0]["text"]

            if stop is not None:
                completion = enforce_stop_tokens(completion, stop)

            return completion

    def _stream_generate(
        self,
        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
        prompt: str,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
    ) -> Generator[str, None, None]:
        """
        Args:
            prompt: The prompt to use for generation.
            model: The model used for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Yields:
            A string token.
        """
        if isinstance(model, RESTfulGenerateModelHandle):
            streaming_response = model.generate(prompt=prompt, generate_config=generate_config)
        else:
            streaming_response = model.chat(prompt=prompt, generate_config=generate_config)

        for chunk in streaming_response:
            if isinstance(chunk, dict):
                choices = chunk.get("choices", [])
                if choices:
                    choice = choices[0]
                    if isinstance(choice, dict):
                        token = choice.get("text", "")
                        log_probs = choice.get("logprobs")
                        if run_manager:
                            run_manager.on_llm_new_token(
                                token=token, verbose=self.verbose, log_probs=log_probs
                            )
                        yield token
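XinferenceLLM keeps LangChain's Xinference configuration (server_url, model_uid) but overrides _call to dispatch on the concrete handle type, so chat-style and chatglm.cpp models work in addition to plain generate models, and to support token streaming. A hedged usage sketch, assuming a running Xinference server and a deployed chat-capable model; server_url and model_uid are placeholders:

from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

llm = XinferenceLLM(server_url="http://127.0.0.1:9997", model_uid="my-model-uid")

# Non-streaming: _call returns completion["choices"][0]["text"] directly.
print(llm("Say hi", generate_config={"max_tokens": 32}))

# Streaming: with "stream": True, _call consumes _stream_generate and
# concatenates the yielded tokens into a single string before returning.
print(llm("Say hi", generate_config={"max_tokens": 32, "stream": True}))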
api/tests/unit_tests/model_providers/test_xinference_provider.py
...
@@ -4,7 +4,6 @@ import json
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import CredentialsValidateFailedError
-from core.model_providers.providers.replicate_provider import ReplicateProvider
 from core.model_providers.providers.xinference_provider import XinferenceProvider
 from models.provider import ProviderType, Provider, ProviderModel
...
@@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.xinference.Xinference._call',
+    mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
                  return_value="abc")

     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
...
@@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid():
 @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
-def test_encrypt_model_credentials(mock_encrypt):
+def test_encrypt_model_credentials(mock_encrypt, mocker):
     api_key = 'http://127.0.0.1:9997/'
+    mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
+                 return_value={'model_handle_type': 'generate', 'model_format': 'ggmlv3'})

     result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
         tenant_id='tenant_id',
         model_name='test_model_name',
...
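The test signature gains pytest-mock's mocker fixture because encrypt_model_credentials now calls _get_extra_credentials, which issues an HTTP request to the Xinference server; patching it lets the test run offline. The general pattern, as a hedged standalone sketch (the patch target mirrors the test above; the surrounding call is elided):

# Requires pytest and pytest-mock; illustrative only.
def test_encrypt_model_credentials_offline(mocker):
    mocker.patch(
        'core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
        return_value={'model_handle_type': 'generate', 'model_format': 'ggmlv3'},
    )
    # ...then invoke encrypt_model_credentials as in the test body above,
    # without needing a live Xinference server.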