Unverified Commit e0a48c49 authored by takatost, committed by GitHub

fix: xinference chat support (#939)

parent f53242c0
core/model_providers/models/llm/xinference_model.py:

 from typing import List, Optional, Any

 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Xinference
 from langchain.schema import LLMResult

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM


 class XinferenceModel(BaseLLM):

@@ -16,8 +16,9 @@ class XinferenceModel(BaseLLM):
     def _init_client(self) -> Any:
         self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

-        client = Xinference(
-            **self.credentials,
+        client = XinferenceLLM(
+            server_url=self.credentials['server_url'],
+            model_uid=self.credentials['model_uid'],
         )

         client.callbacks = self.callbacks
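Note: the net effect of this hunk is that the wrapper no longer splats arbitrary credential keys into the LangChain class but passes exactly the two fields XinferenceLLM needs. A minimal sketch of the equivalent direct usage (the server URL and model UID below are placeholder values, not from this commit):

# Sketch only: mirrors the wiring _init_client now performs; values are placeholders.
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

client = XinferenceLLM(
    server_url='http://127.0.0.1:9997',  # Xinference server endpoint
    model_uid='my-model-uid',            # UID assigned when the model was launched
)
print(client('ping'))  # LangChain __call__ routes through XinferenceLLM._call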
core/model_providers/providers/xinference_provider.py:
 import json
 from typing import Type

-from langchain.llms import Xinference
+import requests
+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle

 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding

@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType
@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
         :param model_type:
         :return:
         """
-        return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](enabled=False),
-            frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](min=10, max=4000, default=256),
-        )
+        credentials = self.get_model_credentials(model_name, model_type)
+        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        elif credentials['model_format'] == "ggmlv3":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        else:
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            )

     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
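Note: the only structural differences between the three branches are which penalty knobs are exposed and the max_tokens alias: for non-ggmlv3 (e.g. PyTorch-format) models, the user-facing max_tokens must reach Xinference as max_new_tokens. A hedged sketch of how such an alias can be applied when translating user kwargs (the helper below is illustrative, not the actual ModelKwargsRules machinery):

# Illustrative helper, not part of this codebase: shows how a KwargRule
# alias renames a user-facing parameter before it is sent to the provider.
def to_provider_kwargs(user_kwargs: dict, aliases: dict) -> dict:
    return {aliases.get(name, name): value for name, value in user_kwargs.items()}

# For a non-ggmlv3 model, max_tokens carries the alias 'max_new_tokens':
print(to_provider_kwargs({'temperature': 0.7, 'max_tokens': 256},
                         {'max_tokens': 'max_new_tokens'}))
# -> {'temperature': 0.7, 'max_new_tokens': 256}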
@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
                 'model_uid': credentials['model_uid'],
             }

-            llm = Xinference(
+            llm = XinferenceLLM(
                 **credential_kwargs
             )

-            llm("ping", generate_config={'max_tokens': 10})
+            llm("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
         :param credentials:
         :return:
         """
+        extra_credentials = cls._get_extra_credentials(credentials)
+        credentials.update(extra_credentials)
+
         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])

         return credentials

     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):
         return credentials

+    @classmethod
+    def _get_extra_credentials(cls, credentials: dict) -> dict:
+        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
+        response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get the model description, detail: {response.json()['detail']}"
+            )
+
+        desc = response.json()
+        extra_credentials = {
+            'model_format': desc['model_format'],
+        }
+
+        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
+            extra_credentials['model_handle_type'] = 'chatglm'
+        elif "generate" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'generate'
+        elif "chat" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'chat'
+        else:
+            raise NotImplementedError("Model handle type not supported.")
+
+        return extra_credentials
+
     @classmethod
     def is_provider_credentials_valid_or_raise(cls, credentials: dict):
         return
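Note: _get_extra_credentials reads only three fields of the model description returned by GET {server_url}/v1/models/{model_uid}. An illustrative payload and the mapping it produces (field values are examples, not a verbatim server response):

# Illustrative model description; only the fields the code above inspects are shown.
desc = {
    'model_name': 'llama-2-chat',
    'model_format': 'ggmlv3',
    'model_ability': ['generate', 'chat'],
}
# Branch order matters: 'chatglm' in the model name wins first, then the
# 'generate' ability, then 'chat'. For this payload the result is:
#   {'model_format': 'ggmlv3', 'model_handle_type': 'generate'}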
core/third_party/langchain/llms/xinference_llm.py (new file):
from typing import Optional, List, Any, Union, Generator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Xinference
from langchain.llms.utils import enforce_stop_tokens
from xinference.client import RESTfulChatglmCppChatModelHandle, \
    RESTfulChatModelHandle, RESTfulGenerateModelHandle


class XinferenceLLM(Xinference):
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the xinference model and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Returns:
            The generated string by the model.
        """
        model = self.client.get_model(self.model_uid)
        if isinstance(model, RESTfulChatModelHandle):
            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

            if stop:
                generate_config["stop"] = stop

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                    model=model,
                    prompt=prompt,
                    run_manager=run_manager,
                    generate_config=generate_config,
                ):
                    combined_text_output += token
                return combined_text_output
            else:
                completion = model.chat(prompt=prompt, generate_config=generate_config)
                return completion["choices"][0]["text"]
        elif isinstance(model, RESTfulGenerateModelHandle):
            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

            if stop:
                generate_config["stop"] = stop

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                    model=model,
                    prompt=prompt,
                    run_manager=run_manager,
                    generate_config=generate_config,
                ):
                    combined_text_output += token
                return combined_text_output
            else:
                completion = model.generate(prompt=prompt, generate_config=generate_config)
                return completion["choices"][0]["text"]
        elif isinstance(model, RESTfulChatglmCppChatModelHandle):
            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})

            if generate_config and generate_config.get("stream"):
                combined_text_output = ""
                for token in self._stream_generate(
                    model=model,
                    prompt=prompt,
                    run_manager=run_manager,
                    generate_config=generate_config,
                ):
                    combined_text_output += token
                completion = combined_text_output
            else:
                completion = model.chat(prompt=prompt, generate_config=generate_config)
                completion = completion["choices"][0]["text"]

            if stop is not None:
                completion = enforce_stop_tokens(completion, stop)

            return completion

    def _stream_generate(
        self,
        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
        prompt: str,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
    ) -> Generator[str, None, None]:
        """
        Args:
            prompt: The prompt to use for generation.
            model: The model used for generation.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Yields:
            A string token.
        """
        if isinstance(model, RESTfulGenerateModelHandle):
            streaming_response = model.generate(
                prompt=prompt, generate_config=generate_config
            )
        else:
            streaming_response = model.chat(
                prompt=prompt, generate_config=generate_config
            )

        for chunk in streaming_response:
            if isinstance(chunk, dict):
                choices = chunk.get("choices", [])
                if choices:
                    choice = choices[0]
                    if isinstance(choice, dict):
                        token = choice.get("text", "")
                        log_probs = choice.get("logprobs")
                        if run_manager:
                            run_manager.on_llm_new_token(
                                token=token, verbose=self.verbose, log_probs=log_probs
                            )
                        yield token
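Note: a short usage sketch of the class above. With 'stream' set in generate_config, _call drains _stream_generate (emitting each token to the run manager) and still returns one concatenated string; without it, the call blocks on a single chat/generate request. The server URL and model UID are placeholders:

# Sketch only: values are placeholders, not from this commit.
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

llm = XinferenceLLM(server_url='http://127.0.0.1:9997', model_uid='my-model-uid')

# Blocking call: _call picks model.chat() or model.generate() based on the
# handle type returned by client.get_model().
text = llm('Write a haiku about the sea.')

# Streaming call: tokens are pushed to the callback manager as they arrive,
# and _call still returns the concatenated string.
text = llm('Write a haiku about the sea.',
           generate_config={'stream': True, 'max_tokens': 64})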
test_xinference_provider.py:

@@ -4,7 +4,6 @@ import json
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import CredentialsValidateFailedError
-from core.model_providers.providers.replicate_provider import ReplicateProvider
 from core.model_providers.providers.xinference_provider import XinferenceProvider
 from models.provider import ProviderType, Provider, ProviderModel
@@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.xinference.Xinference._call',
+    mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
                  return_value="abc")

     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid():
 @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
-def test_encrypt_model_credentials(mock_encrypt):
+def test_encrypt_model_credentials(mock_encrypt, mocker):
     api_key = 'http://127.0.0.1:9997/'

+    mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
+                 return_value={
+                     'model_handle_type': 'generate',
+                     'model_format': 'ggmlv3'
+                 })
+
     result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
         tenant_id='tenant_id',
         model_name='test_model_name',
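Note: since _get_extra_credentials is stubbed out in the test above, a companion test could cover its mapping directly. A sketch under the same fixtures (the response payload is illustrative):

def test_get_extra_credentials(mocker):
    # Sketch: stub the Xinference model-description endpoint and check the
    # derived handle type. The payload below is illustrative.
    mock_response = mocker.Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        'model_name': 'llama-2-chat',
        'model_format': 'ggmlv3',
        'model_ability': ['generate', 'chat'],
    }
    mocker.patch('core.model_providers.providers.xinference_provider.requests.get',
                 return_value=mock_response)

    extra = XinferenceProvider._get_extra_credentials({
        'server_url': 'http://127.0.0.1:9997',
        'model_uid': 'test_model_uid',
    })
    assert extra == {'model_format': 'ggmlv3', 'model_handle_type': 'generate'}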