Unverified commit d38eac95, authored by takatost and committed by GitHub

fix: wenxin model name invalid when llm call (#1248)

parent 9dbb8acd
...@@ -18,6 +18,7 @@ class WenxinModel(BaseLLM): ...@@ -18,6 +18,7 @@ class WenxinModel(BaseLLM):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db) # TODO load price_config from configs(db)
return Wenxin( return Wenxin(
model=self.name,
streaming=self.streaming, streaming=self.streaming,
callbacks=self.callbacks, callbacks=self.callbacks,
**self.credentials, **self.credentials,
......
...@@ -61,13 +61,18 @@ class WenxinProvider(BaseModelProvider): ...@@ -61,13 +61,18 @@ class WenxinProvider(BaseModelProvider):
:param model_type: :param model_type:
:return: :return:
""" """
model_max_tokens = {
'ernie-bot': 4800,
'ernie-bot-turbo': 11200,
}
if model_name in ['ernie-bot', 'ernie-bot-turbo']: if model_name in ['ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules( return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2), temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2), top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
presence_penalty=KwargRule[float](enabled=False), presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False), max_tokens=KwargRule[int](enabled=False, max=model_max_tokens.get(model_name)),
) )
else: else:
return ModelKwargsRules( return ModelKwargsRules(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment