Unverified Commit 296bf443 authored by takatost's avatar takatost Committed by GitHub

feat: reuse decoding_rsa_key & decoding_cipher_rsa & optimize construct (#1937)

parent af7be9bd
......@@ -520,7 +520,13 @@ class ProviderConfiguration(BaseModel):
provider_models.extend(
[
ModelWithProviderEntity(
**m.dict(),
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
......@@ -569,7 +575,13 @@ class ProviderConfiguration(BaseModel):
for m in models:
provider_models.append(
ModelWithProviderEntity(
**m.dict(),
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
)
......@@ -597,7 +609,13 @@ class ProviderConfiguration(BaseModel):
provider_models.append(
ModelWithProviderEntity(
**custom_model_schema.dict(),
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=custom_model_schema.fetch_from,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
......
......@@ -24,6 +24,9 @@ class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
"""
def __init__(self) -> None:
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
......@@ -472,15 +475,16 @@ class ProviderManager:
provider_credentials = {}
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
self.decoding_rsa_key,
self.decoding_cipher_rsa
)
except ValueError:
pass
......@@ -524,15 +528,16 @@ class ProviderManager:
continue
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
self.decoding_rsa_key,
self.decoding_cipher_rsa
)
except ValueError:
pass
......@@ -641,15 +646,16 @@ class ProviderManager:
)
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
self.decoding_rsa_key,
self.decoding_cipher_rsa
)
except ValueError:
pass
......
......@@ -14,7 +14,7 @@ from core.provider_manager import ProviderManager
from models.provider import ProviderType
from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \
SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \
DefaultModelResponse, ModelWithProviderEntityResponse
DefaultModelResponse, ModelWithProviderEntityResponse, SimpleProviderEntityResponse
logger = logging.getLogger(__name__)
......@@ -45,7 +45,17 @@ class ModelProviderService:
continue
provider_response = ProviderResponse(
**provider_configuration.provider.dict(),
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=provider_configuration.preferred_provider_type,
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
......@@ -53,7 +63,9 @@ class ModelProviderService:
else CustomConfigurationStatus.NO_CONFIGURE
),
system_configuration=SystemConfigurationResponse(
**provider_configuration.system_configuration.dict()
enabled=provider_configuration.system_configuration.enabled,
current_quota_type=provider_configuration.system_configuration.current_quota_type,
quota_configurations=provider_configuration.system_configuration.quota_configurations
)
)
......@@ -369,7 +381,15 @@ class ModelProviderService:
)
return DefaultModelResponse(
**result.dict()
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
)
) if result else None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment