Unverified commit 7c9b585a authored by takatost, committed by GitHub

feat: support weixin ernie-bot-4 and chat mode (#1375)

parent c039f4af
@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.wenxin import Wenxin
 
 
 class WenxinModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        # TODO load price_config from configs(db)
         return Wenxin(
             model=self.name,
             streaming=self.streaming,
@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+        generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)
 
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
+
+    @property
+    def support_streaming(self):
+        return True
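
The reworked `_run` above builds the generate kwargs as a dict and forwards `functions` only when the caller actually supplies them, so models without function calling never receive the key. A minimal, self-contained sketch of that pattern (`FakeChatClient` and its `generate()` signature below are hypothetical stand-ins, not Dify or LangChain APIs):

```python
# Hedged sketch of the optional-kwarg forwarding used in _run above.
from typing import Any, List, Optional


class FakeChatClient:
    def generate(self, messages: List[Any], stop: Optional[List[str]] = None,
                 callbacks: Any = None, functions: Optional[list] = None) -> dict:
        # Echo back what was forwarded so the behaviour is visible.
        return {"messages": messages, "stop": stop, "functions": functions}


def run(client: FakeChatClient, prompts: list, stop: Optional[List[str]] = None,
        callbacks: Any = None, **kwargs: Any) -> dict:
    generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
    if 'functions' in kwargs:
        # Only forward tools when the caller provided them.
        generate_kwargs['functions'] = kwargs['functions']
    return client.generate(**generate_kwargs)


print(run(FakeChatClient(), ['ping']))                                       # functions omitted
print(run(FakeChatClient(), ['ping'], functions=[{'name': 'get_weather'}]))  # functions forwarded
```
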
@@ -2,6 +2,8 @@ import json
 from json import JSONDecodeError
 from typing import Type
 
+from langchain.schema import HumanMessage
+
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
             return [
+                {
+                    'id': 'ernie-bot-4',
+                    'name': 'ERNIE-Bot-4',
+                    'mode': ModelMode.CHAT.value,
+                },
                 {
                     'id': 'ernie-bot',
                     'name': 'ERNIE-Bot',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'ernie-bot-turbo',
                     'name': 'ERNIE-Bot-turbo',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'bloomz-7b',
                     'name': 'BLOOMZ-7B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
+            'ernie-bot-4': 4800,
             'ernie-bot': 4800,
             'ernie-bot-turbo': 11200,
         }
 
-        if model_name in ['ernie-bot', 'ernie-bot-turbo']:
+        if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
                 temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
                 top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
                 **credential_kwargs
             )
 
-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
...
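
Because `Wenxin` now behaves as a LangChain chat model, the credential check above invokes it with a list of messages instead of a bare prompt string. A hedged usage sketch, assuming valid Baidu Qianfan credentials (the key values below are placeholders):

```python
# Hedged usage sketch of the chat-mode wrapper; api_key/secret_key are placeholders.
from langchain.schema import HumanMessage

from core.third_party.langchain.llms.wenxin import Wenxin

llm = Wenxin(
    model='ernie-bot-4',           # new model id added in this change
    api_key='your-api-key',        # placeholder
    secret_key='your-secret-key',  # placeholder
    streaming=False,
)

# A chat model is called with a list of messages and returns an AIMessage,
# which is why the validation above sends [HumanMessage(content='ping')].
result = llm([HumanMessage(content='ping')])
print(result.content)
```
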
@@ -5,6 +5,12 @@
     "system_config": null,
     "model_flexibility": "fixed",
     "price_config": {
+        "ernie-bot-4": {
+            "prompt": "0",
+            "completion": "0",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
         "ernie-bot": {
             "prompt": "0.012",
             "completion": "0.012",
...
@@ -8,12 +8,15 @@ from typing import (
     Any,
     Dict,
     List,
-    Optional, Iterator,
+    Optional, Iterator, Tuple,
 )
 
 import requests
+from langchain.chat_models.base import BaseChatModel
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.schema.output import GenerationChunk
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
 from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
 
 from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
             raise ValueError(f"Wenxin Model name is required")
 
         model_url_map = {
+            'ernie-bot-4': 'completions_pro',
             'ernie-bot': 'completions',
             'ernie-bot-turbo': 'eb-instant',
             'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):
         access_token = self.get_access_token()
         api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
+        del request['model']
 
         headers = {"Content-Type": "application/json"}
         response = requests.post(api_url,
@@ -86,22 +91,21 @@ class _WenxinEndpointClient(BaseModel):
                     f"Wenxin API {json_response['error_code']}"
                    f" error: {json_response['error_msg']}"
                 )
-            return json_response["result"]
+            return json_response
         else:
             return response
 
 
-class Wenxin(LLM):
-    """Wrapper around Wenxin large language models.
+class Wenxin(BaseChatModel):
+    """Wrapper around Wenxin large language models."""
 
-    To use, you should have the environment variable
-    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
-    or pass them as a named parameter to the constructor.
-
-    Example:
-        .. code-block:: python
-
-        from langchain.llms.wenxin import Wenxin
-        wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
-                        secret_key="my-group-id")
-    """
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
 
     _client: _WenxinEndpointClient = PrivateAttr()
     model: str = "ernie-bot"
@@ -161,64 +165,89 @@ class Wenxin(LLM):
             secret_key=self.secret_key,
         )
 
-    def _call(
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
+    def _create_message_dicts(
+            self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        dict_messages = []
+        system = None
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if message['role'] == 'system':
+                if not system:
+                    system = message['content']
+                else:
+                    system += f"\n{message['content']}"
+                continue
+
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages, system
+
+    def _generate(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        r"""Call out to Wenxin's completion endpoint to chat
-        Args:
-            prompt: The prompt to pass into the model.
-        Returns:
-            The string generated by the model.
-        Example:
-            .. code-block:: python
-                response = wenxin("Tell me a joke.")
-        """
+    ) -> ChatResult:
         if self.streaming:
-            completion = ""
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
             for chunk in self._stream(
-                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
             ):
-                completion += chunk.text
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
         else:
+            message_dicts, system = self._create_message_dicts(messages)
             request = self._default_params
-            request["messages"] = [{"role": "user", "content": prompt}]
+            request["messages"] = message_dicts
+            if system:
+                request["system"] = system
             request.update(kwargs)
-            completion = self._client.post(request)
-
-        if stop is not None:
-            completion = enforce_stop_tokens(completion, stop)
-
-        return completion
+            response = self._client.post(request)
+            return self._create_chat_result(response)
 
     def _stream(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        r"""Call wenxin completion_stream and return the resulting generator.
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-        Returns:
-            A generator representing the stream of tokens from Wenxin.
-        Example:
-            .. code-block:: python
-                prompt = "Write a poem about a stream."
-                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
-                generator = wenxin.stream(prompt)
-                for token in generator:
-                    yield token
-        """
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, system = self._create_message_dicts(messages)
         request = self._default_params
-        request["messages"] = [{"role": "user", "content": prompt}]
+        request["messages"] = message_dicts
+        if system:
+            request["system"] = system
         request.update(kwargs)
 
         for token in self._client.post(request).iter_lines():
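
The new `_create_message_dicts` helper above pulls system messages into a separate `system` string and merges consecutive messages that share a role before the request body is built. A self-contained re-statement of that rule (runs without Wenxin credentials; `create_message_dicts` here is a local copy for illustration, not the class method):

```python
# Standalone sketch of the merge behaviour in _create_message_dicts above.
from typing import Any, Dict, List, Optional, Tuple

from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage


def create_message_dicts(messages: List[BaseMessage]) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    role_map = {HumanMessage: "user", AIMessage: "assistant", SystemMessage: "system"}
    dict_messages: List[Dict[str, Any]] = []
    system: Optional[str] = None
    for m in messages:
        role = role_map[type(m)]
        if role == "system":
            # System prompts are collected into one string and sent via the "system" field.
            system = m.content if system is None else f"{system}\n{m.content}"
            continue
        if dict_messages and dict_messages[-1]["role"] == role:
            # Consecutive turns from the same role are merged into a single message.
            dict_messages[-1]["content"] += f"\n{m.content}"
        else:
            dict_messages.append({"role": role, "content": m.content})
    return dict_messages, system


dicts, system = create_message_dicts([
    SystemMessage(content="You are terse."),
    HumanMessage(content="Hello"),
    HumanMessage(content="1 + 1 = ?"),
    AIMessage(content="2"),
])
print(system)  # You are terse.
print(dicts)   # [{'role': 'user', 'content': 'Hello\n1 + 1 = ?'}, {'role': 'assistant', 'content': '2'}]
```

The merge keeps the outgoing user/assistant turns strictly alternating, which is the shape the chat endpoint expects.
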
@@ -228,12 +257,18 @@ class Wenxin(LLM):
             if token.startswith('data:'):
                 completion = json.loads(token[5:])
-                yield GenerationChunk(text=completion['result'])
-                if run_manager:
-                    run_manager.on_llm_new_token(completion['result'])
+
+                chunk_dict = {
+                    'message': AIMessageChunk(content=completion['result']),
+                }
 
                 if completion['is_end']:
-                    break
+                    token_usage = completion['usage']
+                    token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+                    chunk_dict['generation_info'] = dict({'token_usage': token_usage})
+
+                yield ChatGenerationChunk(**chunk_dict)
+                if run_manager:
+                    run_manager.on_llm_new_token(completion['result'])
             else:
                 try:
                     json_response = json.loads(token)
@@ -245,3 +280,40 @@ class Wenxin(LLM):
                         f" error: {json_response['error_msg']}, "
                         f"please confirm if the model you have chosen is already paid for."
                     )
+
+    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
+        generations = [ChatGeneration(
+            message=AIMessage(content=response['result']),
+        )]
+        token_usage = response.get("usage")
+        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage, "model_name": self.model}
@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
 
     model = get_mock_model('ernie-bot')
-    messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
+    messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
 
     rst = model.run(
-        messages,
-        stop=['\nHuman:'],
+        messages
     )
     assert len(rst.content) > 0
@@ -2,6 +2,8 @@ import pytest
 from unittest.mock import patch
 import json
 
+from langchain.schema import AIMessage, ChatGeneration, ChatResult
+
 from core.model_providers.providers.base import CredentialsValidateFailedError
 from core.model_providers.providers.wenxin_provider import WenxinProvider
 from models.provider import ProviderType, Provider
@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 def test_is_provider_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
 
     MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
...