Unverified commit c8bd76cd, authored by takatost and committed by GitHub

fix: inference embedding validate (#1187)

parent ec5f585d
...@@ -2,6 +2,7 @@ import json ...@@ -2,6 +2,7 @@ import json
from typing import Type from typing import Type
import requests import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
...@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider): ...@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
'model_uid': credentials['model_uid'], 'model_uid': credentials['model_uid'],
} }
llm = XinferenceLLM( if model_type == ModelType.TEXT_GENERATION:
**credential_kwargs llm = XinferenceLLM(
) **credential_kwargs
)
llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = XinferenceEmbeddings(
**credential_kwargs
)
llm("ping") embedding.embed_query("ping")
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))
...@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider): ...@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
:param credentials: :param credentials:
:return: :return:
""" """
extra_credentials = cls._get_extra_credentials(credentials) if model_type == ModelType.TEXT_GENERATION:
credentials.update(extra_credentials) extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url']) credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
......
...@@ -19,7 +19,7 @@ pytest~=7.3.1 ...@@ -19,7 +19,7 @@ pytest~=7.3.1
pytest-mock~=3.11.1 pytest-mock~=3.11.1
tiktoken==0.3.3 tiktoken==0.3.3
Authlib==1.2.0 Authlib==1.2.0
boto3~=1.26.123 boto3==1.28.17
tenacity==8.2.2 tenacity==8.2.2
cachetools~=5.3.0 cachetools~=5.3.0
weaviate-client~=3.21.0 weaviate-client~=3.21.0
...@@ -49,5 +49,5 @@ huggingface_hub~=0.16.4 ...@@ -49,5 +49,5 @@ huggingface_hub~=0.16.4
transformers~=4.31.0 transformers~=4.31.0
stripe~=5.5.0 stripe~=5.5.0
pandas==1.5.3 pandas==1.5.3
xinference==0.2.1 xinference==0.4.2
safetensors==0.3.2 safetensors==0.3.2
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment