Unverified Commit 1d4f019d authored by takatost's avatar takatost Committed by GitHub

feat: add baichuan llm support (#1294)

Co-authored-by: 's avatarzxhlyh <jasonapring2015@outlook.com>
parent 677aacc8
...@@ -51,6 +51,9 @@ class ModelProviderFactory: ...@@ -51,6 +51,9 @@ class ModelProviderFactory:
elif provider_name == 'chatglm': elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider return ChatGLMProvider
elif provider_name == 'baichuan':
from core.model_providers.providers.baichuan_provider import BaichuanProvider
return BaichuanProvider
elif provider_name == 'azure_openai': elif provider_name == 'azure_openai':
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
return AzureOpenAIProvider return AzureOpenAIProvider
......
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
class BaichuanModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return BaichuanChatLLM(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Baichuan: {str(ex)}")
@property
def support_streaming(self):
return True
import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
from models.provider import ProviderType
class BaichuanProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'baichuan'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'baichuan2-53b',
'name': 'Baichuan2-53B',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = BaichuanModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Baichuan api_key must be provided.')
if 'secret_key' not in credentials:
raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
'secret_key': credentials['secret_key'],
}
llm = BaichuanChatLLM(
temperature=0,
**credential_kwargs
)
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
'secret_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
if credentials['secret_key']:
credentials['secret_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['secret_key']
)
if obfuscated:
credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])
return credentials
else:
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
...@@ -7,10 +7,11 @@ ...@@ -7,10 +7,11 @@
"spark", "spark",
"wenxin", "wenxin",
"zhipuai", "zhipuai",
"baichuan",
"chatglm", "chatglm",
"replicate", "replicate",
"huggingface_hub", "huggingface_hub",
"xinference", "xinference",
"openllm", "openllm",
"localai" "localai"
] ]
\ No newline at end of file
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"baichuan2-53b": {
"prompt": "0.01",
"completion": "0.01",
"unit": "0.001",
"currency": "RMB"
}
}
}
\ No newline at end of file
"""Wrapper around Baichuan APIs."""
from __future__ import annotations
import hashlib
import json
import logging
import time
from typing import (
Any,
Dict,
List,
Optional, Iterator,
)
import requests
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class BaichuanModelAPI(BaseModel):
api_key: str
secret_key: str
base_url: str = "https://api.baichuan-ai.com/v1"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def do_request(self, model: str, messages: list[dict], parameters: dict, **kwargs: Any):
stream = 'stream' in kwargs and kwargs['stream']
url = self.base_url + ("/stream/chat" if stream else "/chat")
data = {
"model": model,
"messages": messages,
"parameters": parameters
}
json_data = json.dumps(data)
time_stamp = int(time.time())
signature = self._calculate_md5(self.secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=(5, 60))
if not response.ok:
raise ValueError(f"HTTP {response.status_code} error: {response.text}")
if not stream:
json_response = response.json()
if json_response['code'] != 0:
raise ValueError(
f"API {json_response['code']}"
f" error: {json_response['msg']}"
)
return json_response
else:
return response
def _calculate_md5(self, input_string):
md5 = hashlib.md5()
md5.update(input_string.encode('utf-8'))
encrypted = md5.hexdigest()
return encrypted
class BaichuanChatLLM(BaseChatModel):
"""Wrapper around Baichuan large language models.
To use, you should pass the api_key as a named parameter to the constructor.
Example:
.. code-block:: python
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
model = BaichuanChatLLM(model="<model_name>", api_key="my-api-key", secret_key="my-secret-key")
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
@property
def lc_serializable(self) -> bool:
return True
client: Any = None #: :meta private:
model: str = "Baichuan2-53B"
"""Model name to use."""
temperature: float = 0.3
"""A non-negative float that tunes the degree of randomness in generation."""
top_p: float = 0.85
"""Total probability mass of tokens to consider at each step."""
streaming: bool = False
"""Whether to stream the response or return it all at once."""
api_key: Optional[str] = None
secret_key: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["api_key"] = get_from_dict_or_env(
values, "api_key", "BAICHUAN_API_KEY"
)
values["secret_key"] = get_from_dict_or_env(
values, "secret_key", "BAICHUAN_SECRET_KEY"
)
values['client'] = BaichuanModelAPI(
api_key=values['api_key'],
secret_key=values['secret_key']
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"model": self.model,
"parameters": {
"temperature": self.temperature,
"top_p": self.top_p
}
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "baichuan"
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
dict_messages = []
for m in messages:
message = self._convert_message_to_dict(m)
if dict_messages:
previous_message = dict_messages[-1]
if previous_message['role'] == message['role']:
dict_messages[-1]['content'] += f"\n{message['content']}"
else:
dict_messages.append(message)
else:
dict_messages.append(message)
return dict_messages
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
llm_output: Optional[Dict] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if generation is None:
generation = chunk
else:
generation += chunk
if chunk.generation_info is not None \
and 'token_usage' in chunk.generation_info:
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
assert generation is not None
return ChatResult(generations=[generation], llm_output=llm_output)
else:
message_dicts = self._create_message_dicts(messages)
params = self._default_params
params["messages"] = message_dicts
params.update(kwargs)
response = self.client.do_request(**params)
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages)
params = self._default_params
params["messages"] = message_dicts
params.update(kwargs)
for event in self.client.do_request(stream=True, **params).iter_lines():
if event:
event = event.decode("utf-8")
meta = json.loads(event)
if meta['code'] != 0:
raise ValueError(
f"API {meta['code']}"
f" error: {meta['msg']}"
)
content = meta['data']['messages'][0]['content']
chunk_kwargs = {
'message': AIMessageChunk(content=content),
}
if 'usage' in meta:
token_usage = meta['usage']
overall_token_usage = {
'prompt_tokens': token_usage.get('prompt_tokens', 0),
'completion_tokens': token_usage.get('answer_tokens', 0),
'total_tokens': token_usage.get('total_tokens', 0)
}
chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}
yield ChatGenerationChunk(**chunk_kwargs)
if run_manager:
run_manager.on_llm_new_token(content)
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
data = response["data"]
generations = []
for res in data["messages"]:
message = self._convert_dict_to_message(res)
gen = ChatGeneration(
message=message
)
generations.append(gen)
usage = response.get("usage")
token_usage = {
'prompt_tokens': usage.get('prompt_tokens', 0),
'completion_tokens': usage.get('answer_tokens', 0),
'total_tokens': usage.get('total_tokens', 0)
}
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Args:
messages: The message inputs to tokenize.
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(m.content) for m in messages])
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
return {"token_usage": token_usage, "model_name": self.model}
...@@ -35,6 +35,10 @@ WENXIN_SECRET_KEY= ...@@ -35,6 +35,10 @@ WENXIN_SECRET_KEY=
# ZhipuAI Credentials # ZhipuAI Credentials
ZHIPUAI_API_KEY= ZHIPUAI_API_KEY=
# Baichuan Credentials
BAICHUAN_API_KEY=
BAICHUAN_SECRET_KEY=
# ChatGLM Credentials # ChatGLM Credentials
CHATGLM_API_BASE= CHATGLM_API_BASE=
......
import json
import os
from unittest.mock import patch
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.baichuan_provider import BaichuanProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key, valid_secret_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='baichuan',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'api_key': valid_api_key,
'secret_key': valid_secret_key,
}),
is_valid=True,
)
def get_mock_model(model_name: str, streaming: bool = False):
model_kwargs = ModelKwargs(
temperature=0.01,
)
valid_api_key = os.environ['BAICHUAN_API_KEY']
valid_secret_key = os.environ['BAICHUAN_SECRET_KEY']
model_provider = BaichuanProvider(provider=get_mock_provider(valid_api_key, valid_secret_key))
return BaichuanModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs,
streaming=streaming
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('baichuan2-53b')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('baichuan2-53b')
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
)
assert len(rst.content) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('baichuan2-53b', streaming=True)
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages
)
assert len(rst.content) > 0
import pytest
from unittest.mock import patch
import json
from langchain.schema import ChatResult, ChatGeneration, AIMessage
from core.model_providers.providers.baichuan_provider import BaichuanProvider
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'baichuan'
MODEL_PROVIDER_CLASS = BaichuanProvider
VALIDATE_CREDENTIAL = {
'api_key': 'valid_key',
'secret_key': 'valid_key',
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.baichuan_llm.BaichuanChatLLM._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
credential['secret_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_key'] == 'valid_key'
assert result['secret_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_key'][6:-2]
secret_key_middle_token = result['secret_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
assert len(secret_key_middle_token) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
assert all(char == '*' for char in middle_token)
assert all(char == '*' for char in secret_key_middle_token)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment