Unverified Commit e0a48c49 authored by takatost, committed by GitHub

fix: xinference chat support (#939)

parent f53242c0
 from typing import List, Optional, Any
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Xinference
 from langchain.schema import LLMResult
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM


 class XinferenceModel(BaseLLM):
@@ -16,8 +16,9 @@ class XinferenceModel(BaseLLM):

     def _init_client(self) -> Any:
         self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        client = Xinference(
-            **self.credentials,
+        client = XinferenceLLM(
+            server_url=self.credentials['server_url'],
+            model_uid=self.credentials['model_uid'],
         )

         client.callbacks = self.callbacks
...
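The client is now constructed from explicit server_url and model_uid keys instead of **self.credentials, presumably because the stored credentials also carry the model_format and model_handle_type keys added by _get_extra_credentials further down, and those are not constructor arguments. A minimal sketch of the new construction style, with a purely hypothetical credentials dict:

    from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

    # Hypothetical credentials dict; all values are illustrative only.
    credentials = {
        'server_url': 'http://127.0.0.1:9997',  # example local Xinference server
        'model_uid': 'some-model-uid',          # example UID of a launched model
        'model_format': 'ggmlv3',               # extra key, not a constructor argument
        'model_handle_type': 'chat',            # extra key, not a constructor argument
    }

    # Only the keys the wrapper understands are passed through.
    client = XinferenceLLM(
        server_url=credentials['server_url'],
        model_uid=credentials['model_uid'],
    )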
 import json
 from typing import Type

-from langchain.llms import Xinference
+import requests
+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle

 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType
@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
         :param model_type:
         :return:
         """
-        return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=256),
-        )
+        credentials = self.get_model_credentials(model_name, model_type)
+
+        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        elif credentials['model_format'] == "ggmlv3":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        else:
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            )

     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
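The parameter rules now depend on how the model is served: ggmlv3 chatglm models expose no penalty parameters, other ggmlv3 models keep the OpenAI-style penalties, and every other format maps max_tokens to the server-side max_new_tokens argument. A rough, purely illustrative sketch of the branch selection (the helper name and return strings are hypothetical):

    # Illustrative only: mirrors the three branches above.
    def pick_rules_branch(model_format: str, model_handle_type: str) -> str:
        if model_format == "ggmlv3" and model_handle_type == "chatglm":
            return "chatglm: no presence/frequency penalty"
        elif model_format == "ggmlv3":
            return "ggmlv3: penalties enabled"
        else:
            return "other formats: max_tokens aliased to max_new_tokens"

    assert pick_rules_branch("ggmlv3", "chatglm") == "chatglm: no presence/frequency penalty"
    assert pick_rules_branch("pytorch", "chat") == "other formats: max_tokens aliased to max_new_tokens"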
@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
                 'model_uid': credentials['model_uid'],
             }

-            llm = Xinference(
+            llm = XinferenceLLM(
                 **credential_kwargs
             )

-            llm("ping", generate_config={'max_tokens': 10})
+            llm("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
         :param credentials:
         :return:
         """
+        extra_credentials = cls._get_extra_credentials(credentials)
+        credentials.update(extra_credentials)
+
         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])

         return credentials

     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):

         return credentials

+    @classmethod
+    def _get_extra_credentials(self, credentials: dict) -> dict:
+        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
+        response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get the model description, detail: {response.json()['detail']}"
+            )
+
+        desc = response.json()
+        extra_credentials = {
+            'model_format': desc['model_format'],
+        }
+
+        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
+            extra_credentials['model_handle_type'] = 'chatglm'
+        elif "generate" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'generate'
+        elif "chat" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'chat'
+        else:
+            raise NotImplementedError(f"Model handle type not supported.")
+
+        return extra_credentials
+
     @classmethod
     def is_provider_credentials_valid_or_raise(cls, credentials: dict):
         return
...
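_get_extra_credentials asks the Xinference server to describe the deployed model and stores model_format plus a derived model_handle_type next to the credentials, so later code can pick the right parameter rules and handle type without another round trip. A sketch of the derivation against a hypothetical model description (field values are illustrative; the keys are the ones the code actually reads):

    # Hypothetical response from GET {server_url}/v1/models/{model_uid}.
    desc = {
        "model_name": "llama-2-chat",           # illustrative
        "model_format": "ggmlv3",               # illustrative
        "model_ability": ["generate", "chat"],  # illustrative
    }

    extra_credentials = {'model_format': desc['model_format']}
    if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
        extra_credentials['model_handle_type'] = 'chatglm'
    elif "generate" in desc["model_ability"]:
        extra_credentials['model_handle_type'] = 'generate'  # "generate" wins over "chat"
    elif "chat" in desc["model_ability"]:
        extra_credentials['model_handle_type'] = 'chat'

    # extra_credentials == {'model_format': 'ggmlv3', 'model_handle_type': 'generate'}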
from typing import Optional, List, Any, Union, Generator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Xinference
from langchain.llms.utils import enforce_stop_tokens
from xinference.client import RESTfulChatglmCppChatModelHandle, \
    RESTfulChatModelHandle, RESTfulGenerateModelHandle


class XinferenceLLM(Xinference):
    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """Call the xinference model and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Returns:
            The generated string by the model.
        """
        model = self.client.get_model(self.model_uid)
        if isinstance(model, RESTfulChatModelHandle):
            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

            if stop:
                generate_config["stop"] = stop

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                        model=model,
                        prompt=prompt,
                        run_manager=run_manager,
                        generate_config=generate_config,
                ):
                    combined_text_output += token
                return combined_text_output
            else:
                completion = model.chat(prompt=prompt, generate_config=generate_config)
                return completion["choices"][0]["text"]
        elif isinstance(model, RESTfulGenerateModelHandle):
            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

            if stop:
                generate_config["stop"] = stop

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                        model=model,
                        prompt=prompt,
                        run_manager=run_manager,
                        generate_config=generate_config,
                ):
                    combined_text_output += token
                return combined_text_output
            else:
                completion = model.generate(prompt=prompt, generate_config=generate_config)
                return completion["choices"][0]["text"]
        elif isinstance(model, RESTfulChatglmCppChatModelHandle):
            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                        model=model,
                        prompt=prompt,
                        run_manager=run_manager,
                        generate_config=generate_config,
                ):
                    combined_text_output += token
                completion = combined_text_output
            else:
                completion = model.chat(prompt=prompt, generate_config=generate_config)
                completion = completion["choices"][0]["text"]

            if stop is not None:
                completion = enforce_stop_tokens(completion, stop)

            return completion

    def _stream_generate(
            self,
            model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
            prompt: str,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
    ) -> Generator[str, None, None]:
        """
        Args:
            prompt: The prompt to use for generation.
            model: The model used for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Yields:
            A string token.
        """
        if isinstance(model, RESTfulGenerateModelHandle):
            streaming_response = model.generate(
                prompt=prompt, generate_config=generate_config
            )
        else:
            streaming_response = model.chat(
                prompt=prompt, generate_config=generate_config
            )

        for chunk in streaming_response:
            if isinstance(chunk, dict):
                choices = chunk.get("choices", [])
                if choices:
                    choice = choices[0]
                    if isinstance(choice, dict):
                        token = choice.get("text", "")
                        log_probs = choice.get("logprobs")
                        if run_manager:
                            run_manager.on_llm_new_token(
                                token=token, verbose=self.verbose, log_probs=log_probs
                            )
                        yield token
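XinferenceLLM appears to exist because the stock LangChain Xinference wrapper always drives the model through generate, which does not work for chat and chatglm models; _call instead inspects the handle type returned by client.get_model and routes to chat or generate, streaming when generate_config asks for it. A minimal usage sketch, assuming an Xinference server is running at the given URL with a launched model (both values hypothetical):

    from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

    llm = XinferenceLLM(
        server_url="http://127.0.0.1:9997",  # hypothetical server
        model_uid="my-launched-model-uid",   # hypothetical model UID
    )

    # Non-streaming call; stop sequences are forwarded via generate_config["stop"].
    text = llm("Say hi in one word.", stop=["\n"])

    # Streaming call; tokens are accumulated by _call and also emitted to callbacks.
    text = llm("Tell me a short joke.", generate_config={"stream": True, "max_tokens": 64})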
@@ -4,7 +4,6 @@ import json

 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import CredentialsValidateFailedError
-from core.model_providers.providers.replicate_provider import ReplicateProvider
 from core.model_providers.providers.xinference_provider import XinferenceProvider
 from models.provider import ProviderType, Provider, ProviderModel
@@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):

 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.xinference.Xinference._call',
+    mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
                  return_value="abc")

     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid():

 @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
-def test_encrypt_model_credentials(mock_encrypt):
+def test_encrypt_model_credentials(mock_encrypt, mocker):
     api_key = 'http://127.0.0.1:9997/'

+    mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
+                 return_value={
+                     'model_handle_type': 'generate',
+                     'model_format': 'ggmlv3'
+                 })
+
     result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
         tenant_id='tenant_id',
         model_name='test_model_name',
...