Unverified Commit 7c9b585a authored by takatost's avatar takatost Committed by GitHub

feat: support weixin ernie-bot-4 and chat mode (#1375)

parent c039f4af
...@@ -6,17 +6,16 @@ from langchain.schema import LLMResult ...@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin from core.third_party.langchain.llms.wenxin import Wenxin
class WenxinModel(BaseLLM): class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any: def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin( return Wenxin(
model=self.name, model=self.name,
streaming=self.streaming, streaming=self.streaming,
...@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM): ...@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
:return: :return:
""" """
prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int: def get_num_tokens(self, messages: List[PromptMessage]) -> int:
""" """
...@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM): ...@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
:return: :return:
""" """
prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0) return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs): def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
...@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM): ...@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}") return LLMBadRequestError(f"Wenxin: {str(ex)}")
@property
def support_streaming(self):
return True
...@@ -2,6 +2,8 @@ import json ...@@ -2,6 +2,8 @@ import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import Type from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
...@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider): ...@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION: if model_type == ModelType.TEXT_GENERATION:
return [ return [
{
'id': 'ernie-bot-4',
'name': 'ERNIE-Bot-4',
'mode': ModelMode.CHAT.value,
},
{ {
'id': 'ernie-bot', 'id': 'ernie-bot',
'name': 'ERNIE-Bot', 'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value, 'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'ernie-bot-turbo', 'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo', 'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value, 'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'bloomz-7b', 'id': 'bloomz-7b',
'name': 'BLOOMZ-7B', 'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value, 'mode': ModelMode.CHAT.value,
} }
] ]
else: else:
...@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider): ...@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
:return: :return:
""" """
model_max_tokens = { model_max_tokens = {
'ernie-bot-4': 4800,
'ernie-bot': 4800, 'ernie-bot': 4800,
'ernie-bot-turbo': 11200, 'ernie-bot-turbo': 11200,
} }
if model_name in ['ernie-bot', 'ernie-bot-turbo']: if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules( return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2), temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2), top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
...@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider): ...@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
**credential_kwargs **credential_kwargs
) )
llm("ping") llm([HumanMessage(content='ping')])
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))
......
...@@ -5,6 +5,12 @@ ...@@ -5,6 +5,12 @@
"system_config": null, "system_config": null,
"model_flexibility": "fixed", "model_flexibility": "fixed",
"price_config": { "price_config": {
"ernie-bot-4": {
"prompt": "0",
"completion": "0",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot": { "ernie-bot": {
"prompt": "0.012", "prompt": "0.012",
"completion": "0.012", "completion": "0.012",
......
...@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker): ...@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('ernie-bot') model = get_mock_model('ernie-bot')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')] messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
rst = model.run( rst = model.run(
messages, messages
stop=['\nHuman:'],
) )
assert len(rst.content) > 0 assert len(rst.content) > 0
...@@ -2,6 +2,8 @@ import pytest ...@@ -2,6 +2,8 @@ import pytest
from unittest.mock import patch from unittest.mock import patch
import json import json
from langchain.schema import AIMessage, ChatGeneration, ChatResult
from core.model_providers.providers.base import CredentialsValidateFailedError from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.wenxin_provider import WenxinProvider from core.model_providers.providers.wenxin_provider import WenxinProvider
from models.provider import ProviderType, Provider from models.provider import ProviderType, Provider
...@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key): ...@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):
def test_is_provider_credentials_valid_or_raise_valid(mocker): def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc") mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL) MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment