Unverified Commit 7c9b585a authored by takatost and committed by GitHub

feat: support wenxin ernie-bot-4 and chat mode (#1375)

parent c039f4af
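For orientation before the diff: this commit switches the Wenxin (ERNIE) wrapper from a completion-style LLM to a LangChain chat model and registers ernie-bot-4. Below is a minimal usage sketch, not part of the commit; the credentials are placeholders, while the constructor fields (model, api_key, secret_key, streaming) and the message-list call style are taken from the diff itself (compare the credential check `llm([HumanMessage(content='ping')])` further down).

# --- illustrative usage sketch (not part of the commit) ---
from langchain.schema import HumanMessage, SystemMessage

from core.third_party.langchain.llms.wenxin import Wenxin

chat = Wenxin(
    model="ernie-bot-4",          # new model id introduced by this commit
    api_key="my-api-key",         # placeholder credential
    secret_key="my-secret-key",   # placeholder credential
    streaming=False,
)

# As a BaseChatModel, the wrapper now takes a list of BaseMessage objects
# rather than a single prompt string, and returns an AIMessage.
reply = chat([
    SystemMessage(content="You are a concise assistant."),
    HumanMessage(content="1 + 1 = ?"),
])
print(reply.content)
# --- end of sketch ---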
@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin
class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
model=self.name,
streaming=self.streaming,
@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")
@property
def support_streaming(self):
return True
@@ -2,6 +2,8 @@ import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'ernie-bot-4',
'name': 'ERNIE-Bot-4',
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
}
]
else:
@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'ernie-bot-4': 4800,
'ernie-bot': 4800,
'ernie-bot-turbo': 11200,
}
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
**credential_kwargs
)
llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -5,6 +5,12 @@
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"ernie-bot-4": {
"prompt": "0",
"completion": "0",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",
@@ -8,12 +8,15 @@ from typing import (
Any,
Dict,
List,
Optional, Iterator,
Optional, Iterator, Tuple,
)
import requests
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema.output import GenerationChunk
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
raise ValueError(f"Wenxin Model name is required")
model_url_map = {
'ernie-bot-4': 'completions_pro',
'ernie-bot': 'completions',
'ernie-bot-turbo': 'eb-instant',
'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):
access_token = self.get_access_token()
api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
del request['model']
headers = {"Content-Type": "application/json"}
response = requests.post(api_url,
@@ -86,22 +91,21 @@ class _WenxinEndpointClient(BaseModel):
f"Wenxin API {json_response['error_code']}"
f" error: {json_response['error_msg']}"
)
return json_response["result"]
return json_response
else:
return response
class Wenxin(LLM):
"""Wrapper around Wenxin large language models.
To use, you should have the environment variable
``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
or pass them as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms.wenxin import Wenxin
wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
secret_key="my-group-id")
"""
class Wenxin(BaseChatModel):
"""Wrapper around Wenxin large language models."""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
@property
def lc_serializable(self) -> bool:
return True
_client: _WenxinEndpointClient = PrivateAttr()
model: str = "ernie-bot"
@@ -161,64 +165,89 @@ class Wenxin(LLM):
secret_key=self.secret_key,
)
def _call(
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> Tuple[List[Dict[str, Any]], str]:
dict_messages = []
system = None
for m in messages:
message = self._convert_message_to_dict(m)
if message['role'] == 'system':
if not system:
system = message['content']
else:
system += f"\n{message['content']}"
continue
if dict_messages:
previous_message = dict_messages[-1]
if previous_message['role'] == message['role']:
dict_messages[-1]['content'] += f"\n{message['content']}"
else:
dict_messages.append(message)
else:
dict_messages.append(message)
return dict_messages, system
def _generate(
self,
prompt: str,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""Call out to Wenxin's completion endpoint to chat
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = wenxin("Tell me a joke.")
"""
) -> ChatResult:
if self.streaming:
completion = ""
generation: Optional[ChatGenerationChunk] = None
llm_output: Optional[Dict] = None
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
if chunk.generation_info is not None \
and 'token_usage' in chunk.generation_info:
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation], llm_output=llm_output)
else:
message_dicts, system = self._create_message_dicts(messages)
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request["messages"] = message_dicts
if system:
request["system"] = system
request.update(kwargs)
completion = self._client.post(request)
if stop is not None:
completion = enforce_stop_tokens(completion, stop)
return completion
response = self._client.post(request)
return self._create_chat_result(response)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
r"""Call wenxin completion_stream and return the resulting generator.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Wenxin.
Example:
.. code-block:: python
prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = wenxin.stream(prompt)
for token in generator:
yield token
"""
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, system = self._create_message_dicts(messages)
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request["messages"] = message_dicts
if system:
request["system"] = system
request.update(kwargs)
for token in self._client.post(request).iter_lines():
@@ -228,12 +257,18 @@ class Wenxin(LLM):
if token.startswith('data:'):
completion = json.loads(token[5:])
yield GenerationChunk(text=completion['result'])
if run_manager:
run_manager.on_llm_new_token(completion['result'])
chunk_dict = {
'message': AIMessageChunk(content=completion['result']),
}
if completion['is_end']:
break
token_usage = completion['usage']
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
chunk_dict['generation_info'] = dict({'token_usage': token_usage})
yield ChatGenerationChunk(**chunk_dict)
if run_manager:
run_manager.on_llm_new_token(completion['result'])
else:
try:
json_response = json.loads(token)
@@ -245,3 +280,40 @@ class Wenxin(LLM):
f" error: {json_response['error_msg']}, "
f"please confirm if the model you have chosen is already paid for."
)
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
generations = [ChatGeneration(
message=AIMessage(content=response['result']),
)]
token_usage = response.get("usage")
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Args:
messages: The message inputs to tokenize.
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(m.content) for m in messages])
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model}
@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('ernie-bot')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
rst = model.run(
messages,
stop=['\nHuman:'],
messages
)
assert len(rst.content) > 0
@@ -2,6 +2,8 @@ import pytest
from unittest.mock import patch
import json
from langchain.schema import AIMessage, ChatGeneration, ChatResult
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.wenxin_provider import WenxinProvider
from models.provider import ProviderType, Provider
@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)