Unverified Commit 296bf443 authored by takatost's avatar takatost Committed by GitHub

feat: reuse decoding_rsa_key & decoding_cipher_rsa & optimize construct (#1937)

parent af7be9bd
...@@ -520,7 +520,13 @@ class ProviderConfiguration(BaseModel): ...@@ -520,7 +520,13 @@ class ProviderConfiguration(BaseModel):
provider_models.extend( provider_models.extend(
[ [
ModelWithProviderEntity( ModelWithProviderEntity(
**m.dict(), model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE status=ModelStatus.ACTIVE
) )
...@@ -569,7 +575,13 @@ class ProviderConfiguration(BaseModel): ...@@ -569,7 +575,13 @@ class ProviderConfiguration(BaseModel):
for m in models: for m in models:
provider_models.append( provider_models.append(
ModelWithProviderEntity( ModelWithProviderEntity(
**m.dict(), model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
) )
...@@ -597,7 +609,13 @@ class ProviderConfiguration(BaseModel): ...@@ -597,7 +609,13 @@ class ProviderConfiguration(BaseModel):
provider_models.append( provider_models.append(
ModelWithProviderEntity( ModelWithProviderEntity(
**custom_model_schema.dict(), model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=custom_model_schema.fetch_from,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE status=ModelStatus.ACTIVE
) )
......
...@@ -24,6 +24,9 @@ class ProviderManager: ...@@ -24,6 +24,9 @@ class ProviderManager:
""" """
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
""" """
def __init__(self) -> None:
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
def get_configurations(self, tenant_id: str) -> ProviderConfigurations: def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
""" """
...@@ -472,15 +475,16 @@ class ProviderManager: ...@@ -472,15 +475,16 @@ class ProviderManager:
provider_credentials = {} provider_credentials = {}
# Get decoding rsa key and cipher for decrypting credentials # Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in provider_credential_secret_variables: for variable in provider_credential_secret_variables:
if variable in provider_credentials: if variable in provider_credentials:
try: try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable), provider_credentials.get(variable),
decoding_rsa_key, self.decoding_rsa_key,
decoding_cipher_rsa self.decoding_cipher_rsa
) )
except ValueError: except ValueError:
pass pass
...@@ -524,15 +528,16 @@ class ProviderManager: ...@@ -524,15 +528,16 @@ class ProviderManager:
continue continue
# Get decoding rsa key and cipher for decrypting credentials # Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in model_credential_secret_variables: for variable in model_credential_secret_variables:
if variable in provider_model_credentials: if variable in provider_model_credentials:
try: try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable), provider_model_credentials.get(variable),
decoding_rsa_key, self.decoding_rsa_key,
decoding_cipher_rsa self.decoding_cipher_rsa
) )
except ValueError: except ValueError:
pass pass
...@@ -641,15 +646,16 @@ class ProviderManager: ...@@ -641,15 +646,16 @@ class ProviderManager:
) )
# Get decoding rsa key and cipher for decrypting credentials # Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in provider_credential_secret_variables: for variable in provider_credential_secret_variables:
if variable in provider_credentials: if variable in provider_credentials:
try: try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable), provider_credentials.get(variable),
decoding_rsa_key, self.decoding_rsa_key,
decoding_cipher_rsa self.decoding_cipher_rsa
) )
except ValueError: except ValueError:
pass pass
......
...@@ -14,7 +14,7 @@ from core.provider_manager import ProviderManager ...@@ -14,7 +14,7 @@ from core.provider_manager import ProviderManager
from models.provider import ProviderType from models.provider import ProviderType
from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \ from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \
SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \ SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \
DefaultModelResponse, ModelWithProviderEntityResponse DefaultModelResponse, ModelWithProviderEntityResponse, SimpleProviderEntityResponse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -45,7 +45,17 @@ class ModelProviderService: ...@@ -45,7 +45,17 @@ class ModelProviderService:
continue continue
provider_response = ProviderResponse( provider_response = ProviderResponse(
**provider_configuration.provider.dict(), provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=provider_configuration.preferred_provider_type, preferred_provider_type=provider_configuration.preferred_provider_type,
custom_configuration=CustomConfigurationResponse( custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE status=CustomConfigurationStatus.ACTIVE
...@@ -53,7 +63,9 @@ class ModelProviderService: ...@@ -53,7 +63,9 @@ class ModelProviderService:
else CustomConfigurationStatus.NO_CONFIGURE else CustomConfigurationStatus.NO_CONFIGURE
), ),
system_configuration=SystemConfigurationResponse( system_configuration=SystemConfigurationResponse(
**provider_configuration.system_configuration.dict() enabled=provider_configuration.system_configuration.enabled,
current_quota_type=provider_configuration.system_configuration.current_quota_type,
quota_configurations=provider_configuration.system_configuration.quota_configurations
) )
) )
...@@ -369,7 +381,15 @@ class ModelProviderService: ...@@ -369,7 +381,15 @@ class ModelProviderService:
) )
return DefaultModelResponse( return DefaultModelResponse(
**result.dict() model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
)
) if result else None ) if result else None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment