Unverified Commit 1779cea6 authored by takatost's avatar takatost Committed by GitHub

fix: model provider credentials null value validate failed (#2009)

parent 26eff330
...@@ -165,7 +165,7 @@ class ProviderConfiguration(BaseModel): ...@@ -165,7 +165,7 @@ class ProviderConfiguration(BaseModel):
if value == '[__HIDDEN__]' and key in original_credentials: if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory.provider_credentials_validate( credentials = model_provider_factory.provider_credentials_validate(
self.provider.provider, self.provider.provider,
credentials credentials
) )
...@@ -308,24 +308,13 @@ class ProviderConfiguration(BaseModel): ...@@ -308,24 +308,13 @@ class ProviderConfiguration(BaseModel):
if value == '[__HIDDEN__]' and key in original_credentials: if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory.model_credentials_validate( credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, provider=self.provider.provider,
model_type=model_type, model_type=model_type,
model=model, model=model,
credentials=credentials credentials=credentials
) )
model_schema = (
model_provider_factory.get_provider_instance(self.provider.provider)
.get_model_instance(model_type)._get_customizable_model_schema(
model=model,
credentials=credentials
)
)
if model_schema:
credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
for key, value in credentials.items(): for key, value in credentials.items():
if key in provider_credential_secret_variables: if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(self.tenant_id, value) credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
......
...@@ -61,7 +61,7 @@ class ModelProviderFactory: ...@@ -61,7 +61,7 @@ class ModelProviderFactory:
# return providers # return providers
return providers return providers
def provider_credentials_validate(self, provider: str, credentials: dict) -> None: def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
""" """
Validate provider credentials Validate provider credentials
...@@ -80,13 +80,15 @@ class ModelProviderFactory: ...@@ -80,13 +80,15 @@ class ModelProviderFactory:
# validate provider credential schema # validate provider credential schema
validator = ProviderCredentialSchemaValidator(provider_credential_schema) validator = ProviderCredentialSchemaValidator(provider_credential_schema)
validator.validate_and_filter(credentials) filtered_credentials = validator.validate_and_filter(credentials)
# validate the credentials, raise exception if validation failed # validate the credentials, raise exception if validation failed
model_provider_instance.validate_provider_credentials(credentials) model_provider_instance.validate_provider_credentials(filtered_credentials)
return filtered_credentials
def model_credentials_validate(self, provider: str, model_type: ModelType, def model_credentials_validate(self, provider: str, model_type: ModelType,
model: str, credentials: dict) -> None: model: str, credentials: dict) -> dict:
""" """
Validate model credentials Validate model credentials
...@@ -107,13 +109,15 @@ class ModelProviderFactory: ...@@ -107,13 +109,15 @@ class ModelProviderFactory:
# validate model credential schema # validate model credential schema
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
validator.validate_and_filter(credentials) filtered_credentials = validator.validate_and_filter(credentials)
# get model instance of the model type # get model instance of the model type
model_instance = model_provider_instance.get_model_instance(model_type) model_instance = model_provider_instance.get_model_instance(model_type)
# call validate_credentials method of model type to validate credentials, raise exception if validation failed # call validate_credentials method of model type to validate credentials, raise exception if validation failed
model_instance.validate_credentials(model, credentials) model_instance.validate_credentials(model, filtered_credentials)
return filtered_credentials
def get_models(self, def get_models(self,
provider: Optional[str] = None, provider: Optional[str] = None,
......
...@@ -46,7 +46,7 @@ class CommonValidator: ...@@ -46,7 +46,7 @@ class CommonValidator:
:return: validated credential form schema value :return: validated credential form schema value
""" """
# If the variable does not exist in credentials # If the variable does not exist in credentials
if credential_form_schema.variable not in credentials: if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]:
# If required is True, an exception is thrown # If required is True, an exception is thrown
if credential_form_schema.required: if credential_form_schema.required:
raise ValueError(f'Variable {credential_form_schema.variable} is required') raise ValueError(f'Variable {credential_form_schema.variable} is required')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment