Commit 2aa8847b authored by Joel's avatar Joel

mrege main

parents 049e858e fdd211e3
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
1. Start the docker-compose stack 1. Start the docker-compose stack
The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`. The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
```bash ```bash
cd ../docker cd ../docker
docker-compose -f docker-compose.middleware.yaml -p dify up -d docker-compose -f docker-compose.middleware.yaml -p dify up -d
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
3. Generate a `SECRET_KEY` in the `.env` file. 3. Generate a `SECRET_KEY` in the `.env` file.
```bash ```bash
openssl rand -base64 42 sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
``` ```
3.5 If you use annaconda, create a new environment and activate it 3.5 If you use annaconda, create a new environment and activate it
```bash ```bash
...@@ -46,7 +46,7 @@ ...@@ -46,7 +46,7 @@
``` ```
pip install -r requirements.txt --upgrade --force-reinstall pip install -r requirements.txt --upgrade --force-reinstall
``` ```
6. Start backend: 6. Start backend:
```bash ```bash
flask run --host 0.0.0.0 --port=5001 --debug flask run --host 0.0.0.0 --port=5001 --debug
......
...@@ -27,7 +27,9 @@ from fields.app_fields import ( ...@@ -27,7 +27,9 @@ from fields.app_fields import (
from libs.login import login_required from libs.login import login_required
from models.model import App, AppModelConfig, Site from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService from services.app_model_config_service import AppModelConfigService
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager
from core.entities.application_entities import AgentToolEntity
def _get_app(app_id, tenant_id): def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
...@@ -236,7 +238,39 @@ class AppApi(Resource): ...@@ -236,7 +238,39 @@ class AppApi(Resource):
def get(self, app_id): def get(self, app_id):
"""Get app detail""" """Get app detail"""
app_id = str(app_id) app_id = str(app_id)
app = _get_app(app_id, current_user.current_tenant_id) app: App = _get_app(app_id, current_user.current_tenant_id)
# get original app model config
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
return app return app
......
import json
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
...@@ -7,6 +8,9 @@ from controllers.console import api ...@@ -7,6 +8,9 @@ from controllers.console import api
from controllers.console.app import _get_app from controllers.console.app import _get_app
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.entities.application_entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated from events.app_event import app_model_config_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import login_required
...@@ -38,6 +42,82 @@ class ModelConfigResource(Resource): ...@@ -38,6 +42,82 @@ class ModelConfigResource(Resource):
) )
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()
# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
db.session.add(new_app_model_config) db.session.add(new_app_model_config)
db.session.flush() db.session.flush()
......
...@@ -82,6 +82,30 @@ class ToolBuiltinProviderIconApi(Resource): ...@@ -82,6 +82,30 @@ class ToolBuiltinProviderIconApi(Resource):
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider) icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=minetype) return send_file(io.BytesIO(icon_bytes), mimetype=minetype)
class ToolModelProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
class ToolModelProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
return ToolManageService.list_model_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
)
class ToolApiProviderAddApi(Resource): class ToolApiProviderAddApi(Resource):
@setup_required @setup_required
...@@ -283,6 +307,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide ...@@ -283,6 +307,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update') api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema') api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon') api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
......
...@@ -200,8 +200,8 @@ class DatasetSegmentApi(DatasetApiResource): ...@@ -200,8 +200,8 @@ class DatasetSegmentApi(DatasetApiResource):
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json') parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args['segments'], document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args['segments'], segment, document, dataset) segment = SegmentService.update_segment(args, segment, document, dataset)
return { return {
'data': marshal(segment, segment_fields), 'data': marshal(segment, segment_fields),
'doc_form': document.doc_form 'doc_form': document.doc_form
......
...@@ -195,6 +195,10 @@ class AssistantApplicationRunner(AppRunner): ...@@ -195,6 +195,10 @@ class AssistantApplicationRunner(AppRunner):
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
db.session.refresh(conversation)
db.session.refresh(message)
db.session.close()
# start agent runner # start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner( assistant_cot_runner = AssistantCotApplicationRunner(
......
...@@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner): ...@@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner):
model=app_orchestration_config.model_config.model model=app_orchestration_config.model_config.model
) )
db.session.close()
invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters, model_parameters=app_orchestration_config.model_config.parameters,
......
...@@ -89,6 +89,10 @@ class GenerateTaskPipeline: ...@@ -89,6 +89,10 @@ class GenerateTaskPipeline:
Process generate task pipeline. Process generate task pipeline.
:return: :return:
""" """
db.session.refresh(self._conversation)
db.session.refresh(self._message)
db.session.close()
if stream: if stream:
return self._process_stream_response() return self._process_stream_response()
else: else:
...@@ -303,6 +307,7 @@ class GenerateTaskPipeline: ...@@ -303,6 +307,7 @@ class GenerateTaskPipeline:
.first() .first()
) )
db.session.refresh(agent_thought) db.session.refresh(agent_thought)
db.session.close()
if agent_thought: if agent_thought:
response = { response = {
...@@ -330,6 +335,8 @@ class GenerateTaskPipeline: ...@@ -330,6 +335,8 @@ class GenerateTaskPipeline:
.filter(MessageFile.id == event.message_file_id) .filter(MessageFile.id == event.message_file_id)
.first() .first()
) )
db.session.close()
# get extension # get extension
if '.' in message_file.url: if '.' in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}' extension = f'.{message_file.url.split(".")[-1]}'
...@@ -413,6 +420,7 @@ class GenerateTaskPipeline: ...@@ -413,6 +420,7 @@ class GenerateTaskPipeline:
usage = llm_result.usage usage = llm_result.usage
self._message = db.session.query(Message).filter(Message.id == self._message.id).first() self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
self._message.message_tokens = usage.prompt_tokens self._message.message_tokens = usage.prompt_tokens
......
...@@ -201,7 +201,7 @@ class ApplicationManager: ...@@ -201,7 +201,7 @@ class ApplicationManager:
logger.exception("Unknown Error when generating") logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally: finally:
db.session.remove() db.session.close()
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager, queue_manager: ApplicationQueueManager,
...@@ -233,8 +233,6 @@ class ApplicationManager: ...@@ -233,8 +233,6 @@ class ApplicationManager:
else: else:
logger.exception(e) logger.exception(e)
raise e raise e
finally:
db.session.remove()
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
-> AppOrchestrationConfigEntity: -> AppOrchestrationConfigEntity:
...@@ -651,6 +649,7 @@ class ApplicationManager: ...@@ -651,6 +649,7 @@ class ApplicationManager:
db.session.add(conversation) db.session.add(conversation)
db.session.commit() db.session.commit()
db.session.refresh(conversation)
else: else:
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
...@@ -689,6 +688,7 @@ class ApplicationManager: ...@@ -689,6 +688,7 @@ class ApplicationManager:
db.session.add(message) db.session.add(message)
db.session.commit() db.session.commit()
db.session.refresh(message)
for file in application_generate_entity.files: for file in application_generate_entity.files:
message_file = MessageFile( message_file = MessageFile(
......
...@@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.agent_thought_count = db.session.query(MessageAgentThought).filter( self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id, MessageAgentThought.message_id == self.message.id,
).count() ).count()
db.session.close()
# check if model supports stream tool call # check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
...@@ -154,9 +155,9 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -154,9 +155,9 @@ class BaseAssistantApplicationRunner(AppRunner):
""" """
convert tool to prompt message tool convert tool to prompt message tool
""" """
tool_entity = ToolManager.get_tool_runtime( tool_entity = ToolManager.get_agent_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name, tenant_id=self.tenant_id,
tenant_id=self.application_generate_entity.tenant_id, agent_tool=tool,
agent_callback=self.agent_callback agent_callback=self.agent_callback
) )
tool_entity.load_variables(self.variables_pool) tool_entity.load_variables(self.variables_pool)
...@@ -171,33 +172,11 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -171,33 +172,11 @@ class BaseAssistantApplicationRunner(AppRunner):
} }
) )
runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters()
parameters = tool_entity.parameters or []
user_parameters = tool_entity.get_runtime_parameters() or []
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
for parameter in parameters: for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string' parameter_type = 'string'
enum = [] enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING: if parameter.type == ToolParameter.ToolParameterType.STRING:
...@@ -213,59 +192,16 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -213,59 +192,16 @@ class BaseAssistantApplicationRunner(AppRunner):
else: else:
raise ValueError(f"parameter type {parameter.type} is not supported") raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.FORM: message_tool.parameters['properties'][parameter.name] = {
# get tool parameter from form "type": parameter_type,
tool_parameter_config = tool.tool_parameters.get(parameter.name) "description": parameter.llm_description or '',
if not tool_parameter_config: }
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
elif parameter.form == ToolParameter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required: if len(enum) > 0:
message_tool.parameters['required'].append(parameter.name) message_tool.parameters['properties'][parameter.name]['enum'] = enum
tool_entity.runtime.runtime_parameters.update(runtime_parameters) if parameter.required:
message_tool.parameters['required'].append(parameter.name)
return message_tool, tool_entity return message_tool, tool_entity
...@@ -305,6 +241,9 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -305,6 +241,9 @@ class BaseAssistantApplicationRunner(AppRunner):
tool_runtime_parameters = tool.get_runtime_parameters() or [] tool_runtime_parameters = tool.get_runtime_parameters() or []
for parameter in tool_runtime_parameters: for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string' parameter_type = 'string'
enum = [] enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING: if parameter.type == ToolParameter.ToolParameterType.STRING:
...@@ -320,18 +259,17 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -320,18 +259,17 @@ class BaseAssistantApplicationRunner(AppRunner):
else: else:
raise ValueError(f"parameter type {parameter.type} is not supported") raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.LLM: prompt_tool.parameters['properties'][parameter.name] = {
prompt_tool.parameters['properties'][parameter.name] = { "type": parameter_type,
"type": parameter_type, "description": parameter.llm_description or '',
"description": parameter.llm_description or '', }
}
if len(enum) > 0: if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required: if parameter.required:
if parameter.name not in prompt_tool.parameters['required']: if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name) prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool return prompt_tool
...@@ -404,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -404,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner):
created_by=self.user_id, created_by=self.user_id,
) )
db.session.add(message_file) db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append(( result.append((
message_file, message_file,
message.save_as message.save_as
)) ))
db.session.commit()
db.session.close()
return result return result
def create_agent_thought(self, message_id: str, message: str, def create_agent_thought(self, message_id: str, message: str,
...@@ -447,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -447,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner):
db.session.add(thought) db.session.add(thought)
db.session.commit() db.session.commit()
db.session.refresh(thought)
db.session.close()
self.agent_thought_count += 1 self.agent_thought_count += 1
...@@ -464,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -464,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner):
""" """
Save agent thought Save agent thought
""" """
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()
if thought is not None: if thought is not None:
agent_thought.thought = thought agent_thought.thought = thought
...@@ -514,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -514,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels) agent_thought.tool_labels_str = json.dumps(labels)
db.session.commit() db.session.commit()
db.session.close()
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
""" """
...@@ -586,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -586,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner):
""" """
convert tool variables to db variables convert tool variables to db variables
""" """
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()
db_variables.updated_at = datetime.utcnow() db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit() db.session.commit()
db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
""" """
...@@ -644,4 +597,6 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -644,4 +597,6 @@ class BaseAssistantApplicationRunner(AppRunner):
if message.answer: if message.answer:
result.append(AssistantPromptMessage(content=message.answer)) result.append(AssistantPromptMessage(content=message.answer))
db.session.close()
return result return result
\ No newline at end of file
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolParameterCacheType(Enum):
PARAMETER = "tool_parameter"
class ToolParameterCache:
def __init__(self,
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_tool_parameter = redis_client.get(self.cache_key)
if cached_tool_parameter:
try:
cached_tool_parameter = cached_tool_parameter.decode('utf-8')
cached_tool_parameter = json.loads(cached_tool_parameter)
except JSONDecodeError:
return None
return cached_tool_parameter
else:
return None
def set(self, parameters: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)
\ No newline at end of file
...@@ -82,6 +82,8 @@ class HostingConfiguration: ...@@ -82,6 +82,8 @@ class HostingConfiguration:
RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM), RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM), RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING), RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING),
] ]
) )
quotas.append(trial_quota) quotas.append(trial_quota)
......
...@@ -47,11 +47,14 @@ class TokenBufferMemory: ...@@ -47,11 +47,14 @@ class TokenBufferMemory:
files, message.app_model_config files, message.app_model_config
) )
prompt_message_contents = [TextPromptMessageContent(data=message.query)] if not file_objs:
for file_obj in file_objs: prompt_messages.append(UserPromptMessage(content=message.query))
prompt_message_contents.append(file_obj.prompt_message_content) else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:
prompt_messages.append(UserPromptMessage(content=message.query)) prompt_messages.append(UserPromptMessage(content=message.query))
......
...@@ -17,7 +17,7 @@ class ModelType(Enum): ...@@ -17,7 +17,7 @@ class ModelType(Enum):
SPEECH2TEXT = "speech2text" SPEECH2TEXT = "speech2text"
MODERATION = "moderation" MODERATION = "moderation"
TTS = "tts" TTS = "tts"
# TEXT2IMG = "text2img" TEXT2IMG = "text2img"
@classmethod @classmethod
def value_of(cls, origin_model_type: str) -> "ModelType": def value_of(cls, origin_model_type: str) -> "ModelType":
...@@ -36,6 +36,8 @@ class ModelType(Enum): ...@@ -36,6 +36,8 @@ class ModelType(Enum):
return cls.SPEECH2TEXT return cls.SPEECH2TEXT
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value: elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
return cls.TTS return cls.TTS
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value: elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION return cls.MODERATION
else: else:
...@@ -59,10 +61,11 @@ class ModelType(Enum): ...@@ -59,10 +61,11 @@ class ModelType(Enum):
return 'tts' return 'tts'
elif self == self.MODERATION: elif self == self.MODERATION:
return 'moderation' return 'moderation'
elif self == self.TEXT2IMG:
return 'text2img'
else: else:
raise ValueError(f'invalid model type {self}') raise ValueError(f'invalid model type {self}')
class FetchFrom(Enum): class FetchFrom(Enum):
""" """
Enum class for fetch from. Enum class for fetch from.
......
from abc import abstractmethod
from typing import IO, Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
class Text2ImageModel(AIModel):
"""
Model class for text2img model.
"""
model_type: ModelType = ModelType.TEXT2IMG
def invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
try:
return self._invoke(model, credentials, prompt, model_parameters, user)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
raise NotImplementedError
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
- togetherai - togetherai
- ollama - ollama
- mistralai - mistralai
- groq
- replicate - replicate
- huggingface_hub - huggingface_hub
- zhipuai - zhipuai
......
...@@ -424,8 +424,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): ...@@ -424,8 +424,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
if isinstance(message, UserPromptMessage): if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}" message_text = f"{human_prompt} {content}"
if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{human_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{human_prompt} [IMAGE]"
elif isinstance(message, AssistantPromptMessage): elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}" if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{ai_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{ai_prompt} [IMAGE]"
elif isinstance(message, SystemPromptMessage): elif isinstance(message, SystemPromptMessage):
message_text = content message_text = content
else: else:
......
...@@ -524,5 +524,62 @@ EMBEDDING_BASE_MODELS = [ ...@@ -524,5 +524,62 @@ EMBEDDING_BASE_MODELS = [
currency='USD', currency='USD',
) )
) )
),
AzureBaseModel(
base_model_name='text-embedding-3-small',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: 8191,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.00002,
unit=0.001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='text-embedding-3-large',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: 8191,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.00013,
unit=0.001,
currency='USD',
)
)
)
]
SPEECH2TEXT_BASE_MODELS = [
AzureBaseModel(
base_model_name='whisper-1',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.SPEECH2TEXT,
model_properties={
ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
}
)
) )
] ]
...@@ -15,6 +15,7 @@ help: ...@@ -15,6 +15,7 @@ help:
supported_model_types: supported_model_types:
- llm - llm
- text-embedding - text-embedding
- speech2text
configurate_methods: configurate_methods:
- customizable-model - customizable-model
model_credential_schema: model_credential_schema:
...@@ -99,6 +100,24 @@ model_credential_schema: ...@@ -99,6 +100,24 @@ model_credential_schema:
show_on: show_on:
- variable: __model_type - variable: __model_type
value: text-embedding value: text-embedding
- label:
en_US: text-embedding-3-small
value: text-embedding-3-small
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-large
value: text-embedding-3-large
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: whisper-1
value: whisper-1
show_on:
- variable: __model_type
value: speech2text
placeholder: placeholder:
zh_Hans: 在此输入您的模型版本 zh_Hans: 在此输入您的模型版本
en_US: Enter your model version en_US: Enter your model version
import copy
from typing import IO, Optional
from openai import AzureOpenAI
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
"""
Model class for OpenAI Speech to text model.
"""
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
return self._speech2text_invoke(model, credentials, file)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, 'rb') as audio_file:
self._speech2text_invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:return: text for given audio file
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
# init model client
client = AzureOpenAI(**credentials_kwargs)
response = client.audio.transcriptions.create(model=model, file=file)
return response.text
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
return ai_model_entity.entity
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None
<svg width="112" height="24" viewBox="0 0 112 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M57.4336 17.092C56.4746 16.5453 55.7216 15.7924 55.1749 14.8244C54.6283 13.8564 54.3594 12.763 54.3594 11.544C54.3594 10.3251 54.6283 9.2137 55.1749 8.24571C55.7216 7.27772 56.4746 6.52485 57.4336 5.98708C58.3926 5.4493 59.4861 5.18042 60.6961 5.18042C61.6999 5.18042 62.623 5.3776 63.4476 5.77197C64.2722 6.16633 64.9445 6.73995 65.4554 7.49284L64.568 8.13816C64.1199 7.51076 63.5642 7.04469 62.9009 6.731C62.2377 6.41729 61.5027 6.26492 60.705 6.26492C59.7281 6.26492 58.8498 6.48899 58.0789 6.92818C57.2992 7.36736 56.6986 7.98579 56.2505 8.79244C55.8113 9.59014 55.5872 10.5133 55.5872 11.553C55.5872 12.5926 55.8113 13.5159 56.2505 14.3136C56.6896 15.1112 57.2992 15.7297 58.0789 16.1778C58.8587 16.617 59.7281 16.8411 60.705 16.8411C61.5027 16.8411 62.2377 16.6888 62.9009 16.375C63.5642 16.0613 64.1199 15.5953 64.568 14.9678L65.4554 15.6132C64.9445 16.366 64.2722 16.9396 63.4476 17.334C62.623 17.7284 61.7089 17.9255 60.6961 17.9255C59.4771 17.9255 58.3926 17.6568 57.4336 17.11V17.092Z" fill="#F55036"/>
<path d="M67.2754 0H68.4763V17.8181H67.2754V0Z" fill="#F55036"/>
<path d="M73.6754 17.092C72.7254 16.5454 71.9725 15.7924 71.4347 14.8244C70.888 13.8564 70.6191 12.763 70.6191 11.544C70.6191 10.3251 70.888 9.23163 71.4347 8.26364C71.9814 7.29566 72.7254 6.54277 73.6754 5.99604C74.6255 5.4493 75.6921 5.18042 76.8841 5.18042C78.0762 5.18042 79.1338 5.4493 80.0928 5.99604C81.0429 6.54277 81.7957 7.29566 82.3335 8.26364C82.8803 9.23163 83.1492 10.3251 83.1492 11.544C83.1492 12.763 82.8803 13.8564 82.3335 14.8244C81.7868 15.7924 81.0429 16.5454 80.0928 17.092C79.1427 17.6387 78.0673 17.9076 76.8841 17.9076C75.7011 17.9076 74.6344 17.6387 73.6754 17.092ZM79.4655 16.1599C80.2273 15.7118 80.8277 15.0843 81.2669 14.2867C81.7062 13.489 81.9302 12.5747 81.9302 11.553C81.9302 10.5312 81.7062 9.61703 81.2669 8.81933C80.8277 8.02164 80.2273 7.39425 79.4655 6.9461C78.7036 6.49796 77.8431 6.27389 76.8841 6.27389C75.9251 6.27389 75.0646 6.49796 74.3028 6.9461C73.5409 7.39425 72.9405 8.02164 72.5013 8.81933C72.0621 9.61703 71.838 10.5312 71.838 11.553C71.838 12.5747 72.0621 13.489 72.5013 14.2867C72.9405 15.0843 73.5409 15.7118 74.3028 16.1599C75.0646 16.608 75.9251 16.8322 76.8841 16.8322C77.8431 16.8322 78.7036 16.608 79.4655 16.1599Z" fill="#F55036"/>
<path d="M96.2799 5.27905V17.8091H95.1237V15.1203C94.7114 15.9986 94.0929 16.6887 93.2774 17.1728C92.4618 17.6567 91.5027 17.9077 90.4003 17.9077C88.769 17.9077 87.4873 17.4506 86.5553 16.5364C85.6231 15.6222 85.166 14.3136 85.166 12.6017V5.27905H86.367V12.5031C86.367 13.9102 86.7255 14.9858 87.4515 15.7207C88.1775 16.4557 89.1903 16.8232 90.4989 16.8232C91.9061 16.8232 93.0264 16.384 93.851 15.5057C94.6756 14.6272 95.0878 13.4442 95.0878 11.9563V5.27905H96.2889H96.2799Z" fill="#F55036"/>
<path d="M110.952 0V17.8181H109.777V14.8604C109.284 15.8374 108.585 16.5902 107.689 17.119C106.793 17.6479 105.78 17.9077 104.642 17.9077C103.503 17.9077 102.419 17.6389 101.469 17.0922C100.528 16.5454 99.7838 15.7925 99.246 14.8336C98.7083 13.8745 98.4395 12.781 98.4395 11.5441C98.4395 10.3073 98.7083 9.2138 99.246 8.24582C99.7838 7.27783 100.519 6.52496 101.469 5.98718C102.41 5.44941 103.468 5.18053 104.642 5.18053C105.816 5.18053 106.766 5.44044 107.653 5.96925C108.541 6.49807 109.24 7.23301 109.75 8.17411V0H110.952ZM107.295 16.16C108.057 15.7119 108.657 15.0844 109.096 14.2868C109.535 13.4891 109.759 12.5749 109.759 11.5531C109.759 10.5313 109.535 9.61713 109.096 8.81944C108.657 8.02174 108.057 7.39434 107.295 6.9462C106.533 6.49807 105.672 6.27399 104.713 6.27399C103.754 6.27399 102.894 6.49807 102.132 6.9462C101.37 7.39434 100.77 8.02174 100.331 8.81944C99.8914 9.61713 99.6673 10.5313 99.6673 11.5531C99.6673 12.5749 99.8914 13.4891 100.331 14.2868C100.77 15.0844 101.37 15.7119 102.132 16.16C102.894 16.6081 103.754 16.8322 104.713 16.8322C105.672 16.8322 106.533 16.6081 107.295 16.16Z" fill="#F55036"/>
<path d="M30.6085 5.27024C27.077 5.27024 24.209 8.13835 24.209 11.6697C24.209 15.201 27.077 18.0692 30.6085 18.0692C34.1399 18.0692 37.0079 15.201 37.0079 11.6697C37.0079 8.13835 34.1399 5.27921 30.6085 5.27024ZM30.6085 15.6672C28.4036 15.6672 26.611 13.8746 26.611 11.6697C26.611 9.46486 28.4036 7.67228 30.6085 7.67228C32.8133 7.67228 34.6059 9.46486 34.6059 11.6697C34.6059 13.8746 32.8133 15.6672 30.6085 15.6672Z" fill="black"/>
<path d="M6.45358 5.23422C2.92222 5.19837 0.036187 8.0396 0.000335591 11.571C-0.0355158 15.1023 2.80571 17.9974 6.33706 18.0242C6.37292 18.0242 6.41773 18.0242 6.45358 18.0242H8.55986V15.6311H6.45358C4.24873 15.658 2.43823 13.8923 2.41134 11.6785C2.38445 9.47365 4.15014 7.66315 6.36395 7.63626C6.39084 7.63626 6.4267 7.63626 6.45358 7.63626C8.65844 7.63626 10.46 9.42884 10.46 11.6337V17.5222C10.46 19.7092 8.67637 21.4929 6.48943 21.5197C5.44078 21.5197 4.44591 21.0895 3.71095 20.3455L2.01698 22.0395C3.1911 23.2227 4.7865 23.8949 6.45358 23.9128H6.54321C10.0298 23.859 12.8351 21.0357 12.853 17.5491V11.4724C12.7635 8.00374 9.93116 5.23422 6.46254 5.23422H6.45358Z" fill="black"/>
<path d="M51.2406 11.5082C51.151 8.03961 48.3187 5.27009 44.8501 5.27009C41.3187 5.23423 38.4237 8.07545 38.3968 11.6068C38.361 15.1382 41.2022 18.0331 44.7335 18.0601C44.7694 18.0601 44.8143 18.0601 44.8501 18.0601H46.9563V15.667H44.8501C42.6452 15.6939 40.8347 13.9282 40.8078 11.7144C40.7809 9.5095 42.5467 7.69902 44.7604 7.67213C44.7874 7.67213 44.8232 7.67213 44.8501 7.67213C47.055 7.67213 48.8565 9.46469 48.8565 11.6696V23.626L51.2406 23.6528V11.5082Z" fill="black"/>
<path d="M14.6808 18.0602H17.0649V11.6607C17.0649 9.45589 18.8575 7.66332 21.0623 7.66332C21.7883 7.66332 22.4695 7.8605 23.0611 8.2011L24.2621 6.12172C23.3209 5.57498 22.2276 5.27024 21.0713 5.27024C17.5399 5.27024 14.6719 8.13835 14.6719 11.6697V18.0692L14.6808 18.0602Z" fill="black"/>
</svg>
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="12" fill="#F55036"/>
<path d="M12.146 6.00022C9.87734 5.97718 8.02325 7.80249 8.00022 10.0712C7.97718 12.3398 9.80249 14.1997 12.0712 14.217C12.0942 14.217 12.123 14.217 12.146 14.217H13.4992V12.6796H12.146C10.7295 12.6968 9.56641 11.5625 9.54913 10.1403C9.53186 8.72377 10.6662 7.56065 12.0884 7.54337C12.1057 7.54337 12.1287 7.54337 12.146 7.54337C13.5625 7.54337 14.7199 8.69498 14.7199 10.1115V13.8945C14.7199 15.2995 13.574 16.4453 12.169 16.4626C11.4953 16.4626 10.8562 16.1862 10.384 15.7083L9.29578 16.7965C10.0501 17.5566 11.075 17.9885 12.146 18H12.2036C14.4435 17.9654 16.2457 16.1516 16.2572 13.9117V10.0078C16.1997 7.77945 14.3801 6.00022 12.1518 6.00022H12.146Z" fill="white"/>
</svg>
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class GroqProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='llama2-70b-4096',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex
provider: groq
label:
zh_Hans: GroqCloud
en_US: GroqCloud
description:
en_US: GroqCloud provides access to the Groq Cloud API, which hosts models like LLama2 and Mixtral.
zh_Hans: GroqCloud 提供对 Groq Cloud API 的访问,其中托管了 LLama2 和 Mixtral 等模型。
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#F5F5F4"
help:
title:
en_US: Get your API Key from GroqCloud
zh_Hans: 从 GroqCloud 获取 API Key
url:
en_US: https://console.groq.com/
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model: llama2-70b-4096
label:
zh_Hans: Llama-2-70B-4096
en_US: Llama-2-70B-4096
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 4096
pricing:
input: '0.7'
output: '0.8'
unit: '0.000001'
currency: USD
from collections.abc import Generator
from typing import Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
model: mixtral-8x7b-32768
label:
zh_Hans: Mixtral-8x7b-Instruct-v0.1
en_US: Mixtral-8x7b-Instruct-v0.1
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 20480
pricing:
input: '0.27'
output: '0.27'
unit: '0.000001'
currency: USD
from os.path import abspath, dirname, join from os.path import abspath, dirname, join
from threading import Lock
from transformers import AutoTokenizer from transformers import AutoTokenizer
class JinaTokenizer: class JinaTokenizer:
@staticmethod _tokenizer = None
def _get_num_tokens_by_jina_base(text: str) -> int: _lock = Lock()
@classmethod
def _get_tokenizer(cls):
if cls._tokenizer is None:
with cls._lock:
if cls._tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
return cls._tokenizer
@classmethod
def _get_num_tokens_by_jina_base(cls, text: str) -> int:
""" """
use jina tokenizer to get num tokens use jina tokenizer to get num tokens
""" """
base_path = abspath(__file__) tokenizer = cls._get_tokenizer()
gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
tokens = tokenizer.encode(text) tokens = tokenizer.encode(text)
return len(tokens) return len(tokens)
@staticmethod @classmethod
def get_num_tokens(text: str) -> int: def get_num_tokens(cls, text: str) -> int:
return JinaTokenizer._get_num_tokens_by_jina_base(text) return cls._get_num_tokens_by_jina_base(text)
\ No newline at end of file \ No newline at end of file
...@@ -2,4 +2,4 @@ model: whisper-1 ...@@ -2,4 +2,4 @@ model: whisper-1
model_type: speech2text model_type: speech2text
model_properties: model_properties:
file_upload_limit: 25 file_upload_limit: 25
supported_file_extensions: mp3,mp4,mpeg,mpga,m4a,wav,webm supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm
...@@ -308,6 +308,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -308,6 +308,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
use_template='max_tokens', use_template='max_tokens',
min=1, min=1,
max=credentials.get('context_length', 2048),
default=512, default=512,
label=I18nObject( label=I18nObject(
zh_Hans='最大生成长度', zh_Hans='最大生成长度',
......
...@@ -44,6 +44,9 @@ class XinferenceRerankModel(RerankModel): ...@@ -44,6 +44,9 @@ class XinferenceRerankModel(RerankModel):
docs=[] docs=[]
) )
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
# initialize client # initialize client
client = Client( client = Client(
base_url=credentials['server_url'] base_url=credentials['server_url']
......
...@@ -10,7 +10,7 @@ from core.rag.models.document import Document ...@@ -10,7 +10,7 @@ from core.rag.models.document import Document
class WordExtractor(BaseExtractor): class WordExtractor(BaseExtractor):
"""Load pdf files. """Load docx files.
Args: Args:
...@@ -46,14 +46,16 @@ class WordExtractor(BaseExtractor): ...@@ -46,14 +46,16 @@ class WordExtractor(BaseExtractor):
def extract(self) -> list[Document]: def extract(self) -> list[Document]:
"""Load given path as single page.""" """Load given path as single page."""
import docx2txt from docx import Document as docx_Document
return [ document = docx_Document(self.file_path)
Document( doc_texts = [paragraph.text for paragraph in document.paragraphs]
page_content=docx2txt.process(self.file_path), content = '\n'.join(doc_texts)
metadata={"source": self.file_path},
) return [Document(
] page_content=content,
metadata={"source": self.file_path},
)]
@staticmethod @staticmethod
def _is_valid_url(url: str) -> bool: def _is_valid_url(url: str) -> bool:
......
...@@ -52,7 +52,7 @@ class BaseIndexProcessor(ABC): ...@@ -52,7 +52,7 @@ class BaseIndexProcessor(ABC):
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"], chunk_size=segmentation["max_tokens"],
chunk_overlap=0, chunk_overlap=segmentation.get('chunk_overlap', 0),
fixed_separator=separator, fixed_separator=separator,
separators=["\n\n", "。", ".", " ", ""], separators=["\n\n", "。", ".", " ", ""],
embedding_model_instance=embedding_model_instance embedding_model_instance=embedding_model_instance
...@@ -61,7 +61,7 @@ class BaseIndexProcessor(ABC): ...@@ -61,7 +61,7 @@ class BaseIndexProcessor(ABC):
# Automatic segmentation # Automatic segmentation
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=0, chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
separators=["\n\n", "。", ".", " ", ""], separators=["\n\n", "。", ".", " ", ""],
embedding_model_instance=embedding_model_instance embedding_model_instance=embedding_model_instance
) )
......
...@@ -30,7 +30,7 @@ def _split_text_with_regex( ...@@ -30,7 +30,7 @@ def _split_text_with_regex(
if separator: if separator:
if keep_separator: if keep_separator:
# The parentheses in the pattern keep the delimiters in the result. # The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text) _splits = re.split(f"({re.escape(separator)})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0: if len(_splits) % 2 == 0:
splits += _splits[-1:] splits += _splits[-1:]
...@@ -94,7 +94,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): ...@@ -94,7 +94,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
documents.append(new_doc) documents.append(new_doc)
return documents return documents
def split_documents(self, documents: Iterable[Document]) -> list[Document]: def split_documents(self, documents: Iterable[Document] ) -> list[Document]:
"""Split documents.""" """Split documents."""
texts, metadatas = [], [] texts, metadatas = [], []
for doc in documents: for doc in documents:
......
...@@ -119,7 +119,7 @@ parameters: # Parameter list ...@@ -119,7 +119,7 @@ parameters: # Parameter list
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc. - The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list - `parameters` Parameter list
- `name` Parameter name, unique, no duplication with other parameters - `name` Parameter name, unique, no duplication with other parameters
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type
- `required` Required or not - `required` Required or not
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter - In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts - In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
......
...@@ -119,7 +119,7 @@ parameters: # 参数列表 ...@@ -119,7 +119,7 @@ parameters: # 参数列表
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等 - `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
- `parameters` 参数列表 - `parameters` 参数列表
- `name` 参数名称,唯一,不允许和其他参数重名 - `name` 参数名称,唯一,不允许和其他参数重名
- `type` 参数类型,目前支持`string``number``boolean``select` 四种类型,分别对应字符串、数字、布尔值、下拉框 - `type` 参数类型,目前支持`string``number``boolean``select``secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
- `required` 是否必填 - `required` 是否必填
-`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数 -`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数
-`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数 -`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数
......
...@@ -8,15 +8,19 @@ class I18nObject(BaseModel): ...@@ -8,15 +8,19 @@ class I18nObject(BaseModel):
Model class for i18n object. Model class for i18n object.
""" """
zh_Hans: Optional[str] = None zh_Hans: Optional[str] = None
pt_BR: Optional[str] = None
en_US: str en_US: str
def __init__(self, **data): def __init__(self, **data):
super().__init__(**data) super().__init__(**data)
if not self.zh_Hans: if not self.zh_Hans:
self.zh_Hans = self.en_US self.zh_Hans = self.en_US
if not self.pt_BR:
self.pt_BR = self.en_US
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
'zh_Hans': self.zh_Hans, 'zh_Hans': self.zh_Hans,
'en_US': self.en_US, 'en_US': self.en_US,
} 'pt_BR': self.pt_BR
\ No newline at end of file }
...@@ -100,6 +100,7 @@ class ToolParameter(BaseModel): ...@@ -100,6 +100,7 @@ class ToolParameter(BaseModel):
NUMBER = "number" NUMBER = "number"
BOOLEAN = "boolean" BOOLEAN = "boolean"
SELECT = "select" SELECT = "select"
SECRET_INPUT = "secret-input"
class ToolParameterForm(Enum): class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool SCHEMA = "schema" # should be set while adding tool
...@@ -304,4 +305,24 @@ class ToolRuntimeVariablePool(BaseModel): ...@@ -304,4 +305,24 @@ class ToolRuntimeVariablePool(BaseModel):
value=value, value=value,
) )
self.pool.append(variable) self.pool.append(variable)
\ No newline at end of file
class ModelToolPropertyKey(Enum):
IMAGE_PARAMETER_NAME = "image_parameter_name"
class ModelToolConfiguration(BaseModel):
"""
Model tool configuration
"""
type: str = Field(..., description="The type of the model tool")
model: str = Field(..., description="The model")
label: I18nObject = Field(..., description="The label of the model tool")
properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")
class ModelToolProviderConfiguration(BaseModel):
"""
Model tool provider configuration
"""
provider: str = Field(..., description="The provider of the model tool")
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
\ No newline at end of file
...@@ -13,6 +13,7 @@ class UserToolProvider(BaseModel): ...@@ -13,6 +13,7 @@ class UserToolProvider(BaseModel):
BUILTIN = "builtin" BUILTIN = "builtin"
APP = "app" APP = "app"
API = "api" API = "api"
MODEL = "model"
id: str id: str
author: str author: str
......
provider: anthropic
label:
en_US: Anthropic Model Tools
zh_Hans: Anthropic 模型能力
pt_BR: Anthropic Model Tools
models:
- type: llm
model: claude-3-sonnet-20240229
label:
zh_Hans: Claude3 Sonnet 视觉
en_US: Claude3 Sonnet Vision
properties:
image_parameter_name: image_id
- type: llm
model: claude-3-opus-20240229
label:
zh_Hans: Claude3 Opus 视觉
en_US: Claude3 Opus Vision
properties:
image_parameter_name: image_id
provider: google
label:
en_US: Google Model Tools
zh_Hans: Google 模型能力
pt_BR: Google Model Tools
models:
- type: llm
model: gemini-pro-vision
label:
zh_Hans: Gemini Pro 视觉
en_US: Gemini Pro Vision
properties:
image_parameter_name: image_id
provider: openai
label:
en_US: OpenAI Model Tools
zh_Hans: OpenAI 模型能力
pt_BR: OpenAI Model Tools
models:
- type: llm
model: gpt-4-vision-preview
label:
zh_Hans: GPT-4 视觉
en_US: GPT-4 Vision
properties:
image_parameter_name: image_id
provider: zhipuai
label:
en_US: ZhipuAI Model Tools
zh_Hans: ZhipuAI 模型能力
pt_BR: ZhipuAI Model Tools
models:
- type: llm
model: glm-4v
label:
zh_Hans: GLM-4 视觉
en_US: GLM-4 Vision
properties:
image_parameter_name: image_id
- google - google
- bing - bing
- duckduckgo - duckduckgo
- yahoo - dalle
- azuredalle
- wikipedia - wikipedia
- model.openai
- model.google
- model.anthropic
- yahoo
- arxiv - arxiv
- pubmed - pubmed
- dalle
- azuredalle
- stablediffusion - stablediffusion
- webscraper - webscraper
- model.zhipuai
- aippt
- youtube - youtube
- wolframalpha - wolframalpha
- maths - maths
......
...@@ -4,24 +4,24 @@ from yaml import FullLoader, load ...@@ -4,24 +4,24 @@ from yaml import FullLoader, load
from core.tools.entities.user_entities import UserToolProvider from core.tools.entities.user_entities import UserToolProvider
position = {}
class BuiltinToolProviderSort: class BuiltinToolProviderSort:
@staticmethod _position = {}
def sort(providers: list[UserToolProvider]) -> list[UserToolProvider]:
global position @classmethod
if not position: def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
tmp_position = {} tmp_position = {}
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml') file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path) as f: with open(file_path) as f:
for pos, val in enumerate(load(f, Loader=FullLoader)): for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos tmp_position[val] = pos
position = tmp_position cls._position = tmp_position
def sort_compare(provider: UserToolProvider) -> int: def sort_compare(provider: UserToolProvider) -> int:
# if provider.type == UserToolProvider.ProviderType.MODEL: if provider.type == UserToolProvider.ProviderType.MODEL:
# return position.get(f'model_provider.{provider.name}', 10000) return cls._position.get(f'model.{provider.name}', 10000)
return position.get(provider.name, 10000) return cls._position.get(provider.name, 10000)
sorted_providers = sorted(providers, key=sort_compare) sorted_providers = sorted(providers, key=sort_compare)
......
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class AIPPTProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
identity:
author: Dify
name: aippt
label:
en_US: AIPPT
zh_Hans: AIPPT
description:
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
icon: icon.png
credentials_for_provider:
aippt_access_key:
type: secret-input
required: true
label:
en_US: AIPPT API key
zh_Hans: AIPPT API key
pt_BR: AIPPT API key
help:
en_US: Please input your AIPPT API key
zh_Hans: 请输入你的 AIPPT API key
pt_BR: Please input your AIPPT API key
placeholder:
en_US: Please input your AIPPT API key
zh_Hans: 请输入你的 AIPPT API key
pt_BR: Please input your AIPPT API key
url: https://www.aippt.cn
aippt_secret_key:
type: secret-input
required: true
label:
en_US: AIPPT Secret key
zh_Hans: AIPPT Secret key
pt_BR: AIPPT Secret key
help:
en_US: Please input your AIPPT Secret key
zh_Hans: 请输入你的 AIPPT Secret key
pt_BR: Please input your AIPPT Secret key
placeholder:
en_US: Please input your AIPPT Secret key
zh_Hans: 请输入你的 AIPPT Secret key
pt_BR: Please input your AIPPT Secret key
from base64 import b64encode
from hashlib import sha1
from hmac import new as hmac_new
from json import loads as json_loads
from threading import Lock
from time import sleep, time
from typing import Any
from httpx import get, post
from requests import get as requests_get
from yarl import URL
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.tool.builtin_tool import BuiltinTool
class AIPPTGenerateTool(BuiltinTool):
"""
A tool for generating a ppt
"""
_api_base_url = URL('https://co.aippt.cn/api')
_api_token_cache = {}
_api_token_cache_lock = Lock()
_style_cache = {}
_style_cache_lock = Lock()
_task = {}
_task_type_map = {
'auto': 1,
'markdown': 7,
}
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
Invokes the AIPPT generate tool with the given user ID and tool parameters.
Args:
user_id (str): The ID of the user invoking the tool.
tool_parameters (dict[str, Any]): The parameters for the tool
Returns:
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
"""
title = tool_parameters.get('title', '')
if not title:
return self.create_text_message('Please provide a title for the ppt')
model = tool_parameters.get('model', 'aippt')
if not model:
return self.create_text_message('Please provide a model for the ppt')
outline = tool_parameters.get('outline', '')
# create task
task_id = self._create_task(
type=self._task_type_map['auto' if not outline else 'markdown'],
title=title,
content=outline,
user_id=user_id
)
# get suit
color = tool_parameters.get('color')
style = tool_parameters.get('style')
if color == '__default__':
color_id = ''
else:
color_id = int(color.split('-')[1])
if style == '__default__':
style_id = ''
else:
style_id = int(style.split('-')[1])
suit_id = self._get_suit(style_id=style_id, colour_id=color_id)
# generate outline
if not outline:
self._generate_outline(
task_id=task_id,
model=model,
user_id=user_id
)
# generate content
self._generate_content(
task_id=task_id,
model=model,
user_id=user_id
)
# generate ppt
_, ppt_url = self._generate_ppt(
task_id=task_id,
suit_id=suit_id,
user_id=user_id
)
return self.create_text_message('''the ppt has been created successfully,'''
f'''the ppt url is {ppt_url}'''
'''please give the ppt url to user and direct user to download it.''')
def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
"""
Create a task
:param type: the task type
:param title: the task title
:param content: the task content
:return: the task ID
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = post(
str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'),
headers=headers,
files={
'type': ('', str(type)),
'title': ('', title),
'content': ('', content)
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to create task: {response.get("msg")}')
return response.get('data', {}).get('id')
def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \
self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline'
api_url %= {'task_id': task_id}
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = requests_get(
url=api_url,
headers=headers,
stream=True,
timeout=(10, 60)
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
outline = ''
for chunk in response.iter_lines(delimiter=b'\n\n'):
if not chunk:
continue
event = ''
lines = chunk.decode('utf-8').split('\n')
for line in lines:
if line.startswith('event:'):
event = line[6:]
elif line.startswith('data:'):
data = line[5:]
if event == 'message':
try:
data = json_loads(data)
outline += data.get('content', '')
except Exception as e:
pass
elif event == 'close':
break
elif event == 'error' or event == 'filter':
raise Exception(f'Failed to generate outline: {data}')
return outline
def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \
self._api_base_url / 'ai' / 'chat' / 'wx' / 'content'
api_url %= {'task_id': task_id}
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = requests_get(
url=api_url,
headers=headers,
stream=True,
timeout=(10, 60)
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
if model == 'aippt':
content = ''
for chunk in response.iter_lines(delimiter=b'\n\n'):
if not chunk:
continue
event = ''
lines = chunk.decode('utf-8').split('\n')
for line in lines:
if line.startswith('event:'):
event = line[6:]
elif line.startswith('data:'):
data = line[5:]
if event == 'message':
try:
data = json_loads(data)
content += data.get('content', '')
except Exception as e:
pass
elif event == 'close':
break
elif event == 'error' or event == 'filter':
raise Exception(f'Failed to generate content: {data}')
return content
elif model == 'wenxin':
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to generate content: {response.get("msg")}')
return response.get('data', '')
return ''
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
"""
Generate a ppt
:param task_id: the task ID
:param suit_id: the suit ID
:return: the cover url of the ppt and the ppt url
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = post(
str(self._api_base_url / 'design' / 'v2' / 'save'),
headers=headers,
data={
'task_id': task_id,
'template_id': suit_id
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
id = response.get('data', {}).get('id')
cover_url = response.get('data', {}).get('cover_url')
response = post(
str(self._api_base_url / 'download' / 'export' / 'file'),
headers=headers,
data={
'id': id,
'format': 'ppt',
'files_to_zip': False,
'edit': True
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
export_code = response.get('data')
if not export_code:
raise Exception('Failed to generate ppt, the export code is empty')
current_iteration = 0
while current_iteration < 50:
# get ppt url
response = post(
str(self._api_base_url / 'download' / 'export' / 'file' / 'result'),
headers=headers,
data={
'task_key': export_code
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
if response.get('msg') == '导出中':
current_iteration += 1
sleep(2)
continue
ppt_url = response.get('data', [])
if len(ppt_url) == 0:
raise Exception('Failed to generate ppt, the ppt url is empty')
return cover_url, ppt_url[0]
raise Exception('Failed to generate ppt, the export is timeout')
@classmethod
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
"""
Get API token
:param credentials: the credentials
:return: the API token
"""
access_key = credentials['aippt_access_key']
secret_key = credentials['aippt_secret_key']
cache_key = f'{access_key}#@#{user_id}'
with cls._api_token_cache_lock:
# clear expired tokens
now = time()
for key in list(cls._api_token_cache.keys()):
if cls._api_token_cache[key]['expire'] < now:
del cls._api_token_cache[key]
if cache_key in cls._api_token_cache:
return cls._api_token_cache[cache_key]['token']
# get token
headers = {
'x-api-key': access_key,
'x-timestamp': str(int(now)),
'x-signature': cls._calculate_sign(access_key, secret_key, int(now))
}
param = {
'uid': user_id,
'channel': ''
}
response = get(
str(cls._api_base_url / 'grant' / 'token'),
params=param,
headers=headers
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
token = response.get('data', {}).get('token')
expire = response.get('data', {}).get('time_expire')
with cls._api_token_cache_lock:
cls._api_token_cache[cache_key] = {
'token': token,
'expire': now + expire
}
return token
@classmethod
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
return b64encode(
hmac_new(
key=secret_key.encode('utf-8'),
msg=f'GET@/api/grant/token/@{timestamp}'.encode(),
digestmod=sha1
).digest()
).decode('utf-8')
@classmethod
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
"""
Get styles
"""
# check cache
with cls._style_cache_lock:
# clear expired styles
now = time()
for key in list(cls._style_cache.keys()):
if cls._style_cache[key]['expire'] < now:
del cls._style_cache[key]
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
if key in cls._style_cache:
return cls._style_cache[key]['colors'], cls._style_cache[key]['styles']
headers = {
'x-channel': '',
'x-api-key': credentials['aippt_access_key'],
'x-token': cls._get_api_token(credentials=credentials, user_id=user_id)
}
response = get(
str(cls._api_base_url / 'template_component' / 'suit' / 'select'),
headers=headers
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
colors = [{
'id': f'id-{item.get("id")}',
'name': item.get('name'),
'en_name': item.get('en_name', item.get('name')),
} for item in response.get('data', {}).get('colour') or []]
styles = [{
'id': f'id-{item.get("id")}',
'name': item.get('title'),
} for item in response.get('data', {}).get('suit_style') or []]
with cls._style_cache_lock:
cls._style_cache[key] = {
'colors': colors,
'styles': styles,
'expire': now + 60 * 60
}
return colors, styles
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
"""
Get styles
:param credentials: the credentials
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
"""
if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'):
raise Exception('Please provide aippt credentials')
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
def _get_suit(self, style_id: int, colour_id: int) -> int:
"""
Get suit
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__')
}
response = get(
str(self._api_base_url / 'template_component' / 'suit' / 'search'),
headers=headers,
params={
'style_id': style_id,
'colour_id': colour_id,
'page': 1,
'page_size': 1
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
if len(response.get('data', {}).get('list') or []) > 0:
return response.get('data', {}).get('list')[0].get('id')
raise Exception('Failed to get suit, the suit does not exist, please check the style and color')
def get_runtime_parameters(self) -> list[ToolParameter]:
"""
Get runtime parameters
Override this method to add runtime parameters to the tool.
"""
try:
colors, styles = self.get_styles(user_id='__dify_system__')
except Exception as e:
colors, styles = [
{'id': -1, 'name': '__default__', 'en_name': '__default__'}
], [
{'id': -1, 'name': '__default__', 'en_name': '__default__'}
]
return [
ToolParameter(
name='color',
label=I18nObject(zh_Hans='颜色', en_US='Color'),
human_description=I18nObject(zh_Hans='颜色', en_US='Color'),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=False,
default=colors[0]['id'],
options=[
ToolParameterOption(
value=color['id'],
label=I18nObject(zh_Hans=color['name'], en_US=color['en_name'])
) for color in colors
]
),
ToolParameter(
name='style',
label=I18nObject(zh_Hans='风格', en_US='Style'),
human_description=I18nObject(zh_Hans='风格', en_US='Style'),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=False,
default=styles[0]['id'],
options=[
ToolParameterOption(
value=style['id'],
label=I18nObject(zh_Hans=style['name'], en_US=style['name'])
) for style in styles
]
),
]
\ No newline at end of file
identity:
name: aippt
author: Dify
label:
en_US: AIPPT
zh_Hans: AIPPT
description:
human:
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you.
parameters:
- name: title
type: string
required: true
label:
en_US: Title
zh_Hans: 标题
human_description:
en_US: The title of the PPT.
zh_Hans: PPT的标题。
llm_description: The title of the PPT, which will be used to generate the PPT outline.
form: llm
- name: outline
type: string
required: false
label:
en_US: Outline
zh_Hans: 大纲
human_description:
en_US: The outline of the PPT
zh_Hans: PPT的大纲
llm_description: The outline of the PPT, which will be used to generate the PPT content. provide it if you have.
form: llm
- name: llm
type: select
required: true
label:
en_US: LLM model
zh_Hans: 生成大纲的LLM
options:
- value: aippt
label:
en_US: AIPPT default model
zh_Hans: AIPPT默认模型
- value: wenxin
label:
en_US: Wenxin ErnieBot
zh_Hans: 文心一言
default: aippt
human_description:
en_US: The LLM model used for generating PPT outline.
zh_Hans: 用于生成PPT大纲的LLM模型。
form: form
<svg viewBox="-29.62167543756803 0.1 574.391675437568 799.8100000000002" xmlns="http://www.w3.org/2000/svg" width="1888"
height="2500">
<linearGradient id="a" gradientUnits="userSpaceOnUse" x1="286.383" x2="542.057" y1="284.169" y2="569.112">
<stop offset="0" stop-color="#37bdff"/>
<stop offset=".25" stop-color="#26c6f4"/>
<stop offset=".5" stop-color="#15d0e9"/>
<stop offset=".75" stop-color="#3bd6df"/>
<stop offset="1" stop-color="#62dcd4"/>
</linearGradient>
<linearGradient id="b" gradientUnits="userSpaceOnUse" x1="108.979" x2="100.756" y1="675.98" y2="43.669">
<stop offset="0" stop-color="#1b48ef"/>
<stop offset=".5" stop-color="#2080f1"/>
<stop offset="1" stop-color="#26b8f4"/>
</linearGradient>
<linearGradient id="c" gradientUnits="userSpaceOnUse" x1="256.823" x2="875.632" y1="649.719" y2="649.719">
<stop offset="0" stop-color="#39d2ff"/>
<stop offset=".5" stop-color="#248ffa"/>
<stop offset="1" stop-color="#104cf5"/>
</linearGradient>
<linearGradient id="d" gradientUnits="userSpaceOnUse" x1="256.823" x2="875.632" y1="649.719" y2="649.719">
<stop offset="0" stop-color="#fff"/>
<stop offset="1"/>
</linearGradient>
<path d="M249.97 277.48c-.12.96-.12 2.05-.12 3.12 0 4.16.83 8.16 2.33 11.84l1.34 2.76 5.3 13.56 27.53 70.23 24.01 61.33c6.85 12.38 17.82 22.1 31.05 27.28l4.11 1.51c.16.05.43.05.65.11l65.81 22.63v.05l25.16 8.64 1.72.58c.06 0 .16.06.22.06 4.96 1.25 9.82 2.93 14.46 4.98 10.73 4.63 20.46 11.23 28.77 19.28 3.35 3.2 6.43 6.65 9.28 10.33a88.64 88.64 0 0 1 6.64 9.72c8.78 14.58 13.82 31.72 13.82 49.97 0 3.26-.16 6.41-.49 9.61-.11 1.41-.28 2.77-.49 4.12v.11c-.22 1.43-.49 2.91-.76 4.36-.28 1.41-.54 2.81-.86 4.21-.05.16-.11.33-.17.49-.3 1.42-.68 2.82-1.07 4.23-.35 1.33-.79 2.7-1.28 3.99a42.96 42.96 0 0 1-1.51 4.16c-.49 1.4-1.07 2.82-1.72 4.16-1.78 4.11-3.9 8.06-6.28 11.83a97.889 97.889 0 0 1-10.47 13.95c30.88-33.2 51.41-76.07 56.52-123.51.86-7.78 1.3-15.67 1.3-23.61 0-5.07-.22-10.09-.55-15.13-3.89-56.89-29.79-107.77-69.32-144.08-10.9-10.09-22.81-19.07-35.62-26.69l-24.2-12.37-122.63-62.93a30.15 30.15 0 0 0-11.93-2.44c-15.88 0-28.99 12.11-30.55 27.56z"
fill="#7f7f7f"/>
<path d="M249.97 277.48c-.12.96-.12 2.05-.12 3.12 0 4.16.83 8.16 2.33 11.84l1.34 2.76 5.3 13.56 27.53 70.23 24.01 61.33c6.85 12.38 17.82 22.1 31.05 27.28l4.11 1.51c.16.05.43.05.65.11l65.81 22.63v.05l25.16 8.64 1.72.58c.06 0 .16.06.22.06 4.96 1.25 9.82 2.93 14.46 4.98 10.73 4.63 20.46 11.23 28.77 19.28 3.35 3.2 6.43 6.65 9.28 10.33a88.64 88.64 0 0 1 6.64 9.72c8.78 14.58 13.82 31.72 13.82 49.97 0 3.26-.16 6.41-.49 9.61-.11 1.41-.28 2.77-.49 4.12v.11c-.22 1.43-.49 2.91-.76 4.36-.28 1.41-.54 2.81-.86 4.21-.05.16-.11.33-.17.49-.3 1.42-.68 2.82-1.07 4.23-.35 1.33-.79 2.7-1.28 3.99a42.96 42.96 0 0 1-1.51 4.16c-.49 1.4-1.07 2.82-1.72 4.16-1.78 4.11-3.9 8.06-6.28 11.83a97.889 97.889 0 0 1-10.47 13.95c30.88-33.2 51.41-76.07 56.52-123.51.86-7.78 1.3-15.67 1.3-23.61 0-5.07-.22-10.09-.55-15.13-3.89-56.89-29.79-107.77-69.32-144.08-10.9-10.09-22.81-19.07-35.62-26.69l-24.2-12.37-122.63-62.93a30.15 30.15 0 0 0-11.93-2.44c-15.88 0-28.99 12.11-30.55 27.56z"
fill="url(#a)"/>
<path d="M31.62.1C14.17.41.16 14.69.16 32.15v559.06c.07 3.9.29 7.75.57 11.66.25 2.06.52 4.2.9 6.28 7.97 44.87 47.01 78.92 94.15 78.92 16.53 0 32.03-4.21 45.59-11.53.08-.06.22-.14.29-.14l4.88-2.95 19.78-11.64 25.16-14.93.06-496.73c0-33.01-16.52-62.11-41.81-79.4-.6-.36-1.18-.74-1.71-1.17L50.12 5.56C45.16 2.28 39.18.22 32.77.1z"
fill="#7f7f7f"/>
<path d="M31.62.1C14.17.41.16 14.69.16 32.15v559.06c.07 3.9.29 7.75.57 11.66.25 2.06.52 4.2.9 6.28 7.97 44.87 47.01 78.92 94.15 78.92 16.53 0 32.03-4.21 45.59-11.53.08-.06.22-.14.29-.14l4.88-2.95 19.78-11.64 25.16-14.93.06-496.73c0-33.01-16.52-62.11-41.81-79.4-.6-.36-1.18-.74-1.71-1.17L50.12 5.56C45.16 2.28 39.18.22 32.77.1z"
fill="url(#b)"/>
<path d="M419.81 510.84L194.72 644.26l-3.24 1.95v.71l-25.16 14.9-19.77 11.67-4.85 2.93-.33.16c-13.53 7.35-29.04 11.51-45.56 11.51-47.13 0-86.22-34.03-94.16-78.92 3.77 32.84 14.96 63.41 31.84 90.04 34.76 54.87 93.54 93.04 161.54 99.67h41.58c36.78-3.84 67.49-18.57 99.77-38.46l49.64-30.36c22.36-14.33 83.05-49.58 100.93-69.36 3.89-4.33 7.4-8.97 10.47-13.94 2.38-3.78 4.5-7.73 6.28-11.84.6-1.4 1.17-2.76 1.72-4.15.52-1.38 1.01-2.77 1.51-4.18.93-2.7 1.67-5.41 2.38-8.2.36-1.59.69-3.16 1.02-4.72 1.08-5.89 1.67-11.94 1.67-18.21 0-18.25-5.04-35.39-13.77-49.95-2-3.4-4.2-6.65-6.64-9.72-2.85-3.7-5.93-7.13-9.28-10.33-8.31-8.05-18.01-14.65-28.77-19.29-4.64-2.05-9.48-3.74-14.46-4.97-.06 0-.16-.06-.22-.06l-1.72-.58z"
fill="#7f7f7f"/>
<path d="M419.81 510.84L194.72 644.26l-3.24 1.95v.71l-25.16 14.9-19.77 11.67-4.85 2.93-.33.16c-13.53 7.35-29.04 11.51-45.56 11.51-47.13 0-86.22-34.03-94.16-78.92 3.77 32.84 14.96 63.41 31.84 90.04 34.76 54.87 93.54 93.04 161.54 99.67h41.58c36.78-3.84 67.49-18.57 99.77-38.46l49.64-30.36c22.36-14.33 83.05-49.58 100.93-69.36 3.89-4.33 7.4-8.97 10.47-13.94 2.38-3.78 4.5-7.73 6.28-11.84.6-1.4 1.17-2.76 1.72-4.15.52-1.38 1.01-2.77 1.51-4.18.93-2.7 1.67-5.41 2.38-8.2.36-1.59.69-3.16 1.02-4.72 1.08-5.89 1.67-11.94 1.67-18.21 0-18.25-5.04-35.39-13.77-49.95-2-3.4-4.2-6.65-6.64-9.72-2.85-3.7-5.93-7.13-9.28-10.33-8.31-8.05-18.01-14.65-28.77-19.29-4.64-2.05-9.48-3.74-14.46-4.97-.06 0-.16-.06-.22-.06l-1.72-.58z"
fill="url(#c)"/>
<path d="M512 595.46c0 6.27-.59 12.33-1.68 18.22-.32 1.56-.65 3.12-1.02 4.7-.7 2.8-1.44 5.51-2.37 8.22-.49 1.4-.99 2.8-1.51 4.16-.54 1.4-1.12 2.76-1.73 4.16a87.873 87.873 0 0 1-6.26 11.83 96.567 96.567 0 0 1-10.48 13.94c-17.88 19.79-78.57 55.04-100.93 69.37l-49.64 30.36c-36.39 22.42-70.77 38.29-114.13 39.38-2.05.06-4.06.11-6.05.11-2.8 0-5.56-.05-8.33-.16-73.42-2.8-137.45-42.25-174.38-100.54a213.368 213.368 0 0 1-31.84-90.04c7.94 44.89 47.03 78.92 94.16 78.92 16.52 0 32.03-4.17 45.56-11.51l.33-.17 4.85-2.92 19.77-11.67 25.16-14.9v-.71l3.24-1.95 225.09-133.43 17.33-10.27 1.72.58c.05 0 .16.06.22.06 4.98 1.23 9.83 2.92 14.46 4.97 10.76 4.64 20.45 11.24 28.77 19.29a92.13 92.13 0 0 1 9.28 10.33c2.44 3.07 4.64 6.32 6.64 9.72 8.73 14.56 13.77 31.7 13.77 49.95z"
fill="#7f7f7f" opacity=".15"/>
<path d="M512 595.46c0 6.27-.59 12.33-1.68 18.22-.32 1.56-.65 3.12-1.02 4.7-.7 2.8-1.44 5.51-2.37 8.22-.49 1.4-.99 2.8-1.51 4.16-.54 1.4-1.12 2.76-1.73 4.16a87.873 87.873 0 0 1-6.26 11.83 96.567 96.567 0 0 1-10.48 13.94c-17.88 19.79-78.57 55.04-100.93 69.37l-49.64 30.36c-36.39 22.42-70.77 38.29-114.13 39.38-2.05.06-4.06.11-6.05.11-2.8 0-5.56-.05-8.33-.16-73.42-2.8-137.45-42.25-174.38-100.54a213.368 213.368 0 0 1-31.84-90.04c7.94 44.89 47.03 78.92 94.16 78.92 16.52 0 32.03-4.17 45.56-11.51l.33-.17 4.85-2.92 19.77-11.67 25.16-14.9v-.71l3.24-1.95 225.09-133.43 17.33-10.27 1.72.58c.05 0 .16.06.22.06 4.98 1.23 9.83 2.92 14.46 4.97 10.76 4.64 20.45 11.24 28.77 19.29a92.13 92.13 0 0 1 9.28 10.33c2.44 3.07 4.64 6.32 6.64 9.72 8.73 14.56 13.77 31.7 13.77 49.95z"
fill="url(#d)" opacity=".15"/>
</svg>
\ No newline at end of file
...@@ -9,7 +9,7 @@ identity: ...@@ -9,7 +9,7 @@ identity:
en_US: Bing Search en_US: Bing Search
zh_Hans: Bing 搜索 zh_Hans: Bing 搜索
pt_BR: Bing Search pt_BR: Bing Search
icon: icon.png icon: icon.svg
credentials_for_provider: credentials_for_provider:
subscription_key: subscription_key:
type: secret-input type: secret-input
......
<svg version="1.1" xmlns="http://www.w3.org/2000/svg" height="1024" width="1024" viewBox="0 0 1024 1024"><path d="M699.052008 894.366428l-253.855434-159.289336-115.571175 114.768472 44.746184-157.983686L887.097839 163.880236 312.470884 651.791088l-205.584597-128.995835L887.097839 163.876212 699.056031 894.364417zM348.039293 321.886051h122.859882L348.039293 374.779976V321.886051z m675.960707 0v-75.373642C1024 109.706813 917.443646 0 782.927466 0H698.090373v224.076951l-80.471512 34.642986V0H242.63167C108.113477 0-0.002012 109.706813-0.002012 246.51442V321.886051h195.143419v80.471513H0v376.276746C0 915.439906 108.115489 1024 242.63167 1024h374.985179v-145.906923l80.471512 51.270412V1024h84.837093C917.445658 1024 1024 915.439906 1024 778.63431V402.357564h-172.255308l20.717391-80.471513H1024z" fill="#0093FD"></path></svg>
\ No newline at end of file
...@@ -9,7 +9,7 @@ identity: ...@@ -9,7 +9,7 @@ identity:
en_US: Autonavi Open Platform service toolkit. en_US: Autonavi Open Platform service toolkit.
zh_Hans: 高德开放平台服务工具包。 zh_Hans: 高德开放平台服务工具包。
pt_BR: Kit de ferramentas de serviço Autonavi Open Platform. pt_BR: Kit de ferramentas de serviço Autonavi Open Platform.
icon: icon.png icon: icon.svg
credentials_for_provider: credentials_for_provider:
api_key: api_key:
type: secret-input type: secret-input
......
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg width="800px" height="800px" viewBox="0 0 20 20" version="1.1" xmlns="http://www.w3.org/2000/svg"
xmlns:xlink="http://www.w3.org/1999/xlink">
<title>github [#142]</title>
<desc>Created with Sketch.</desc>
<defs>
</defs>
<g id="Page-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="Dribbble-Light-Preview" transform="translate(-140.000000, -7559.000000)" fill="#000000">
<g id="icons" transform="translate(56.000000, 160.000000)">
<path d="M94,7399 C99.523,7399 104,7403.59 104,7409.253 C104,7413.782 101.138,7417.624 97.167,7418.981 C96.66,7419.082 96.48,7418.762 96.48,7418.489 C96.48,7418.151 96.492,7417.047 96.492,7415.675 C96.492,7414.719 96.172,7414.095 95.813,7413.777 C98.04,7413.523 100.38,7412.656 100.38,7408.718 C100.38,7407.598 99.992,7406.684 99.35,7405.966 C99.454,7405.707 99.797,7404.664 99.252,7403.252 C99.252,7403.252 98.414,7402.977 96.505,7404.303 C95.706,7404.076 94.85,7403.962 94,7403.958 C93.15,7403.962 92.295,7404.076 91.497,7404.303 C89.586,7402.977 88.746,7403.252 88.746,7403.252 C88.203,7404.664 88.546,7405.707 88.649,7405.966 C88.01,7406.684 87.619,7407.598 87.619,7408.718 C87.619,7412.646 89.954,7413.526 92.175,7413.785 C91.889,7414.041 91.63,7414.493 91.54,7415.156 C90.97,7415.418 89.522,7415.871 88.63,7414.304 C88.63,7414.304 88.101,7413.319 87.097,7413.247 C87.097,7413.247 86.122,7413.234 87.029,7413.87 C87.029,7413.87 87.684,7414.185 88.139,7415.37 C88.139,7415.37 88.726,7417.2 91.508,7416.58 C91.513,7417.437 91.522,7418.245 91.522,7418.489 C91.522,7418.76 91.338,7419.077 90.839,7418.982 C86.865,7417.627 84,7413.783 84,7409.253 C84,7403.59 88.478,7399 94,7399"
id="github-[#142]">
</path>
</g>
</g>
</g>
</svg>
\ No newline at end of file
...@@ -9,7 +9,7 @@ identity: ...@@ -9,7 +9,7 @@ identity:
en_US: GitHub is an online software source code hosting service. en_US: GitHub is an online software source code hosting service.
zh_Hans: GitHub是一个在线软件源代码托管服务平台。 zh_Hans: GitHub是一个在线软件源代码托管服务平台。
pt_BR: GitHub é uma plataforma online para serviços de hospedagem de código fonte de software. pt_BR: GitHub é uma plataforma online para serviços de hospedagem de código fonte de software.
icon: icon.png icon: icon.svg
credentials_for_provider: credentials_for_provider:
access_tokens: access_tokens:
type: secret-input type: secret-input
......
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="800px" height="800px" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"> <svg width="800px" height="800px" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
<g> <g>
<path fill="none" d="M0 0h24v24H0z"/> <path fill="none" d="M0 0h24v24H0z"/>
......
...@@ -2,14 +2,23 @@ import io ...@@ -2,14 +2,23 @@ import io
import logging import logging
from typing import Any, Union from typing import Any, Union
import qrcode from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
from qrcode.image.base import BaseImage
from qrcode.image.pure import PyPNGImage from qrcode.image.pure import PyPNGImage
from qrcode.main import QRCode
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
class QRCodeGeneratorTool(BuiltinTool): class QRCodeGeneratorTool(BuiltinTool):
error_correction_levels = {
'L': ERROR_CORRECT_L, # <=7%
'M': ERROR_CORRECT_M, # <=15%
'Q': ERROR_CORRECT_Q, # <=25%
'H': ERROR_CORRECT_H, # <=30%
}
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_parameters: dict[str, Any], tool_parameters: dict[str, Any],
...@@ -17,19 +26,44 @@ class QRCodeGeneratorTool(BuiltinTool): ...@@ -17,19 +26,44 @@ class QRCodeGeneratorTool(BuiltinTool):
""" """
invoke tools invoke tools
""" """
# get expression # get text content
content = tool_parameters.get('content', '') content = tool_parameters.get('content', '')
if not content: if not content:
return self.create_text_message('Invalid parameter content') return self.create_text_message('Invalid parameter content')
# get border size
border = tool_parameters.get('border', 0)
if border < 0 or border > 100:
return self.create_text_message('Invalid parameter border')
# get error_correction
error_correction = tool_parameters.get('error_correction', '')
if error_correction not in self.error_correction_levels.keys():
return self.create_text_message('Invalid parameter error_correction')
try: try:
img = qrcode.make(data=content, image_factory=PyPNGImage) image = self._generate_qrcode(content, border, error_correction)
byte_stream = io.BytesIO() image_bytes = self._image_to_byte_array(image)
img.save(byte_stream) return self.create_blob_message(blob=image_bytes,
byte_array = byte_stream.getvalue()
return self.create_blob_message(blob=byte_array,
meta={'mime_type': 'image/png'}, meta={'mime_type': 'image/png'},
save_as=self.VARIABLE_KEY.IMAGE.value) save_as=self.VARIABLE_KEY.IMAGE.value)
except Exception: except Exception:
logging.exception(f'Failed to generate QR code for content: {content}') logging.exception(f'Failed to generate QR code for content: {content}')
return self.create_text_message('Failed to generate QR code') return self.create_text_message('Failed to generate QR code')
def _generate_qrcode(self, content: str, border: int, error_correction: str) -> BaseImage:
qr = QRCode(
image_factory=PyPNGImage,
error_correction=self.error_correction_levels.get(error_correction),
border=border,
)
qr.add_data(data=content)
qr.make(fit=True)
img = qr.make_image()
return img
@staticmethod
def _image_to_byte_array(image: BaseImage) -> bytes:
byte_stream = io.BytesIO()
image.save(byte_stream)
return byte_stream.getvalue()
...@@ -2,9 +2,9 @@ identity: ...@@ -2,9 +2,9 @@ identity:
name: qrcode_generator name: qrcode_generator
author: Bowen Liang author: Bowen Liang
label: label:
en_US: QR Code Generator en_US: Generate QR Code
zh_Hans: 二维码生成器 zh_Hans: 生成二维码
pt_BR: QR Code Generator pt_BR: Generate QR Code
description: description:
human: human:
en_US: A tool for generating QR code image en_US: A tool for generating QR code image
...@@ -24,3 +24,53 @@ parameters: ...@@ -24,3 +24,53 @@ parameters:
zh_Hans: 二维码文本内容 zh_Hans: 二维码文本内容
pt_BR: 二维码文本内容 pt_BR: 二维码文本内容
form: llm form: llm
- name: error_correction
type: select
required: true
default: M
label:
en_US: Error Correction
zh_Hans: 容错等级
pt_BR: Error Correction
human_description:
en_US: Error Correction in L, M, Q or H, from low to high, the bigger size of generated QR code with the better error correction effect
zh_Hans: 容错等级,可设置为低、中、偏高或高,从低到高,生成的二维码越大且容错效果越好
pt_BR: Error Correction in L, M, Q or H, from low to high, the bigger size of generated QR code with the better error correction effect
options:
- value: L
label:
en_US: Low
zh_Hans:
pt_BR: Low
- value: M
label:
en_US: Medium
zh_Hans:
pt_BR: Medium
- value: Q
label:
en_US: Quartile
zh_Hans: 偏高
pt_BR: Quartile
- value: H
label:
en_US: High
zh_Hans:
pt_BR: High
form: form
- name: border
type: number
required: true
default: 2
min: 0
max: 100
label:
en_US: border size
zh_Hans: 边框粗细
pt_BR: border size
human_description:
en_US: border size(default to 2)
zh_Hans: 边框粗细的格数(默认为2)
pt_BR: border size(default to 2)
llm: border size, default to 2
form: form
...@@ -2,11 +2,11 @@ import io ...@@ -2,11 +2,11 @@ import io
import json import json
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from copy import deepcopy from copy import deepcopy
from os.path import join
from typing import Any, Union from typing import Any, Union
from httpx import get, post from httpx import get, post
from PIL import Image from PIL import Image
from yarl import URL
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
...@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):
# set model # set model
try: try:
url = join(base_url, 'sdapi/v1/options') url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
response = post(url, data=json.dumps({ response = post(url, data=json.dumps({
'sd_model_checkpoint': model 'sd_model_checkpoint': model
})) }))
...@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool): ...@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool):
if not model: if not model:
raise ToolProviderCredentialValidationError('Please input model') raise ToolProviderCredentialValidationError('Please input model')
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120) api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
if response.status_code != 200: response = get(url=api_url, timeout=10)
if response.status_code == 404:
# try draw a picture
self._invoke(
user_id='test',
tool_parameters={
'prompt': 'a cat',
'width': 1024,
'height': 1024,
'steps': 1,
'lora': '',
}
)
elif response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to get models') raise ToolProviderCredentialValidationError('Failed to get models')
else: else:
models = [d['model_name'] for d in response.json()] models = [d['model_name'] for d in response.json()]
...@@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool): ...@@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e: except Exception as e:
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
def get_sd_models(self) -> list[str]:
"""
get sd models
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
if response.status_code != 200:
return []
else:
return [d['model_name'] for d in response.json()]
except Exception as e:
return []
def img2img(self, base_url: str, lora: str, image_binary: bytes, def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str, prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \ width: int, height: int, steps: int) \
...@@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['prompt'] = prompt draw_options['prompt'] = prompt
try: try:
url = join(base_url, 'sdapi/v1/img2img') url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
response = post(url, data=json.dumps(draw_options), timeout=120) response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200: if response.status_code != 200:
return self.create_text_message('Failed to generate image') return self.create_text_message('Failed to generate image')
...@@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['negative_prompt'] = negative_prompt draw_options['negative_prompt'] = negative_prompt
try: try:
url = join(base_url, 'sdapi/v1/txt2img') url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
response = post(url, data=json.dumps(draw_options), timeout=120) response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200: if response.status_code != 200:
return self.create_text_message('Failed to generate image') return self.create_text_message('Failed to generate image')
...@@ -269,5 +299,29 @@ class StableDiffusionTool(BuiltinTool): ...@@ -269,5 +299,29 @@ class StableDiffusionTool(BuiltinTool):
label=I18nObject(en_US=i.name, zh_Hans=i.name) label=I18nObject(en_US=i.name, zh_Hans=i.name)
) for i in self.list_default_image_variables()]) ) for i in self.list_default_image_variables()])
) )
if self.runtime.credentials:
try:
models = self.get_sd_models()
if len(models) != 0:
parameters.append(
ToolParameter(name='model',
label=I18nObject(en_US='Model', zh_Hans='Model'),
human_description=I18nObject(
en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
required=True,
default=models[0],
options=[ToolParameterOption(
value=i,
label=I18nObject(en_US=i, zh_Hans=i)
) for i in models])
)
except:
pass
return parameters return parameters
from typing import Any from typing import Any
from twilio.base.exceptions import TwilioRestException
from twilio.rest import Client
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
...@@ -7,19 +10,20 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl ...@@ -7,19 +10,20 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class TwilioProvider(BuiltinToolProviderController): class TwilioProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None: def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try: try:
""" # Extract credentials
SendMessageTool().fork_tool_runtime( account_sid = credentials["account_sid"]
meta={ auth_token = credentials["auth_token"]
"credentials": credentials, from_number = credentials["from_number"]
}
).invoke( # Initialize twilio client
user_id="", client = Client(account_sid, auth_token)
tool_parameters={
"message": "Credential validation message", # fetch account
"to_number": "+14846624384", client.api.accounts(account_sid).fetch()
},
) except TwilioRestException as e:
""" raise ToolProviderCredentialValidationError(f"Twilio API error: {e.msg}") from e
pass except KeyError as e:
raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e
except Exception as e: except Exception as e:
raise ToolProviderCredentialValidationError(str(e)) raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
...@@ -4,9 +4,10 @@ import httpx ...@@ -4,9 +4,10 @@ import httpx
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.utils.uuid_utils import is_valid_uuid
class WecomRepositoriesTool(BuiltinTool): class WecomGroupBotTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any] def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
""" """
...@@ -17,8 +18,9 @@ class WecomRepositoriesTool(BuiltinTool): ...@@ -17,8 +18,9 @@ class WecomRepositoriesTool(BuiltinTool):
return self.create_text_message('Invalid parameter content') return self.create_text_message('Invalid parameter content')
hook_key = tool_parameters.get('hook_key', '') hook_key = tool_parameters.get('hook_key', '')
if not hook_key: if not is_valid_uuid(hook_key):
return self.create_text_message('Invalid parameter hook_key') return self.create_text_message(
f'Invalid parameter hook_key ${hook_key}, not a valid UUID')
msgtype = 'text' msgtype = 'text'
api_url = 'https://qyapi.weixin.qq.com/cgi-bin/webhook/send' api_url = 'https://qyapi.weixin.qq.com/cgi-bin/webhook/send'
......
...@@ -14,7 +14,7 @@ description: ...@@ -14,7 +14,7 @@ description:
llm: A tool for sending messages to a chat group on Wecom(企业微信) . llm: A tool for sending messages to a chat group on Wecom(企业微信) .
parameters: parameters:
- name: hook_key - name: hook_key
type: string type: secret-input
required: true required: true
label: label:
en_US: Wecom Group bot webhook key en_US: Wecom Group bot webhook key
......
from core.tools.provider.builtin.wecom.tools.wecom_group_bot import WecomRepositoriesTool from core.tools.provider.builtin.wecom.tools.wecom_group_bot import WecomGroupBotTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class WecomProvider(BuiltinToolProviderController): class WecomProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None: def _validate_credentials(self, credentials: dict) -> None:
WecomRepositoriesTool() WecomGroupBotTool()
pass pass
<?xml version="1.0" encoding="UTF-8"?>
<svg width="800px" height="800px" viewBox="0 -38 256 256" version="1.1" xmlns="http://www.w3.org/2000/svg"
xmlns:xlink="http://www.w3.org/1999/xlink" preserveAspectRatio="xMidYMid">
<g>
<path d="M250.346231,28.0746923 C247.358133,17.0320558 238.732098,8.40602109 227.689461,5.41792308 C207.823743,0 127.868333,0 127.868333,0 C127.868333,0 47.9129229,0.164179487 28.0472049,5.58210256 C17.0045684,8.57020058 8.37853373,17.1962353 5.39043571,28.2388718 C-0.618533519,63.5374615 -2.94988224,117.322662 5.5546152,151.209308 C8.54271322,162.251944 17.1687479,170.877979 28.2113844,173.866077 C48.0771024,179.284 128.032513,179.284 128.032513,179.284 C128.032513,179.284 207.987923,179.284 227.853641,173.866077 C238.896277,170.877979 247.522312,162.251944 250.51041,151.209308 C256.847738,115.861464 258.801474,62.1091 250.346231,28.0746923 Z"
fill="#FF0000">
</path>
<polygon fill="#FFFFFF" points="102.420513 128.06 168.749025 89.642 102.420513 51.224">
</polygon>
</g>
</svg>
\ No newline at end of file
...@@ -9,7 +9,7 @@ identity: ...@@ -9,7 +9,7 @@ identity:
en_US: YouTube en_US: YouTube
zh_Hans: YouTube(油管)是全球最大的视频分享网站,用户可以在上面上传、观看和分享视频。 zh_Hans: YouTube(油管)是全球最大的视频分享网站,用户可以在上面上传、观看和分享视频。
pt_BR: YouTube é o maior site de compartilhamento de vídeos do mundo, onde os usuários podem fazer upload, assistir e compartilhar vídeos. pt_BR: YouTube é o maior site de compartilhamento de vídeos do mundo, onde os usuários podem fazer upload, assistir e compartilhar vídeos.
icon: icon.png icon: icon.svg
credentials_for_provider: credentials_for_provider:
google_api_key: google_api_key:
type: secret-input type: secret-input
......
from copy import deepcopy
from typing import Any
from core.entities.model_entities import ModelStatus
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ModelToolPropertyKey,
ToolDescription,
ToolIdentity,
ToolParameter,
ToolProviderCredentials,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.model_tool import ModelTool
from core.tools.tool.tool import Tool
from core.tools.utils.configuration import ModelToolConfigurationManager
class ModelToolProviderController(ToolProviderController):
configuration: ProviderConfiguration = None
is_active: bool = False
def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
"""
init the provider
:param data: the data of the provider
"""
super().__init__(**kwargs)
self.configuration = configuration
@staticmethod
def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController':
"""
init the provider from db
:param configuration: the configuration of the provider
"""
# check if all models are active
if configuration is None:
return None
is_active = True
models = configuration.get_provider_models()
for model in models:
if model.status != ModelStatus.ACTIVE:
is_active = False
break
# get the provider configuration
model_tool_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
if model_tool_configuration is None:
raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
# override the configuration
if model_tool_configuration.label:
label = deepcopy(model_tool_configuration.label)
if label.en_US:
label.en_US = model_tool_configuration.label.en_US
if label.zh_Hans:
label.zh_Hans = model_tool_configuration.label.zh_Hans
else:
label = I18nObject(
en_US=configuration.provider.label.en_US,
zh_Hans=configuration.provider.label.zh_Hans
)
return ModelToolProviderController(
is_active=is_active,
identity=ToolProviderIdentity(
author='Dify',
name=configuration.provider.provider,
description=I18nObject(
zh_Hans=f'{label.zh_Hans} 模型能力提供商',
en_US=f'{label.en_US} model capability provider'
),
label=I18nObject(
zh_Hans=label.zh_Hans,
en_US=label.en_US
),
icon=configuration.provider.icon_small.en_US,
),
configuration=configuration,
credentials_schema={},
)
@staticmethod
def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
"""
check if the configuration has a model can be used as a tool
"""
models = configuration.get_provider_models()
for model in models:
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
return True
return False
def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
provider_manager = ProviderManager()
if self.configuration is None:
configurations = provider_manager.get_configurations(tenant_id=tenant_id).values()
self.configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None)
# get all tools
tools: list[ModelTool] = []
# get all models
if not self.configuration:
return tools
configuration = self.configuration
provider_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
if provider_configuration is None:
raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
for model in configuration.get_provider_models():
model_configuration = ModelToolConfigurationManager.get_model_configuration(self.configuration.provider.provider, model.model)
if model_configuration is None:
continue
if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
provider_instance = configuration.get_provider_instance()
model_type_instance = provider_instance.get_model_instance(model.model_type)
provider_model_bundle = ProviderModelBundle(
configuration=configuration,
provider_instance=provider_instance,
model_type_instance=model_type_instance
)
try:
model_instance = ModelInstance(provider_model_bundle, model.model)
except ProviderTokenNotInitError:
model_instance = None
tools.append(ModelTool(
identity=ToolIdentity(
author='Dify',
name=model.model,
label=model_configuration.label,
),
parameters=[
ToolParameter(
name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value,
label=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
required=True,
default=Tool.VARIABLE_KEY.IMAGE.value
)
],
description=ToolDescription(
human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'),
llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.',
),
is_team_authorization=model.status == ModelStatus.ACTIVE,
tool_type=ModelTool.ModelToolType.VISION,
model_instance=model_instance,
model=model.model,
))
self.tools = tools
return tools
def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
return {}
def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
return self._get_model_tools(tenant_id=tenant_id)
def get_tool(self, tool_name: str) -> ModelTool:
"""
get tool by name
:param tool_name: the name of the tool
:return: the tool
"""
if self.tools is None:
self.get_tools(user_id='', tenant_id=self.configuration.tenant_id)
for tool in self.tools:
if tool.identity.name == tool_name:
return tool
raise ValueError(f'tool {tool_name} not found')
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
"""
returns the parameters of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
if tool is None:
raise ToolNotFoundError(f'tool {tool_name} not found')
return tool.parameters
@property
def app_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.MODEL
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
pass
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
pass
\ No newline at end of file
...@@ -12,6 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage ...@@ -12,6 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
API_TOOL_DEFAULT_TIMEOUT = (10, 60)
class ApiTool(Tool): class ApiTool(Tool):
api_bundle: ApiBasedToolBundle api_bundle: ApiBasedToolBundle
...@@ -211,19 +212,19 @@ class ApiTool(Tool): ...@@ -211,19 +212,19 @@ class ApiTool(Tool):
# do http request # do http request
if method == 'get': if method == 'get':
response = ssrf_proxy.get(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True) response = ssrf_proxy.get(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'post': elif method == 'post':
response = ssrf_proxy.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True) response = ssrf_proxy.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'put': elif method == 'put':
response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True) response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'delete': elif method == 'delete':
response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, allow_redirects=True) response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, allow_redirects=True)
elif method == 'patch': elif method == 'patch':
response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True) response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'head': elif method == 'head':
response = ssrf_proxy.head(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True) response = ssrf_proxy.head(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'options': elif method == 'options':
response = ssrf_proxy.options(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True) response = ssrf_proxy.options(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
else: else:
raise ValueError(f'Invalid http method {method}') raise ValueError(f'Invalid http method {method}')
......
from base64 import b64encode
from enum import Enum
from typing import Any, cast
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessageContent,
PromptMessageContentType,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage
from core.tools.tool.tool import Tool
VISION_PROMPT = """## Image Recognition Task
### Task Description
I require a powerful vision language model for an image recognition task. The model should be capable of extracting various details from the images, including but not limited to text content, layout distribution, color distribution, main subjects, and emotional expressions.
### Specific Requirements
1. **Text Content Extraction:** Ensure that the model accurately recognizes and extracts text content from the images, regardless of text size, font, or color.
2. **Layout Distribution Analysis:** The model should analyze the layout structure of the images, capturing the relationships between various elements and providing detailed information about the image layout.
3. **Color Distribution Analysis:** Extract information about color distribution in the images, including primary colors, color combinations, and other relevant details.
4. **Main Subject Recognition:** The model should accurately identify the main subjects in the images and provide detailed descriptions of these subjects.
5. **Emotional Expression Analysis:** Analyze and describe the emotions or expressions conveyed in the images based on facial expressions, postures, and other relevant features.
### Additional Considerations
- Ensure that the extracted information is as comprehensive and accurate as possible.
- For each task, provide confidence scores or relevance scores for the model outputs to assess the reliability of the results.
- If necessary, pose specific questions for different tasks to guide the model in better understanding the images and providing relevant information."""
class ModelTool(Tool):
class ModelToolType(Enum):
"""
the type of the model tool
"""
VISION = 'vision'
model_configuration: dict[str, Any] = None
tool_type: ModelToolType
def __init__(self, model_instance: ModelInstance = None, model: str = None,
tool_type: ModelToolType = ModelToolType.VISION,
properties: dict[ModelToolPropertyKey, Any] = None,
**kwargs):
"""
init the tool
"""
kwargs['model_configuration'] = {
'model_instance': model_instance,
'model': model,
'properties': properties
}
kwargs['tool_type'] = tool_type
super().__init__(**kwargs)
"""
Model tool
"""
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
return self.__class__(
identity=self.identity.copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
model_instance=self.model_configuration['model_instance'],
model=self.model_configuration['model'],
tool_type=self.tool_type,
runtime=Tool.Runtime(**meta)
)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> None:
"""
validate the credentials for Model tool
"""
pass
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
"""
model_instance = self.model_configuration['model_instance']
if not model_instance:
return self.create_text_message('the tool is not configured correctly')
if self.tool_type == ModelTool.ModelToolType.VISION:
return self._invoke_llm_vision(user_id, tool_parameters)
else:
return self.create_text_message('the tool is not configured correctly')
def _invoke_llm_vision(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
# get image
image_parameter_name = self.model_configuration['properties'].get(ModelToolPropertyKey.IMAGE_PARAMETER_NAME, 'image_id')
image_id = tool_parameters.pop(image_parameter_name, '')
if not image_id:
image = self.get_default_image_variable()
if not image:
return self.create_text_message('Please upload an image or input image_id')
else:
image = self.get_variable(image_id)
if not image:
image = self.get_default_image_variable()
if not image:
return self.create_text_message('Please upload an image or input image_id')
if not image:
return self.create_text_message('Please upload an image or input image_id')
# get image
image = self.get_variable_file(image.name)
if not image:
return self.create_text_message('Failed to get image')
# organize prompt messages
prompt_messages = [
SystemPromptMessage(
content=VISION_PROMPT
),
UserPromptMessage(
content=[
PromptMessageContent(
type=PromptMessageContentType.TEXT,
data='Recognize the image and extract the information from the image.'
),
PromptMessageContent(
type=PromptMessageContentType.IMAGE,
data=f'data:image/png;base64,{b64encode(image).decode("utf-8")}'
)
]
)
]
llm_instance = cast(LargeLanguageModel, self.model_configuration['model_instance'])
result: LLMResult = llm_instance.invoke(
model=self.model_configuration['model'],
credentials=self.runtime.credentials,
prompt_messages=prompt_messages,
model_parameters=tool_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
)
if not result:
return self.create_text_message('Failed to extract information from the image')
# get result
content = result.message.content
if not content:
return self.create_text_message('Failed to extract information from the image')
return self.create_text_message(content)
\ No newline at end of file
...@@ -266,6 +266,40 @@ class Tool(BaseModel, ABC): ...@@ -266,6 +266,40 @@ class Tool(BaseModel, ABC):
""" """
return self.parameters return self.parameters
def get_all_runtime_parameters(self) -> list[ToolParameter]:
"""
get all runtime parameters
:return: all runtime parameters
"""
parameters = self.parameters or []
parameters = parameters.copy()
user_parameters = self.get_runtime_parameters() or []
user_parameters = user_parameters.copy()
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
return parameters
def is_tool_available(self) -> bool: def is_tool_available(self) -> bool:
""" """
check if the tool is available check if the tool is available
......
...@@ -6,20 +6,33 @@ from os import listdir, path ...@@ -6,20 +6,33 @@ from os import listdir, path
from typing import Any, Union from typing import Any, Union
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.entities.application_entities import AgentToolEntity
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
from core.provider_manager import ProviderManager
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constant import DEFAULT_PROVIDERS from core.tools.entities.constant import DEFAULT_PROVIDERS
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeMessage,
ToolParameter,
ToolProviderCredentials,
)
from core.tools.entities.user_entities import UserToolProvider from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolProviderNotFoundError from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.utils.configuration import ToolConfiguration from core.tools.tool.tool import Tool
from core.tools.utils.configuration import (
ModelToolConfigurationManager,
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
from core.tools.utils.encoder import serialize_base_model_dict from core.tools.utils.encoder import serialize_base_model_dict
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider
...@@ -135,7 +148,7 @@ class ToolManager: ...@@ -135,7 +148,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found') raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod @staticmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id, def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
agent_callback: DifyAgentCallbackHandler = None) \ agent_callback: DifyAgentCallbackHandler = None) \
-> Union[BuiltinTool, ApiTool]: -> Union[BuiltinTool, ApiTool]:
""" """
...@@ -170,7 +183,7 @@ class ToolManager: ...@@ -170,7 +183,7 @@ class ToolManager:
# decrypt the credentials # decrypt the credentials
credentials = builtin_provider.credentials credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name) controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
...@@ -187,18 +200,96 @@ class ToolManager: ...@@ -187,18 +200,96 @@ class ToolManager:
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name) api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
# decrypt the credentials # decrypt the credentials
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={ return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
'tenant_id': tenant_id, 'tenant_id': tenant_id,
'credentials': decrypted_credentials, 'credentials': decrypted_credentials,
}) })
elif provider_type == 'model':
if tenant_id is None:
raise ValueError('tenant id is required for model provider')
# get model provider
model_provider = ToolManager.get_model_provider(tenant_id, provider_name)
# get tool
model_tool = model_provider.get_tool(tool_name)
return model_tool.fork_tool_runtime(meta={
'tenant_id': tenant_id,
'credentials': model_tool.model_configuration['model_instance'].credentials
})
elif provider_type == 'app': elif provider_type == 'app':
raise NotImplementedError('app provider not implemented') raise NotImplementedError('app provider not implemented')
else: else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found') raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
"""
get the agent tool runtime
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
agent_callback=agent_callback
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = agent_tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
tool_runtime=tool_entity,
provider_name=agent_tool.provider_id,
provider_type=agent_tool.provider_type,
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@staticmethod @staticmethod
def get_builtin_provider_icon(provider: str) -> tuple[str, str]: def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
""" """
...@@ -266,6 +357,49 @@ class ToolManager: ...@@ -266,6 +357,49 @@ class ToolManager:
return builtin_providers return builtin_providers
@staticmethod
def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
"""
list all the model providers
:return: the list of the model providers
"""
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get configurations
model_configurations = ModelToolConfigurationManager.get_all_configuration()
# get all providers
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id).values()
# get model providers
model_providers: list[ModelToolProviderController] = []
for configuration in configurations:
# all the model tool should be configurated
if configuration.provider.provider not in model_configurations:
continue
if not ModelToolProviderController.is_configuration_valid(configuration):
continue
model_providers.append(ModelToolProviderController.from_db(configuration))
return model_providers
@staticmethod
def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController:
"""
get the model provider
:param provider_name: the name of the provider
:return: the provider
"""
# get configurations
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id)
configuration = configurations.get(provider_name)
if configuration is None:
raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
return ModelToolProviderController.from_db(configuration)
@staticmethod @staticmethod
def get_tool_label(tool_name: str) -> Union[I18nObject, None]: def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
""" """
...@@ -338,13 +472,35 @@ class ToolManager: ...@@ -338,13 +472,35 @@ class ToolManager:
controller = ToolManager.get_builtin_provider(provider_name) controller = ToolManager.get_builtin_provider(provider_name)
# init tool configuration # init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
result_providers[provider_name].team_credentials = masked_credentials result_providers[provider_name].team_credentials = masked_credentials
# get model tool providers
model_providers = ToolManager.list_model_providers(tenant_id=tenant_id)
# append model providers
for provider in model_providers:
result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider(
id=provider.identity.name,
author=provider.identity.author,
name=provider.identity.name,
description=I18nObject(
en_US=provider.identity.description.en_US,
zh_Hans=provider.identity.description.zh_Hans,
),
icon=provider.identity.icon,
label=I18nObject(
en_US=provider.identity.label.en_US,
zh_Hans=provider.identity.label.zh_Hans,
),
type=UserToolProvider.ProviderType.MODEL,
team_credentials={},
is_team_authorization=provider.is_active,
)
# get db api providers # get db api providers
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all() filter(ApiToolProvider.tenant_id == tenant_id).all()
...@@ -383,7 +539,7 @@ class ToolManager: ...@@ -383,7 +539,7 @@ class ToolManager:
) )
# init tool configuration # init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
...@@ -443,7 +599,7 @@ class ToolManager: ...@@ -443,7 +599,7 @@ class ToolManager:
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
) )
# init tool configuration # init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
......
from typing import Any import os
from typing import Any, Union
from pydantic import BaseModel from pydantic import BaseModel
from yaml import FullLoader, load
from core.helper import encrypter from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import ToolProviderCredentials from core.tools.entities.tool_entities import (
ModelToolConfiguration,
ModelToolProviderConfiguration,
ToolParameter,
ToolProviderCredentials,
)
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
class ToolConfiguration(BaseModel): class ToolConfigurationManager(BaseModel):
tenant_id: str tenant_id: str
provider_controller: ToolProviderController provider_controller: ToolProviderController
...@@ -94,3 +103,187 @@ class ToolConfiguration(BaseModel): ...@@ -94,3 +103,187 @@ class ToolConfiguration(BaseModel):
cache_type=ToolProviderCredentialsCacheType.PROVIDER cache_type=ToolProviderCredentialsCacheType.PROVIDER
) )
cache.delete() cache.delete()
class ToolParameterConfigurationManager(BaseModel):
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: str
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return {key: value for key, value in parameters.items()}
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = \
parameters[parameter.name][:2] + \
'*' * (len(parameters[parameter.name]) - 4) +\
parameters[parameter.name][-2:]
else:
parameters[parameter.name] = '*' * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters
# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
try:
has_secret_input = True
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except:
pass
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
)
cache.delete()
class ModelToolConfigurationManager:
"""
Model as tool configuration
"""
_configurations: dict[str, ModelToolProviderConfiguration] = {}
_model_configurations: dict[str, ModelToolConfiguration] = {}
_inited = False
@classmethod
def _init_configuration(cls):
"""
init configuration
"""
absolute_path = os.path.abspath(os.path.dirname(__file__))
model_tools_path = os.path.join(absolute_path, '..', 'model_tools')
# get all .yaml file
files = [f for f in os.listdir(model_tools_path) if f.endswith('.yaml')]
for file in files:
provider = file.split('.')[0]
with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader))
models = configurations.models or []
for model in models:
model_key = f'{provider}.{model.model}'
cls._model_configurations[model_key] = model
cls._configurations[provider] = configurations
cls._inited = True
@classmethod
def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]:
"""
get configuration by provider
"""
if not cls._inited:
cls._init_configuration()
return cls._configurations.get(provider, None)
@classmethod
def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]:
"""
get all configurations
"""
if not cls._inited:
cls._init_configuration()
return cls._configurations
@classmethod
def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]:
"""
get model configuration
"""
key = f'{provider}.{model}'
if not cls._inited:
cls._init_configuration()
return cls._model_configurations.get(key, None)
\ No newline at end of file
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
try:
uuid.UUID(uuid_str)
return True
except Exception:
return False
...@@ -32,8 +32,6 @@ class Mail: ...@@ -32,8 +32,6 @@ class Mail:
from libs.smtp import SMTPClient from libs.smtp import SMTPClient
if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'): if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'):
raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type') raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type')
if not app.config.get('SMTP_USERNAME') or not app.config.get('SMTP_PASSWORD'):
raise ValueError('SMTP_USERNAME and SMTP_PASSWORD are required for smtp mail type')
self._client = SMTPClient( self._client = SMTPClient(
server=app.config.get('SMTP_SERVER'), server=app.config.get('SMTP_SERVER'),
port=app.config.get('SMTP_PORT'), port=app.config.get('SMTP_PORT'),
......
...@@ -16,7 +16,8 @@ class SMTPClient: ...@@ -16,7 +16,8 @@ class SMTPClient:
smtp = smtplib.SMTP(self.server, self.port) smtp = smtplib.SMTP(self.server, self.port)
if self._use_tls: if self._use_tls:
smtp.starttls() smtp.starttls()
smtp.login(self.username, self.password) if (self.username):
smtp.login(self.username, self.password)
msg = MIMEMultipart() msg = MIMEMultipart()
msg['Subject'] = mail['subject'] msg['Subject'] = mail['subject']
msg['From'] = self._from msg['From'] = self._from
......
...@@ -32,7 +32,7 @@ celery==5.2.7 ...@@ -32,7 +32,7 @@ celery==5.2.7
redis~=4.5.4 redis~=4.5.4
openpyxl==3.1.2 openpyxl==3.1.2
chardet~=5.1.0 chardet~=5.1.0
docx2txt==0.8 python-docx~=1.1.0
pypdfium2==4.16.0 pypdfium2==4.16.0
resend~=0.7.0 resend~=0.7.0
pyjwt~=2.8.0 pyjwt~=2.8.0
......
...@@ -15,7 +15,7 @@ from events.tenant_event import tenant_was_created ...@@ -15,7 +15,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.helper import get_remote_ip from libs.helper import get_remote_ip
from libs.passport import PassportService from libs.passport import PassportService
from libs.password import compare_password, hash_password from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair from libs.rsa import generate_key_pair
from models.account import * from models.account import *
from services.errors.account import ( from services.errors.account import (
...@@ -58,7 +58,7 @@ class AccountService: ...@@ -58,7 +58,7 @@ class AccountService:
account.current_tenant_id = available_ta.tenant_id account.current_tenant_id = available_ta.tenant_id
available_ta.current = True available_ta.current = True
db.session.commit() db.session.commit()
if datetime.utcnow() - account.last_active_at > timedelta(minutes=10): if datetime.utcnow() - account.last_active_at > timedelta(minutes=10):
account.last_active_at = datetime.utcnow() account.last_active_at = datetime.utcnow()
db.session.commit() db.session.commit()
...@@ -104,6 +104,9 @@ class AccountService: ...@@ -104,6 +104,9 @@ class AccountService:
if account.password and not compare_password(password, account.password, account.password_salt): if account.password and not compare_password(password, account.password, account.password_salt):
raise CurrentPasswordIncorrectError("Current password is incorrect.") raise CurrentPasswordIncorrectError("Current password is incorrect.")
# may be raised
valid_password(new_password)
# generate password salt # generate password salt
salt = secrets.token_bytes(16) salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode() base64_salt = base64.b64encode(salt).decode()
...@@ -140,9 +143,9 @@ class AccountService: ...@@ -140,9 +143,9 @@ class AccountService:
account.interface_language = interface_language account.interface_language = interface_language
account.interface_theme = interface_theme account.interface_theme = interface_theme
# Set timezone based on language # Set timezone based on language
account.timezone = language_timezone_mapping.get(interface_language, 'UTC') account.timezone = language_timezone_mapping.get(interface_language, 'UTC')
db.session.add(account) db.session.add(account)
db.session.commit() db.session.commit()
...@@ -279,7 +282,7 @@ class TenantService: ...@@ -279,7 +282,7 @@ class TenantService:
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
if not tenant_account_join: if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else: else:
TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False}) TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False})
tenant_account_join.current = True tenant_account_join.current = True
# Set the current tenant for the account # Set the current tenant for the account
...@@ -449,7 +452,7 @@ class RegisterService: ...@@ -449,7 +452,7 @@ class RegisterService:
return account return account
@classmethod @classmethod
def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str: def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str:
"""Invite new member""" """Invite new member"""
account = Account.query.filter_by(email=email).first() account = Account.query.filter_by(email=email).first()
......
...@@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import ( ...@@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ApiProviderSchemaType, ApiProviderSchemaType,
ToolCredentialsOption, ToolCredentialsOption,
ToolParameter,
ToolProviderCredentials, ToolProviderCredentials,
) )
from core.tools.entities.user_entities import UserTool, UserToolProvider from core.tools.entities.user_entities import UserTool, UserToolProvider
...@@ -16,11 +17,12 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio ...@@ -16,11 +17,12 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfiguration from core.tools.utils.configuration import ToolConfigurationManager
from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict
from core.tools.utils.parser import ApiBasedToolSchemaParser from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider
from services.model_provider_service import ModelProviderService
class ToolManageService: class ToolManageService:
...@@ -49,11 +51,13 @@ class ToolManageService: ...@@ -49,11 +51,13 @@ class ToolManageService:
:param provider: the provider dict :param provider: the provider dict
""" """
url_prefix = (current_app.config.get("CONSOLE_API_URL") url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ "/console/api/workspaces/current/tool-provider/builtin/") + "/console/api/workspaces/current/tool-provider/")
if 'icon' in provider: if 'icon' in provider:
if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value: if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value:
provider['icon'] = url_prefix + provider['name'] + '/icon' provider['icon'] = url_prefix + 'builtin/' + provider['name'] + '/icon'
elif provider['type'] == UserToolProvider.ProviderType.MODEL.value:
provider['icon'] = url_prefix + 'model/' + provider['name'] + '/icon'
elif provider['type'] == UserToolProvider.ProviderType.API.value: elif provider['type'] == UserToolProvider.ProviderType.API.value:
try: try:
provider['icon'] = json.loads(provider['icon']) provider['icon'] = json.loads(provider['icon'])
...@@ -73,15 +77,52 @@ class ToolManageService: ...@@ -73,15 +77,52 @@ class ToolManageService:
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools() tools = provider_controller.get_tools()
result = [ tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
UserTool( # check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
result = []
for tool in tools:
# fork tool runtime
tool = tool.fork_tool_runtime(meta={
'credentials': credentials,
'tenant_id': tenant_id,
})
# get tool parameters
parameters = tool.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters()
# override parameters
current_parameters = parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
user_tool = UserTool(
author=tool.identity.author, author=tool.identity.author,
name=tool.identity.name, name=tool.identity.name,
label=tool.identity.label, label=tool.identity.label,
description=tool.description.human, description=tool.description.human,
parameters=tool.parameters or [] parameters=current_parameters
) for tool in tools )
] result.append(user_tool)
return json.loads( return json.loads(
serialize_base_model_array(result) serialize_base_model_array(result)
...@@ -238,7 +279,7 @@ class ToolManageService: ...@@ -238,7 +279,7 @@ class ToolManageService:
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials # encrypt credentials
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials) db_provider.credentials_str = json.dumps(encrypted_credentials)
...@@ -325,7 +366,7 @@ class ToolManageService: ...@@ -325,7 +366,7 @@ class ToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider_name) provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
raise ValueError(f'provider {provider_name} does not need credentials') raise ValueError(f'provider {provider_name} does not need credentials')
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# get original credentials if exists # get original credentials if exists
if provider is not None: if provider is not None:
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
...@@ -409,7 +450,7 @@ class ToolManageService: ...@@ -409,7 +450,7 @@ class ToolManageService:
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists # get original credentials if exists
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
...@@ -449,7 +490,7 @@ class ToolManageService: ...@@ -449,7 +490,7 @@ class ToolManageService:
# delete cache # delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name) provider_controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache() tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' } return { 'result': 'success' }
...@@ -467,6 +508,46 @@ class ToolManageService: ...@@ -467,6 +508,46 @@ class ToolManageService:
return icon_bytes, mime_type return icon_bytes, mime_type
@staticmethod
def get_model_tool_provider_icon(
provider: str
):
"""
get tool provider icon and it's mimetype
"""
service = ModelProviderService()
icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US')
if icon_bytes is None:
raise ValueError(f'provider {provider} does not exists')
return icon_bytes, mime_type
@staticmethod
def list_model_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
):
"""
list model tool provider tools
"""
provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider)
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
result = [
UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
parameters=tool.parameters or []
) for tool in tools
]
return json.loads(
serialize_base_model_array(result)
)
@staticmethod @staticmethod
def delete_api_tool_provider( def delete_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str user_id: str, tenant_id: str, provider_name: str
...@@ -551,7 +632,7 @@ class ToolManageService: ...@@ -551,7 +632,7 @@ class ToolManageService:
# decrypt credentials # decrypt credentials
if db_provider.id: if db_provider.id:
tool_configuration = ToolConfiguration( tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_controller=provider_controller provider_controller=provider_controller
) )
......
...@@ -29,7 +29,17 @@ export default function ChartView({ appId }: IChartViewProps) { ...@@ -29,7 +29,17 @@ export default function ChartView({ appId }: IChartViewProps) {
const [period, setPeriod] = useState<PeriodParams>({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) const [period, setPeriod] = useState<PeriodParams>({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } })
const onSelect = (item: Item) => { const onSelect = (item: Item) => {
setPeriod({ name: item.name, query: item.value === 'all' ? undefined : { start: today.subtract(item.value as number, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) if (item.value === 'all') {
setPeriod({ name: item.name, query: undefined })
}
else if (item.value === 0) {
const startOfToday = today.startOf('day').format(queryDateFormat)
const endOfToday = today.endOf('day').format(queryDateFormat)
setPeriod({ name: item.name, query: { start: startOfToday, end: endOfToday } })
}
else {
setPeriod({ name: item.name, query: { start: today.subtract(item.value as number, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } })
}
} }
if (!response) if (!response)
......
...@@ -62,8 +62,10 @@ const ActivateForm = () => { ...@@ -62,8 +62,10 @@ const ActivateForm = () => {
showErrorMessage(t('login.error.passwordEmpty')) showErrorMessage(t('login.error.passwordEmpty'))
return false return false
} }
if (!validPassword.test(password)) if (!validPassword.test(password)) {
showErrorMessage(t('login.error.passwordInvalid')) showErrorMessage(t('login.error.passwordInvalid'))
return false
}
return true return true
}, [name, password, showErrorMessage, t]) }, [name, password, showErrorMessage, t])
......
...@@ -24,7 +24,7 @@ const WarningMask: FC<IWarningMaskProps> = ({ ...@@ -24,7 +24,7 @@ const WarningMask: FC<IWarningMaskProps> = ({
return ( return (
<div className={`${s.mask} absolute z-10 inset-0 pt-16`} <div className={`${s.mask} absolute z-10 inset-0 pt-16`}
> >
<div className='mx-auto w-[535px]'> <div className='mx-auto px-10'>
<div className={`${s.icon} flex items-center justify-center w-11 h-11 rounded-xl bg-white`}>{warningIcon}</div> <div className={`${s.icon} flex items-center justify-center w-11 h-11 rounded-xl bg-white`}>{warningIcon}</div>
<div className='mt-4 text-[24px] leading-normal font-semibold text-gray-800'> <div className='mt-4 text-[24px] leading-normal font-semibold text-gray-800'>
{title} {title}
......
...@@ -25,6 +25,7 @@ import { useToastContext } from '@/app/components/base/toast' ...@@ -25,6 +25,7 @@ import { useToastContext } from '@/app/components/base/toast'
import { useEventEmitterContextContext } from '@/context/event-emitter' import { useEventEmitterContextContext } from '@/context/event-emitter'
import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/config-var' import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/config-var'
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block'
import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block'
export type ISimplePromptInput = { export type ISimplePromptInput = {
mode: AppType mode: AppType
...@@ -122,6 +123,10 @@ const Prompt: FC<ISimplePromptInput> = ({ ...@@ -122,6 +123,10 @@ const Prompt: FC<ISimplePromptInput> = ({
if (mode === AppType.chat) if (mode === AppType.chat)
setIntroduction(res.opening_statement) setIntroduction(res.opening_statement)
showAutomaticFalse() showAutomaticFalse()
eventEmitter?.emit({
type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER,
payload: res.prompt,
} as any)
} }
const minHeight = 228 const minHeight = 228
const [editorHeight, setEditorHeight] = useState(minHeight) const [editorHeight, setEditorHeight] = useState(minHeight)
......
...@@ -34,7 +34,7 @@ const AgentTools: FC = () => { ...@@ -34,7 +34,7 @@ const AgentTools: FC = () => {
const [selectedProviderId, setSelectedProviderId] = useState<string | undefined>(undefined) const [selectedProviderId, setSelectedProviderId] = useState<string | undefined>(undefined)
const [isShowSettingTool, setIsShowSettingTool] = useState(false) const [isShowSettingTool, setIsShowSettingTool] = useState(false)
const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => { const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => {
const collection = collectionList.find(collection => collection.id === item.provider_id) const collection = collectionList.find(collection => collection.id === item.provider_id && collection.type === item.provider_type)
const icon = collection?.icon const icon = collection?.icon
return { return {
...item, ...item,
......
...@@ -8,7 +8,7 @@ import Drawer from '@/app/components/base/drawer-plus' ...@@ -8,7 +8,7 @@ import Drawer from '@/app/components/base/drawer-plus'
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
import type { Collection, Tool } from '@/app/components/tools/types' import type { Collection, Tool } from '@/app/components/tools/types'
import { fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools' import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList } from '@/service/tools'
import I18n from '@/context/i18n' import I18n from '@/context/i18n'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading' import Loading from '@/app/components/base/loading'
...@@ -19,6 +19,7 @@ import AppIcon from '@/app/components/base/app-icon' ...@@ -19,6 +19,7 @@ import AppIcon from '@/app/components/base/app-icon'
type Props = { type Props = {
collection: Collection collection: Collection
isBuiltIn?: boolean isBuiltIn?: boolean
isModel?: boolean
toolName: string toolName: string
setting?: Record<string, any> setting?: Record<string, any>
readonly?: boolean readonly?: boolean
...@@ -29,6 +30,7 @@ type Props = { ...@@ -29,6 +30,7 @@ type Props = {
const SettingBuiltInTool: FC<Props> = ({ const SettingBuiltInTool: FC<Props> = ({
collection, collection,
isBuiltIn = true, isBuiltIn = true,
isModel = true,
toolName, toolName,
setting = {}, setting = {},
readonly, readonly,
...@@ -56,7 +58,11 @@ const SettingBuiltInTool: FC<Props> = ({ ...@@ -56,7 +58,11 @@ const SettingBuiltInTool: FC<Props> = ({
(async () => { (async () => {
setIsLoading(true) setIsLoading(true)
try { try {
const list = isBuiltIn ? await fetchBuiltInToolList(collection.name) : await fetchCustomToolList(collection.name) const list = isBuiltIn
? await fetchBuiltInToolList(collection.name)
: isModel
? await fetchModelToolList(collection.name)
: await fetchCustomToolList(collection.name)
setTools(list) setTools(list)
const currTool = list.find(tool => tool.name === toolName) const currTool = list.find(tool => tool.name === toolName)
if (currTool) { if (currTool) {
......
...@@ -130,7 +130,7 @@ const Debug: FC<IDebug> = ({ ...@@ -130,7 +130,7 @@ const Debug: FC<IDebug> = ({
const { notify } = useContext(ToastContext) const { notify } = useContext(ToastContext)
const logError = useCallback((message: string) => { const logError = useCallback((message: string) => {
notify({ type: 'error', message }) notify({ type: 'error', message, duration: 3000 })
}, [notify]) }, [notify])
const [completionFiles, setCompletionFiles] = useState<VisionFile[]>([]) const [completionFiles, setCompletionFiles] = useState<VisionFile[]>([])
......
...@@ -12,6 +12,7 @@ import { SimpleSelect } from '@/app/components/base/select' ...@@ -12,6 +12,7 @@ import { SimpleSelect } from '@/app/components/base/select'
import type { AppDetailResponse } from '@/models/app' import type { AppDetailResponse } from '@/models/app'
import type { Language } from '@/types/app' import type { Language } from '@/types/app'
import EmojiPicker from '@/app/components/base/emoji-picker' import EmojiPicker from '@/app/components/base/emoji-picker'
import { useToastContext } from '@/app/components/base/toast'
import { languages } from '@/i18n/language' import { languages } from '@/i18n/language'
...@@ -42,6 +43,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({ ...@@ -42,6 +43,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
onClose, onClose,
onSave, onSave,
}) => { }) => {
const { notify } = useToastContext()
const [isShowMore, setIsShowMore] = useState(false) const [isShowMore, setIsShowMore] = useState(false)
const { icon, icon_background } = appInfo const { icon, icon_background } = appInfo
const { title, description, copyright, privacy_policy, default_language } = appInfo.site const { title, description, copyright, privacy_policy, default_language } = appInfo.site
...@@ -67,6 +69,10 @@ const SettingsModal: FC<ISettingsModalProps> = ({ ...@@ -67,6 +69,10 @@ const SettingsModal: FC<ISettingsModalProps> = ({
} }
const onClickSave = async () => { const onClickSave = async () => {
if (!inputInfo.title) {
notify({ type: 'error', message: t('app.newApp.nameNotEmpty') })
return
}
setSaveLoading(true) setSaveLoading(true)
const params = { const params = {
title: inputInfo.title, title: inputInfo.title,
......
...@@ -95,7 +95,10 @@ const ConfigPanel = () => { ...@@ -95,7 +95,10 @@ const ConfigPanel = () => {
<Button <Button
type='primary' type='primary'
className='mr-2 text-sm font-medium' className='mr-2 text-sm font-medium'
onClick={handleStartChat} onClick={() => {
setCollapsed(true)
handleStartChat()
}}
> >
{t('common.operation.save')} {t('common.operation.save')}
</Button> </Button>
......
import type { FC } from 'react' import type { FC } from 'react'
import { useState } from 'react' import { useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import Uploader from './uploader' import Uploader from './uploader'
import ImageLinkInput from './image-link-input' import ImageLinkInput from './image-link-input'
import { ImagePlus } from '@/app/components/base/icons/src/vender/line/images' import { ImagePlus } from '@/app/components/base/icons/src/vender/line/images'
...@@ -25,16 +26,16 @@ const UploadOnlyFromLocal: FC<UploadOnlyFromLocalProps> = ({ ...@@ -25,16 +26,16 @@ const UploadOnlyFromLocal: FC<UploadOnlyFromLocalProps> = ({
}) => { }) => {
return ( return (
<Uploader onUpload={onUpload} disabled={disabled} limit={limit}> <Uploader onUpload={onUpload} disabled={disabled} limit={limit}>
{ {hovering => (
hovering => ( <div
<div className={` className={`
relative flex items-center justify-center w-8 h-8 rounded-lg cursor-pointer relative flex items-center justify-center w-8 h-8 rounded-lg cursor-pointer
${hovering && 'bg-gray-100'} ${hovering && 'bg-gray-100'}
`}> `}
<ImagePlus className='w-4 h-4 text-gray-500' /> >
</div> <ImagePlus className="w-4 h-4 text-gray-500" />
) </div>
} )}
</Uploader> </Uploader>
) )
} }
...@@ -54,13 +55,16 @@ const UploaderButton: FC<UploaderButtonProps> = ({ ...@@ -54,13 +55,16 @@ const UploaderButton: FC<UploaderButtonProps> = ({
const { t } = useTranslation() const { t } = useTranslation()
const [open, setOpen] = useState(false) const [open, setOpen] = useState(false)
const hasUploadFromLocal = methods.find(method => method === TransferMethod.local_file) const hasUploadFromLocal = methods.find(
method => method === TransferMethod.local_file,
)
const handleUpload = (imageFile: ImageFile) => { const handleUpload = (imageFile: ImageFile) => {
setOpen(false)
onUpload(imageFile) onUpload(imageFile)
} }
const closePopover = () => setOpen(false)
const handleToggle = () => { const handleToggle = () => {
if (disabled) if (disabled)
return return
...@@ -72,43 +76,46 @@ const UploaderButton: FC<UploaderButtonProps> = ({ ...@@ -72,43 +76,46 @@ const UploaderButton: FC<UploaderButtonProps> = ({
<PortalToFollowElem <PortalToFollowElem
open={open} open={open}
onOpenChange={setOpen} onOpenChange={setOpen}
placement='top-start' placement="top-start"
> >
<PortalToFollowElemTrigger onClick={handleToggle}> <PortalToFollowElemTrigger onClick={handleToggle}>
<div className={` <button
relative flex items-center justify-center w-8 h-8 hover:bg-gray-100 rounded-lg type="button"
${disabled ? 'cursor-not-allowed' : 'cursor-pointer'} disabled={disabled}
`}> className="relative flex items-center justify-center w-8 h-8 enabled:hover:bg-gray-100 rounded-lg disabled:cursor-not-allowed"
<ImagePlus className='w-4 h-4 text-gray-500' /> >
</div> <ImagePlus className="w-4 h-4 text-gray-500" />
</button>
</PortalToFollowElemTrigger> </PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-50'> <PortalToFollowElemContent className="z-50">
<div className='p-2 w-[260px] bg-white rounded-lg border-[0.5px] border-gray-200 shadow-lg'> <div className="p-2 w-[260px] bg-white rounded-lg border-[0.5px] border-gray-200 shadow-lg">
<ImageLinkInput onUpload={handleUpload} /> <ImageLinkInput onUpload={handleUpload} />
{ {hasUploadFromLocal && (
hasUploadFromLocal && ( <>
<> <div className="flex items-center mt-2 px-2 text-xs font-medium text-gray-400">
<div className='flex items-center mt-2 px-2 text-xs font-medium text-gray-400'> <div className="mr-3 w-[93px] h-[1px] bg-gradient-to-l from-[#F3F4F6]" />
<div className='mr-3 w-[93px] h-[1px] bg-gradient-to-l from-[#F3F4F6]' /> OR
OR <div className="ml-3 w-[93px] h-[1px] bg-gradient-to-r from-[#F3F4F6]" />
<div className='ml-3 w-[93px] h-[1px] bg-gradient-to-r from-[#F3F4F6]' /> </div>
</div> <Uploader
<Uploader onUpload={handleUpload} limit={limit}> onUpload={handleUpload}
{ limit={limit}
hovering => ( closePopover={closePopover}
<div className={` >
flex items-center justify-center h-8 text-[13px] font-medium text-[#155EEF] rounded-lg cursor-pointer {hovering => (
${hovering && 'bg-primary-50'} <div
`}> className={cn(
<Upload03 className='mr-1 w-4 h-4' /> 'flex items-center justify-center h-8 text-[13px] font-medium text-[#155EEF] rounded-lg cursor-pointer',
{t('common.imageUploader.uploadFromComputer')} hovering && 'bg-primary-50',
</div> )}
) >
} <Upload03 className="mr-1 w-4 h-4" />
</Uploader> {t('common.imageUploader.uploadFromComputer')}
</> </div>
) )}
} </Uploader>
</>
)}
</div> </div>
</PortalToFollowElemContent> </PortalToFollowElemContent>
</PortalToFollowElem> </PortalToFollowElem>
...@@ -125,7 +132,9 @@ const ChatImageUploader: FC<ChatImageUploaderProps> = ({ ...@@ -125,7 +132,9 @@ const ChatImageUploader: FC<ChatImageUploaderProps> = ({
onUpload, onUpload,
disabled, disabled,
}) => { }) => {
const onlyUploadLocal = settings.transfer_methods.length === 1 && settings.transfer_methods[0] === TransferMethod.local_file const onlyUploadLocal
= settings.transfer_methods.length === 1
&& settings.transfer_methods[0] === TransferMethod.local_file
if (onlyUploadLocal) { if (onlyUploadLocal) {
return ( return (
......
...@@ -30,6 +30,7 @@ const ImageLinkInput: FC<ImageLinkInputProps> = ({ ...@@ -30,6 +30,7 @@ const ImageLinkInput: FC<ImageLinkInputProps> = ({
return ( return (
<div className='flex items-center pl-1.5 pr-1 h-8 border border-gray-200 bg-white shadow-xs rounded-lg'> <div className='flex items-center pl-1.5 pr-1 h-8 border border-gray-200 bg-white shadow-xs rounded-lg'>
<input <input
type="text"
className='grow mr-0.5 px-1 h-[18px] text-[13px] outline-none appearance-none' className='grow mr-0.5 px-1 h-[18px] text-[13px] outline-none appearance-none'
value={imageLink} value={imageLink}
onChange={e => setImageLink(e.target.value)} onChange={e => setImageLink(e.target.value)}
......
import type { FC } from 'react' import type { FC } from 'react'
import { useState } from 'react' import { useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { Loading02, XClose } from '@/app/components/base/icons/src/vender/line/general' import cn from 'classnames'
import {
Loading02,
XClose,
} from '@/app/components/base/icons/src/vender/line/general'
import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows' import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import TooltipPlus from '@/app/components/base/tooltip-plus' import TooltipPlus from '@/app/components/base/tooltip-plus'
...@@ -30,7 +34,11 @@ const ImageList: FC<ImageListProps> = ({ ...@@ -30,7 +34,11 @@ const ImageList: FC<ImageListProps> = ({
const [imagePreviewUrl, setImagePreviewUrl] = useState('') const [imagePreviewUrl, setImagePreviewUrl] = useState('')
const handleImageLinkLoadSuccess = (item: ImageFile) => { const handleImageLinkLoadSuccess = (item: ImageFile) => {
if (item.type === TransferMethod.remote_url && onImageLinkLoadSuccess && item.progress !== -1) if (
item.type === TransferMethod.remote_url
&& onImageLinkLoadSuccess
&& item.progress !== -1
)
onImageLinkLoadSuccess(item._id) onImageLinkLoadSuccess(item._id)
} }
const handleImageLinkLoadError = (item: ImageFile) => { const handleImageLinkLoadError = (item: ImageFile) => {
...@@ -39,89 +47,95 @@ const ImageList: FC<ImageListProps> = ({ ...@@ -39,89 +47,95 @@ const ImageList: FC<ImageListProps> = ({
} }
return ( return (
<div className='flex flex-wrap'> <div className="flex flex-wrap">
{ {list.map(item => (
list.map(item => ( <div
<div key={item._id}
key={item._id} className="group relative mr-1 border-[0.5px] border-black/5 rounded-lg"
className='group relative mr-1 border-[0.5px] border-black/5 rounded-lg' >
> {item.type === TransferMethod.local_file && item.progress !== 100 && (
{ <>
item.type === TransferMethod.local_file && item.progress !== 100 && ( <div
<> className="absolute inset-0 flex items-center justify-center z-[1] bg-black/30"
<div style={{ left: item.progress > -1 ? `${item.progress}%` : 0 }}
className='absolute inset-0 flex items-center justify-center z-[1] bg-black/30' >
style={{ left: item.progress > -1 ? `${item.progress}%` : 0 }} {item.progress === -1 && (
> <RefreshCcw01
{ className="w-5 h-5 text-white"
item.progress === -1 && ( onClick={() => onReUpload && onReUpload(item._id)}
<RefreshCcw01 className='w-5 h-5 text-white' onClick={() => onReUpload && onReUpload(item._id)} /> />
) )}
} </div>
</div> {item.progress > -1 && (
{ <span className="absolute top-[50%] left-[50%] translate-x-[-50%] translate-y-[-50%] text-sm text-white mix-blend-lighten z-[1]">
item.progress > -1 && ( {item.progress}%
<span className='absolute top-[50%] left-[50%] translate-x-[-50%] translate-y-[-50%] text-sm text-white mix-blend-lighten z-[1]'>{item.progress}%</span> </span>
) )}
} </>
</> )}
) {item.type === TransferMethod.remote_url && item.progress !== 100 && (
} <div
{ className={`
item.type === TransferMethod.remote_url && item.progress !== 100 && (
<div className={`
absolute inset-0 flex items-center justify-center rounded-lg z-[1] border absolute inset-0 flex items-center justify-center rounded-lg z-[1] border
${item.progress === -1 ? 'bg-[#FEF0C7] border-[#DC6803]' : 'bg-black/[0.16] border-transparent'} ${
`}> item.progress === -1
{ ? 'bg-[#FEF0C7] border-[#DC6803]'
item.progress > -1 && ( : 'bg-black/[0.16] border-transparent'
<Loading02 className='animate-spin w-5 h-5 text-white' />
)
}
{
item.progress === -1 && (
<TooltipPlus popupContent={t('common.imageUploader.pasteImageLinkInvalid')}>
<AlertTriangle className='w-4 h-4 text-[#DC6803]' />
</TooltipPlus>
)
}
</div>
)
} }
<img `}
className='w-16 h-16 rounded-lg object-cover cursor-pointer border-[0.5px] border-black/5' >
alt='' {item.progress > -1 && (
onLoad={() => handleImageLinkLoadSuccess(item)} <Loading02 className="animate-spin w-5 h-5 text-white" />
onError={() => handleImageLinkLoadError(item)} )}
src={item.type === TransferMethod.remote_url ? item.url : item.base64Url} {item.progress === -1 && (
onClick={() => item.progress === 100 && setImagePreviewUrl((item.type === TransferMethod.remote_url ? item.url : item.base64Url) as string)} <TooltipPlus
/> popupContent={t('common.imageUploader.pasteImageLinkInvalid')}
{
!readonly && (
<div
className={`
absolute z-10 -top-[9px] -right-[9px] items-center justify-center w-[18px] h-[18px]
bg-white hover:bg-gray-50 border-[0.5px] border-black/[0.02] rounded-2xl shadow-lg
cursor-pointer
${item.progress === -1 ? 'flex' : 'hidden group-hover:flex'}
`}
onClick={() => onRemove && onRemove(item._id)}
> >
<XClose className='w-3 h-3 text-gray-500' /> <AlertTriangle className="w-4 h-4 text-[#DC6803]" />
</div> </TooltipPlus>
)}
</div>
)}
<img
className="w-16 h-16 rounded-lg object-cover cursor-pointer border-[0.5px] border-black/5"
alt={item.file?.name}
onLoad={() => handleImageLinkLoadSuccess(item)}
onError={() => handleImageLinkLoadError(item)}
src={
item.type === TransferMethod.remote_url
? item.url
: item.base64Url
}
onClick={() =>
item.progress === 100
&& setImagePreviewUrl(
(item.type === TransferMethod.remote_url
? item.url
: item.base64Url) as string,
) )
} }
</div>
))
}
{
imagePreviewUrl && (
<ImagePreview
url={imagePreviewUrl}
onCancel={() => setImagePreviewUrl('')}
/> />
) {!readonly && (
} <button
type="button"
className={cn(
'absolute z-10 -top-[9px] -right-[9px] items-center justify-center w-[18px] h-[18px]',
'bg-white hover:bg-gray-50 border-[0.5px] border-black/[0.02] rounded-2xl shadow-lg',
item.progress === -1 ? 'flex' : 'hidden group-hover:flex',
)}
onClick={() => onRemove && onRemove(item._id)}
>
<XClose className="w-3 h-3 text-gray-500" />
</button>
)}
</div>
))}
{imagePreviewUrl && (
<ImagePreview
url={imagePreviewUrl}
onCancel={() => setImagePreviewUrl('')}
/>
)}
</div> </div>
) )
} }
......
...@@ -7,6 +7,7 @@ import { ALLOW_FILE_EXTENSIONS } from '@/types/app' ...@@ -7,6 +7,7 @@ import { ALLOW_FILE_EXTENSIONS } from '@/types/app'
type UploaderProps = { type UploaderProps = {
children: (hovering: boolean) => JSX.Element children: (hovering: boolean) => JSX.Element
onUpload: (imageFile: ImageFile) => void onUpload: (imageFile: ImageFile) => void
closePopover?: () => void
limit?: number limit?: number
disabled?: boolean disabled?: boolean
} }
...@@ -14,11 +15,16 @@ type UploaderProps = { ...@@ -14,11 +15,16 @@ type UploaderProps = {
const Uploader: FC<UploaderProps> = ({ const Uploader: FC<UploaderProps> = ({
children, children,
onUpload, onUpload,
closePopover,
limit, limit,
disabled, disabled,
}) => { }) => {
const [hovering, setHovering] = useState(false) const [hovering, setHovering] = useState(false)
const { handleLocalFileUpload } = useLocalFileUploader({ limit, onUpload, disabled }) const { handleLocalFileUpload } = useLocalFileUploader({
limit,
onUpload,
disabled,
})
const handleChange = (e: ChangeEvent<HTMLInputElement>) => { const handleChange = (e: ChangeEvent<HTMLInputElement>) => {
const file = e.target.files?.[0] const file = e.target.files?.[0]
...@@ -27,6 +33,7 @@ const Uploader: FC<UploaderProps> = ({ ...@@ -27,6 +33,7 @@ const Uploader: FC<UploaderProps> = ({
return return
handleLocalFileUpload(file) handleLocalFileUpload(file)
closePopover?.()
} }
return ( return (
...@@ -37,11 +44,8 @@ const Uploader: FC<UploaderProps> = ({ ...@@ -37,11 +44,8 @@ const Uploader: FC<UploaderProps> = ({
> >
{children(hovering)} {children(hovering)}
<input <input
className={` className='absolute block inset-0 opacity-0 text-[0] w-full disabled:cursor-not-allowed cursor-pointer'
absolute block inset-0 opacity-0 text-[0] w-full onClick={e => ((e.target as HTMLInputElement).value = '')}
${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}
`}
onClick={e => (e.target as HTMLInputElement).value = ''}
type='file' type='file'
accept={ALLOW_FILE_EXTENSIONS.map(ext => `.${ext}`).join(',')} accept={ALLOW_FILE_EXTENSIONS.map(ext => `.${ext}`).join(',')}
onChange={handleChange} onChange={handleChange}
......
...@@ -32,6 +32,7 @@ import VariableValueBlock from './plugins/variable-value-block' ...@@ -32,6 +32,7 @@ import VariableValueBlock from './plugins/variable-value-block'
import { VariableValueBlockNode } from './plugins/variable-value-block/node' import { VariableValueBlockNode } from './plugins/variable-value-block/node'
import { CustomTextNode } from './plugins/custom-text/node' import { CustomTextNode } from './plugins/custom-text/node'
import OnBlurBlock from './plugins/on-blur-or-focus-block' import OnBlurBlock from './plugins/on-blur-or-focus-block'
import UpdateBlock from './plugins/update-block'
import { textToEditorState } from './utils' import { textToEditorState } from './utils'
import type { Dataset } from './plugins/context-block' import type { Dataset } from './plugins/context-block'
import type { RoleName } from './plugins/history-block' import type { RoleName } from './plugins/history-block'
...@@ -226,6 +227,7 @@ const PromptEditor: FC<PromptEditorProps> = ({ ...@@ -226,6 +227,7 @@ const PromptEditor: FC<PromptEditorProps> = ({
<VariableValueBlock /> <VariableValueBlock />
<OnChangePlugin onChange={handleEditorChange} /> <OnChangePlugin onChange={handleEditorChange} />
<OnBlurBlock onBlur={onBlur} onFocus={onFocus} /> <OnBlurBlock onBlur={onBlur} onFocus={onFocus} />
<UpdateBlock />
{/* <TreeView /> */} {/* <TreeView /> */}
</div> </div>
</LexicalComposer> </LexicalComposer>
......
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { textToEditorState } from '../utils'
import { useEventEmitterContextContext } from '@/context/event-emitter'
export const PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER = 'PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER'
const UpdateBlock = () => {
const { eventEmitter } = useEventEmitterContextContext()
const [editor] = useLexicalComposerContext()
eventEmitter?.useSubscription((v: any) => {
if (v.type === PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER) {
const editorState = editor.parseEditorState(textToEditorState(v.payload))
editor.setEditorState(editorState)
}
})
return null
}
export default UpdateBlock
...@@ -838,7 +838,7 @@ const StepTwo = ({ ...@@ -838,7 +838,7 @@ const StepTwo = ({
{!isSetting {!isSetting
? ( ? (
<div className='flex items-center mt-8 py-2'> <div className='flex items-center mt-8 py-2'>
<Button onClick={() => onStepChange && onStepChange(-1)}>{t('datasetCreation.stepTwo.lastStep')}</Button> <Button onClick={() => onStepChange && onStepChange(-1)}>{t('datasetCreation.stepTwo.previousStep')}</Button>
<div className={s.divider} /> <div className={s.divider} />
<Button loading={isCreating} type='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.nextStep')}</Button> <Button loading={isCreating} type='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.nextStep')}</Button>
</div> </div>
......
...@@ -427,6 +427,53 @@ Chat applications support session persistence, allowing previous chat history to ...@@ -427,6 +427,53 @@ Chat applications support session persistence, allowing previous chat history to
--- ---
<Heading
url='/messages/{message_id}/suggested'
method='GET'
title='next suggested questions'
name='#suggested'
/>
<Row>
<Col>
Get next questions suggestions for the current message
### Path Params
<Properties>
<Property name='message_id' type='string' key='message_id'>
Message ID
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup title="Request" tag="GET" label="/messages/{message_id}/suggested" targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested \\\n--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \\\n--header 'Content-Type: application/json'`}>
```bash {{ title: 'cURL' }}
curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \
--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \
--header 'Content-Type: application/json' \
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"result": "success",
"data": [
"a",
"b",
"c"
]
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading <Heading
url='/messages' url='/messages'
method='GET' method='GET'
......
...@@ -442,6 +442,55 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ...@@ -442,6 +442,55 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
--- ---
<Heading
url='/messages/{message_id}/suggested'
method='GET'
title='获取下一轮建议问题列表'
name='#suggested'
/>
<Row>
<Col>
获取下一轮建议问题列表。
### Path Params
<Properties>
<Property name='message_id' type='string' key='message_id'>
Message ID
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup title="Request" tag="GET" label="/messages/{message_id}/suggested" targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested \\\n--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \\\n--header 'Content-Type: application/json'`}>
```bash {{ title: 'cURL' }}
curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \
--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \
--header 'Content-Type: application/json' \
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"result": "success",
"data": [
"a",
"b",
"c"
]
}
```
</CodeGroup>
</Col>
</Row>
---
---
<Heading <Heading
url='/messages' url='/messages'
method='GET' method='GET'
......
...@@ -71,10 +71,14 @@ export default function AccountPage() { ...@@ -71,10 +71,14 @@ export default function AccountPage() {
showErrorMessage(t('login.error.passwordEmpty')) showErrorMessage(t('login.error.passwordEmpty'))
return false return false
} }
if (!validPassword.test(password)) if (!validPassword.test(password)) {
showErrorMessage(t('login.error.passwordInvalid')) showErrorMessage(t('login.error.passwordInvalid'))
if (password !== confirmPassword) return false
}
if (password !== confirmPassword) {
showErrorMessage(t('common.account.notEqual')) showErrorMessage(t('common.account.notEqual'))
return false
}
return true return true
} }
......
...@@ -18,7 +18,7 @@ import NoSearchRes from './info/no-search-res' ...@@ -18,7 +18,7 @@ import NoSearchRes from './info/no-search-res'
import NoCustomToolPlaceholder from './no-custom-tool-placeholder' import NoCustomToolPlaceholder from './no-custom-tool-placeholder'
import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import TabSlider from '@/app/components/base/tab-slider' import TabSlider from '@/app/components/base/tab-slider'
import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools' import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList } from '@/service/tools'
import type { AgentTool } from '@/types/app' import type { AgentTool } from '@/types/app'
type Props = { type Props = {
...@@ -89,9 +89,11 @@ const Tools: FC<Props> = ({ ...@@ -89,9 +89,11 @@ const Tools: FC<Props> = ({
const showCollectionList = (() => { const showCollectionList = (() => {
let typeFilteredList: Collection[] = [] let typeFilteredList: Collection[] = []
if (collectionType === CollectionType.all) if (collectionType === CollectionType.all)
typeFilteredList = collectionList typeFilteredList = collectionList.filter(item => item.type !== CollectionType.model)
else else if (collectionType === CollectionType.builtIn)
typeFilteredList = collectionList.filter(item => item.type === collectionType) typeFilteredList = collectionList.filter(item => item.type === CollectionType.builtIn)
else if (collectionType === CollectionType.custom)
typeFilteredList = collectionList.filter(item => item.type === CollectionType.custom)
if (query) if (query)
return typeFilteredList.filter(item => item.name.includes(query)) return typeFilteredList.filter(item => item.name.includes(query))
...@@ -122,6 +124,10 @@ const Tools: FC<Props> = ({ ...@@ -122,6 +124,10 @@ const Tools: FC<Props> = ({
const list = await fetchBuiltInToolList(currCollection.name) const list = await fetchBuiltInToolList(currCollection.name)
setCurrentTools(list) setCurrentTools(list)
} }
else if (currCollection.type === CollectionType.model) {
const list = await fetchModelToolList(currCollection.name)
setCurrentTools(list)
}
else { else {
const list = await fetchCustomToolList(currCollection.name) const list = await fetchCustomToolList(currCollection.name)
setCurrentTools(list) setCurrentTools(list)
...@@ -130,7 +136,7 @@ const Tools: FC<Props> = ({ ...@@ -130,7 +136,7 @@ const Tools: FC<Props> = ({
catch (e) { } catch (e) { }
setIsDetailLoading(false) setIsDetailLoading(false)
})() })()
}, [currCollection?.name]) }, [currCollection?.name, currCollection?.type])
const [isShowEditCollectionToolModal, setIsShowEditCollectionToolModal] = useState(false) const [isShowEditCollectionToolModal, setIsShowEditCollectionToolModal] = useState(false)
const handleCreateToolCollection = () => { const handleCreateToolCollection = () => {
...@@ -197,7 +203,7 @@ const Tools: FC<Props> = ({ ...@@ -197,7 +203,7 @@ const Tools: FC<Props> = ({
(showCollectionList.length > 0 || !query) (showCollectionList.length > 0 || !query)
? <ToolNavList ? <ToolNavList
className='mt-2 grow height-0 overflow-y-auto' className='mt-2 grow height-0 overflow-y-auto'
currentName={currCollection?.name || ''} currentIndex={currCollectionIndex || 0}
list={showCollectionList} list={showCollectionList}
onChosen={setCurrCollectionIndex} onChosen={setCurrCollectionIndex}
/> />
......
...@@ -29,9 +29,8 @@ const Header: FC<Props> = ({ ...@@ -29,9 +29,8 @@ const Header: FC<Props> = ({
const { t } = useTranslation() const { t } = useTranslation()
const isInToolsPage = loc === LOC.tools const isInToolsPage = loc === LOC.tools
const isInDebugPage = !isInToolsPage const isInDebugPage = !isInToolsPage
const needAuth = collection?.allow_delete
// const isBuiltIn = collection.type === CollectionType.builtIn const needAuth = collection?.allow_delete || collection?.type === CollectionType.model
const isAuthed = collection.is_team_authorization const isAuthed = collection.is_team_authorization
return ( return (
<div className={cn(isInToolsPage ? 'py-4 px-6' : 'py-[11px] pl-4 pr-3', 'flex justify-between items-start border-b border-gray-200')}> <div className={cn(isInToolsPage ? 'py-4 px-6' : 'py-[11px] pl-4 pr-3', 'flex justify-between items-start border-b border-gray-200')}>
...@@ -50,10 +49,13 @@ const Header: FC<Props> = ({ ...@@ -50,10 +49,13 @@ const Header: FC<Props> = ({
)} )}
</div> </div>
</div> </div>
{collection.type === CollectionType.builtIn && needAuth && ( {(collection.type === CollectionType.builtIn || collection.type === CollectionType.model) && needAuth && (
<div <div
className={cn('cursor-pointer', 'ml-1 shrink-0 flex items-center h-8 border border-gray-200 rounded-lg px-3 space-x-2 shadow-xs')} className={cn('cursor-pointer', 'ml-1 shrink-0 flex items-center h-8 border border-gray-200 rounded-lg px-3 space-x-2 shadow-xs')}
onClick={() => onShowAuth()} onClick={() => {
if (collection.type === CollectionType.builtIn || collection.type === CollectionType.model)
onShowAuth()
}}
> >
<div className={cn(isAuthed ? 'border-[#12B76A] bg-[#32D583]' : 'border-gray-400 bg-gray-300', 'rounded h-2 w-2 border')}></div> <div className={cn(isAuthed ? 'border-[#12B76A] bg-[#32D583]' : 'border-gray-400 bg-gray-300', 'rounded h-2 w-2 border')}></div>
<div className='leading-5 text-sm font-medium text-gray-700'>{t(`tools.auth.${isAuthed ? 'authorized' : 'unauthorized'}`)}</div> <div className='leading-5 text-sm font-medium text-gray-700'>{t(`tools.auth.${isAuthed ? 'authorized' : 'unauthorized'}`)}</div>
......
...@@ -8,6 +8,7 @@ import type { Collection, CustomCollectionBackend, Tool } from '../types' ...@@ -8,6 +8,7 @@ import type { Collection, CustomCollectionBackend, Tool } from '../types'
import Loading from '../../base/loading' import Loading from '../../base/loading'
import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows' import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows'
import Toast from '../../base/toast' import Toast from '../../base/toast'
import { ConfigurateMethodEnum } from '../../header/account-setting/model-provider-page/declarations'
import Header from './header' import Header from './header'
import Item from './item' import Item from './item'
import AppIcon from '@/app/components/base/app-icon' import AppIcon from '@/app/components/base/app-icon'
...@@ -16,6 +17,8 @@ import { fetchCustomCollection, removeBuiltInToolCredential, removeCustomCollect ...@@ -16,6 +17,8 @@ import { fetchCustomCollection, removeBuiltInToolCredential, removeCustomCollect
import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal' import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal'
import type { AgentTool } from '@/types/app' import type { AgentTool } from '@/types/app'
import { MAX_TOOLS_NUM } from '@/config' import { MAX_TOOLS_NUM } from '@/config'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
type Props = { type Props = {
collection: Collection | null collection: Collection | null
...@@ -42,9 +45,32 @@ const ToolList: FC<Props> = ({ ...@@ -42,9 +45,32 @@ const ToolList: FC<Props> = ({
const { t } = useTranslation() const { t } = useTranslation()
const isInToolsPage = loc === LOC.tools const isInToolsPage = loc === LOC.tools
const isBuiltIn = collection?.type === CollectionType.builtIn const isBuiltIn = collection?.type === CollectionType.builtIn
const isModel = collection?.type === CollectionType.model
const needAuth = collection?.allow_delete const needAuth = collection?.allow_delete
const { setShowModelModal } = useModalContext()
const [showSettingAuth, setShowSettingAuth] = useState(false) const [showSettingAuth, setShowSettingAuth] = useState(false)
const { modelProviders: providers } = useProviderContext()
const showSettingAuthModal = () => {
if (isModel) {
const provider = providers.find(item => item.provider === collection?.id)
if (provider) {
setShowModelModal({
payload: {
currentProvider: provider,
currentConfigurateMethod: ConfigurateMethodEnum.predefinedModel,
currentCustomConfigrationModelFixedFields: undefined,
},
onSaveCallback: () => {
onRefreshData()
},
})
}
}
else {
setShowSettingAuth(true)
}
}
const [customCollection, setCustomCollection] = useState<CustomCollectionBackend | null>(null) const [customCollection, setCustomCollection] = useState<CustomCollectionBackend | null>(null)
useEffect(() => { useEffect(() => {
...@@ -116,7 +142,7 @@ const ToolList: FC<Props> = ({ ...@@ -116,7 +142,7 @@ const ToolList: FC<Props> = ({
icon={icon} icon={icon}
collection={collection} collection={collection}
loc={loc} loc={loc}
onShowAuth={() => setShowSettingAuth(true)} onShowAuth={() => showSettingAuthModal()}
onShowEditCustomCollection={() => setIsShowEditCustomCollectionModal(true)} onShowEditCustomCollection={() => setIsShowEditCustomCollectionModal(true)}
/> />
<div className={cn(isInToolsPage ? 'px-6 pt-4' : 'px-4 pt-3')}> <div className={cn(isInToolsPage ? 'px-6 pt-4' : 'px-4 pt-3')}>
...@@ -124,12 +150,12 @@ const ToolList: FC<Props> = ({ ...@@ -124,12 +150,12 @@ const ToolList: FC<Props> = ({
<div className=''>{t('tools.includeToolNum', { <div className=''>{t('tools.includeToolNum', {
num: list.length, num: list.length,
})}</div> })}</div>
{needAuth && isBuiltIn && !collection.is_team_authorization && ( {needAuth && (isBuiltIn || isModel) && !collection.is_team_authorization && (
<> <>
<div>·</div> <div>·</div>
<div <div
className='flex items-center text-[#155EEF] cursor-pointer' className='flex items-center text-[#155EEF] cursor-pointer'
onClick={() => setShowSettingAuth(true)} onClick={() => showSettingAuthModal()}
> >
<div>{t('tools.auth.setup')}</div> <div>{t('tools.auth.setup')}</div>
<ArrowNarrowRight className='ml-0.5 w-3 h-3' /> <ArrowNarrowRight className='ml-0.5 w-3 h-3' />
...@@ -149,7 +175,7 @@ const ToolList: FC<Props> = ({ ...@@ -149,7 +175,7 @@ const ToolList: FC<Props> = ({
collection={collection} collection={collection}
isInToolsPage={isInToolsPage} isInToolsPage={isInToolsPage}
isToolNumMax={(addedTools?.length || 0) >= MAX_TOOLS_NUM} isToolNumMax={(addedTools?.length || 0) >= MAX_TOOLS_NUM}
added={!!addedTools?.find(v => v.provider_id === collection.id && v.tool_name === item.name)} added={!!addedTools?.find(v => v.provider_id === collection.id && v.provider_type === collection.type && v.tool_name === item.name)}
onAdd={!isInToolsPage ? tool => onAddTool?.(collection as Collection, tool) : undefined} onAdd={!isInToolsPage ? tool => onAddTool?.(collection as Collection, tool) : undefined}
/> />
))} ))}
......
...@@ -35,6 +35,7 @@ const Item: FC<Props> = ({ ...@@ -35,6 +35,7 @@ const Item: FC<Props> = ({
const language = getLanguage(locale) const language = getLanguage(locale)
const isBuiltIn = collection.type === CollectionType.builtIn const isBuiltIn = collection.type === CollectionType.builtIn
const isModel = collection.type === CollectionType.model
const canShowDetail = isInToolsPage const canShowDetail = isInToolsPage
const [showDetail, setShowDetail] = useState(false) const [showDetail, setShowDetail] = useState(false)
const addBtn = <Button className='shrink-0 flex items-center h-7 !px-3 !text-xs !font-medium !text-gray-700' disabled={added || !collection.is_team_authorization} onClick={() => onAdd?.(payload)}>{t(`common.operation.${added ? 'added' : 'add'}`)}</Button> const addBtn = <Button className='shrink-0 flex items-center h-7 !px-3 !text-xs !font-medium !text-gray-700' disabled={added || !collection.is_team_authorization} onClick={() => onAdd?.(payload)}>{t(`common.operation.${added ? 'added' : 'add'}`)}</Button>
...@@ -73,6 +74,7 @@ const Item: FC<Props> = ({ ...@@ -73,6 +74,7 @@ const Item: FC<Props> = ({
setShowDetail(false) setShowDetail(false)
}} }}
isBuiltIn={isBuiltIn} isBuiltIn={isBuiltIn}
isModel={isModel}
/> />
)} )}
</> </>
......
...@@ -6,21 +6,21 @@ import Item from './item' ...@@ -6,21 +6,21 @@ import Item from './item'
import type { Collection } from '@/app/components/tools/types' import type { Collection } from '@/app/components/tools/types'
type Props = { type Props = {
className?: string className?: string
currentName: string currentIndex: number
list: Collection[] list: Collection[]
onChosen: (index: number) => void onChosen: (index: number) => void
} }
const ToolNavList: FC<Props> = ({ const ToolNavList: FC<Props> = ({
className, className,
currentName, currentIndex,
list, list,
onChosen, onChosen,
}) => { }) => {
return ( return (
<div className={cn(className)}> <div className={cn(className)}>
{list.map((item, index) => ( {list.map((item, index) => (
<Item isCurrent={item.name === currentName} key={item.name} payload={item} onClick={() => onChosen(index)}></Item> <Item isCurrent={index === currentIndex} key={index} payload={item} onClick={() => onChosen(index)}></Item>
))} ))}
</div> </div>
) )
......
...@@ -26,6 +26,7 @@ export enum CollectionType { ...@@ -26,6 +26,7 @@ export enum CollectionType {
all = 'all', all = 'all',
builtIn = 'builtin', builtIn = 'builtin',
custom = 'api', custom = 'api',
model = 'model',
} }
export type Emoji = { export type Emoji = {
......
...@@ -89,7 +89,7 @@ const translation = { ...@@ -89,7 +89,7 @@ const translation = {
other: 'and other ', other: 'and other ',
fileUnit: ' files', fileUnit: ' files',
notionUnit: ' pages', notionUnit: ' pages',
lastStep: 'Last step', previousStep: 'Previous step',
nextStep: 'Save & Process', nextStep: 'Save & Process',
save: 'Save & Process', save: 'Save & Process',
cancel: 'Cancel', cancel: 'Cancel',
......
...@@ -89,7 +89,7 @@ const translation = { ...@@ -89,7 +89,7 @@ const translation = {
other: 'その他', other: 'その他',
fileUnit: 'ファイル', fileUnit: 'ファイル',
notionUnit: 'ページ', notionUnit: 'ページ',
lastStep: '最後のステップ', previousStep: '前のステップ',
nextStep: '保存して処理', nextStep: '保存して処理',
save: '保存して処理', save: '保存して処理',
cancel: 'キャンセル', cancel: 'キャンセル',
......
...@@ -89,7 +89,7 @@ const translation = { ...@@ -89,7 +89,7 @@ const translation = {
other: 'e outros ', other: 'e outros ',
fileUnit: ' arquivos', fileUnit: ' arquivos',
notionUnit: ' páginas', notionUnit: ' páginas',
lastStep: 'Última etapa', previousStep: 'Passo anterior',
nextStep: 'Salvar e Processar', nextStep: 'Salvar e Processar',
save: 'Salvar e Processar', save: 'Salvar e Processar',
cancel: 'Cancelar', cancel: 'Cancelar',
......
...@@ -89,7 +89,7 @@ const translation = { ...@@ -89,7 +89,7 @@ const translation = {
other: ' та інші ', other: ' та інші ',
fileUnit: ' файли', fileUnit: ' файли',
notionUnit: ' сторінки', notionUnit: ' сторінки',
lastStep: 'Попередній крок', previousStep: 'Попередній крок',
nextStep: 'Зберегти та обробити', nextStep: 'Зберегти та обробити',
save: 'Зберегти та обробити', save: 'Зберегти та обробити',
cancel: 'Скасувати', cancel: 'Скасувати',
......
...@@ -89,7 +89,7 @@ const translation = { ...@@ -89,7 +89,7 @@ const translation = {
other: '和其他 ', other: '和其他 ',
fileUnit: ' 个文件', fileUnit: ' 个文件',
notionUnit: ' 个页面', notionUnit: ' 个页面',
lastStep: '上一步', previousStep: '上一步',
nextStep: '保存并处理', nextStep: '保存并处理',
save: '保存并处理', save: '保存并处理',
cancel: '取消', cancel: '取消',
......
...@@ -12,6 +12,11 @@ export const fetchBuiltInToolList = (collectionName: string) => { ...@@ -12,6 +12,11 @@ export const fetchBuiltInToolList = (collectionName: string) => {
export const fetchCustomToolList = (collectionName: string) => { export const fetchCustomToolList = (collectionName: string) => {
return get<Tool[]>(`/workspaces/current/tool-provider/api/tools?provider=${collectionName}`) return get<Tool[]>(`/workspaces/current/tool-provider/api/tools?provider=${collectionName}`)
} }
export const fetchModelToolList = (collectionName: string) => {
return get<Tool[]>(`/workspaces/current/tool-provider/model/tools?provider=${collectionName}`)
}
export const fetchBuiltInToolCredentialSchema = (collectionName: string) => { export const fetchBuiltInToolCredentialSchema = (collectionName: string) => {
return get<ToolCredential[]>(`/workspaces/current/tool-provider/builtin/${collectionName}/credentials_schema`) return get<ToolCredential[]>(`/workspaces/current/tool-provider/builtin/${collectionName}/credentials_schema`)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment