Unverified Commit 36686d74 authored by Yeuoly's avatar Yeuoly Committed by GitHub

fix: test custom tool already exists without decrypting credentials (#2668)

parent 34387ec0
...@@ -259,6 +259,7 @@ class ToolApiProviderPreviousTestApi(Resource): ...@@ -259,6 +259,7 @@ class ToolApiProviderPreviousTestApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json') parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json') parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
...@@ -268,6 +269,7 @@ class ToolApiProviderPreviousTestApi(Resource): ...@@ -268,6 +269,7 @@ class ToolApiProviderPreviousTestApi(Resource):
return ToolManageService.test_api_tool_preview( return ToolManageService.test_api_tool_preview(
current_user.current_tenant_id, current_user.current_tenant_id,
args['provider_name'] if args['provider_name'] else '',
args['tool_name'], args['tool_name'],
args['credentials'], args['credentials'],
args['parameters'], args['parameters'],
......
import json import json
from json import dumps from json import dumps
from typing import Any, Union from typing import Any, Union
from urllib.parse import urlencode
import httpx import httpx
import requests import requests
...@@ -203,6 +204,8 @@ class ApiTool(Tool): ...@@ -203,6 +204,8 @@ class ApiTool(Tool):
if 'Content-Type' in headers: if 'Content-Type' in headers:
if headers['Content-Type'] == 'application/json': if headers['Content-Type'] == 'application/json':
body = dumps(body) body = dumps(body)
elif headers['Content-Type'] == 'application/x-www-form-urlencoded':
body = urlencode(body)
else: else:
body = body body = body
......
...@@ -498,12 +498,16 @@ class ToolManageService: ...@@ -498,12 +498,16 @@ class ToolManageService:
@staticmethod @staticmethod
def test_api_tool_preview( def test_api_tool_preview(
tenant_id: str, tool_name: str, credentials: dict, parameters: dict, schema_type: str, schema: str tenant_id: str,
provider_name: str,
tool_name: str,
credentials: dict,
parameters: dict,
schema_type: str,
schema: str
): ):
""" """
test api tool before adding api tool provider test api tool before adding api tool provider
1. parse schema into tool bundle
""" """
if schema_type not in [member.value for member in ApiProviderSchemaType]: if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f'invalid schema type {schema_type}') raise ValueError(f'invalid schema type {schema_type}')
...@@ -518,15 +522,21 @@ class ToolManageService: ...@@ -518,15 +522,21 @@ class ToolManageService:
if tool_bundle is None: if tool_bundle is None:
raise ValueError(f'invalid tool name {tool_name}') raise ValueError(f'invalid tool name {tool_name}')
# create a fake db provider db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
db_provider = ApiToolProvider( ApiToolProvider.tenant_id == tenant_id,
tenant_id='', user_id='', name='', icon='', ApiToolProvider.name == provider_name,
schema=schema, ).first()
description='',
schema_type_str=ApiProviderSchemaType.OPENAPI.value, if not db_provider:
tools_str=serialize_base_model_array(tool_bundles), # create a fake db provider
credentials_str=json.dumps(credentials), db_provider = ApiToolProvider(
) tenant_id='', user_id='', name='', icon='',
schema=schema,
description='',
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
tools_str=serialize_base_model_array(tool_bundles),
credentials_str=json.dumps(credentials),
)
if 'auth_type' not in credentials: if 'auth_type' not in credentials:
raise ValueError('auth_type is required') raise ValueError('auth_type is required')
...@@ -539,6 +549,19 @@ class ToolManageService: ...@@ -539,6 +549,19 @@ class ToolManageService:
# load tools into provider entity # load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# decrypt credentials
if db_provider.id:
tool_configuration = ToolConfiguration(
tenant_id=tenant_id,
provider_controller=provider_controller
)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
# check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]
try: try:
provider_controller.validate_credentials_format(credentials) provider_controller.validate_credentials_format(credentials)
# get tool # get tool
......
...@@ -42,6 +42,7 @@ const TestApi: FC<Props> = ({ ...@@ -42,6 +42,7 @@ const TestApi: FC<Props> = ({
delete credentials.api_key_value delete credentials.api_key_value
} }
const data = { const data = {
provider_name: customCollection.provider,
tool_name: toolName, tool_name: toolName,
credentials, credentials,
schema_type: customCollection.schema_type, schema_type: customCollection.schema_type,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment