Commit 898062c3 authored by takatost's avatar takatost

feat: add model mode in text generation model list api

parent 677aacc8
...@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage ...@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType from core.model_providers.models.llm.base import ModelType
...@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider): ...@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
{ {
'id': 'claude-instant-1', 'id': 'claude-instant-1',
'name': 'claude-instant-1', 'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'claude-2', 'id': 'claude-2',
'name': 'claude-2', 'name': 'claude-2',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider): ...@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    The model name is ignored: this provider reports chat mode for
    every text generation model it exposes.
    """
    mode = ModelMode.CHAT
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -12,7 +12,7 @@ from core.helper import encrypter ...@@ -12,7 +12,7 @@ from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \ from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
} }
credentials = json.loads(provider_model.encrypted_config) credentials = json.loads(provider_model.encrypted_config)
if provider_model.model_type == ModelType.TEXT_GENERATION.value:
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
if credentials['base_model_name'] in [ if credentials['base_model_name'] in [
'gpt-4', 'gpt-4',
'gpt-4-32k', 'gpt-4-32k',
...@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
return model_list return model_list
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for an Azure OpenAI base model.

    Only 'text-davinci-003' is a completion model; every other base
    model is reported as a chat model.
    """
    is_completion = model_name == 'text-davinci-003'
    return (ModelMode.COMPLETION if is_completion else ModelMode.CHAT).value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION: if model_type == ModelType.TEXT_GENERATION:
models = [ models = [
{ {
'id': 'gpt-3.5-turbo', 'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo', 'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-3.5-turbo-16k', 'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k', 'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-4', 'id': 'gpt-4',
'name': 'gpt-4', 'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-4-32k', 'id': 'gpt-4-32k',
'name': 'gpt-4-32k', 'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{ {
'id': 'text-davinci-003', 'id': 'text-davinci-003',
'name': 'text-davinci-003', 'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
} }
] ]
......
...@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC): ...@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
ProviderModel.is_valid == True ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all() ).order_by(ProviderModel.created_at.asc()).all()
return [{ provider_model_list = []
'id': provider_model.model_name, for provider_model in provider_models:
'name': provider_model.model_name provider_model_dict = {
} for provider_model in provider_models] 'id': provider_model.model_name,
'name': provider_model.model_name
}
if model_type == ModelType.TEXT_GENERATION:
provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
provider_model_list.append(provider_model_dict)
return provider_model_list
@abstractmethod @abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
...@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC): ...@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def _get_text_generation_model_mode(self, model_name) -> str:
    """
    Return the mode of the given text generation model as a string
    (a ModelMode value). Subclasses decide, per provider, whether a
    model is chat-style or completion-style.

    :param model_name: name of the text generation model
    :return: mode string for the model
    """
    raise NotImplementedError
@abstractmethod @abstractmethod
def get_model_class(self, model_type: ModelType) -> Type: def get_model_class(self, model_type: ModelType) -> Type:
""" """
......
...@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM ...@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.chatglm_model import ChatGLMModel from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType from models.provider import ProviderType
...@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider): ...@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
{ {
'id': 'chatglm2-6b', 'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B', 'name': 'ChatGLM2-6B',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'chatglm-6b', 'id': 'chatglm-6b',
'name': 'ChatGLM-6B', 'name': 'ChatGLM-6B',
'mode': ModelMode.COMPLETION.value,
} }
] ]
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    The model name is ignored: this provider reports completion mode
    for every text generation model it exposes.
    """
    mode = ModelMode.COMPLETION
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -5,7 +5,7 @@ import requests ...@@ -5,7 +5,7 @@ import requests
from huggingface_hub import HfApi from huggingface_hub import HfApi
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider): ...@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    Hugging Face Hub models are always reported as completion models,
    regardless of the model name.
    """
    mode = ModelMode.COMPLETION
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage ...@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
from core.model_providers.models.llm.localai_model import LocalAIModel from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider): ...@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a LocalAI text generation model.

    The mode is taken from the model's stored credentials: a
    'completion_type' of 'chat_completion' means chat mode, any other
    value means completion mode.
    """
    creds = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
    is_chat = creds['completion_type'] == 'chat_completion'
    return (ModelMode.CHAT if is_chat else ModelMode.COMPLETION).value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -7,7 +7,7 @@ from langchain.llms import Minimax ...@@ -7,7 +7,7 @@ from langchain.llms import Minimax
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.minimax_model import MinimaxModel from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType, ProviderQuotaType from models.provider import ProviderType, ProviderQuotaType
...@@ -28,10 +28,12 @@ class MinimaxProvider(BaseModelProvider): ...@@ -28,10 +28,12 @@ class MinimaxProvider(BaseModelProvider):
{ {
'id': 'abab5.5-chat', 'id': 'abab5.5-chat',
'name': 'abab5.5-chat', 'name': 'abab5.5-chat',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'abab5-chat', 'id': 'abab5-chat',
'name': 'abab5-chat', 'name': 'abab5-chat',
'mode': ModelMode.COMPLETION.value,
} }
] ]
elif model_type == ModelType.EMBEDDINGS: elif model_type == ModelType.EMBEDDINGS:
...@@ -44,6 +46,9 @@ class MinimaxProvider(BaseModelProvider): ...@@ -44,6 +46,9 @@ class MinimaxProvider(BaseModelProvider):
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    The model name is ignored: this provider reports completion mode
    for every text generation model it exposes.
    """
    mode = ModelMode.COMPLETION
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature ...@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers from core.model_providers.providers.hosted import hosted_model_providers
...@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider): ...@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-3.5-turbo', 'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo', 'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider): ...@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-3.5-turbo-instruct', 'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct', 'name': 'GPT-3.5-Turbo-Instruct',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'gpt-3.5-turbo-16k', 'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k', 'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider): ...@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-4', 'id': 'gpt-4',
'name': 'gpt-4', 'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider): ...@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-4-32k', 'id': 'gpt-4-32k',
'name': 'gpt-4-32k', 'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider): ...@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'text-davinci-003', 'id': 'text-davinci-003',
'name': 'text-davinci-003', 'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
} }
] ]
...@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider): ...@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for an OpenAI text generation model.

    Models listed in COMPLETION_MODELS use the completion API; all
    other models are reported as chat models.
    """
    mode = ModelMode.COMPLETION if model_name in COMPLETION_MODELS else ModelMode.CHAT
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -2,7 +2,7 @@ import json ...@@ -2,7 +2,7 @@ import json
from typing import Type from typing import Type
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.openllm_model import OpenLLMModel from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -22,6 +22,9 @@ class OpenLLMProvider(BaseModelProvider): ...@@ -22,6 +22,9 @@ class OpenLLMProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    OpenLLM models are always reported as completion models,
    regardless of the model name.
    """
    mode = ModelMode.COMPLETION
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -6,7 +6,8 @@ import replicate ...@@ -6,7 +6,8 @@ import replicate
from replicate.exceptions import ReplicateError from replicate.exceptions import ReplicateError
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider): ...@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a Replicate model, inferred from its name.

    Model names ending in '-chat' are chat models; everything else is
    treated as a completion model.
    """
    if model_name.endswith('-chat'):
        return ModelMode.CHAT.value
    return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage ...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.spark_model import SparkModel from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark from core.third_party.langchain.llms.spark import ChatSpark
...@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider): ...@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider):
{ {
'id': 'spark', 'id': 'spark',
'name': 'Spark V1.5', 'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'spark-v2', 'id': 'spark-v2',
'name': 'Spark V2.0', 'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value,
} }
] ]
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    The model name is ignored: Spark models are always reported as
    chat models.
    """
    mode = ModelMode.CHAT
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -4,7 +4,7 @@ from typing import Type ...@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.tongyi_model import TongyiModel from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
...@@ -26,15 +26,20 @@ class TongyiProvider(BaseModelProvider): ...@@ -26,15 +26,20 @@ class TongyiProvider(BaseModelProvider):
{ {
'id': 'qwen-v1', 'id': 'qwen-v1',
'name': 'qwen-v1', 'name': 'qwen-v1',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'qwen-plus-v1', 'id': 'qwen-plus-v1',
'name': 'qwen-plus-v1', 'name': 'qwen-plus-v1',
'mode': ModelMode.COMPLETION.value,
} }
] ]
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    The model name is ignored: Tongyi models are always reported as
    completion models.
    """
    mode = ModelMode.COMPLETION
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -4,7 +4,7 @@ from typing import Type ...@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.wenxin_model import WenxinModel from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin from core.third_party.langchain.llms.wenxin import Wenxin
...@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider): ...@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider):
{ {
'id': 'ernie-bot', 'id': 'ernie-bot',
'name': 'ERNIE-Bot', 'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'ernie-bot-turbo', 'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo', 'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'bloomz-7b', 'id': 'bloomz-7b',
'name': 'BLOOMZ-7B', 'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
} }
] ]
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    The model name is ignored: Wenxin models are always reported as
    completion models.
    """
    mode = ModelMode.COMPLETION
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -6,7 +6,7 @@ from langchain.embeddings import XinferenceEmbeddings ...@@ -6,7 +6,7 @@ from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.xinference_model import XinferenceModel from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider): ...@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    The model name is ignored: Xinference models are always reported
    as completion models.
    """
    mode = ModelMode.COMPLETION
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage ...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
...@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider): ...@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider):
{ {
'id': 'chatglm_pro', 'id': 'chatglm_pro',
'name': 'chatglm_pro', 'name': 'chatglm_pro',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'chatglm_std', 'id': 'chatglm_std',
'name': 'chatglm_std', 'name': 'chatglm_std',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'chatglm_lite', 'id': 'chatglm_lite',
'name': 'chatglm_lite', 'name': 'chatglm_lite',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'chatglm_lite_32k', 'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k', 'name': 'chatglm_lite_32k',
'mode': ModelMode.CHAT.value,
} }
] ]
elif model_type == ModelType.EMBEDDINGS: elif model_type == ModelType.EMBEDDINGS:
...@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider): ...@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider):
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
    """Return the mode string for a text generation model.

    The model name is ignored: ZhipuAI models are always reported as
    chat models.
    """
    mode = ModelMode.CHAT
    return mode.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -482,6 +482,9 @@ class ProviderService: ...@@ -482,6 +482,9 @@ class ProviderService:
'features': [] 'features': []
} }
if 'mode' in model:
valid_model_dict['model_mode'] = model['mode']
if 'features' in model: if 'features' in model:
valid_model_dict['features'] = model['features'] valid_model_dict['features'] = model['features']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment