Commit 41ab8713 authored by John Wang's avatar John Wang

feat: fulfill few comments

parent 1c114eae
...@@ -14,8 +14,15 @@ from models.provider import Provider, ProviderType, ProviderModel ...@@ -14,8 +14,15 @@ from models.provider import Provider, ProviderType, ProviderModel
class BaseModelProvider(BaseModel, ABC): class BaseModelProvider(BaseModel, ABC):
"""
This is the base class for all model providers.
"""
provider: Provider provider: Provider
"""
The corresponding Provider database table object, which exists in two types of providers: SYSTEM and CUSTOM.
System providers have three types of quota: TRAIL, FREE (third-party free trial), and PAID.
"""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
...@@ -26,7 +33,7 @@ class BaseModelProvider(BaseModel, ABC): ...@@ -26,7 +33,7 @@ class BaseModelProvider(BaseModel, ABC):
@abstractmethod @abstractmethod
def provider_name(self): def provider_name(self):
""" """
Returns the name of a provider. Returns the unique name of a provider.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -39,9 +46,14 @@ class BaseModelProvider(BaseModel, ABC): ...@@ -39,9 +46,14 @@ class BaseModelProvider(BaseModel, ABC):
def get_supported_model_list(self, model_type: ModelType) -> list[dict]: def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
""" """
get supported model object list for use. get supported model object list for use.
If the provider rule is for fixed models, return the fixed model list provided by `_get_fixed_model_list`,
otherwise get the list of user-configured models under the provider.
:param model_type: :param model_type: The type of model to get the list for.
:return: :type model_type: ModelType
:return: A list of dictionaries representing the supported models.
:rtype: list[dict]
""" """
rules = self.get_rules() rules = self.get_rules()
if 'custom' not in rules['support_provider_types']: if 'custom' not in rules['support_provider_types']:
...@@ -69,20 +81,26 @@ class BaseModelProvider(BaseModel, ABC): ...@@ -69,20 +81,26 @@ class BaseModelProvider(BaseModel, ABC):
@abstractmethod @abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
""" """
get supported model object list for use. Get a list of supported fixed model objects for use.
:param model_type: :param model_type: The type of model to get the list for.
:return: :type model_type: ModelType
:return: A list of dictionaries representing the supported models.
:rtype: list[dict]
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_model_class(self, model_type: ModelType) -> Type: def get_model_class(self, model_type: ModelType) -> Type:
""" """
get specific model class. Get the specific model class for the given model type.
:param model_type: :param model_type: The type of model to get the class for.
:return: :type model_type: ModelType
:return: The class object for the specified model type.
:rtype: Type
""" """
raise NotImplementedError raise NotImplementedError
...@@ -90,9 +108,14 @@ class BaseModelProvider(BaseModel, ABC): ...@@ -90,9 +108,14 @@ class BaseModelProvider(BaseModel, ABC):
@abstractmethod @abstractmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict): def is_provider_credentials_valid_or_raise(cls, credentials: dict):
""" """
check provider credentials valid. Check if the given credentials are valid for this provider.
:param credentials: :param credentials: A dictionary of credentials to check.
:type credentials: dict
:raises: CredentialsValidateFailedError if the credentials are invalid.
:return: None
""" """
raise NotImplementedError raise NotImplementedError
...@@ -111,10 +134,13 @@ class BaseModelProvider(BaseModel, ABC): ...@@ -111,10 +134,13 @@ class BaseModelProvider(BaseModel, ABC):
@abstractmethod @abstractmethod
def get_provider_credentials(self, obfuscated: bool = False) -> dict: def get_provider_credentials(self, obfuscated: bool = False) -> dict:
""" """
get credentials for llm use. Get the credentials for this provider.
:param obfuscated: :param obfuscated: Whether to obfuscate the credentials or not.
:return: :type obfuscated: bool
:return: A dictionary of credentials for this provider.
:rtype: dict
""" """
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment