Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
4ab4bcc0
Unverified
Commit
4ab4bcc0
authored
Oct 10, 2023
by
takatost
Committed by
GitHub
Oct 10, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: support openllm embedding (#1293)
parent
1d4f019d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
172 additions
and
13 deletions
+172
-13
openllm_embedding.py
...ore/model_providers/models/embedding/openllm_embedding.py
+22
-0
xinference_embedding.py
.../model_providers/models/embedding/xinference_embedding.py
+1
-5
openllm_provider.py
api/core/model_providers/providers/openllm_provider.py
+19
-8
openllm_embedding.py
...ore/third_party/langchain/embeddings/openllm_embedding.py
+67
-0
test_openllm_embedding.py
...egration_tests/models/embedding/test_openllm_embedding.py
+63
-0
No files found.
api/core/model_providers/models/embedding/openllm_embedding.py
0 → 100644
View file @
4ab4bcc0
from
core.third_party.langchain.embeddings.openllm_embedding
import
OpenLLMEmbeddings
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.providers.base
import
BaseModelProvider
from
core.model_providers.models.embedding.base
import
BaseEmbedding
class OpenLLMEmbedding(BaseEmbedding):
    """Embedding model wrapper backed by a running OpenLLM server."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        # Resolve the decrypted credentials for this named embedding model,
        # then build the underlying langchain client from the server URL.
        credentials = model_provider.get_model_credentials(model_name=name, model_type=self.type)
        embedding_client = OpenLLMEmbeddings(server_url=credentials['server_url'])
        super().__init__(model_provider, embedding_client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        # Every backend failure is surfaced uniformly as a bad-request error.
        return LLMBadRequestError(f"OpenLLM embedding: {str(ex)}")
api/core/model_providers/models/embedding/xinference_embedding.py
View file @
4ab4bcc0
from
core.third_party.langchain.embeddings.xinference_embedding
import
XinferenceEmbedding
as
XinferenceEmbeddings
from
replicate.exceptions
import
ModelError
,
ReplicateError
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.providers.base
import
BaseModelProvider
...
...
@@ -21,7 +20,4 @@ class XinferenceEmbedding(BaseEmbedding):
super
()
.
__init__
(
model_provider
,
client
,
name
)
    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map a backend exception raised during embedding to the provider-facing error type.

        NOTE(review): the caught types come from ``replicate.exceptions`` (see the
        import at the top of this file); for a Xinference backend this looks like a
        copy-paste from the Replicate provider — confirm these are the exception
        types Xinference actually raises.
        """
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
        else:
            # Anything unexpected is returned unchanged for the caller to re-raise.
            return ex
api/core/model_providers/providers/openllm_provider.py
View file @
4ab4bcc0
...
...
@@ -2,11 +2,13 @@ import json
from
typing
import
Type
from
core.helper
import
encrypter
from
core.model_providers.models.embedding.openllm_embedding
import
OpenLLMEmbedding
from
core.model_providers.models.entity.model_params
import
KwargRule
,
ModelKwargsRules
,
ModelType
from
core.model_providers.models.llm.openllm_model
import
OpenLLMModel
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
from
core.model_providers.models.base
import
BaseProviderModel
from
core.third_party.langchain.embeddings.openllm_embedding
import
OpenLLMEmbeddings
from
core.third_party.langchain.llms.openllm
import
OpenLLM
from
models.provider
import
ProviderType
...
...
@@ -31,6 +33,8 @@ class OpenLLMProvider(BaseModelProvider):
"""
if
model_type
==
ModelType
.
TEXT_GENERATION
:
model_class
=
OpenLLMModel
elif
model_type
==
ModelType
.
EMBEDDINGS
:
model_class
=
OpenLLMEmbedding
else
:
raise
NotImplementedError
...
...
@@ -69,6 +73,7 @@ class OpenLLMProvider(BaseModelProvider):
'server_url'
:
credentials
[
'server_url'
]
}
if
model_type
==
ModelType
.
TEXT_GENERATION
:
llm
=
OpenLLM
(
llm_kwargs
=
{
'max_new_tokens'
:
10
...
...
@@ -77,6 +82,12 @@ class OpenLLMProvider(BaseModelProvider):
)
llm
(
"ping"
)
elif
model_type
==
ModelType
.
EMBEDDINGS
:
embedding
=
OpenLLMEmbeddings
(
**
credential_kwargs
)
embedding
.
embed_query
(
"ping"
)
except
Exception
as
ex
:
raise
CredentialsValidateFailedError
(
str
(
ex
))
...
...
api/core/third_party/langchain/embeddings/openllm_embedding.py
0 → 100644
View file @
4ab4bcc0
"""Wrapper around OpenLLM embedding models."""
from
typing
import
Any
,
List
,
Optional
import
requests
from
pydantic
import
BaseModel
,
Extra
from
langchain.embeddings.base
import
Embeddings
class OpenLLMEmbeddings(BaseModel, Embeddings):
    """Wrapper around OpenLLM embedding models.

    Calls the ``/v1/embeddings`` HTTP endpoint of a server started with
    ``openllm start``.
    """

    client: Any  #: :meta private:

    server_url: Optional[str] = None
    """Optional server URL that currently runs a LLMServer with 'openllm start'."""

    timeout: Optional[float] = None
    """Per-request timeout in seconds. ``None`` (the default) preserves the
    previous wait-forever behavior; set it to avoid hanging on a dead server."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OpenLLM's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # One HTTP round-trip per text; an empty input list yields [].
        embeddings = []
        for text in texts:
            result = self.invoke_embedding(text=text)
            embeddings.append(result)

        # Normalize to plain Python floats regardless of the server's JSON numbers.
        return [list(map(float, e)) for e in embeddings]

    def invoke_embedding(self, text: str) -> List[float]:
        """POST a single text to the server and return its raw embedding vector.

        Raises:
            ValueError: if ``server_url`` is not configured, or if the server
                answers with a non-2xx status.
        """
        if not self.server_url:
            # Fail fast with a clear message instead of issuing a request
            # to the nonsensical URL "None/v1/embeddings".
            raise ValueError("server_url is not set for OpenLLMEmbeddings")

        params = [text]
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            f'{self.server_url}/v1/embeddings',
            headers=headers,
            json=params,
            timeout=self.timeout)

        if not response.ok:
            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")

        json_response = response.json()
        # Response shape observed from the endpoint: a list whose first entry
        # holds an "embeddings" batch; the first vector corresponds to our
        # single input text.
        return json_response[0]["embeddings"][0]

    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenLLM's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
api/tests/integration_tests/models/embedding/test_openllm_embedding.py
0 → 100644
View file @
4ab4bcc0
import
json
import
os
from
unittest.mock
import
patch
,
MagicMock
from
core.model_providers.models.embedding.openllm_embedding
import
OpenLLMEmbedding
from
core.model_providers.models.entity.model_params
import
ModelType
from
core.model_providers.providers.openllm_provider
import
OpenLLMProvider
from
models.provider
import
Provider
,
ProviderType
,
ProviderModel
def get_mock_provider():
    """Build a minimal custom-type 'openllm' Provider record for the tests."""
    provider_fields = {
        'id': 'provider_id',
        'tenant_id': 'tenant_id',
        'provider_name': 'openllm',
        'provider_type': ProviderType.CUSTOM.value,
        'encrypted_config': '',
        'is_valid': True,
    }
    return Provider(**provider_fields)
def get_mock_embedding_model(mocker):
    """Create an OpenLLMEmbedding whose DB credential lookup is stubbed out.

    The provider-model row normally fetched from the database is replaced by a
    MagicMock, so the test only needs the OPENLLM_SERVER_URL environment variable.
    """
    embedding_model_name = 'facebook/opt-125m'
    openllm_server_url = os.environ['OPENLLM_SERVER_URL']
    provider = OpenLLMProvider(provider=get_mock_provider())

    # The chain db.session.query(...).filter(...).first() must yield a
    # ProviderModel row carrying the (mock-encrypted) server_url credential.
    provider_model_row = ProviderModel(
        provider_name='openllm',
        model_name=embedding_model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=json.dumps({'server_url': openllm_server_url}),
        is_valid=True,
    )
    fake_query = MagicMock()
    fake_query.filter.return_value.first.return_value = provider_model_row
    mocker.patch('extensions.ext_database.db.session.query', return_value=fake_query)

    return OpenLLMEmbedding(model_provider=provider, name=embedding_model_name)
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Pass-through stand-in for `encrypter.decrypt_token`: returns the
    'encrypted' value unchanged so tests need no real crypto."""
    return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
    """Embedding two texts yields a two-element list of non-empty vectors."""
    embedding_model = get_mock_embedding_model(mocker)

    vectors = embedding_model.client.embed_documents(['test', 'test1'])

    assert isinstance(vectors, list)
    assert len(vectors) == 2
    assert len(vectors[0]) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
    """Embedding a single query yields one non-empty vector."""
    embedding_model = get_mock_embedding_model(mocker)

    vector = embedding_model.client.embed_query('test')

    assert isinstance(vector, list)
    assert len(vector) > 0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment