Unverified Commit 23e95fd7 authored by Yeuoly's avatar Yeuoly Committed by GitHub

Fix tool provider credential caching issue (#2433)

parent e1045f01
...@@ -85,4 +85,12 @@ class ToolConfiguration(BaseModel): ...@@ -85,4 +85,12 @@ class ToolConfiguration(BaseModel):
pass pass
cache.set(credentials) cache.set(credentials)
return credentials return credentials
\ No newline at end of file
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cache.delete()
...@@ -355,10 +355,12 @@ class ToolManageService: ...@@ -355,10 +355,12 @@ class ToolManageService:
else: else:
provider.encrypted_credentials = json.dumps(credentials) provider.encrypted_credentials = json.dumps(credentials)
db.session.add(provider) db.session.add(provider)
db.session.commit() db.session.commit()
# delete cache
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' } return { 'result': 'success' }
@staticmethod @staticmethod
...@@ -393,7 +395,6 @@ class ToolManageService: ...@@ -393,7 +395,6 @@ class ToolManageService:
provider.description = extra_info.get('description', '') provider.description = extra_info.get('description', '')
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
provider.tools_str = serialize_base_model_array(tool_bundles) provider.tools_str = serialize_base_model_array(tool_bundles)
provider.credentials_str = json.dumps(credentials)
provider.privacy_policy = privacy_policy provider.privacy_policy = privacy_policy
if 'auth_type' not in credentials: if 'auth_type' not in credentials:
...@@ -403,33 +404,54 @@ class ToolManageService: ...@@ -403,33 +404,54 @@ class ToolManageService:
auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
# create provider entity # create provider entity
provider_entity = ApiBasedToolProviderController.from_db(provider, auth_type) provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type)
# load tools into provider entity # load tools into provider entity
provider_entity.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
credentials = tool_configuration.encrypt_tool_credentials(credentials)
provider.credentials_str = json.dumps(credentials)
db.session.add(provider) db.session.add(provider)
db.session.commit() db.session.commit()
# delete cache
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' } return { 'result': 'success' }
@staticmethod @staticmethod
def delete_builtin_tool_provider( def delete_builtin_tool_provider(
user_id: str, tenant_id: str, provider: str user_id: str, tenant_id: str, provider_name: str
): ):
""" """
delete tool provider delete tool provider
""" """
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider, BuiltinToolProvider.provider == provider_name,
).first() ).first()
if provider is None: if provider is None:
raise ValueError(f'you have not added provider {provider}') raise ValueError(f'you have not added provider {provider_name}')
db.session.delete(provider) db.session.delete(provider)
db.session.commit() db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' } return { 'result': 'success' }
@staticmethod @staticmethod
...@@ -437,7 +459,7 @@ class ToolManageService: ...@@ -437,7 +459,7 @@ class ToolManageService:
provider: str provider: str
): ):
""" """
get tool provider icon and it's minetype get tool provider icon and it's mimetype
""" """
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
with open(icon_path, 'rb') as f: with open(icon_path, 'rb') as f:
...@@ -447,18 +469,18 @@ class ToolManageService: ...@@ -447,18 +469,18 @@ class ToolManageService:
@staticmethod @staticmethod
def delete_api_tool_provider( def delete_api_tool_provider(
user_id: str, tenant_id: str, provider: str user_id: str, tenant_id: str, provider_name: str
): ):
""" """
delete tool provider delete tool provider
""" """
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider, ApiToolProvider.name == provider_name,
).first() ).first()
if provider is None: if provider is None:
raise ValueError(f'you have not added provider {provider}') raise ValueError(f'you have not added provider {provider_name}')
db.session.delete(provider) db.session.delete(provider)
db.session.commit() db.session.commit()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment