Commit dd500e3b authored by John Wang's avatar John Wang

fix: providers list include system token

parent 05493c35
...@@ -51,7 +51,8 @@ class ProviderListApi(Resource): ...@@ -51,7 +51,8 @@ class ProviderListApi(Resource):
'quota_used': p.quota_used 'quota_used': p.quota_used
} if p.provider_type == ProviderType.SYSTEM.value else {}), } if p.provider_type == ProviderType.SYSTEM.value else {}),
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant, 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
ProviderName(p.provider_name)) ProviderName(p.provider_name), only_custom=True)
if p.provider_type == ProviderType.CUSTOM.value else None
} }
for p in providers for p in providers
] ]
......
...@@ -32,12 +32,12 @@ class AnthropicProvider(BaseProvider): ...@@ -32,12 +32,12 @@ class AnthropicProvider(BaseProvider):
def get_provider_name(self): def get_provider_name(self):
return ProviderName.ANTHROPIC return ProviderName.ANTHROPIC
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the provider configs. Returns the provider configs.
""" """
try: try:
config = self.get_provider_api_key() config = self.get_provider_api_key(only_custom=only_custom)
except: except:
config = { config = {
'anthropic_api_key': '' 'anthropic_api_key': ''
......
...@@ -52,12 +52,12 @@ class AzureProvider(BaseProvider): ...@@ -52,12 +52,12 @@ class AzureProvider(BaseProvider):
def get_provider_name(self): def get_provider_name(self):
return ProviderName.AZURE_OPENAI return ProviderName.AZURE_OPENAI
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the provider configs. Returns the provider configs.
""" """
try: try:
config = self.get_provider_api_key() config = self.get_provider_api_key(only_custom=only_custom)
except: except:
config = { config = {
'openai_api_type': 'azure', 'openai_api_type': 'azure',
......
...@@ -14,13 +14,13 @@ class BaseProvider(ABC): ...@@ -14,13 +14,13 @@ class BaseProvider(ABC):
def __init__(self, tenant_id: str): def __init__(self, tenant_id: str):
self.tenant_id = tenant_id self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]: def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the decrypted API key for the given tenant_id and provider_name. Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError. If the provider is not found or not valid, raises a ProviderTokenNotInitError.
""" """
provider = self.get_provider(prefer_custom) provider = self.get_provider(only_custom)
if not provider: if not provider:
raise ProviderTokenNotInitError( raise ProviderTokenNotInitError(
f"No valid {llm_constant.models[model_id]} model provider credentials found. " f"No valid {llm_constant.models[model_id]} model provider credentials found. "
...@@ -41,19 +41,19 @@ class BaseProvider(ABC): ...@@ -41,19 +41,19 @@ class BaseProvider(ABC):
else: else:
return self.get_decrypted_token(provider.encrypted_config) return self.get_decrypted_token(provider.encrypted_config)
def get_provider(self, prefer_custom: bool) -> Optional[Provider]: def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
""" """
Returns the Provider instance for the given tenant_id and provider_name. Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
""" """
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom) return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
@classmethod @classmethod
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[ def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
Provider]: Provider]:
""" """
Returns the Provider instance for the given tenant_id and provider_name. Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. If both CUSTOM and System providers exist.
""" """
query = db.session.query(Provider).filter( query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id Provider.tenant_id == tenant_id
...@@ -62,23 +62,18 @@ class BaseProvider(ABC): ...@@ -62,23 +62,18 @@ class BaseProvider(ABC):
if provider_name: if provider_name:
query = query.filter(Provider.provider_name == provider_name) query = query.filter(Provider.provider_name == provider_name)
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() if only_custom:
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
custom_provider = None providers = query.order_by(Provider.provider_type.asc()).all()
system_provider = None
for provider in providers: for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
custom_provider = provider return provider
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
system_provider = provider return provider
if custom_provider: return None
return custom_provider
elif system_provider:
return system_provider
else:
return None
def get_hosted_credentials(self) -> Union[str | dict]: def get_hosted_credentials(self) -> Union[str | dict]:
raise ProviderTokenNotInitError( raise ProviderTokenNotInitError(
...@@ -86,12 +81,12 @@ class BaseProvider(ABC): ...@@ -86,12 +81,12 @@ class BaseProvider(ABC):
f"Please go to Settings -> Model Provider to complete your provider credentials." f"Please go to Settings -> Model Provider to complete your provider credentials."
) )
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the provider configs. Returns the provider configs.
""" """
try: try:
config = self.get_provider_api_key() config = self.get_provider_api_key(only_custom=only_custom)
except: except:
config = '' config = ''
......
...@@ -31,11 +31,11 @@ class LLMProviderService: ...@@ -31,11 +31,11 @@ class LLMProviderService:
def get_credentials(self, model_id: Optional[str] = None) -> dict: def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id) return self.provider.get_credentials(model_id)
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated) return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]: def get_provider_db_record(self) -> Optional[Provider]:
return self.provider.get_provider(prefer_custom) return self.provider.get_provider()
def config_validate(self, config: Union[dict | str]): def config_validate(self, config: Union[dict | str]):
""" """
......
...@@ -41,9 +41,9 @@ class ProviderService: ...@@ -41,9 +41,9 @@ class ProviderService:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def get_obfuscated_api_key(tenant, provider_name: ProviderName): def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
llm_provider_service = LLMProviderService(tenant.id, provider_name.value) llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
return llm_provider_service.get_provider_configs(obfuscated=True) return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
@staticmethod @staticmethod
def get_token_type(tenant, provider_name: ProviderName): def get_token_type(tenant, provider_name: ProviderName):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment