Unverified Commit 2851a9f0 authored by takatost's avatar takatost Committed by GitHub

feat: optimize minimax llm call (#1312)

parent c536f85b
import decimal
from typing import List, Optional, Any from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
class MinimaxModel(BaseLLM): class MinimaxModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any: def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Minimax( return MinimaxChatLLM(
model=self.name, model=self.name,
model_kwargs={ streaming=self.streaming,
'stream': False
},
callbacks=self.callbacks, callbacks=self.callbacks,
**self.credentials, **self.credentials,
**provider_model_kwargs **provider_model_kwargs
...@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM): ...@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
:return: :return:
""" """
prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0) return max(self._client.get_num_tokens_from_messages(prompts), 0)
def get_currency(self): def get_currency(self):
return 'RMB' return 'RMB'
...@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM): ...@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
return LLMBadRequestError(f"Minimax: {str(ex)}") return LLMBadRequestError(f"Minimax: {str(ex)}")
else: else:
return ex return ex
@property
def support_streaming(self):
return True
...@@ -2,7 +2,7 @@ import json ...@@ -2,7 +2,7 @@ import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import Type from typing import Type
from langchain.llms import Minimax from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
...@@ -10,6 +10,7 @@ from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbed ...@@ -10,6 +10,7 @@ from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbed
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.minimax_model import MinimaxModel from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
from models.provider import ProviderType, ProviderQuotaType from models.provider import ProviderType, ProviderQuotaType
...@@ -98,14 +99,14 @@ class MinimaxProvider(BaseModelProvider): ...@@ -98,14 +99,14 @@ class MinimaxProvider(BaseModelProvider):
'minimax_api_key': credentials['minimax_api_key'], 'minimax_api_key': credentials['minimax_api_key'],
} }
llm = Minimax( llm = MinimaxChatLLM(
model='abab5.5-chat', model='abab5.5-chat',
max_tokens=10, max_tokens=10,
temperature=0.01, temperature=0.01,
**credential_kwargs **credential_kwargs
) )
llm("ping") llm([HumanMessage(content='ping')])
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment