Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
3efaa713
Unverified
Commit
3efaa713
authored
Oct 13, 2023
by
takatost
Committed by
GitHub
Oct 13, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: use xinference client instead of xinference (#1339)
parent
9822f687
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
83 additions
and
14 deletions
+83
-14
xinference_embedding.py
.../model_providers/models/embedding/xinference_embedding.py
+1
-2
xinference_provider.py
api/core/model_providers/providers/xinference_provider.py
+1
-1
xinference_embedding.py
.../third_party/langchain/embeddings/xinference_embedding.py
+38
-5
xinference_llm.py
api/core/third_party/langchain/llms/xinference_llm.py
+42
-5
requirements.txt
api/requirements.txt
+1
-1
No files found.
api/core/model_providers/models/embedding/xinference_embedding.py
View file @
3efaa713
from
core.third_party.langchain.embeddings.xinference_embedding
import
XinferenceEmbedding
as
XinferenceEmbeddings
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.providers.base
import
BaseModelProvider
from
core.model_providers.models.embedding.base
import
BaseEmbedding
from
core.third_party.langchain.embeddings.xinference_embedding
import
XinferenceEmbeddings
class
XinferenceEmbedding
(
BaseEmbedding
):
...
...
api/core/model_providers/providers/xinference_provider.py
View file @
3efaa713
...
...
@@ -2,7 +2,6 @@ import json
from
typing
import
Type
import
requests
from
langchain.embeddings
import
XinferenceEmbeddings
from
core.helper
import
encrypter
from
core.model_providers.models.embedding.xinference_embedding
import
XinferenceEmbedding
...
...
@@ -11,6 +10,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
from
core.model_providers.models.base
import
BaseProviderModel
from
core.third_party.langchain.embeddings.xinference_embedding
import
XinferenceEmbeddings
from
core.third_party.langchain.llms.xinference_llm
import
XinferenceLLM
from
models.provider
import
ProviderType
...
...
api/core/third_party/langchain/embeddings/xinference_embedding.py
View file @
3efaa713
from
typing
import
List
from
typing
import
List
,
Optional
,
Any
import
numpy
as
np
from
langchain.embeddings
import
XinferenceEmbeddings
from
langchain.embeddings.base
import
Embeddings
from
xinference_client.client.restful.restful_client
import
Client
class XinferenceEmbeddings(Embeddings):
    """Embeddings backed by a remote Xinference server.

    Talks to the server via the ``xinference_client`` RESTful client and
    L2-normalizes every returned vector, so downstream cosine-similarity
    search can use plain dot products.
    """

    # RESTful client connected to the Xinference server.
    client: Any
    # URL of the xinference server.
    server_url: Optional[str]
    # UID of the launched model.
    model_uid: Optional[str]

    def __init__(self, server_url: Optional[str] = None, model_uid: Optional[str] = None):
        """Validate connection parameters and create the RESTful client.

        :param server_url: base URL of the Xinference server (required).
        :param model_uid: UID of the launched model (required).
        :raises ValueError: if either argument is missing.
        """
        super().__init__()
        if server_url is None:
            raise ValueError("Please provide server URL")
        if model_uid is None:
            raise ValueError("Please provide the model UID")
        self.server_url = server_url
        self.model_uid = model_uid
        self.client = Client(server_url)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents, returning unit-length vectors.

        Issues one REST call per text; each response nests the vector under
        ``data[0].embedding``.
        """
        model = self.client.get_model(self.model_uid)
        embeddings = [
            model.create_embedding(text)["data"][0]["embedding"]
            for text in texts
        ]
        return [self._normalize(embedding) for embedding in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string, returning a unit-length vector."""
        model = self.client.get_model(self.model_uid)
        embedding = model.create_embedding(text)["data"][0]["embedding"]
        return self._normalize(embedding)

    @staticmethod
    def _normalize(embedding: List[float]) -> List[float]:
        # np.asarray(..., dtype=float) replaces the former
        # list(map(float, ...)) pass and makes the division an explicit
        # ndarray operation instead of relying on numpy's reflected
        # list-by-scalar division.
        vector = np.asarray(embedding, dtype=float)
        return (vector / np.linalg.norm(vector)).tolist()
api/core/third_party/langchain/llms/xinference_llm.py
View file @
3efaa713
from
typing
import
Optional
,
List
,
Any
,
Union
,
Generator
from
typing
import
Optional
,
List
,
Any
,
Union
,
Generator
,
Mapping
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
from
langchain.llms
import
Xinference
from
langchain.llms
.base
import
LLM
from
langchain.llms.utils
import
enforce_stop_tokens
from
xinference
.
client
import
(
from
xinference
_client.client.restful.restful_
client
import
(
RESTfulChatglmCppChatModelHandle
,
RESTfulChatModelHandle
,
RESTfulGenerateModelHandle
,
RESTfulGenerateModelHandle
,
Client
,
)
class
XinferenceLLM
(
Xinference
):
class
XinferenceLLM
(
LLM
):
client
:
Any
server_url
:
Optional
[
str
]
"""URL of the xinference server"""
model_uid
:
Optional
[
str
]
"""UID of the launched model"""
def
__init__
(
self
,
server_url
:
Optional
[
str
]
=
None
,
model_uid
:
Optional
[
str
]
=
None
):
super
()
.
__init__
(
**
{
"server_url"
:
server_url
,
"model_uid"
:
model_uid
,
}
)
if
self
.
server_url
is
None
:
raise
ValueError
(
"Please provide server URL"
)
if
self
.
model_uid
is
None
:
raise
ValueError
(
"Please provide the model UID"
)
self
.
client
=
Client
(
server_url
)
@
property
def
_llm_type
(
self
)
->
str
:
"""Return type of llm."""
return
"xinference"
@
property
def
_identifying_params
(
self
)
->
Mapping
[
str
,
Any
]:
"""Get the identifying parameters."""
return
{
**
{
"server_url"
:
self
.
server_url
},
**
{
"model_uid"
:
self
.
model_uid
},
}
def
_call
(
self
,
prompt
:
str
,
...
...
api/requirements.txt
View file @
3efaa713
...
...
@@ -49,7 +49,7 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference
==0.5
.2
xinference
-client~=0.1
.2
safetensors==0.3.2
zhipuai==1.0.7
werkzeug==2.3.7
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment