Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
25264e78
Unverified
Commit
25264e78
authored
Aug 20, 2023
by
takatost
Committed by
GitHub
Aug 20, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: add xinference embedding model support (#930)
parent
18dd0d56
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
94 additions
and
0 deletions
+94
-0
xinference_embedding.py
.../model_providers/models/embedding/xinference_embedding.py
+26
-0
xinference_provider.py
api/core/model_providers/providers/xinference_provider.py
+3
-0
test_xinference_embedding.py
...ation_tests/models/embedding/test_xinference_embedding.py
+65
-0
No files found.
api/core/model_providers/models/embedding/xinference_embedding.py
0 → 100644
View file @
25264e78
from
langchain.embeddings
import
XinferenceEmbeddings
from
replicate.exceptions
import
ModelError
,
ReplicateError
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.providers.base
import
BaseModelProvider
from
core.model_providers.models.embedding.base
import
BaseEmbedding
class XinferenceEmbedding(BaseEmbedding):
    """Embedding model whose client is a LangChain ``XinferenceEmbeddings``."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        """Create the Xinference client from the provider's stored credentials.

        Args:
            model_provider: Provider queried for this model's credentials.
            name: Model name, used for the credential lookup and passed on
                to the base class.
        """
        model_credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type,
        )
        # Every credential entry is forwarded as a keyword argument, so the
        # credential keys must match what XinferenceEmbeddings accepts.
        embeddings_client = XinferenceEmbeddings(**model_credentials)

        super().__init__(model_provider, embeddings_client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Translate a client error into the project's error type.

        Args:
            ex: The exception raised while using the embedding client.

        Returns:
            ``LLMBadRequestError`` when ``ex`` is a ``ModelError`` or
            ``ReplicateError``; otherwise ``ex`` itself, unchanged.
        """
        # NOTE(review): both matched types are imported from the `replicate`
        # package; presumably carried over from another provider -- confirm
        # which exceptions the Xinference client actually raises.
        if not isinstance(ex, (ModelError, ReplicateError)):
            return ex
        return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
api/core/model_providers/providers/xinference_provider.py
View file @
25264e78
...
@@ -4,6 +4,7 @@ from typing import Type
...
@@ -4,6 +4,7 @@ from typing import Type
from
langchain.llms
import
Xinference
from
langchain.llms
import
Xinference
from
core.helper
import
encrypter
from
core.helper
import
encrypter
from
core.model_providers.models.embedding.xinference_embedding
import
XinferenceEmbedding
from
core.model_providers.models.entity.model_params
import
KwargRule
,
ModelKwargsRules
,
ModelType
from
core.model_providers.models.entity.model_params
import
KwargRule
,
ModelKwargsRules
,
ModelType
from
core.model_providers.models.llm.xinference_model
import
XinferenceModel
from
core.model_providers.models.llm.xinference_model
import
XinferenceModel
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
...
@@ -32,6 +33,8 @@ class XinferenceProvider(BaseModelProvider):
...
@@ -32,6 +33,8 @@ class XinferenceProvider(BaseModelProvider):
"""
"""
if
model_type
==
ModelType
.
TEXT_GENERATION
:
if
model_type
==
ModelType
.
TEXT_GENERATION
:
model_class
=
XinferenceModel
model_class
=
XinferenceModel
elif
model_type
==
ModelType
.
EMBEDDINGS
:
model_class
=
XinferenceEmbedding
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
...
...
api/tests/integration_tests/models/embedding/test_xinference_embedding.py
0 → 100644
View file @
25264e78
import
json
import
os
from
unittest.mock
import
patch
,
MagicMock
from
core.model_providers.models.embedding.xinference_embedding
import
XinferenceEmbedding
from
core.model_providers.models.entity.model_params
import
ModelType
from
core.model_providers.providers.xinference_provider
import
XinferenceProvider
from
models.provider
import
Provider
,
ProviderType
,
ProviderModel
def get_mock_provider():
    """Return a valid custom-type 'xinference' ``Provider`` with an empty config."""
    provider_fields = {
        'id': 'provider_id',
        'tenant_id': 'tenant_id',
        'provider_name': 'xinference',
        'provider_type': ProviderType.CUSTOM.value,
        'encrypted_config': '',
        'is_valid': True,
    }
    return Provider(**provider_fields)
def get_mock_embedding_model(mocker):
    """Build an ``XinferenceEmbedding`` whose credential lookup is stubbed.

    The server URL and model UID are read from the ``XINFERENCE_SERVER_URL``
    and ``XINFERENCE_MODEL_UID`` environment variables; a missing variable
    raises ``KeyError``.

    Args:
        mocker: Object providing ``patch``; presumably the pytest-mock
            fixture -- confirm against the test setup.

    Returns:
        An ``XinferenceEmbedding`` for the model named 'vicuna-v1.3'.
    """
    model_name = 'vicuna-v1.3'
    server_url = os.environ['XINFERENCE_SERVER_URL']
    model_uid = os.environ['XINFERENCE_MODEL_UID']
    provider = XinferenceProvider(provider=get_mock_provider())

    # The config is stored as plain JSON here, not as real ciphertext.
    stored_config = json.dumps({
        'server_url': server_url,
        'model_uid': model_uid,
    })
    stored_model = ProviderModel(
        provider_name='xinference',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=stored_config,
        is_valid=True,
    )

    # Stub the chain query(...).filter(...).first() so it yields the record
    # built above.
    query_stub = MagicMock()
    query_stub.filter.return_value.first.return_value = stored_model
    mocker.patch(
        'extensions.ext_database.db.session.query',
        return_value=query_stub,
    )

    return XinferenceEmbedding(model_provider=provider, name=model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Return ``encrypted_api_key`` unchanged; ``tenant_id`` is ignored.

    Used as the ``side_effect`` when patching the token decrypter, so the
    value passed in comes back as-is.
    """
    passthrough_value = encrypted_api_key
    return passthrough_value
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
    """Embed two documents and check the shape of the returned vectors."""
    model = get_mock_embedding_model(mocker)
    documents = ['test', 'test1']

    vectors = model.client.embed_documents(documents)

    assert isinstance(vectors, list)
    assert len(vectors) == 2
    # NOTE(review): 4096 is presumably the embedding size of the model behind
    # XINFERENCE_MODEL_UID -- confirm if a different model is deployed.
    assert len(vectors[0]) == 4096
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
    """Embed a single query string and check the returned vector's length."""
    model = get_mock_embedding_model(mocker)

    vector = model.client.embed_query('test')

    assert isinstance(vector, list)
    # NOTE(review): 4096 is presumably the embedding size of the model behind
    # XINFERENCE_MODEL_UID -- confirm if a different model is deployed.
    assert len(vector) == 4096
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment