ai-tech / dify / Commits / 0e627c92

Commit 0e627c92 (unverified)
Authored Nov 24, 2023 by takatost; committed via GitHub on Nov 24, 2023
Parent: ea35f1dc

feat: xinference rerank model support (#1615)

Showing 9 changed files with 215 additions and 6 deletions (+215 / -6)
Changed files:
  api/controllers/console/workspace/model_providers.py                        +3   -3
  api/core/model_providers/models/reranking/xinference_reranking.py           +58  -0
  api/core/model_providers/providers/xinference_provider.py                   +8   -0
  api/core/model_providers/rules/xinference.json                              +2   -1
  api/requirements.txt                                                        +1   -1
  api/tests/integration_tests/.env.example                                    +4   -1
  api/tests/integration_tests/models/reranking/__init__.py                    +0   -0
  api/tests/integration_tests/models/reranking/test_cohere_reranking.py       +61  -0
  api/tests/integration_tests/models/reranking/test_xinference_reranking.py   +78  -0
api/controllers/console/workspace/model_providers.py

@@ -115,7 +115,7 @@ class ModelProviderModelValidateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
         parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()

@@ -155,7 +155,7 @@ class ModelProviderModelUpdateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
         parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()

@@ -184,7 +184,7 @@ class ModelProviderModelUpdateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
         args = parser.parse_args()

         provider_service = ProviderService()
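The change above only widens the choices whitelist that flask-restful's reqparse applies to model_type; any other value is rejected before the request reaches provider code. A minimal sketch of that validation in isolation follows; the Flask app, request payload, and model name are hypothetical and not Dify's actual endpoint, assuming flask-restful is installed:

# Standalone sketch of the 'choices' validation shown in the diff (hypothetical app/payload).
from flask import Flask
from flask_restful import reqparse

app = Flask(__name__)

parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
                    choices=['text-generation', 'embeddings', 'speech2text', 'reranking'],
                    location='json')

with app.test_request_context(json={'model_name': 'bge-reranker-base', 'model_type': 'reranking'}):
    args = parser.parse_args()   # passes now that 'reranking' is an allowed choice
    print(args['model_type'])    # -> 'reranking'

With the pre-change choices list, the same payload would be rejected by parse_args() as an invalid choice.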
api/core/model_providers/models/reranking/xinference_reranking.py
0 → 100644

import logging
from typing import Optional, List

from langchain.schema import Document
from xinference_client.client.restful.restful_client import Client

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.providers.base import BaseModelProvider


class XinferenceReranking(BaseReranking):

    def __init__(self, model_provider: BaseModelProvider, name: str):
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = Client(self.credentials['server_url'])

        super().__init__(model_provider, client, name)

    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float],
               top_k: Optional[int]) -> Optional[List[Document]]:
        docs = []
        doc_id = []
        for document in documents:
            if document.metadata['doc_id'] not in doc_id:
                doc_id.append(document.metadata['doc_id'])
                docs.append(document.page_content)

        model = self.client.get_model(self.credentials['model_uid'])
        response = model.rerank(query=query, documents=docs, top_n=top_k)

        rerank_documents = []
        for idx, result in enumerate(response['results']):
            # format document
            index = result['index']
            rerank_document = Document(
                page_content=result['document'],
                metadata={
                    "doc_id": documents[index].metadata['doc_id'],
                    "doc_hash": documents[index].metadata['doc_hash'],
                    "document_id": documents[index].metadata['document_id'],
                    "dataset_id": documents[index].metadata['dataset_id'],
                    'score': result['relevance_score']
                }
            )
            # score threshold check
            if score_threshold is not None:
                if result['relevance_score'] >= score_threshold:
                    rerank_documents.append(rerank_document)
            else:
                rerank_documents.append(rerank_document)

        return rerank_documents

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Xinference rerank: {str(ex)}")
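The class above is a thin wrapper over the xinference_client RESTful API: it deduplicates input documents by doc_id, calls the remote rerank model, and rebuilds langchain Documents from the response. A hedged sketch of the underlying client call it depends on follows; the server URL, model UID, query, and sample texts are placeholders for a running Xinference deployment serving a rerank model, while the response keys ('results', 'index', 'relevance_score', 'document') are exactly the ones the rerank() method reads:

# Sketch of the xinference_client call wrapped by XinferenceReranking (placeholder credentials).
from xinference_client.client.restful.restful_client import Client

client = Client("http://127.0.0.1:9997")          # assumed local Xinference server
model = client.get_model("my-rerank-model-uid")   # UID of a launched rerank model

response = model.rerank(
    query="What is Dify?",
    documents=["Dify is an LLMOps platform.", "Bananas are yellow."],
    top_n=2,
)

for result in response['results']:
    # Each result points back to an input document by index and carries a relevance score.
    print(result['index'], result['relevance_score'], result['document'])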
api/core/model_providers/providers/xinference_provider.py

@@ -2,11 +2,13 @@ import json
 from typing import Type

 import requests
+from xinference_client.client.restful.restful_client import Client

 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
 from core.model_providers.models.llm.xinference_model import XinferenceModel
+from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel

@@ -40,6 +42,8 @@ class XinferenceProvider(BaseModelProvider):
             model_class = XinferenceModel
         elif model_type == ModelType.EMBEDDINGS:
             model_class = XinferenceEmbedding
+        elif model_type == ModelType.RERANKING:
+            model_class = XinferenceReranking
         else:
             raise NotImplementedError

@@ -113,6 +117,10 @@ class XinferenceProvider(BaseModelProvider):
                 )

                 embedding.embed_query("ping")
+            elif model_type == ModelType.RERANKING:
+                rerank_client = Client(credential_kwargs['server_url'])
+                model = rerank_client.get_model(credential_kwargs['model_uid'])
+                model.rerank(query="ping", documents=["ping", "pong"], top_n=2)
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
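The new RERANKING branch in credential validation is a liveness probe: it reranks a trivial "ping"/"pong" pair against the configured server and lets any exception surface as CredentialsValidateFailedError. A standalone sketch of the same probe follows; the function name and placeholder credentials are illustrative and not part of the codebase, and it returns a boolean where the real provider re-raises:

# Hedged sketch of the credential probe added above, outside the provider class.
from xinference_client.client.restful.restful_client import Client

def check_rerank_credentials(server_url: str, model_uid: str) -> bool:
    try:
        client = Client(server_url)
        model = client.get_model(model_uid)
        # A trivial rerank call; a wrong server URL or model UID raises here.
        model.rerank(query="ping", documents=["ping", "pong"], top_n=2)
        return True
    except Exception:
        return False

# Example (hypothetical endpoint and UID):
# check_rerank_credentials("http://127.0.0.1:9997", "my-rerank-model-uid")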
api/core/model_providers/rules/xinference.json

@@ -6,6 +6,7 @@
     "model_flexibility": "configurable",
     "supported_model_types": [
         "text-generation",
-        "embeddings"
+        "embeddings",
+        "reranking"
     ]
 }
\ No newline at end of file
api/requirements.txt

@@ -48,7 +48,7 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference-client~=0.5.4
+xinference-client~=0.6.4
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug==2.3.7
api/tests/integration_tests/.env.example

@@ -50,4 +50,7 @@ XINFERENCE_MODEL_UID=
 OPENLLM_SERVER_URL=

 # LocalAI Credentials
-LOCALAI_SERVER_URL=
\ No newline at end of file
+LOCALAI_SERVER_URL=
+
+# Cohere Credentials
+COHERE_API_KEY=
\ No newline at end of file
api/tests/integration_tests/models/reranking/__init__.py
0 → 100644 (new, empty file)
api/tests/integration_tests/models/reranking/test_cohere_reranking.py
0 → 100644

import json
import os
from unittest.mock import patch

from langchain.schema import Document

from core.model_providers.models.reranking.cohere_reranking import CohereReranking
from core.model_providers.providers.cohere_provider import CohereProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='cohere',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({'api_key': valid_api_key}),
        is_valid=True,
    )


def get_mock_model():
    valid_api_key = os.environ['COHERE_API_KEY']
    provider = CohereProvider(provider=get_mock_provider(valid_api_key))
    return CohereReranking(
        model_provider=provider,
        name='rerank-english-v2.0'
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt):
    model = get_mock_model()

    docs = []
    docs.append(Document(
        page_content='bye',
        metadata={
            "doc_id": 'a',
            "doc_hash": 'doc_hash',
            "document_id": 'document_id',
            "dataset_id": 'dataset_id',
        }
    ))
    docs.append(Document(
        page_content='hello',
        metadata={
            "doc_id": 'b',
            "doc_hash": 'doc_hash',
            "document_id": 'document_id',
            "dataset_id": 'dataset_id',
        }
    ))
    rst = model.rerank('hello', docs, None, 2)

    assert rst[0].page_content == 'hello'
api/tests/integration_tests/models/reranking/test_xinference_reranking.py
0 → 100644

import json
import os
from unittest.mock import patch, MagicMock

from langchain.schema import Document

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking
from core.model_providers.providers.xinference_provider import XinferenceProvider
from models.provider import Provider, ProviderType, ProviderModel


def get_mock_provider(valid_server_url, valid_model_uid):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='xinference',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({'server_url': valid_server_url, 'model_uid': valid_model_uid}),
        is_valid=True,
    )


def get_mock_model(mocker):
    valid_server_url = os.environ['XINFERENCE_SERVER_URL']
    valid_model_uid = os.environ['XINFERENCE_MODEL_UID']
    model_name = 'bge-reranker-base'
    provider = XinferenceProvider(provider=get_mock_provider(valid_server_url, valid_model_uid))

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        provider_name='xinference',
        model_name=model_name,
        model_type=ModelType.RERANKING.value,
        encrypted_config=json.dumps({'server_url': valid_server_url, 'model_uid': valid_model_uid}),
        is_valid=True,
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    return XinferenceReranking(
        model_provider=provider,
        name=model_name
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    model = get_mock_model(mocker)

    docs = []
    docs.append(Document(
        page_content='bye',
        metadata={
            "doc_id": 'a',
            "doc_hash": 'doc_hash',
            "document_id": 'document_id',
            "dataset_id": 'dataset_id',
        }
    ))
    docs.append(Document(
        page_content='hello',
        metadata={
            "doc_id": 'b',
            "doc_hash": 'doc_hash',
            "document_id": 'document_id',
            "dataset_id": 'dataset_id',
        }
    ))
    rst = model.rerank('hello', docs, None, 2)

    assert rst[0].page_content == 'hello'