Commit a0b8c224 authored by John Wang's avatar John Wang

feat: completed tool plugins api

parent ee2857de
......@@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
# Import workspace controllers
from .workspace import workspace, members, providers, account
from .workspace import workspace, members, model_providers, account, tool_providers
# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
import json
from flask_login import login_required, current_user
from flask_restful import Resource, abort, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.provider.tool_provider_service import ToolProviderService
from extensions.ext_database import db
from models.tool import ToolProvider, ToolProviderName
class ToolProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
provider_list = [
{
'tool_name': p.tool_name,
'is_enabled': p.is_enabled,
'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
}
for p in tool_providers
]
return provider_list
class ToolProviderCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ToolProviderName]:
abort(404)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
f'current_role is {current_user.current_tenant.current_role}')
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
tool_provider_service = ToolProviderService(tenant_id, provider)
try:
tool_provider_service.credentials_validate(args['credentials'])
except ToolValidateFailedError as ex:
raise ValueError(str(ex))
encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
tenant = current_user.current_tenant
tool_provider_model = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == tenant.id,
ToolProvider.tool_name == provider,
).first()
# Only allow updating token for CUSTOM provider type
if tool_provider_model:
tool_provider_model.encrypted_credentials = encrypted_credentials
tool_provider_model.is_enabled = True
else:
tool_provider_model = ToolProvider(
tenant_id=tenant.id,
tool_name=provider,
encrypted_credentials=encrypted_credentials,
is_enabled=True
)
db.session.add(tool_provider_model)
db.session.commit()
return {'result': 'success'}, 201
class ToolProviderCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ToolProviderName]:
abort(404)
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
result = True
error = None
tenant_id = current_user.current_tenant_id
tool_provider_service = ToolProviderService(tenant_id, provider)
try:
tool_provider_service.credentials_validate(args['credentials'])
except ToolValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
api.add_resource(ToolProviderCredentialsValidateApi,
'/workspaces/current/tool-providers/<provider>/credentials-validate')
......@@ -16,6 +16,10 @@ class BaseToolProvider(ABC):
def get_provider_name(self) -> ToolProviderName:
raise NotImplementedError
@abstractmethod
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
raise NotImplementedError
......@@ -30,11 +34,11 @@ class BaseToolProvider(ABC):
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
Returns the Provider instance for the given tenant_id and tool_name.
"""
query = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == self.tenant_id,
ToolProvider.provider_name == self.get_provider_name()
ToolProvider.tool_name == self.get_provider_name().value
)
if must_enabled:
......
......@@ -2,6 +2,7 @@ from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName
......@@ -25,14 +26,14 @@ class SerpAPIToolProvider(BaseToolProvider):
if not tool_provider:
return None
config = tool_provider.config
if not config:
credentials = tool_provider.credentials
if not credentials:
return None
if config.get('api_key'):
config['api_key'] = self.decrypt_token(config.get('api_key'), obfuscated)
if credentials.get('api_key'):
credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
return config
return credentials
def credentials_to_func_kwargs(self) -> Optional[dict]:
"""
......@@ -57,3 +58,20 @@ class SerpAPIToolProvider(BaseToolProvider):
"""
if 'api_key' not in credentials or not credentials.get('api_key'):
raise ToolValidateFailedError("SerpAPI api_key is required.")
api_key = credentials.get('api_key')
try:
OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
except Exception as e:
raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
return credentials
......@@ -16,6 +16,12 @@ class ToolProviderService:
raise Exception('tool provider {} not found'.format(provider_name))
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for Tool as a dictionary.
:param obfuscated:
:return:
"""
return self.provider.get_credentials(obfuscated)
def credentials_validate(self, credentials: dict):
......@@ -26,3 +32,12 @@ class ToolProviderService:
:raises: ValidateFailedError
"""
return self.provider.credentials_validate(credentials)
def encrypt_credentials(self, credentials: dict):
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
return self.provider.encrypt_credentials(credentials)
......@@ -22,7 +22,7 @@ def upgrade():
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
sa.Column('tool_name', sa.String(length=40), nullable=False),
sa.Column('encrypted_config', sa.Text(), nullable=True),
sa.Column('encrypted_credentials', sa.Text(), nullable=True),
sa.Column('is_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
......
......@@ -5,6 +5,7 @@ from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db
class ToolProviderName(Enum):
SERPAPI = 'serpapi'
......@@ -26,21 +27,21 @@ class ToolProvider(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
tool_name = db.Column(db.String(40), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True)
encrypted_credentials = db.Column(db.Text, nullable=True)
is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property
def config_is_set(self):
def credentials_is_set(self):
"""
Returns True if the encrypted_config is not None, indicating that the token is set.
"""
return self.encrypted_config is not None
return self.encrypted_credentials is not None
@property
def config(self):
def credentials(self):
"""
Returns the decrypted config.
"""
return json.loads(self.decrypt_config()) if self.encrypted_config is not None else None
return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None
......@@ -37,4 +37,5 @@ pyjwt~=2.6.0
newspaper3k==0.2.8
google-api-python-client==2.90.0
wikipedia==1.4.0
readabilipy==0.2.0
\ No newline at end of file
readabilipy==0.2.0
google-search-results==2.4.2
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment