Unverified Commit 2f9cb8c4 authored by Yeuoly's avatar Yeuoly

Merge branch 'main' into feat/agent-image

parents 0debd75b 625b0afa
This diff is collapsed.
...@@ -107,20 +107,33 @@ class AppListApi(Resource): ...@@ -107,20 +107,33 @@ class AppListApi(Resource):
# validate config # validate config
model_config_dict = args['model_config'] model_config_dict = args['model_config']
# get model provider # Get provider configurations
model_manager = ModelManager() provider_manager = ProviderManager()
model_instance = model_manager.get_default_model_instance( provider_configurations = provider_manager.get_configurations(current_user.current_tenant_id)
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM # get available models from provider_configurations
available_models = provider_configurations.get_models(
model_type=ModelType.LLM,
only_active=True
) )
if not model_instance: # check if model is available
raise ProviderNotInitializeError( available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
f"No Default System Reasoning Model available. Please configure " provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
f"in the Settings -> Model Provider.") if provider_model not in available_models_names:
else: model_manager = ModelManager()
model_config_dict["model"]["provider"] = model_instance.provider model_instance = model_manager.get_default_model_instance(
model_config_dict["model"]["name"] = model_instance.model tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
......
...@@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom ...@@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool from core.tools.entities.tool_entities import ToolRuntimeVariablePool
...@@ -169,7 +170,7 @@ class AssistantApplicationRunner(AppRunner): ...@@ -169,7 +170,7 @@ class AssistantApplicationRunner(AppRunner):
# load tool variables # load tool variables
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
tanent_id=application_generate_entity.tenant_id) tenant_id=application_generate_entity.tenant_id)
# convert db variables to tool variables # convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
...@@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner): ...@@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner):
memory=memory, memory=memory,
) )
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
# start agent runner # start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner( assistant_cot_runner = AssistantCotApplicationRunner(
...@@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner): ...@@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner):
prompt_messages=prompt_message, prompt_messages=prompt_message,
variables_pool=tool_variables, variables_pool=tool_variables,
db_variables=tool_conversation_variables, db_variables=tool_conversation_variables,
model_instance=model_instance
) )
invoke_result = assistant_cot_runner.run( invoke_result = assistant_cot_runner.run(
model_instance=model_instance,
conversation=conversation, conversation=conversation,
message=message, message=message,
query=query, query=query,
...@@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner): ...@@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner):
memory=memory, memory=memory,
prompt_messages=prompt_message, prompt_messages=prompt_message,
variables_pool=tool_variables, variables_pool=tool_variables,
db_variables=tool_conversation_variables db_variables=tool_conversation_variables,
model_instance=model_instance
) )
invoke_result = assistant_fc_runner.run( invoke_result = assistant_fc_runner.run(
model_instance=model_instance,
conversation=conversation, conversation=conversation,
message=message, message=message,
query=query, query=query,
...@@ -246,13 +254,13 @@ class AssistantApplicationRunner(AppRunner): ...@@ -246,13 +254,13 @@ class AssistantApplicationRunner(AppRunner):
agent=True agent=True
) )
def _load_tool_variables(self, conversation_id: str, user_id: str, tanent_id: str) -> ToolConversationVariables: def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
""" """
load tool variables from database load tool variables from database
""" """
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter( tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == conversation_id, ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tanent_id ToolConversationVariables.tenant_id == tenant_id
).first() ).first()
if tool_variables: if tool_variables:
...@@ -263,7 +271,7 @@ class AssistantApplicationRunner(AppRunner): ...@@ -263,7 +271,7 @@ class AssistantApplicationRunner(AppRunner):
tool_variables = ToolConversationVariables( tool_variables = ToolConversationVariables(
conversation_id=conversation_id, conversation_id=conversation_id,
user_id=user_id, user_id=user_id,
tenant_id=tanent_id, tenant_id=tenant_id,
variables_str='[]', variables_str='[]',
) )
db.session.add(tool_variables) db.session.add(tool_variables)
......
import logging import logging
import json import json
from typing import Optional, List, Tuple, Union from typing import Optional, List, Tuple, Union, cast
from datetime import datetime from datetime import datetime
from mimetypes import guess_extension from mimetypes import guess_extension
...@@ -12,7 +12,7 @@ from models.model import MessageAgentThought, Message, MessageFile ...@@ -12,7 +12,7 @@ from models.model import MessageAgentThought, Message, MessageFile
from models.tools import ToolConversationVariables from models.tools import ToolConversationVariables
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, \ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, \
ToolRuntimeVariablePool, ToolParamter ToolRuntimeVariablePool, ToolParameter
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
...@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \ ...@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_manager import ModelInstance
from core.file.message_file_parser import FileTransferMethod from core.file.message_file_parser import FileTransferMethod
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
prompt_messages: Optional[List[PromptMessage]] = None, prompt_messages: Optional[List[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None, db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
) -> None: ) -> None:
""" """
Agent runner Agent runner
...@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.history_prompt_messages = prompt_messages self.history_prompt_messages = prompt_messages
self.variables_pool = variables_pool self.variables_pool = variables_pool
self.db_variables_pool = db_variables self.db_variables_pool = db_variables
self.model_instance = model_instance
# init callback # init callback
self.agent_callback = DifyAgentCallbackHandler() self.agent_callback = DifyAgentCallbackHandler()
...@@ -95,9 +100,17 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -95,9 +100,17 @@ class BaseAssistantApplicationRunner(AppRunner):
MessageAgentThought.message_id == self.message.id, MessageAgentThought.message_id == self.message.id,
).count() ).count()
def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity: # check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
self.stream_tool_call = True
else:
self.stream_tool_call = False
def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
""" """
Repacket app orchestration config Repack app orchestration config
""" """
if app_orchestration_config.prompt_template.simple_prompt_template is None: if app_orchestration_config.prompt_template.simple_prompt_template is None:
app_orchestration_config.prompt_template.simple_prompt_template = '' app_orchestration_config.prompt_template.simple_prompt_template = ''
...@@ -113,7 +126,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -113,7 +126,7 @@ class BaseAssistantApplicationRunner(AppRunner):
if response.type == ToolInvokeMessage.MessageType.TEXT: if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK: elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please dirct user to check it." result += f"result link: {response.message}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE: response.type == ToolInvokeMessage.MessageType.IMAGE:
result += f"image has been created and sent to user already, you should tell user to check it now." result += f"image has been created and sent to user already, you should tell user to check it now."
...@@ -172,20 +185,20 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -172,20 +185,20 @@ class BaseAssistantApplicationRunner(AppRunner):
for parameter in parameters: for parameter in parameters:
parameter_type = 'string' parameter_type = 'string'
enum = [] enum = []
if parameter.type == ToolParamter.ToolParameterType.STRING: if parameter.type == ToolParameter.ToolParameterType.STRING:
parameter_type = 'string' parameter_type = 'string'
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN: elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
parameter_type = 'boolean' parameter_type = 'boolean'
elif parameter.type == ToolParamter.ToolParameterType.NUMBER: elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
parameter_type = 'number' parameter_type = 'number'
elif parameter.type == ToolParamter.ToolParameterType.SELECT: elif parameter.type == ToolParameter.ToolParameterType.SELECT:
for option in parameter.options: for option in parameter.options:
enum.append(option.value) enum.append(option.value)
parameter_type = 'string' parameter_type = 'string'
else: else:
raise ValueError(f"parameter type {parameter.type} is not supported") raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParamter.ToolParameterForm.FORM: if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form # get tool parameter from form
tool_parameter_config = tool.tool_parameters.get(parameter.name) tool_parameter_config = tool.tool_parameters.get(parameter.name)
if not tool_parameter_config: if not tool_parameter_config:
...@@ -194,7 +207,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -194,7 +207,7 @@ class BaseAssistantApplicationRunner(AppRunner):
if not tool_parameter_config and parameter.required: if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config") raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParamter.ToolParameterType.SELECT: if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options # check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options)) options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options: if tool_parameter_config not in options:
...@@ -202,7 +215,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -202,7 +215,7 @@ class BaseAssistantApplicationRunner(AppRunner):
# convert tool parameter config to correct type # convert tool parameter config to correct type
try: try:
if parameter.type == ToolParamter.ToolParameterType.NUMBER: if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer # check if tool parameter is integer
if isinstance(tool_parameter_config, int): if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config tool_parameter_config = tool_parameter_config
...@@ -213,11 +226,11 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -213,11 +226,11 @@ class BaseAssistantApplicationRunner(AppRunner):
tool_parameter_config = float(tool_parameter_config) tool_parameter_config = float(tool_parameter_config)
else: else:
tool_parameter_config = int(tool_parameter_config) tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN: elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config) tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParamter.ToolParameterType.SELECT, ToolParamter.ToolParameterType.STRING]: elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config) tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParamter.ToolParameterType: elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config) tool_parameter_config = str(tool_parameter_config)
except Exception as e: except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
...@@ -225,7 +238,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -225,7 +238,7 @@ class BaseAssistantApplicationRunner(AppRunner):
# save tool parameter to tool entity memory # save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config runtime_parameters[parameter.name] = tool_parameter_config
elif parameter.form == ToolParamter.ToolParameterForm.LLM: elif parameter.form == ToolParameter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = { message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type, "type": parameter_type,
"description": parameter.llm_description or '', "description": parameter.llm_description or '',
...@@ -279,20 +292,20 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -279,20 +292,20 @@ class BaseAssistantApplicationRunner(AppRunner):
for parameter in tool_runtime_parameters: for parameter in tool_runtime_parameters:
parameter_type = 'string' parameter_type = 'string'
enum = [] enum = []
if parameter.type == ToolParamter.ToolParameterType.STRING: if parameter.type == ToolParameter.ToolParameterType.STRING:
parameter_type = 'string' parameter_type = 'string'
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN: elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
parameter_type = 'boolean' parameter_type = 'boolean'
elif parameter.type == ToolParamter.ToolParameterType.NUMBER: elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
parameter_type = 'number' parameter_type = 'number'
elif parameter.type == ToolParamter.ToolParameterType.SELECT: elif parameter.type == ToolParameter.ToolParameterType.SELECT:
for option in parameter.options: for option in parameter.options:
enum.append(option.value) enum.append(option.value)
parameter_type = 'string' parameter_type = 'string'
else: else:
raise ValueError(f"parameter type {parameter.type} is not supported") raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParamter.ToolParameterForm.LLM: if parameter.form == ToolParameter.ToolParameterForm.LLM:
prompt_tool.parameters['properties'][parameter.name] = { prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type, "type": parameter_type,
"description": parameter.llm_description or '', "description": parameter.llm_description or '',
......
...@@ -12,7 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMRes ...@@ -12,7 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMRes
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.tools.errors import ToolInvokeError, ToolNotFoundError, \ from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \ ToolNotSupportedError, ToolProviderNotFoundError, ToolParameterValidationError, \
ToolProviderCredentialValidationError ToolProviderCredentialValidationError
from core.features.assistant_base_runner import BaseAssistantApplicationRunner from core.features.assistant_base_runner import BaseAssistantApplicationRunner
...@@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner ...@@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from models.model import Conversation, Message from models.model import Conversation, Message
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance, def run(self, conversation: Conversation,
conversation: Conversation,
message: Message, message: Message,
query: str, query: str,
) -> Union[Generator, LLMResult]: ) -> Union[Generator, LLMResult]:
...@@ -29,7 +28,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -29,7 +28,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
Run Cot agent application Run Cot agent application
""" """
app_orchestration_config = self.app_orchestration_config app_orchestration_config = self.app_orchestration_config
self._repacket_app_orchestration_config(app_orchestration_config) self._repack_app_orchestration_config(app_orchestration_config)
agent_scratchpad: List[AgentScratchpadUnit] = [] agent_scratchpad: List[AgentScratchpadUnit] = []
...@@ -72,7 +71,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -72,7 +71,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
} }
final_answer = '' final_answer = ''
def increse_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage): def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']: if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage final_llm_usage_dict['usage'] = usage
else: else:
...@@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
llm_usage.prompt_price += usage.prompt_price llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price llm_usage.completion_price += usage.completion_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps: while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there is not any tool call # continue to run until there is not any tool call
function_call_state = False function_call_state = False
...@@ -104,7 +105,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -104,7 +105,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# update prompt messages # update prompt messages
prompt_messages = self._originze_cot_prompt_messages( prompt_messages = self._organize_cot_prompt_messages(
mode=app_orchestration_config.model_config.mode, mode=app_orchestration_config.model_config.mode,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=prompt_messages_tools, tools=prompt_messages_tools,
...@@ -137,7 +138,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -137,7 +138,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# get llm usage # get llm usage
if llm_result.usage: if llm_result.usage:
increse_usage(llm_usage, llm_result.usage) increase_usage(llm_usage, llm_result.usage)
# publish agent thought if it's first iteration # publish agent thought if it's first iteration
if iteration_step == 1: if iteration_step == 1:
...@@ -207,7 +208,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -207,7 +208,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
try: try:
tool_response = tool_instance.invoke( tool_response = tool_instance.invoke(
user_id=self.user_id, user_id=self.user_id,
tool_paramters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args) tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
) )
# transform tool response to llm friendly response # transform tool response to llm friendly response
tool_response = self.transform_tool_invoke_messages(tool_response) tool_response = self.transform_tool_invoke_messages(tool_response)
...@@ -225,15 +226,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -225,15 +226,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
message_file_ids = [message_file.id for message_file, _ in message_files] message_file_ids = [message_file.id for message_file, _ in message_files]
except ToolProviderCredentialValidationError as e: except ToolProviderCredentialValidationError as e:
error_response = f"Plese check your tool provider credentials" error_response = f"Please check your tool provider credentials"
except ( except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e: ) as e:
error_response = f"there is not a tool named {tool_call_name}" error_response = f"there is not a tool named {tool_call_name}"
except ( except (
ToolParamterValidationError ToolParameterValidationError
) as e: ) as e:
error_response = f"tool paramters validation error: {e}, please check your tool paramters" error_response = f"tool parameters validation error: {e}, please check your tool parameters"
except ToolInvokeError as e: except ToolInvokeError as e:
error_response = f"tool invoke error: {e}" error_response = f"tool invoke error: {e}"
except Exception as e: except Exception as e:
...@@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# remove Action: xxx from agent thought # remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE) agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input: if action_name and action_input is not None:
return AgentScratchpadUnit( return AgentScratchpadUnit(
agent_response=content, agent_response=content,
thought=agent_thought, thought=agent_thought,
...@@ -468,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -468,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
if not next_iteration.find("{{observation}}") >= 0: if not next_iteration.find("{{observation}}") >= 0:
raise ValueError("{{observation}} is required in next_iteration") raise ValueError("{{observation}} is required in next_iteration")
def _convert_strachpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str: def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
""" """
convert agent scratchpad list to str convert agent scratchpad list to str
""" """
...@@ -480,7 +481,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -480,7 +481,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
return result return result
def _originze_cot_prompt_messages(self, mode: Literal["completion", "chat"], def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
prompt_messages: List[PromptMessage], prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool], tools: List[PromptMessageTool],
agent_scratchpad: List[AgentScratchpadUnit], agent_scratchpad: List[AgentScratchpadUnit],
...@@ -489,7 +490,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -489,7 +490,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
input: str, input: str,
) -> List[PromptMessage]: ) -> List[PromptMessage]:
""" """
originze chain of thought prompt messages, a standard prompt message is like: organize chain of thought prompt messages, a standard prompt message is like:
Respond to the human as helpfully and accurately as possible. Respond to the human as helpfully and accurately as possible.
{{instruction}} {{instruction}}
...@@ -527,7 +528,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -527,7 +528,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
.replace("{{tools}}", tools_str) \ .replace("{{tools}}", tools_str) \
.replace("{{tool_names}}", tool_names) .replace("{{tool_names}}", tool_names)
# originze prompt messages # organize prompt messages
if mode == "chat": if mode == "chat":
# override system message # override system message
overrided = False overrided = False
...@@ -558,7 +559,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -558,7 +559,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
return prompt_messages return prompt_messages
elif mode == "completion": elif mode == "completion":
# parse agent scratchpad # parse agent scratchpad
agent_scratchpad_str = self._convert_strachpad_list_to_str(agent_scratchpad) agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
# parse prompt messages # parse prompt messages
return [UserPromptMessage( return [UserPromptMessage(
content=first_prompt.replace("{{instruction}}", instruction) content=first_prompt.replace("{{instruction}}", instruction)
......
This diff is collapsed.
...@@ -78,6 +78,7 @@ class ModelFeature(Enum): ...@@ -78,6 +78,7 @@ class ModelFeature(Enum):
MULTI_TOOL_CALL = "multi-tool-call" MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought" AGENT_THOUGHT = "agent-thought"
VISION = "vision" VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"
class DefaultParameterName(Enum): class DefaultParameterName(Enum):
......
...@@ -36,6 +36,7 @@ LLM_BASE_MODELS = [ ...@@ -36,6 +36,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
...@@ -80,6 +81,7 @@ LLM_BASE_MODELS = [ ...@@ -80,6 +81,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
...@@ -124,6 +126,7 @@ LLM_BASE_MODELS = [ ...@@ -124,6 +126,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
...@@ -198,6 +201,7 @@ LLM_BASE_MODELS = [ ...@@ -198,6 +201,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
...@@ -272,6 +276,7 @@ LLM_BASE_MODELS = [ ...@@ -272,6 +276,7 @@ LLM_BASE_MODELS = [
features=[ features=[
ModelFeature.AGENT_THOUGHT, ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
], ],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
......
...@@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
tools: Optional[list[PromptMessageTool]] = None) -> Generator: tools: Optional[list[PromptMessageTool]] = None) -> Generator:
index = 0 index = 0
full_assistant_content = '' full_assistant_content = ''
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
real_model = model real_model = model
system_fingerprint = None system_fingerprint = None
completion = '' completion = ''
...@@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
delta = chunk.choices[0] delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
delta.delta.function_call is None:
continue continue
# assistant_message_tool_calls = delta.delta.tool_calls # assistant_message_tool_calls = delta.delta.tool_calls
assistant_message_function_call = delta.delta.function_call assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if delta_assistant_message_function_call_storage is not None:
# handle process of stream function call
if assistant_message_function_call:
# message has not ended ever
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
continue
else:
# message has ended
assistant_message_function_call = delta_assistant_message_function_call_storage
delta_assistant_message_function_call_storage = None
else:
if assistant_message_function_call:
# start of stream function call
delta_assistant_message_function_call_storage = assistant_message_function_call
if delta_assistant_message_function_call_storage.arguments is None:
delta_assistant_message_function_call_storage.arguments = ''
continue
# extract tool calls from response # extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call) function_call = self._extract_response_function_call(assistant_message_function_call)
...@@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
else: else:
raise ValueError(f"Got unknown type {message}") raise ValueError(f"Got unknown type {message}")
if message.name is not None: if message.name:
message_dict["name"] = message.name message_dict["name"] = message.name
return message_dict return message_dict
...@@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
num_tokens = 0 num_tokens = 0
for tool in tools: for tool in tools:
num_tokens += len(encoding.encode('type')) num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(tool.get("type")))
num_tokens += len(encoding.encode('function')) num_tokens += len(encoding.encode('function'))
# calculate num tokens for function object # calculate num tokens for function object
......
...@@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast ...@@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction, from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
PromptMessageTool, SystemPromptMessage, UserPromptMessage) PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
...@@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ...@@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage): elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message) message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# check if last message is user message
message = cast(ToolPromptMessage, message)
message_dict = {"role": "function", "content": message.content}
else: else:
raise ValueError(f"Unknown message type {type(message)}") raise ValueError(f"Unknown message type {type(message)}")
......
...@@ -4,6 +4,8 @@ label: ...@@ -4,6 +4,8 @@ label:
model_type: llm model_type: llm
features: features:
- agent-thought - agent-thought
- tool-call
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 16384 context_size: 16384
......
...@@ -4,6 +4,8 @@ label: ...@@ -4,6 +4,8 @@ label:
model_type: llm model_type: llm
features: features:
- agent-thought - agent-thought
- tool-call
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 32768 context_size: 32768
......
...@@ -16,7 +16,7 @@ class MinimaxChatCompletion(object): ...@@ -16,7 +16,7 @@ class MinimaxChatCompletion(object):
""" """
def generate(self, model: str, api_key: str, group_id: str, def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: List[MinimaxMessage], model_parameters: dict, prompt_messages: List[MinimaxMessage], model_parameters: dict,
tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \ tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
""" """
generate chat completion generate chat completion
...@@ -162,7 +162,6 @@ class MinimaxChatCompletion(object): ...@@ -162,7 +162,6 @@ class MinimaxChatCompletion(object):
continue continue
for choice in choices: for choice in choices:
print(choice)
message = choice['delta'] message = choice['delta']
yield MinimaxMessage( yield MinimaxMessage(
content=message, content=message,
......
...@@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object): ...@@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object):
""" """
def generate(self, model: str, api_key: str, group_id: str, def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: List[MinimaxMessage], model_parameters: dict, prompt_messages: List[MinimaxMessage], model_parameters: dict,
tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \ tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
""" """
generate chat completion generate chat completion
...@@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object): ...@@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object):
**extra_kwargs **extra_kwargs
} }
if tools:
body['functions'] = tools
body['function_call'] = { 'type': 'auto' }
try: try:
response = post( response = post(
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
...@@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object): ...@@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object):
""" """
handle stream chat generate response handle stream chat generate response
""" """
function_call_storage = None
for line in response.iter_lines(): for line in response.iter_lines():
if not line: if not line:
continue continue
...@@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object): ...@@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object):
msg = data['base_resp']['status_msg'] msg = data['base_resp']['status_msg']
self._handle_error(code, msg) self._handle_error(code, msg)
if data['reply']: if data['reply'] or 'usage' in data and data['usage']:
total_tokens = data['usage']['total_tokens'] total_tokens = data['usage']['total_tokens']
message = MinimaxMessage( message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value, role=MinimaxMessage.Role.ASSISTANT.value,
...@@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object): ...@@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object):
'total_tokens': total_tokens 'total_tokens': total_tokens
} }
message.stop_reason = data['choices'][0]['finish_reason'] message.stop_reason = data['choices'][0]['finish_reason']
if function_call_storage:
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
function_call_message.function_call = function_call_storage
yield function_call_message
yield message yield message
return return
...@@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object): ...@@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object):
continue continue
for choice in choices: for choice in choices:
message = choice['messages'][0]['text'] message = choice['messages'][0]
if not message:
continue if 'function_call' in message:
if not function_call_storage:
function_call_storage = message['function_call']
if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
function_call_storage['arguments'] = ''
continue
else:
function_call_storage['arguments'] += message['function_call']['arguments']
continue
else:
if function_call_storage:
message['function_call'] = function_call_storage
function_call_storage = None
yield MinimaxMessage( minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
content=message,
role=MinimaxMessage.Role.ASSISTANT.value if 'function_call' in message:
) minimax_message.function_call = message['function_call']
\ No newline at end of file
if 'text' in message:
minimax_message.content = message['text']
yield minimax_message
\ No newline at end of file
...@@ -2,7 +2,7 @@ from typing import Generator, List ...@@ -2,7 +2,7 @@ from typing import Generator, List
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage) SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
...@@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): ...@@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
""" """
client: MinimaxChatCompletionPro = self.model_apis[model]() client: MinimaxChatCompletionPro = self.model_apis[model]()
if tools:
tools = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
response = client.generate( response = client.generate(
model=model, model=model,
api_key=credentials['minimax_api_key'], api_key=credentials['minimax_api_key'],
...@@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): ...@@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
elif isinstance(prompt_message, UserPromptMessage): elif isinstance(prompt_message, UserPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content) return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
elif isinstance(prompt_message, AssistantPromptMessage): elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
message.function_call={
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
}
return message
return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content) return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
elif isinstance(prompt_message, ToolPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
else: else:
raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
...@@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): ...@@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
finish_reason=message.stop_reason if message.stop_reason else None, finish_reason=message.stop_reason if message.stop_reason else None,
), ),
) )
elif message.function_call:
if 'name' not in message.function_call or 'arguments' not in message.function_call:
continue
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content='',
tool_calls=[AssistantPromptMessage.ToolCall(
id='',
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.function_call['name'],
arguments=message.function_call['arguments']
)
)]
),
),
)
else: else:
yield LLMResultChunk( yield LLMResultChunk(
model=model, model=model,
......
...@@ -7,13 +7,23 @@ class MinimaxMessage: ...@@ -7,13 +7,23 @@ class MinimaxMessage:
USER = 'USER' USER = 'USER'
ASSISTANT = 'BOT' ASSISTANT = 'BOT'
SYSTEM = 'SYSTEM' SYSTEM = 'SYSTEM'
FUNCTION = 'FUNCTION'
role: str = Role.USER.value role: str = Role.USER.value
content: str content: str
usage: Dict[str, int] = None usage: Dict[str, int] = None
stop_reason: str = '' stop_reason: str = ''
function_call: Dict[str, Any] = None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
return {
'sender_type': 'BOT',
'sender_name': '专家',
'text': '',
'function_call': self.function_call
}
return { return {
'sender_type': self.role, 'sender_type': self.role,
'sender_name': '我' if self.role == 'USER' else '专家', 'sender_name': '我' if self.role == 'USER' else '专家',
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 4096 context_size: 4096
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 16385 context_size: 16385
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 16385 context_size: 16385
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 16385 context_size: 16385
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 4096 context_size: 4096
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 128000 context_size: 128000
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 128000 context_size: 128000
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 32768 context_size: 32768
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 128000 context_size: 128000
......
...@@ -6,6 +6,7 @@ model_type: llm ...@@ -6,6 +6,7 @@ model_type: llm
features: features:
- multi-tool-call - multi-tool-call
- agent-thought - agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 8192 context_size: 8192
......
...@@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): ...@@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
else: else:
raise ValueError(f"Got unknown type {message}") raise ValueError(f"Got unknown type {message}")
if message.name is not None: if message.name:
message_dict["name"] = message.name message_dict["name"] = message.name
return message_dict return message_dict
......
...@@ -41,7 +41,7 @@ class OpenLLMGenerate(object): ...@@ -41,7 +41,7 @@ class OpenLLMGenerate(object):
if not server_url: if not server_url:
raise InvalidAuthenticationError('Invalid server URL') raise InvalidAuthenticationError('Invalid server URL')
defautl_llm_config = { default_llm_config = {
"max_new_tokens": 128, "max_new_tokens": 128,
"min_length": 0, "min_length": 0,
"early_stopping": False, "early_stopping": False,
...@@ -75,19 +75,19 @@ class OpenLLMGenerate(object): ...@@ -75,19 +75,19 @@ class OpenLLMGenerate(object):
} }
if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int:
defautl_llm_config['max_new_tokens'] = model_parameters['max_tokens'] default_llm_config['max_new_tokens'] = model_parameters['max_tokens']
if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: if 'temperature' in model_parameters and type(model_parameters['temperature']) == float:
defautl_llm_config['temperature'] = model_parameters['temperature'] default_llm_config['temperature'] = model_parameters['temperature']
if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
defautl_llm_config['top_p'] = model_parameters['top_p'] default_llm_config['top_p'] = model_parameters['top_p']
if 'top_k' in model_parameters and type(model_parameters['top_k']) == int: if 'top_k' in model_parameters and type(model_parameters['top_k']) == int:
defautl_llm_config['top_k'] = model_parameters['top_k'] default_llm_config['top_k'] = model_parameters['top_k']
if 'use_cache' in model_parameters and type(model_parameters['use_cache']) == bool: if 'use_cache' in model_parameters and type(model_parameters['use_cache']) == bool:
defautl_llm_config['use_cache'] = model_parameters['use_cache'] default_llm_config['use_cache'] = model_parameters['use_cache']
headers = { headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
...@@ -104,7 +104,7 @@ class OpenLLMGenerate(object): ...@@ -104,7 +104,7 @@ class OpenLLMGenerate(object):
data = { data = {
'stop': stop if stop else [], 'stop': stop if stop else [],
'prompt': '\n'.join([message.content for message in prompt_messages]), 'prompt': '\n'.join([message.content for message in prompt_messages]),
'llm_config': defautl_llm_config, 'llm_config': default_llm_config,
} }
try: try:
......
...@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast ...@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage) SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
ParameterRule, ParameterType) ParameterRule, ParameterType, ModelFeature)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper, from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
XinferenceModelExtraParameter) XinferenceModelExtraParameter)
from core.model_runtime.utils import helper from core.model_runtime.utils import helper
from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,
...@@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
""" """
if 'temperature' in model_parameters:
if model_parameters['temperature'] < 0.01:
model_parameters['temperature'] = 0.01
elif model_parameters['temperature'] > 1.0:
model_parameters['temperature'] = 0.99
return self._generate( return self._generate(
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
tools=tools, stop=stop, stream=stream, user=user, tools=tools, stop=stop, stream=stream, user=user,
...@@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
credentials['completion_type'] = 'completion' credentials['completion_type'] = 'completion'
else: else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported') raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
if extra_param.support_function_call:
credentials['support_function_call'] = True
except RuntimeError as e: except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
...@@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage): elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message) message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else: else:
raise ValueError(f"Unknown message type {type(message)}") raise ValueError(f"Unknown message type {type(message)}")
...@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
label=I18nObject( label=I18nObject(
zh_Hans='温度', zh_Hans='温度',
en_US='Temperature' en_US='Temperature'
) ),
), ),
ParameterRule( ParameterRule(
name='top_p', name='top_p',
...@@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
completion_type = LLMMode.COMPLETION.value completion_type = LLMMode.COMPLETION.value
else: else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
support_function_call = credentials.get('support_function_call', False)
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
...@@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
), ),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM, model_type=ModelType.LLM,
features=[
ModelFeature.TOOL_CALL
] if support_function_call else [],
model_properties={ model_properties={
ModelPropertyKey.MODE: completion_type, ModelPropertyKey.MODE: completion_type,
}, },
...@@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ...@@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
""" """
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
client = OpenAI( client = OpenAI(
base_url=f'{credentials["server_url"]}/v1', base_url=f'{credentials["server_url"]}/v1',
api_key='abc', api_key='abc',
......
...@@ -2,7 +2,7 @@ import time ...@@ -2,7 +2,7 @@ import time
from typing import Optional from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
...@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError ...@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
class XinferenceTextEmbeddingModel(TextEmbeddingModel): class XinferenceTextEmbeddingModel(TextEmbeddingModel):
""" """
...@@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): ...@@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
""" """
server_url = credentials['server_url'] server_url = credentials['server_url']
model_uid = credentials['model_uid'] model_uid = credentials['model_uid']
if server_url.endswith('/'):
server_url = server_url[:-1]
client = Client(base_url=server_url) client = Client(base_url=server_url)
try: try:
...@@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): ...@@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
:return: :return:
""" """
try: try:
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens
self._invoke(model=model, credentials=credentials, texts=['ping']) self._invoke(model=model, credentials=credentials, texts=['ping'])
except InvokeAuthorizationError: except (InvokeAuthorizationError, RuntimeError):
raise CredentialsValidateFailedError('Invalid api key') raise CredentialsValidateFailedError('Invalid api key')
@property @property
...@@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): ...@@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
""" """
used to define customizable model schema used to define customizable model schema
""" """
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
label=I18nObject( label=I18nObject(
...@@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): ...@@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
), ),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model_properties={}, model_properties={
ModelPropertyKey.MAX_CHUNKS: 1,
ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
},
parameter_rules=[] parameter_rules=[]
) )
......
from threading import Lock from threading import Lock
from time import time from time import time
from typing import List from typing import List
from os import path
from requests import get from requests import get
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
...@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object): ...@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object):
model_format: str model_format: str
model_handle_type: str model_handle_type: str
model_ability: List[str] model_ability: List[str]
max_tokens: int = 512
support_function_call: bool = False
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None: def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
support_function_call: bool, max_tokens: int) -> None:
self.model_format = model_format self.model_format = model_format
self.model_handle_type = model_handle_type self.model_handle_type = model_handle_type
self.model_ability = model_ability self.model_ability = model_ability
self.support_function_call = support_function_call
self.max_tokens = max_tokens
cache = {} cache = {}
cache_lock = Lock() cache_lock = Lock()
...@@ -49,7 +55,7 @@ class XinferenceHelper: ...@@ -49,7 +55,7 @@ class XinferenceHelper:
get xinference model extra parameter like model_format and model_handle_type get xinference model extra parameter like model_format and model_handle_type
""" """
url = f'{server_url}/v1/models/{model_uid}' url = path.join(server_url, 'v1/models', model_uid)
# this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
session = Session() session = Session()
...@@ -66,10 +72,12 @@ class XinferenceHelper: ...@@ -66,10 +72,12 @@ class XinferenceHelper:
response_json = response.json() response_json = response.json()
model_format = response_json['model_format'] model_format = response_json.get('model_format', 'ggmlv3')
model_ability = response_json['model_ability'] model_ability = response_json.get('model_ability', [])
if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: if response_json.get('model_type') == 'embedding':
model_handle_type = 'embedding'
elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
model_handle_type = 'chatglm' model_handle_type = 'chatglm'
elif 'generate' in model_ability: elif 'generate' in model_ability:
model_handle_type = 'generate' model_handle_type = 'generate'
...@@ -78,8 +86,13 @@ class XinferenceHelper: ...@@ -78,8 +86,13 @@ class XinferenceHelper:
else: else:
raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported') raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
support_function_call = 'tools' in model_ability
max_tokens = response_json.get('max_tokens', 512)
return XinferenceModelExtraParameter( return XinferenceModelExtraParameter(
model_format=model_format, model_format=model_format,
model_handle_type=model_handle_type, model_handle_type=model_handle_type,
model_ability=model_ability model_ability=model_ability,
support_function_call=support_function_call,
max_tokens=max_tokens
) )
\ No newline at end of file
...@@ -2,6 +2,10 @@ model: glm-3-turbo ...@@ -2,6 +2,10 @@ model: glm-3-turbo
label: label:
en_US: glm-3-turbo en_US: glm-3-turbo
model_type: llm model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
parameter_rules: parameter_rules:
......
...@@ -2,6 +2,10 @@ model: glm-4 ...@@ -2,6 +2,10 @@ model: glm-4
label: label:
en_US: glm-4 en_US: glm-4
model_type: llm model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
parameter_rules: parameter_rules:
......
...@@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ...@@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
'content': prompt_message.content, 'content': prompt_message.content,
'tool_call_id': prompt_message.tool_call_id 'tool_call_id': prompt_message.tool_call_id
}) })
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content,
'tool_calls': [
{
'id': tool_call.id,
'type': tool_call.type,
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments
}
} for tool_call in prompt_message.tool_calls
]
})
else:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content
})
else: else:
params['messages'].append({ params['messages'].append({
'role': prompt_message.role.value, 'role': prompt_message.role.value,
......
...@@ -218,15 +218,30 @@ class ProviderManager: ...@@ -218,15 +218,30 @@ class ProviderManager:
) )
if available_models: if available_models:
available_model = available_models[0] found = False
default_model = TenantDefaultModel( for available_model in available_models:
tenant_id=tenant_id, if available_model.model == "gpt-3.5-turbo-1106":
model_type=model_type.to_origin_model_type(), default_model = TenantDefaultModel(
provider_name=available_model.provider.provider, tenant_id=tenant_id,
model_name=available_model.model model_type=model_type.to_origin_model_type(),
) provider_name=available_model.provider.provider,
db.session.add(default_model) model_name=available_model.model
db.session.commit() )
db.session.add(default_model)
db.session.commit()
found = True
break
if not found:
available_model = available_models[0]
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.to_origin_model_type(),
provider_name=available_model.provider.provider,
model_name=available_model.model
)
db.session.add(default_model)
db.session.commit()
if not default_model: if not default_model:
return None return None
......
...@@ -125,7 +125,7 @@ from openai import OpenAI ...@@ -125,7 +125,7 @@ from openai import OpenAI
class DallE3Tool(BuiltinTool): class DallE3Tool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
...@@ -135,7 +135,7 @@ class DallE3Tool(BuiltinTool): ...@@ -135,7 +135,7 @@ class DallE3Tool(BuiltinTool):
) )
# prompt # prompt
prompt = tool_paramters.get('prompt', '') prompt = tool_parameters.get('prompt', '')
if not prompt: if not prompt:
return self.create_text_message('Please input prompt') return self.create_text_message('Please input prompt')
...@@ -163,7 +163,7 @@ Next, we use Vectorizer.AI to convert the PNG icon generated by DallE3 into a ve ...@@ -163,7 +163,7 @@ Next, we use Vectorizer.AI to convert the PNG icon generated by DallE3 into a ve
```python ```python
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
...@@ -171,20 +171,20 @@ from httpx import post ...@@ -171,20 +171,20 @@ from httpx import post
from base64 import b64decode from base64 import b64decode
class VectorizerTool(BuiltinTool): class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool
""" """
def get_runtime_parameters(self) -> List[ToolParamter]: def get_runtime_parameters(self) -> List[ToolParameter]:
""" """
Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list
""" """
def is_tool_avaliable(self) -> bool: def is_tool_available(self) -> bool:
""" """
Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here
""" """
...@@ -194,7 +194,7 @@ Next, let's implement these three functions ...@@ -194,7 +194,7 @@ Next, let's implement these three functions
```python ```python
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
...@@ -202,7 +202,7 @@ from httpx import post ...@@ -202,7 +202,7 @@ from httpx import post
from base64 import b64decode from base64 import b64decode
class VectorizerTool(BuiltinTool): class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
...@@ -214,7 +214,7 @@ class VectorizerTool(BuiltinTool): ...@@ -214,7 +214,7 @@ class VectorizerTool(BuiltinTool):
raise ToolProviderCredentialValidationError('Please input api key name and value') raise ToolProviderCredentialValidationError('Please input api key name and value')
# Get image_id, the definition of image_id can be found in get_runtime_parameters # Get image_id, the definition of image_id can be found in get_runtime_parameters
image_id = tool_paramters.get('image_id', '') image_id = tool_parameters.get('image_id', '')
if not image_id: if not image_id:
return self.create_text_message('Please input image id') return self.create_text_message('Please input image id')
...@@ -241,24 +241,24 @@ class VectorizerTool(BuiltinTool): ...@@ -241,24 +241,24 @@ class VectorizerTool(BuiltinTool):
meta={'mime_type': 'image/svg+xml'}) meta={'mime_type': 'image/svg+xml'})
] ]
def get_runtime_parameters(self) -> List[ToolParamter]: def get_runtime_parameters(self) -> List[ToolParameter]:
""" """
override the runtime parameters override the runtime parameters
""" """
# Here, we override the tool parameter list, define the image_id, and set its option list to all images in the current variable pool. The configuration here is consistent with the configuration in yaml. # Here, we override the tool parameter list, define the image_id, and set its option list to all images in the current variable pool. The configuration here is consistent with the configuration in yaml.
return [ return [
ToolParamter.get_simple_instance( ToolParameter.get_simple_instance(
name='image_id', name='image_id',
llm_description=f'the image id that you want to vectorize, \ llm_description=f'the image id that you want to vectorize, \
and the image id should be specified in \ and the image id should be specified in \
{[i.name for i in self.list_default_image_variables()]}', {[i.name for i in self.list_default_image_variables()]}',
type=ToolParamter.ToolParameterType.SELECT, type=ToolParameter.ToolParameterType.SELECT,
required=True, required=True,
options=[i.name for i in self.list_default_image_variables()] options=[i.name for i in self.list_default_image_variables()]
) )
] ]
def is_tool_avaliable(self) -> bool: def is_tool_available(self) -> bool:
# Only when there are images in the variable pool, the LLM needs to use this tool # Only when there are images in the variable pool, the LLM needs to use this tool
return len(self.list_default_image_variables()) > 0 return len(self.list_default_image_variables()) > 0
``` ```
......
...@@ -146,13 +146,13 @@ from typing import Any, Dict, List, Union ...@@ -146,13 +146,13 @@ from typing import Any, Dict, List, Union
class GoogleSearchTool(BuiltinTool): class GoogleSearchTool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
query = tool_paramters['query'] query = tool_parameters['query']
result_type = tool_paramters['result_type'] result_type = tool_parameters['result_type']
api_key = self.runtime.credentials['serpapi_api_key'] api_key = self.runtime.credentials['serpapi_api_key']
# TODO: search with serpapi # TODO: search with serpapi
result = SerpAPI(api_key).run(query, result_type=result_type) result = SerpAPI(api_key).run(query, result_type=result_type)
...@@ -163,7 +163,7 @@ class GoogleSearchTool(BuiltinTool): ...@@ -163,7 +163,7 @@ class GoogleSearchTool(BuiltinTool):
``` ```
### Parameters ### Parameters
The overall logic of the tool is in the `_invoke` method, this method accepts two parameters: `user_id` and `tool_paramters`, which represent the user ID and tool parameters respectively The overall logic of the tool is in the `_invoke` method, this method accepts two parameters: `user_id` and `tool_parameters`, which represent the user ID and tool parameters respectively
### Return Data ### Return Data
When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message. When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message.
...@@ -195,7 +195,7 @@ class GoogleProvider(BuiltinToolProviderController): ...@@ -195,7 +195,7 @@ class GoogleProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"query": "test", "query": "test",
"result_type": "link" "result_type": "link"
}, },
......
...@@ -125,7 +125,7 @@ from openai import OpenAI ...@@ -125,7 +125,7 @@ from openai import OpenAI
class DallE3Tool(BuiltinTool): class DallE3Tool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
...@@ -135,7 +135,7 @@ class DallE3Tool(BuiltinTool): ...@@ -135,7 +135,7 @@ class DallE3Tool(BuiltinTool):
) )
# prompt # prompt
prompt = tool_paramters.get('prompt', '') prompt = tool_parameters.get('prompt', '')
if not prompt: if not prompt:
return self.create_text_message('Please input prompt') return self.create_text_message('Please input prompt')
...@@ -163,7 +163,7 @@ class DallE3Tool(BuiltinTool): ...@@ -163,7 +163,7 @@ class DallE3Tool(BuiltinTool):
```python ```python
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
...@@ -171,20 +171,20 @@ from httpx import post ...@@ -171,20 +171,20 @@ from httpx import post
from base64 import b64decode from base64 import b64decode
class VectorizerTool(BuiltinTool): class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
工具调用,图片变量名需要从这里传递进来,从而我们就可以从变量池中获取到图片 工具调用,图片变量名需要从这里传递进来,从而我们就可以从变量池中获取到图片
""" """
def get_runtime_parameters(self) -> List[ToolParamter]: def get_runtime_parameters(self) -> List[ToolParameter]:
""" """
重写工具参数列表,我们可以根据当前变量池里的实际情况来动态生成参数列表,从而LLM可以根据参数列表来生成表单 重写工具参数列表,我们可以根据当前变量池里的实际情况来动态生成参数列表,从而LLM可以根据参数列表来生成表单
""" """
def is_tool_avaliable(self) -> bool: def is_tool_available(self) -> bool:
""" """
当前工具是否可用,如果当前变量池中没有图片,那么我们就不需要展示这个工具,这里返回False即可 当前工具是否可用,如果当前变量池中没有图片,那么我们就不需要展示这个工具,这里返回False即可
""" """
...@@ -194,7 +194,7 @@ class VectorizerTool(BuiltinTool): ...@@ -194,7 +194,7 @@ class VectorizerTool(BuiltinTool):
```python ```python
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
...@@ -202,7 +202,7 @@ from httpx import post ...@@ -202,7 +202,7 @@ from httpx import post
from base64 import b64decode from base64 import b64decode
class VectorizerTool(BuiltinTool): class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
...@@ -214,7 +214,7 @@ class VectorizerTool(BuiltinTool): ...@@ -214,7 +214,7 @@ class VectorizerTool(BuiltinTool):
raise ToolProviderCredentialValidationError('Please input api key name and value') raise ToolProviderCredentialValidationError('Please input api key name and value')
# 获取image_id,image_id的定义可以在get_runtime_parameters中找到 # 获取image_id,image_id的定义可以在get_runtime_parameters中找到
image_id = tool_paramters.get('image_id', '') image_id = tool_parameters.get('image_id', '')
if not image_id: if not image_id:
return self.create_text_message('Please input image id') return self.create_text_message('Please input image id')
...@@ -241,24 +241,24 @@ class VectorizerTool(BuiltinTool): ...@@ -241,24 +241,24 @@ class VectorizerTool(BuiltinTool):
meta={'mime_type': 'image/svg+xml'}) meta={'mime_type': 'image/svg+xml'})
] ]
def get_runtime_parameters(self) -> List[ToolParamter]: def get_runtime_parameters(self) -> List[ToolParameter]:
""" """
override the runtime parameters override the runtime parameters
""" """
# 这里,我们重写了工具参数列表,定义了image_id,并设置了它的选项列表为当前变量池中的所有图片,这里的配置与yaml中的配置是一致的 # 这里,我们重写了工具参数列表,定义了image_id,并设置了它的选项列表为当前变量池中的所有图片,这里的配置与yaml中的配置是一致的
return [ return [
ToolParamter.get_simple_instance( ToolParameter.get_simple_instance(
name='image_id', name='image_id',
llm_description=f'the image id that you want to vectorize, \ llm_description=f'the image id that you want to vectorize, \
and the image id should be specified in \ and the image id should be specified in \
{[i.name for i in self.list_default_image_variables()]}', {[i.name for i in self.list_default_image_variables()]}',
type=ToolParamter.ToolParameterType.SELECT, type=ToolParameter.ToolParameterType.SELECT,
required=True, required=True,
options=[i.name for i in self.list_default_image_variables()] options=[i.name for i in self.list_default_image_variables()]
) )
] ]
def is_tool_avaliable(self) -> bool: def is_tool_available(self) -> bool:
# 只有当变量池中有图片时,LLM才需要使用这个工具 # 只有当变量池中有图片时,LLM才需要使用这个工具
return len(self.list_default_image_variables()) > 0 return len(self.list_default_image_variables()) > 0
``` ```
......
...@@ -146,13 +146,13 @@ from typing import Any, Dict, List, Union ...@@ -146,13 +146,13 @@ from typing import Any, Dict, List, Union
class GoogleSearchTool(BuiltinTool): class GoogleSearchTool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
query = tool_paramters['query'] query = tool_parameters['query']
result_type = tool_paramters['result_type'] result_type = tool_parameters['result_type']
api_key = self.runtime.credentials['serpapi_api_key'] api_key = self.runtime.credentials['serpapi_api_key']
# TODO: search with serpapi # TODO: search with serpapi
result = SerpAPI(api_key).run(query, result_type=result_type) result = SerpAPI(api_key).run(query, result_type=result_type)
...@@ -163,7 +163,7 @@ class GoogleSearchTool(BuiltinTool): ...@@ -163,7 +163,7 @@ class GoogleSearchTool(BuiltinTool):
``` ```
### 参数 ### 参数
工具的整体逻辑都在`_invoke`方法中,这个方法接收两个参数:`user_id``tool_paramters`,分别表示用户ID和工具参数 工具的整体逻辑都在`_invoke`方法中,这个方法接收两个参数:`user_id``tool_parameters`,分别表示用户ID和工具参数
### 返回数据 ### 返回数据
在工具返回时,你可以选择返回一个消息或者多个消息,这里我们返回一个消息,使用`create_text_message``create_link_message`可以创建一个文本消息或者一个链接消息。 在工具返回时,你可以选择返回一个消息或者多个消息,这里我们返回一个消息,使用`create_text_message``create_link_message`可以创建一个文本消息或者一个链接消息。
...@@ -195,7 +195,7 @@ class GoogleProvider(BuiltinToolProviderController): ...@@ -195,7 +195,7 @@ class GoogleProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"query": "test", "query": "test",
"result_type": "link" "result_type": "link"
}, },
......
from pydantic import BaseModel from pydantic import BaseModel
from typing import Dict, Optional, Any, List from typing import Dict, Optional, Any, List
from core.tools.entities.tool_entities import ToolProviderType, ToolParamter from core.tools.entities.tool_entities import ToolProviderType, ToolParameter
class ApiBasedToolBundle(BaseModel): class ApiBasedToolBundle(BaseModel):
""" """
...@@ -16,7 +16,7 @@ class ApiBasedToolBundle(BaseModel): ...@@ -16,7 +16,7 @@ class ApiBasedToolBundle(BaseModel):
# operation_id # operation_id
operation_id: str = None operation_id: str = None
# parameters # parameters
parameters: Optional[List[ToolParamter]] = None parameters: Optional[List[ToolParameter]] = None
# author # author
author: str author: str
# icon # icon
......
...@@ -89,11 +89,11 @@ class ToolInvokeMessageBinary(BaseModel): ...@@ -89,11 +89,11 @@ class ToolInvokeMessageBinary(BaseModel):
url: str = Field(..., description="The url of the binary") url: str = Field(..., description="The url of the binary")
save_as: str = '' save_as: str = ''
class ToolParamterOption(BaseModel): class ToolParameterOption(BaseModel):
value: str = Field(..., description="The value of the option") value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option") label: I18nObject = Field(..., description="The label of the option")
class ToolParamter(BaseModel): class ToolParameter(BaseModel):
class ToolParameterType(Enum): class ToolParameterType(Enum):
STRING = "string" STRING = "string"
NUMBER = "number" NUMBER = "number"
...@@ -115,12 +115,12 @@ class ToolParamter(BaseModel): ...@@ -115,12 +115,12 @@ class ToolParamter(BaseModel):
default: Optional[str] = None default: Optional[str] = None
min: Optional[Union[float, int]] = None min: Optional[Union[float, int]] = None
max: Optional[Union[float, int]] = None max: Optional[Union[float, int]] = None
options: Optional[List[ToolParamterOption]] = None options: Optional[List[ToolParameterOption]] = None
@classmethod @classmethod
def get_simple_instance(cls, def get_simple_instance(cls,
name: str, llm_description: str, type: ToolParameterType, name: str, llm_description: str, type: ToolParameterType,
required: bool, options: Optional[List[str]] = None) -> 'ToolParamter': required: bool, options: Optional[List[str]] = None) -> 'ToolParameter':
""" """
get a simple tool parameter get a simple tool parameter
...@@ -130,9 +130,9 @@ class ToolParamter(BaseModel): ...@@ -130,9 +130,9 @@ class ToolParamter(BaseModel):
:param required: if the parameter is required :param required: if the parameter is required
:param options: the options of the parameter :param options: the options of the parameter
""" """
# convert options to ToolParamterOption # convert options to ToolParameterOption
if options: if options:
options = [ToolParamterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options] options = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
return cls( return cls(
name=name, name=name,
label=I18nObject(en_US='', zh_Hans=''), label=I18nObject(en_US='', zh_Hans=''),
...@@ -184,7 +184,7 @@ class ToolProviderCredentials(BaseModel): ...@@ -184,7 +184,7 @@ class ToolProviderCredentials(BaseModel):
raise ValueError(f'invalid mode value {value}') raise ValueError(f'invalid mode value {value}')
@staticmethod @staticmethod
def defaut(value: str) -> str: def default(value: str) -> str:
return "" return ""
name: str = Field(..., description="The name of the credentials") name: str = Field(..., description="The name of the credentials")
......
...@@ -4,7 +4,7 @@ from typing import List, Dict, Optional ...@@ -4,7 +4,7 @@ from typing import List, Dict, Optional
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderCredentials from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.tool.tool import ToolParamter from core.tools.tool.tool import ToolParameter
class UserToolProvider(BaseModel): class UserToolProvider(BaseModel):
class ProviderType(Enum): class ProviderType(Enum):
...@@ -46,4 +46,4 @@ class UserTool(BaseModel): ...@@ -46,4 +46,4 @@ class UserTool(BaseModel):
name: str # identifier name: str # identifier
label: I18nObject # label label: I18nObject # label
description: I18nObject description: I18nObject
parameters: Optional[List[ToolParamter]] parameters: Optional[List[ToolParameter]]
\ No newline at end of file \ No newline at end of file
...@@ -4,7 +4,7 @@ class ToolProviderNotFoundError(ValueError): ...@@ -4,7 +4,7 @@ class ToolProviderNotFoundError(ValueError):
class ToolNotFoundError(ValueError): class ToolNotFoundError(ValueError):
pass pass
class ToolParamterValidationError(ValueError): class ToolParameterValidationError(ValueError):
pass pass
class ToolProviderCredentialValidationError(ValueError): class ToolProviderCredentialValidationError(ValueError):
......
...@@ -123,12 +123,12 @@ class ApiBasedToolProviderController(ToolProviderController): ...@@ -123,12 +123,12 @@ class ApiBasedToolProviderController(ToolProviderController):
return self.tools return self.tools
def get_tools(self, user_id: str, tanent_id: str) -> List[ApiTool]: def get_tools(self, user_id: str, tenant_id: str) -> List[ApiTool]:
""" """
fetch tools from database fetch tools from database
:param user_id: the user id :param user_id: the user id
:param tanent_id: the tanent id :param tenant_id: the tenant id
:return: the tools :return: the tools
""" """
if self.tools is not None: if self.tools is not None:
...@@ -136,9 +136,9 @@ class ApiBasedToolProviderController(ToolProviderController): ...@@ -136,9 +136,9 @@ class ApiBasedToolProviderController(ToolProviderController):
tools: List[Tool] = [] tools: List[Tool] = []
# get tanent api providers # get tenant api providers
db_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider).filter( db_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tanent_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == self.identity.name ApiToolProvider.name == self.identity.name
).all() ).all()
......
from typing import Any, Dict, List from typing import Any, Dict, List
from core.tools.entities.tool_entities import ToolProviderType, ToolParamter, ToolParamterOption from core.tools.entities.tool_entities import ToolProviderType, ToolParameter, ToolParameterOption
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
...@@ -71,7 +71,7 @@ class AppBasedToolProviderEntity(ToolProviderController): ...@@ -71,7 +71,7 @@ class AppBasedToolProviderEntity(ToolProviderController):
variable_name = input_form[form_type]['variable_name'] variable_name = input_form[form_type]['variable_name']
options = input_form[form_type].get('options', []) options = input_form[form_type].get('options', [])
if form_type == 'paragraph' or form_type == 'text-input': if form_type == 'paragraph' or form_type == 'text-input':
tool['parameters'].append(ToolParamter( tool['parameters'].append(ToolParameter(
name=variable_name, name=variable_name,
label=I18nObject( label=I18nObject(
en_US=label, en_US=label,
...@@ -82,13 +82,13 @@ class AppBasedToolProviderEntity(ToolProviderController): ...@@ -82,13 +82,13 @@ class AppBasedToolProviderEntity(ToolProviderController):
zh_Hans=label zh_Hans=label
), ),
llm_description=label, llm_description=label,
form=ToolParamter.ToolParameterForm.FORM, form=ToolParameter.ToolParameterForm.FORM,
type=ToolParamter.ToolParameterType.STRING, type=ToolParameter.ToolParameterType.STRING,
required=required, required=required,
default=default default=default
)) ))
elif form_type == 'select': elif form_type == 'select':
tool['parameters'].append(ToolParamter( tool['parameters'].append(ToolParameter(
name=variable_name, name=variable_name,
label=I18nObject( label=I18nObject(
en_US=label, en_US=label,
...@@ -99,11 +99,11 @@ class AppBasedToolProviderEntity(ToolProviderController): ...@@ -99,11 +99,11 @@ class AppBasedToolProviderEntity(ToolProviderController):
zh_Hans=label zh_Hans=label
), ),
llm_description=label, llm_description=label,
form=ToolParamter.ToolParameterForm.FORM, form=ToolParameter.ToolParameterForm.FORM,
type=ToolParamter.ToolParameterType.SELECT, type=ToolParameter.ToolParameterType.SELECT,
required=required, required=required,
default=default, default=default,
options=[ToolParamterOption( options=[ToolParameterOption(
value=option, value=option,
label=I18nObject( label=I18nObject(
en_US=option, en_US=option,
......
...@@ -13,7 +13,7 @@ class AzureDALLEProvider(BuiltinToolProviderController): ...@@ -13,7 +13,7 @@ class AzureDALLEProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"prompt": "cute girl, blue eyes, white hair, anime style", "prompt": "cute girl, blue eyes, white hair, anime style",
"size": "square", "size": "square",
"n": 1 "n": 1
......
...@@ -2,13 +2,13 @@ identity: ...@@ -2,13 +2,13 @@ identity:
author: Leslie author: Leslie
name: azuredalle name: azuredalle
label: label:
en_US: AzureDALL-E en_US: Azure DALL-E
zh_Hans: AzureDALL-E 绘画 zh_Hans: Azure DALL-E 绘画
pt_BR: AzureDALL-E pt_BR: Azure DALL-E
description: description:
en_US: AZURE DALL-E art en_US: Azure DALL-E art
zh_Hans: AZURE DALL-E 绘画 zh_Hans: Azure DALL-E 绘画
pt_BR: AZURE DALL-E art pt_BR: Azure DALL-E art
icon: icon.png icon: icon.png
credentials_for_provider: credentials_for_provider:
azure_openai_api_key: azure_openai_api_key:
...@@ -21,26 +21,26 @@ credentials_for_provider: ...@@ -21,26 +21,26 @@ credentials_for_provider:
help: help:
en_US: Please input your Azure OpenAI API key en_US: Please input your Azure OpenAI API key
zh_Hans: 请输入你的 Azure OpenAI API key zh_Hans: 请输入你的 Azure OpenAI API key
pt_BR: Please input your Azure OpenAI API key pt_BR: Introduza a sua chave de API OpenAI do Azure
placeholder: placeholder:
en_US: Please input your Azure OpenAI API key en_US: Please input your Azure OpenAI API key
zh_Hans: 请输入你的 Azure OpenAI API key zh_Hans: 请输入你的 Azure OpenAI API key
pt_BR: Please input your Azure OpenAI API key pt_BR: Introduza a sua chave de API OpenAI do Azure
azure_openai_api_model_name: azure_openai_api_model_name:
type: text-input type: text-input
required: true required: true
label: label:
en_US: Deployment Name en_US: Deployment Name
zh_Hans: 部署名称 zh_Hans: 部署名称
pt_BR: Deployment Name pt_BR: Nome da Implantação
help: help:
en_US: Please input the name of your Azure Openai DALL-E API deployment en_US: Please input the name of your Azure Openai DALL-E API deployment
zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称 zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称
pt_BR: Please input the name of your Azure Openai DALL-E API deployment pt_BR: Insira o nome da implantação da API DALL-E do Azure Openai
placeholder: placeholder:
en_US: Please input the name of your Azure Openai DALL-E API deployment en_US: Please input the name of your Azure Openai DALL-E API deployment
zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称 zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称
pt_BR: Please input the name of your Azure Openai DALL-E API deployment pt_BR: Insira o nome da implantação da API DALL-E do Azure Openai
azure_openai_base_url: azure_openai_base_url:
type: text-input type: text-input
required: true required: true
...@@ -49,13 +49,13 @@ credentials_for_provider: ...@@ -49,13 +49,13 @@ credentials_for_provider:
zh_Hans: API 域名 zh_Hans: API 域名
pt_BR: API Endpoint URL pt_BR: API Endpoint URL
help: help:
en_US: Please input your Azure OpenAI Endpoint URL,eg:https://xxx.openai.azure.com/ en_US: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/ zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
pt_BR: Please input your Azure OpenAI Endpoint URL,eg:https://xxx.openai.azure.com/ pt_BR: Introduza a URL do Azure OpenAI Endpoint, e.g. https://xxx.openai.azure.com/
placeholder: placeholder:
en_US: Please input your Azure OpenAI Endpoint URL,eg:https://xxx.openai.azure.com/ en_US: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/ zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
pt_BR: Please input your Azure OpenAI Endpoint URL,eg:https://xxx.openai.azure.com/ pt_BR: Introduza a URL do Azure OpenAI Endpoint, e.g. https://xxx.openai.azure.com/
azure_openai_api_version: azure_openai_api_version:
type: text-input type: text-input
required: true required: true
...@@ -64,10 +64,10 @@ credentials_for_provider: ...@@ -64,10 +64,10 @@ credentials_for_provider:
zh_Hans: API 版本 zh_Hans: API 版本
pt_BR: API Version pt_BR: API Version
help: help:
en_US: Please input your Azure OpenAI API Version,eg:2023-12-01-preview en_US: Please input your Azure OpenAI API Version,e.g. 2023-12-01-preview
zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
pt_BR: Please input your Azure OpenAI API Version,eg:2023-12-01-preview pt_BR: Introduza a versão da API OpenAI do Azure,e.g. 2023-12-01-preview
placeholder: placeholder:
en_US: Please input your Azure OpenAI API Version,eg:2023-12-01-preview en_US: Please input your Azure OpenAI API Version,e.g. 2023-12-01-preview
zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
pt_BR: Please input your Azure OpenAI API Version,eg:2023-12-01-preview pt_BR: Introduza a versão da API OpenAI do Azure,e.g. 2023-12-01-preview
...@@ -10,7 +10,7 @@ from openai import AzureOpenAI ...@@ -10,7 +10,7 @@ from openai import AzureOpenAI
class DallE3Tool(BuiltinTool): class DallE3Tool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
...@@ -28,19 +28,19 @@ class DallE3Tool(BuiltinTool): ...@@ -28,19 +28,19 @@ class DallE3Tool(BuiltinTool):
} }
# prompt # prompt
prompt = tool_paramters.get('prompt', '') prompt = tool_parameters.get('prompt', '')
if not prompt: if not prompt:
return self.create_text_message('Please input prompt') return self.create_text_message('Please input prompt')
# get size # get size
size = SIZE_MAPPING[tool_paramters.get('size', 'square')] size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
# get n # get n
n = tool_paramters.get('n', 1) n = tool_parameters.get('n', 1)
# get quality # get quality
quality = tool_paramters.get('quality', 'standard') quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']: if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality') return self.create_text_message('Invalid quality')
# get style # get style
style = tool_paramters.get('style', 'vivid') style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']: if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style') return self.create_text_message('Invalid style')
......
identity: identity:
name: dalle3 name: azure_dalle3
author: Leslie author: Leslie
label: label:
en_US: DALL-E 3 en_US: Azure DALL-E 3
zh_Hans: DALL-E 3 绘画 zh_Hans: Azure DALL-E 3 绘画
pt_BR: DALL-E 3 pt_BR: Azure DALL-E 3
description: description:
en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
zh_Hans: DALL-E 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像,相比于DallE 2, DallE 3拥有更强的绘画能力,但会消耗更多的资源 zh_Hans: DALL-E 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像,相比于DallE 2, DallE 3拥有更强的绘画能力,但会消耗更多的资源
pt_BR: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources pt_BR: DALL-E 3 é uma poderosa ferramenta de desenho que pode desenhar a imagem que você deseja com base em seu prompt, em comparação com DallE 2, DallE 3 tem uma capacidade de desenho mais forte, mas consumirá mais recursos
description: description:
human: human:
en_US: DALL-E is a text to image tool en_US: DALL-E is a text to image tool
zh_Hans: DALL-E 是一个文本到图像的工具 zh_Hans: DALL-E 是一个文本到图像的工具
pt_BR: DALL-E is a text to image tool pt_BR: DALL-E é uma ferramenta de texto para imagem
llm: DALL-E is a tool used to generate images from text llm: DALL-E is a tool used to generate images from text
parameters: parameters:
- name: prompt - name: prompt
...@@ -26,7 +26,7 @@ parameters: ...@@ -26,7 +26,7 @@ parameters:
human_description: human_description:
en_US: Image prompt, you can check the official documentation of DallE 3 en_US: Image prompt, you can check the official documentation of DallE 3
zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档 zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档
pt_BR: Image prompt, you can check the official documentation of DallE 3 pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
form: llm form: llm
- name: size - name: size
...@@ -35,18 +35,18 @@ parameters: ...@@ -35,18 +35,18 @@ parameters:
human_description: human_description:
en_US: selecting the image size en_US: selecting the image size
zh_Hans: 选择图像大小 zh_Hans: 选择图像大小
pt_BR: selecting the image size pt_BR: seleccionar o tamanho da imagem
label: label:
en_US: Image size en_US: Image size
zh_Hans: 图像大小 zh_Hans: 图像大小
pt_BR: Image size pt_BR: Tamanho da imagem
form: form form: form
options: options:
- value: square - value: square
label: label:
en_US: Squre(1024x1024) en_US: Squre(1024x1024)
zh_Hans: 方(1024x1024) zh_Hans: 方(1024x1024)
pt_BR: Squre(1024x1024) pt_BR: Squire(1024x1024)
- value: vertical - value: vertical
label: label:
en_US: Vertical(1024x1792) en_US: Vertical(1024x1792)
...@@ -64,11 +64,11 @@ parameters: ...@@ -64,11 +64,11 @@ parameters:
human_description: human_description:
en_US: selecting the number of images en_US: selecting the number of images
zh_Hans: 选择图像数量 zh_Hans: 选择图像数量
pt_BR: selecting the number of images pt_BR: seleccionar o número de imagens
label: label:
en_US: Number of images en_US: Number of images
zh_Hans: 图像数量 zh_Hans: 图像数量
pt_BR: Number of images pt_BR: Número de imagens
form: form form: form
min: 1 min: 1
max: 1 max: 1
...@@ -79,18 +79,18 @@ parameters: ...@@ -79,18 +79,18 @@ parameters:
human_description: human_description:
en_US: selecting the image quality en_US: selecting the image quality
zh_Hans: 选择图像质量 zh_Hans: 选择图像质量
pt_BR: selecting the image quality pt_BR: seleccionar a qualidade da imagem
label: label:
en_US: Image quality en_US: Image quality
zh_Hans: 图像质量 zh_Hans: 图像质量
pt_BR: Image quality pt_BR: Qualidade da imagem
form: form form: form
options: options:
- value: standard - value: standard
label: label:
en_US: Standard en_US: Standard
zh_Hans: 标准 zh_Hans: 标准
pt_BR: Standard pt_BR: Normal
- value: hd - value: hd
label: label:
en_US: HD en_US: HD
...@@ -103,18 +103,18 @@ parameters: ...@@ -103,18 +103,18 @@ parameters:
human_description: human_description:
en_US: selecting the image style en_US: selecting the image style
zh_Hans: 选择图像风格 zh_Hans: 选择图像风格
pt_BR: selecting the image style pt_BR: seleccionar o estilo da imagem
label: label:
en_US: Image style en_US: Image style
zh_Hans: 图像风格 zh_Hans: 图像风格
pt_BR: Image style pt_BR: Estilo da imagem
form: form form: form
options: options:
- value: vivid - value: vivid
label: label:
en_US: Vivid en_US: Vivid
zh_Hans: 生动 zh_Hans: 生动
pt_BR: Vivid pt_BR: Vívido
- value: natural - value: natural
label: label:
en_US: Natural en_US: Natural
......
...@@ -16,7 +16,7 @@ class ChartProvider(BuiltinToolProviderController): ...@@ -16,7 +16,7 @@ class ChartProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"data": "1,3,5,7,9,2,4,6,8,10", "data": "1,3,5,7,9,2,4,6,8,10",
}, },
) )
......
...@@ -6,9 +6,9 @@ import io ...@@ -6,9 +6,9 @@ import io
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
class BarChartTool(BuiltinTool): class BarChartTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
data = tool_paramters.get('data', '') data = tool_parameters.get('data', '')
if not data: if not data:
return self.create_text_message('Please input data') return self.create_text_message('Please input data')
data = data.split(';') data = data.split(';')
...@@ -19,7 +19,7 @@ class BarChartTool(BuiltinTool): ...@@ -19,7 +19,7 @@ class BarChartTool(BuiltinTool):
else: else:
data = [float(i) for i in data] data = [float(i) for i in data]
axis = tool_paramters.get('x_axis', None) or None axis = tool_parameters.get('x_axis', None) or None
if axis: if axis:
axis = axis.split(';') axis = axis.split(';')
if len(axis) != len(data): if len(axis) != len(data):
......
...@@ -8,14 +8,14 @@ from typing import Any, Dict, List, Union ...@@ -8,14 +8,14 @@ from typing import Any, Dict, List, Union
class LinearChartTool(BuiltinTool): class LinearChartTool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
data = tool_paramters.get('data', '') data = tool_parameters.get('data', '')
if not data: if not data:
return self.create_text_message('Please input data') return self.create_text_message('Please input data')
data = data.split(';') data = data.split(';')
axis = tool_paramters.get('x_axis', None) or None axis = tool_parameters.get('x_axis', None) or None
if axis: if axis:
axis = axis.split(';') axis = axis.split(';')
if len(axis) != len(data): if len(axis) != len(data):
......
...@@ -8,13 +8,13 @@ from typing import Any, Dict, List, Union ...@@ -8,13 +8,13 @@ from typing import Any, Dict, List, Union
class PieChartTool(BuiltinTool): class PieChartTool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
data = tool_paramters.get('data', '') data = tool_parameters.get('data', '')
if not data: if not data:
return self.create_text_message('Please input data') return self.create_text_message('Please input data')
data = data.split(';') data = data.split(';')
categories = tool_paramters.get('categories', None) or None categories = tool_parameters.get('categories', None) or None
# if all data is int, convert to int # if all data is int, convert to int
if all([i.isdigit() for i in data]): if all([i.isdigit() for i in data]):
......
...@@ -13,7 +13,7 @@ class DALLEProvider(BuiltinToolProviderController): ...@@ -13,7 +13,7 @@ class DALLEProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"prompt": "cute girl, blue eyes, white hair, anime style", "prompt": "cute girl, blue eyes, white hair, anime style",
"size": "small", "size": "small",
"n": 1 "n": 1
......
...@@ -10,7 +10,7 @@ from openai import OpenAI ...@@ -10,7 +10,7 @@ from openai import OpenAI
class DallE2Tool(BuiltinTool): class DallE2Tool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
...@@ -37,15 +37,15 @@ class DallE2Tool(BuiltinTool): ...@@ -37,15 +37,15 @@ class DallE2Tool(BuiltinTool):
} }
# prompt # prompt
prompt = tool_paramters.get('prompt', '') prompt = tool_parameters.get('prompt', '')
if not prompt: if not prompt:
return self.create_text_message('Please input prompt') return self.create_text_message('Please input prompt')
# get size # get size
size = SIZE_MAPPING[tool_paramters.get('size', 'large')] size = SIZE_MAPPING[tool_parameters.get('size', 'large')]
# get n # get n
n = tool_paramters.get('n', 1) n = tool_parameters.get('n', 1)
# call openapi dalle2 # call openapi dalle2
response = client.images.generate( response = client.images.generate(
......
...@@ -10,7 +10,7 @@ from openai import OpenAI ...@@ -10,7 +10,7 @@ from openai import OpenAI
class DallE3Tool(BuiltinTool): class DallE3Tool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
...@@ -37,19 +37,19 @@ class DallE3Tool(BuiltinTool): ...@@ -37,19 +37,19 @@ class DallE3Tool(BuiltinTool):
} }
# prompt # prompt
prompt = tool_paramters.get('prompt', '') prompt = tool_parameters.get('prompt', '')
if not prompt: if not prompt:
return self.create_text_message('Please input prompt') return self.create_text_message('Please input prompt')
# get size # get size
size = SIZE_MAPPING[tool_paramters.get('size', 'square')] size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
# get n # get n
n = tool_paramters.get('n', 1) n = tool_parameters.get('n', 1)
# get quality # get quality
quality = tool_paramters.get('quality', 'standard') quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']: if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality') return self.create_text_message('Invalid quality')
# get style # get style
style = tool_paramters.get('style', 'vivid') style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']: if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style') return self.create_text_message('Invalid style')
......
identity: identity:
author: CharlirWei author: CharlieWei
name: gaode name: gaode
label: label:
en_US: GaoDe en_US: Autonavi
zh_Hans: 高德 zh_Hans: 高德
pt_BR: GaoDe pt_BR: Autonavi
description: description:
en_US: Autonavi Open Platform service toolkit. en_US: Autonavi Open Platform service toolkit.
zh_Hans: 高德开放平台服务工具包。 zh_Hans: 高德开放平台服务工具包。
...@@ -19,11 +19,11 @@ credentials_for_provider: ...@@ -19,11 +19,11 @@ credentials_for_provider:
zh_Hans: API Key zh_Hans: API Key
pt_BR: Fogo a chave pt_BR: Fogo a chave
placeholder: placeholder:
en_US: Please enter your GaoDe API Key en_US: Please enter your Autonavi API Key
zh_Hans: 请输入你的高德开放平台 API Key zh_Hans: 请输入你的高德开放平台 API Key
pt_BR: Insira sua chave de API GaoDe pt_BR: Insira sua chave de API Autonavi
help: help:
en_US: Get your API Key from GaoDe en_US: Get your API Key from Autonavi
zh_Hans: 从高德获取您的 API Key zh_Hans: 从高德获取您的 API Key
pt_BR: Obtenha sua chave de API do GaoDe pt_BR: Obtenha sua chave de API do Autonavi
url: https://console.amap.com/dev/key/app url: https://console.amap.com/dev/key/app
...@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Union ...@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Union
class GaodeRepositoriesTool(BuiltinTool): class GaodeRepositoriesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
city = tool_paramters.get('city', '') city = tool_parameters.get('city', '')
if not city: if not city:
return self.create_text_message('Please tell me your city') return self.create_text_message('Please tell me your city')
......
identity: identity:
author: CharlirWei author: CharlieWei
name: github name: github
label: label:
en_US: Github en_US: Github
......
...@@ -9,12 +9,12 @@ from typing import Any, Dict, List, Union ...@@ -9,12 +9,12 @@ from typing import Any, Dict, List, Union
class GihubRepositoriesTool(BuiltinTool): class GihubRepositoriesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
top_n = tool_paramters.get('top_n', 5) top_n = tool_parameters.get('top_n', 5)
query = tool_paramters.get('query', '') query = tool_parameters.get('query', '')
if not query: if not query:
return self.create_text_message('Please input symbol') return self.create_text_message('Please input symbol')
......
identity: identity:
name: repositories name: github_repositories
author: CharlieWei author: CharlieWei
label: label:
en_US: Search Repositories en_US: Search Repositories
...@@ -24,7 +24,7 @@ parameters: ...@@ -24,7 +24,7 @@ parameters:
en_US: You want to find the project development language, keywords, For example. Find 10 Python developed PDF document parsing projects. en_US: You want to find the project development language, keywords, For example. Find 10 Python developed PDF document parsing projects.
zh_Hans: 你想要找的项目开发语言、关键字,如:找10个Python开发的PDF文档解析项目。 zh_Hans: 你想要找的项目开发语言、关键字,如:找10个Python开发的PDF文档解析项目。
pt_BR: Você deseja encontrar a linguagem de desenvolvimento do projeto, palavras-chave, Por exemplo. Encontre 10 projetos de análise de documentos PDF desenvolvidos em Python. pt_BR: Você deseja encontrar a linguagem de desenvolvimento do projeto, palavras-chave, Por exemplo. Encontre 10 projetos de análise de documentos PDF desenvolvidos em Python.
llm_description: The query of you want to search, format query condition like "keywords+language:js", language can be other dev languages, por exemplo. Procuro um projeto de análise de documentos PDF desenvolvido em Python. llm_description: The query of you want to search, format query condition like "keywords+language:js", language can be other dev languages.
form: llm form: llm
- name: top_n - name: top_n
type: number type: number
......
...@@ -14,7 +14,7 @@ class GoogleProvider(BuiltinToolProviderController): ...@@ -14,7 +14,7 @@ class GoogleProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"query": "test", "query": "test",
"result_type": "link" "result_type": "link"
}, },
......
...@@ -148,13 +148,13 @@ class SerpAPI: ...@@ -148,13 +148,13 @@ class SerpAPI:
class GoogleSearchTool(BuiltinTool): class GoogleSearchTool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
query = tool_paramters['query'] query = tool_parameters['query']
result_type = tool_paramters['result_type'] result_type = tool_parameters['result_type']
api_key = self.runtime.credentials['serpapi_api_key'] api_key = self.runtime.credentials['serpapi_api_key']
result = SerpAPI(api_key).run(query, result_type=result_type) result = SerpAPI(api_key).run(query, result_type=result_type)
if result_type == 'text': if result_type == 'text':
......
...@@ -14,7 +14,7 @@ class StableDiffusionProvider(BuiltinToolProviderController): ...@@ -14,7 +14,7 @@ class StableDiffusionProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"prompt": "cat", "prompt": "cat",
"lora": "", "lora": "",
"steps": 1, "steps": 1,
......
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter, ToolParamterOption from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
...@@ -60,7 +60,7 @@ DRAW_TEXT_OPTIONS = { ...@@ -60,7 +60,7 @@ DRAW_TEXT_OPTIONS = {
} }
class StableDiffusionTool(BuiltinTool): class StableDiffusionTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
...@@ -86,25 +86,25 @@ class StableDiffusionTool(BuiltinTool): ...@@ -86,25 +86,25 @@ class StableDiffusionTool(BuiltinTool):
# prompt # prompt
prompt = tool_paramters.get('prompt', '') prompt = tool_parameters.get('prompt', '')
if not prompt: if not prompt:
return self.create_text_message('Please input prompt') return self.create_text_message('Please input prompt')
# get negative prompt # get negative prompt
negative_prompt = tool_paramters.get('negative_prompt', '') negative_prompt = tool_parameters.get('negative_prompt', '')
# get size # get size
width = tool_paramters.get('width', 1024) width = tool_parameters.get('width', 1024)
height = tool_paramters.get('height', 1024) height = tool_parameters.get('height', 1024)
# get steps # get steps
steps = tool_paramters.get('steps', 1) steps = tool_parameters.get('steps', 1)
# get lora # get lora
lora = tool_paramters.get('lora', '') lora = tool_parameters.get('lora', '')
# get image id # get image id
image_id = tool_paramters.get('image_id', '') image_id = tool_parameters.get('image_id', '')
if image_id.strip(): if image_id.strip():
image_variable = self.get_default_image_variable() image_variable = self.get_default_image_variable()
if image_variable: if image_variable:
...@@ -188,6 +188,8 @@ class StableDiffusionTool(BuiltinTool): ...@@ -188,6 +188,8 @@ class StableDiffusionTool(BuiltinTool):
if lora: if lora:
draw_options['prompt'] = f'{lora},{prompt}' draw_options['prompt'] = f'{lora},{prompt}'
else:
draw_options['prompt'] = prompt
draw_options['width'] = width draw_options['width'] = width
draw_options['height'] = height draw_options['height'] = height
...@@ -210,32 +212,32 @@ class StableDiffusionTool(BuiltinTool): ...@@ -210,32 +212,32 @@ class StableDiffusionTool(BuiltinTool):
return self.create_text_message('Failed to generate image') return self.create_text_message('Failed to generate image')
def get_runtime_parameters(self) -> List[ToolParamter]: def get_runtime_parameters(self) -> List[ToolParameter]:
parameters = [ parameters = [
ToolParamter(name='prompt', ToolParameter(name='prompt',
label=I18nObject(en_US='Prompt', zh_Hans='Prompt'), label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
human_description=I18nObject( human_description=I18nObject(
en_US='Image prompt, you can check the official documentation of Stable Diffusion', en_US='Image prompt, you can check the official documentation of Stable Diffusion',
zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档', zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档',
), ),
type=ToolParamter.ToolParameterType.STRING, type=ToolParameter.ToolParameterType.STRING,
form=ToolParamter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.', llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.',
required=True), required=True),
] ]
if len(self.list_default_image_variables()) != 0: if len(self.list_default_image_variables()) != 0:
parameters.append( parameters.append(
ToolParamter(name='image_id', ToolParameter(name='image_id',
label=I18nObject(en_US='image_id', zh_Hans='image_id'), label=I18nObject(en_US='image_id', zh_Hans='image_id'),
human_description=I18nObject( human_description=I18nObject(
en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.', en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',
zh_Hans='您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。', zh_Hans='您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。',
), ),
type=ToolParamter.ToolParameterType.STRING, type=ToolParameter.ToolParameterType.STRING,
form=ToolParamter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.', llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.',
required=True, required=True,
options=[ToolParamterOption( options=[ToolParameterOption(
value=i.name, value=i.name,
label=I18nObject(en_US=i.name, zh_Hans=i.name) label=I18nObject(en_US=i.name, zh_Hans=i.name)
) for i in self.list_default_image_variables()]) ) for i in self.list_default_image_variables()])
......
...@@ -10,7 +10,7 @@ class WikiPediaProvider(BuiltinToolProviderController): ...@@ -10,7 +10,7 @@ class WikiPediaProvider(BuiltinToolProviderController):
try: try:
CurrentTimeTool().invoke( CurrentTimeTool().invoke(
user_id='', user_id='',
tool_paramters={}, tool_parameters={},
) )
except Exception as e: except Exception as e:
raise ToolProviderCredentialValidationError(str(e)) raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
...@@ -8,7 +8,7 @@ from datetime import datetime, timezone ...@@ -8,7 +8,7 @@ from datetime import datetime, timezone
class CurrentTimeTool(BuiltinTool): class CurrentTimeTool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
......
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
...@@ -8,21 +8,21 @@ from httpx import post ...@@ -8,21 +8,21 @@ from httpx import post
from base64 import b64decode from base64 import b64decode
class VectorizerTool(BuiltinTool): class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
api_key_name = self.runtime.credentials.get('api_key_name', None) api_key_name = self.runtime.credentials.get('api_key_name', None)
api_key_value = self.runtime.credentials.get('api_key_value', None) api_key_value = self.runtime.credentials.get('api_key_value', None)
mode = tool_paramters.get('mode', 'test') mode = tool_parameters.get('mode', 'test')
if mode == 'production': if mode == 'production':
mode = 'preview' mode = 'preview'
if not api_key_name or not api_key_value: if not api_key_name or not api_key_value:
raise ToolProviderCredentialValidationError('Please input api key name and value') raise ToolProviderCredentialValidationError('Please input api key name and value')
image_id = tool_paramters.get('image_id', '') image_id = tool_parameters.get('image_id', '')
if not image_id: if not image_id:
return self.create_text_message('Please input image id') return self.create_text_message('Please input image id')
...@@ -54,21 +54,21 @@ class VectorizerTool(BuiltinTool): ...@@ -54,21 +54,21 @@ class VectorizerTool(BuiltinTool):
meta={'mime_type': 'image/svg+xml'}) meta={'mime_type': 'image/svg+xml'})
] ]
def get_runtime_parameters(self) -> List[ToolParamter]: def get_runtime_parameters(self) -> List[ToolParameter]:
""" """
override the runtime parameters override the runtime parameters
""" """
return [ return [
ToolParamter.get_simple_instance( ToolParameter.get_simple_instance(
name='image_id', name='image_id',
llm_description=f'the image id that you want to vectorize, \ llm_description=f'the image id that you want to vectorize, \
and the image id should be specified in \ and the image id should be specified in \
{[i.name for i in self.list_default_image_variables()]}', {[i.name for i in self.list_default_image_variables()]}',
type=ToolParamter.ToolParameterType.SELECT, type=ToolParameter.ToolParameterType.SELECT,
required=True, required=True,
options=[i.name for i in self.list_default_image_variables()] options=[i.name for i in self.list_default_image_variables()]
) )
] ]
def is_tool_avaliable(self) -> bool: def is_tool_available(self) -> bool:
return len(self.list_default_image_variables()) > 0 return len(self.list_default_image_variables()) > 0
\ No newline at end of file
...@@ -14,7 +14,7 @@ class VectorizerProvider(BuiltinToolProviderController): ...@@ -14,7 +14,7 @@ class VectorizerProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"mode": "test", "mode": "test",
"image_id": "__test_123" "image_id": "__test_123"
}, },
......
...@@ -7,14 +7,14 @@ from typing import Any, Dict, List, Union ...@@ -7,14 +7,14 @@ from typing import Any, Dict, List, Union
class WebscraperTool(BuiltinTool): class WebscraperTool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
try: try:
url = tool_paramters.get('url', '') url = tool_parameters.get('url', '')
user_agent = tool_paramters.get('user_agent', '') user_agent = tool_parameters.get('user_agent', '')
if not url: if not url:
return self.create_text_message('Please input url') return self.create_text_message('Please input url')
......
...@@ -14,7 +14,7 @@ class WebscraperProvider(BuiltinToolProviderController): ...@@ -14,7 +14,7 @@ class WebscraperProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
'url': 'https://www.google.com', 'url': 'https://www.google.com',
'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' 'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
}, },
......
...@@ -14,12 +14,12 @@ class WikipediaInput(BaseModel): ...@@ -14,12 +14,12 @@ class WikipediaInput(BaseModel):
class WikiPediaSearchTool(BuiltinTool): class WikiPediaSearchTool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
query = tool_paramters.get('query', '') query = tool_parameters.get('query', '')
if not query: if not query:
return self.create_text_message('Please input query') return self.create_text_message('Please input query')
......
...@@ -12,7 +12,7 @@ class WikiPediaProvider(BuiltinToolProviderController): ...@@ -12,7 +12,7 @@ class WikiPediaProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"query": "misaka mikoto", "query": "misaka mikoto",
}, },
) )
......
...@@ -11,12 +11,12 @@ class WolframAlphaTool(BuiltinTool): ...@@ -11,12 +11,12 @@ class WolframAlphaTool(BuiltinTool):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_paramters: Dict[str, Any], tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
query = tool_paramters.get('query', '') query = tool_parameters.get('query', '')
if not query: if not query:
return self.create_text_message('Please input query') return self.create_text_message('Please input query')
appid = self.runtime.credentials.get('appid', '') appid = self.runtime.credentials.get('appid', '')
......
...@@ -16,7 +16,7 @@ class GoogleProvider(BuiltinToolProviderController): ...@@ -16,7 +16,7 @@ class GoogleProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"query": "1+2+....+111", "query": "1+2+....+111",
}, },
) )
......
...@@ -9,23 +9,23 @@ from yfinance import download ...@@ -9,23 +9,23 @@ from yfinance import download
import pandas as pd import pandas as pd
class YahooFinanceAnalyticsTool(BuiltinTool): class YahooFinanceAnalyticsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
symbol = tool_paramters.get('symbol', '') symbol = tool_parameters.get('symbol', '')
if not symbol: if not symbol:
return self.create_text_message('Please input symbol') return self.create_text_message('Please input symbol')
time_range = [None, None] time_range = [None, None]
start_date = tool_paramters.get('start_date', '') start_date = tool_parameters.get('start_date', '')
if start_date: if start_date:
time_range[0] = start_date time_range[0] = start_date
else: else:
time_range[0] = '1800-01-01' time_range[0] = '1800-01-01'
end_date = tool_paramters.get('end_date', '') end_date = tool_parameters.get('end_date', '')
if end_date: if end_date:
time_range[1] = end_date time_range[1] = end_date
else: else:
......
...@@ -7,13 +7,13 @@ from requests.exceptions import HTTPError, ReadTimeout ...@@ -7,13 +7,13 @@ from requests.exceptions import HTTPError, ReadTimeout
import yfinance import yfinance
class YahooFinanceSearchTickerTool(BuiltinTool): class YahooFinanceSearchTickerTool(BuiltinTool):
def _invoke(self,user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self,user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
''' '''
invoke tools invoke tools
''' '''
query = tool_paramters.get('symbol', '') query = tool_parameters.get('symbol', '')
if not query: if not query:
return self.create_text_message('Please input symbol') return self.create_text_message('Please input symbol')
......
...@@ -7,12 +7,12 @@ from requests.exceptions import HTTPError, ReadTimeout ...@@ -7,12 +7,12 @@ from requests.exceptions import HTTPError, ReadTimeout
from yfinance import Ticker from yfinance import Ticker
class YahooFinanceSearchTickerTool(BuiltinTool): class YahooFinanceSearchTickerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
query = tool_paramters.get('symbol', '') query = tool_parameters.get('symbol', '')
if not query: if not query:
return self.create_text_message('Please input symbol') return self.create_text_message('Please input symbol')
......
...@@ -12,7 +12,7 @@ class YahooFinanceProvider(BuiltinToolProviderController): ...@@ -12,7 +12,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"ticker": "MSFT", "ticker": "MSFT",
}, },
) )
......
...@@ -7,23 +7,23 @@ from datetime import datetime ...@@ -7,23 +7,23 @@ from datetime import datetime
from googleapiclient.discovery import build from googleapiclient.discovery import build
class YoutubeVideosAnalyticsTool(BuiltinTool): class YoutubeVideosAnalyticsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
""" """
invoke tools invoke tools
""" """
channel = tool_paramters.get('channel', '') channel = tool_parameters.get('channel', '')
if not channel: if not channel:
return self.create_text_message('Please input symbol') return self.create_text_message('Please input symbol')
time_range = [None, None] time_range = [None, None]
start_date = tool_paramters.get('start_date', '') start_date = tool_parameters.get('start_date', '')
if start_date: if start_date:
time_range[0] = start_date time_range[0] = start_date
else: else:
time_range[0] = '1800-01-01' time_range[0] = '1800-01-01'
end_date = tool_paramters.get('end_date', '') end_date = tool_parameters.get('end_date', '')
if end_date: if end_date:
time_range[1] = end_date time_range[1] = end_date
else: else:
......
...@@ -12,7 +12,7 @@ class YahooFinanceProvider(BuiltinToolProviderController): ...@@ -12,7 +12,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
} }
).invoke( ).invoke(
user_id='', user_id='',
tool_paramters={ tool_parameters={
"channel": "TOKYO GIRLS COLLECTION", "channel": "TOKYO GIRLS COLLECTION",
"start_date": "2020-01-01", "start_date": "2020-01-01",
"end_date": "2024-12-31", "end_date": "2024-12-31",
......
...@@ -5,13 +5,13 @@ from os import path, listdir ...@@ -5,13 +5,13 @@ from os import path, listdir
from yaml import load, FullLoader from yaml import load, FullLoader
from core.tools.entities.tool_entities import ToolProviderType, \ from core.tools.entities.tool_entities import ToolProviderType, \
ToolParamter, ToolProviderCredentials ToolParameter, ToolProviderCredentials
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.entities.user_entities import UserToolProviderCredentials from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError, \ from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError, \
ToolParamterValidationError, ToolProviderCredentialValidationError ToolParameterValidationError, ToolProviderCredentialValidationError
import importlib import importlib
...@@ -109,7 +109,7 @@ class BuiltinToolProviderController(ToolProviderController): ...@@ -109,7 +109,7 @@ class BuiltinToolProviderController(ToolProviderController):
""" """
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
def get_parameters(self, tool_name: str) -> List[ToolParamter]: def get_parameters(self, tool_name: str) -> List[ToolParameter]:
""" """
returns the parameters of the tool returns the parameters of the tool
...@@ -148,62 +148,62 @@ class BuiltinToolProviderController(ToolProviderController): ...@@ -148,62 +148,62 @@ class BuiltinToolProviderController(ToolProviderController):
""" """
tool_parameters_schema = self.get_parameters(tool_name) tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: Dict[str, ToolParamter] = {} tool_parameters_need_to_validate: Dict[str, ToolParameter] = {}
for parameter in tool_parameters_schema: for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters: for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate: if parameter not in tool_parameters_need_to_validate:
raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}') raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type # check type
parameter_schema = tool_parameters_need_to_validate[parameter] parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParamter.ToolParameterType.STRING: if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str): if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string') raise ToolParameterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER: elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], (int, float)): if not isinstance(tool_parameters[parameter], (int, float)):
raise ToolParamterValidationError(f'parameter {parameter} should be number') raise ToolParameterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN: elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool): if not isinstance(tool_parameters[parameter], bool):
raise ToolParamterValidationError(f'parameter {parameter} should be boolean') raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT: elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str): if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string') raise ToolParameterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options options = parameter_schema.options
if not isinstance(options, list): if not isinstance(options, list):
raise ToolParamterValidationError(f'parameter {parameter} options should be list') raise ToolParameterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]: if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}') raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter) tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate: for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter] parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required: if parameter_schema.required:
raise ToolParamterValidationError(f'parameter {parameter} is required') raise ToolParameterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed # the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None: if parameter_schema.default is not None:
default_value = parameter_schema.default default_value = parameter_schema.default
# parse default value into the correct type # parse default value into the correct type
if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \ if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \
parameter_schema.type == ToolParamter.ToolParameterType.SELECT: parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
default_value = str(default_value) default_value = str(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER: elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
default_value = float(default_value) default_value = float(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN: elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
default_value = bool(default_value) default_value = bool(default_value)
tool_parameters[parameter] = default_value tool_parameters[parameter] = default_value
......
...@@ -4,11 +4,11 @@ from typing import List, Dict, Any, Optional ...@@ -4,11 +4,11 @@ from typing import List, Dict, Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolProviderType, \ from core.tools.entities.tool_entities import ToolProviderType, \
ToolProviderIdentity, ToolParamter, ToolProviderCredentials ToolProviderIdentity, ToolParameter, ToolProviderCredentials
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.entities.user_entities import UserToolProviderCredentials from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ToolNotFoundError, \ from core.tools.errors import ToolNotFoundError, \
ToolParamterValidationError, ToolProviderCredentialValidationError ToolParameterValidationError, ToolProviderCredentialValidationError
class ToolProviderController(BaseModel, ABC): class ToolProviderController(BaseModel, ABC):
identity: Optional[ToolProviderIdentity] = None identity: Optional[ToolProviderIdentity] = None
...@@ -50,7 +50,7 @@ class ToolProviderController(BaseModel, ABC): ...@@ -50,7 +50,7 @@ class ToolProviderController(BaseModel, ABC):
""" """
pass pass
def get_parameters(self, tool_name: str) -> List[ToolParamter]: def get_parameters(self, tool_name: str) -> List[ToolParameter]:
""" """
returns the parameters of the tool returns the parameters of the tool
...@@ -80,62 +80,62 @@ class ToolProviderController(BaseModel, ABC): ...@@ -80,62 +80,62 @@ class ToolProviderController(BaseModel, ABC):
""" """
tool_parameters_schema = self.get_parameters(tool_name) tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: Dict[str, ToolParamter] = {} tool_parameters_need_to_validate: Dict[str, ToolParameter] = {}
for parameter in tool_parameters_schema: for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters: for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate: if parameter not in tool_parameters_need_to_validate:
raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}') raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type # check type
parameter_schema = tool_parameters_need_to_validate[parameter] parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParamter.ToolParameterType.STRING: if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str): if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string') raise ToolParameterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER: elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], (int, float)): if not isinstance(tool_parameters[parameter], (int, float)):
raise ToolParamterValidationError(f'parameter {parameter} should be number') raise ToolParameterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN: elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool): if not isinstance(tool_parameters[parameter], bool):
raise ToolParamterValidationError(f'parameter {parameter} should be boolean') raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT: elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str): if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string') raise ToolParameterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options options = parameter_schema.options
if not isinstance(options, list): if not isinstance(options, list):
raise ToolParamterValidationError(f'parameter {parameter} options should be list') raise ToolParameterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]: if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}') raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter) tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate: for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter] parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required: if parameter_schema.required:
raise ToolParamterValidationError(f'parameter {parameter} is required') raise ToolParameterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed # the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None: if parameter_schema.default is not None:
default_value = parameter_schema.default default_value = parameter_schema.default
# parse default value into the correct type # parse default value into the correct type
if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \ if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \
parameter_schema.type == ToolParamter.ToolParameterType.SELECT: parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
default_value = str(default_value) default_value = str(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER: elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
default_value = float(default_value) default_value = float(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN: elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
default_value = bool(default_value) default_value = bool(default_value)
tool_parameters[parameter] = default_value tool_parameters[parameter] = default_value
......
...@@ -8,6 +8,7 @@ from core.tools.errors import ToolProviderCredentialValidationError ...@@ -8,6 +8,7 @@ from core.tools.errors import ToolProviderCredentialValidationError
import httpx import httpx
import requests import requests
import json
class ApiTool(Tool): class ApiTool(Tool):
api_bundle: ApiBasedToolBundle api_bundle: ApiBasedToolBundle
...@@ -79,11 +80,29 @@ class ApiTool(Tool): ...@@ -79,11 +80,29 @@ class ApiTool(Tool):
if isinstance(response, httpx.Response): if isinstance(response, httpx.Response):
if response.status_code >= 400: if response.status_code >= 400:
raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}") raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
return response.text if not response.content:
return 'Empty response from the tool, please check your parameters and try again.'
try:
response = response.json()
try:
return json.dumps(response, ensure_ascii=False)
except Exception as e:
return json.dumps(response)
except Exception as e:
return response.text
elif isinstance(response, requests.Response): elif isinstance(response, requests.Response):
if not response.ok: if not response.ok:
raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}") raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
return response.text if not response.content:
return 'Empty response from the tool, please check your parameters and try again.'
try:
response = response.json()
try:
return json.dumps(response, ensure_ascii=False)
except Exception as e:
return json.dumps(response)
except Exception as e:
return response.text
else: else:
raise ValueError(f'Invalid response type {type(response)}') raise ValueError(f'Invalid response type {type(response)}')
...@@ -204,15 +223,15 @@ class ApiTool(Tool): ...@@ -204,15 +223,15 @@ class ApiTool(Tool):
return response return response
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]: def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]:
""" """
invoke http request invoke http request
""" """
# assemble request # assemble request
headers = self.assembling_request(tool_paramters) headers = self.assembling_request(tool_parameters)
# do http request # do http request
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_paramters) response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
# validate response # validate response
response = self.validate_and_parse_response(response) response = self.validate_and_parse_response(response)
......
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from core.features.dataset_retrieval import DatasetRetrievalFeature from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter, ToolIdentity, ToolDescription from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolIdentity, ToolDescription
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
...@@ -63,23 +63,23 @@ class DatasetRetrieverTool(Tool): ...@@ -63,23 +63,23 @@ class DatasetRetrieverTool(Tool):
return tools return tools
def get_runtime_parameters(self) -> List[ToolParamter]: def get_runtime_parameters(self) -> List[ToolParameter]:
return [ return [
ToolParamter(name='query', ToolParameter(name='query',
label=I18nObject(en_US='', zh_Hans=''), label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''), human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParamter.ToolParameterType.STRING, type=ToolParameter.ToolParameterType.STRING,
form=ToolParamter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.', llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True, required=True,
default=''), default=''),
] ]
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]: def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]:
""" """
invoke dataset retriever tool invoke dataset retriever tool
""" """
query = tool_paramters.get('query', None) query = tool_parameters.get('query', None)
if not query: if not query:
return self.create_text_message(text='please input query') return self.create_text_message(text='please input query')
......
...@@ -5,13 +5,13 @@ from abc import abstractmethod, ABC ...@@ -5,13 +5,13 @@ from abc import abstractmethod, ABC
from enum import Enum from enum import Enum
from core.tools.entities.tool_entities import ToolIdentity, ToolInvokeMessage,\ from core.tools.entities.tool_entities import ToolIdentity, ToolInvokeMessage,\
ToolParamter, ToolDescription, ToolRuntimeVariablePool, ToolRuntimeVariable, ToolRuntimeImageVariable ToolParameter, ToolDescription, ToolRuntimeVariablePool, ToolRuntimeVariable, ToolRuntimeImageVariable
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
class Tool(BaseModel, ABC): class Tool(BaseModel, ABC):
identity: ToolIdentity = None identity: ToolIdentity = None
parameters: Optional[List[ToolParamter]] = None parameters: Optional[List[ToolParameter]] = None
description: ToolDescription = None description: ToolDescription = None
is_team_authorization: bool = False is_team_authorization: bool = False
agent_callback: Optional[DifyAgentCallbackHandler] = None agent_callback: Optional[DifyAgentCallbackHandler] = None
...@@ -166,22 +166,22 @@ class Tool(BaseModel, ABC): ...@@ -166,22 +166,22 @@ class Tool(BaseModel, ABC):
return result return result
def invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> List[ToolInvokeMessage]: def invoke(self, user_id: str, tool_parameters: Dict[str, Any]) -> List[ToolInvokeMessage]:
# update tool_paramters # update tool_parameters
if self.runtime.runtime_parameters: if self.runtime.runtime_parameters:
tool_paramters.update(self.runtime.runtime_parameters) tool_parameters.update(self.runtime.runtime_parameters)
# hit callback # hit callback
if self.use_callback: if self.use_callback:
self.agent_callback.on_tool_start( self.agent_callback.on_tool_start(
tool_name=self.identity.name, tool_name=self.identity.name,
tool_inputs=tool_paramters tool_inputs=tool_parameters
) )
try: try:
result = self._invoke( result = self._invoke(
user_id=user_id, user_id=user_id,
tool_paramters=tool_paramters, tool_parameters=tool_parameters,
) )
except Exception as e: except Exception as e:
if self.use_callback: if self.use_callback:
...@@ -195,7 +195,7 @@ class Tool(BaseModel, ABC): ...@@ -195,7 +195,7 @@ class Tool(BaseModel, ABC):
if self.use_callback: if self.use_callback:
self.agent_callback.on_tool_end( self.agent_callback.on_tool_end(
tool_name=self.identity.name, tool_name=self.identity.name,
tool_inputs=tool_paramters, tool_inputs=tool_parameters,
tool_outputs=self._convert_tool_response_to_str(result) tool_outputs=self._convert_tool_response_to_str(result)
) )
...@@ -210,7 +210,7 @@ class Tool(BaseModel, ABC): ...@@ -210,7 +210,7 @@ class Tool(BaseModel, ABC):
if response.type == ToolInvokeMessage.MessageType.TEXT: if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK: elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please dirct user to check it." result += f"result link: {response.message}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE: response.type == ToolInvokeMessage.MessageType.IMAGE:
result += f"image has been created and sent to user already, you should tell user to check it now." result += f"image has been created and sent to user already, you should tell user to check it now."
...@@ -225,7 +225,7 @@ class Tool(BaseModel, ABC): ...@@ -225,7 +225,7 @@ class Tool(BaseModel, ABC):
return result return result
@abstractmethod @abstractmethod
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
pass pass
def validate_credentials(self, credentials: Dict[str, Any], parameters: Dict[str, Any]) -> None: def validate_credentials(self, credentials: Dict[str, Any], parameters: Dict[str, Any]) -> None:
...@@ -237,7 +237,7 @@ class Tool(BaseModel, ABC): ...@@ -237,7 +237,7 @@ class Tool(BaseModel, ABC):
""" """
pass pass
def get_runtime_parameters(self) -> List[ToolParamter]: def get_runtime_parameters(self) -> List[ToolParameter]:
""" """
get the runtime parameters get the runtime parameters
...@@ -247,11 +247,11 @@ class Tool(BaseModel, ABC): ...@@ -247,11 +247,11 @@ class Tool(BaseModel, ABC):
""" """
return self.parameters return self.parameters
def is_tool_avaliable(self) -> bool: def is_tool_available(self) -> bool:
""" """
check if the tool is avaliable check if the tool is available
:return: if the tool is avaliable :return: if the tool is available
""" """
return True return True
......
...@@ -13,7 +13,7 @@ from core.tools.errors import ToolProviderNotFoundError ...@@ -13,7 +13,7 @@ from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
from core.tools.entities.user_entities import UserToolProvider from core.tools.entities.user_entities import UserToolProvider
from core.tools.utils.configration import ToolConfiguration from core.tools.utils.configuration import ToolConfiguration
from core.tools.utils.encoder import serialize_base_model_dict from core.tools.utils.encoder import serialize_base_model_dict
from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
...@@ -117,7 +117,7 @@ class ToolManager: ...@@ -117,7 +117,7 @@ class ToolManager:
return tool return tool
@staticmethod @staticmethod
def get_tool(provider_type: str, provider_id: str, tool_name: str, tanent_id: str = None) \ def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
-> Union[BuiltinTool, ApiTool]: -> Union[BuiltinTool, ApiTool]:
""" """
get the tool get the tool
...@@ -131,9 +131,9 @@ class ToolManager: ...@@ -131,9 +131,9 @@ class ToolManager:
if provider_type == 'builtin': if provider_type == 'builtin':
return ToolManager.get_builtin_tool(provider_id, tool_name) return ToolManager.get_builtin_tool(provider_id, tool_name)
elif provider_type == 'api': elif provider_type == 'api':
if tanent_id is None: if tenant_id is None:
raise ValueError('tanent id is required for api provider') raise ValueError('tenant id is required for api provider')
api_provider, _ = ToolManager.get_api_provider_controller(tanent_id, provider_id) api_provider, _ = ToolManager.get_api_provider_controller(tenant_id, provider_id)
return api_provider.get_tool(tool_name) return api_provider.get_tool(tool_name)
elif provider_type == 'app': elif provider_type == 'app':
raise NotImplementedError('app provider not implemented') raise NotImplementedError('app provider not implemented')
...@@ -188,7 +188,7 @@ class ToolManager: ...@@ -188,7 +188,7 @@ class ToolManager:
elif provider_type == 'api': elif provider_type == 'api':
if tenant_id is None: if tenant_id is None:
raise ValueError('tanent id is required for api provider') raise ValueError('tenant id is required for api provider')
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name) api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
...@@ -202,7 +202,7 @@ class ToolManager: ...@@ -202,7 +202,7 @@ class ToolManager:
}) })
elif provider_type == 'model': elif provider_type == 'model':
if tenant_id is None: if tenant_id is None:
raise ValueError('tanent id is required for model provider') raise ValueError('tenant id is required for model provider')
# get model provider # get model provider
model_provider = ToolManager.get_model_provider(tenant_id, provider_name) model_provider = ToolManager.get_model_provider(tenant_id, provider_name)
...@@ -374,7 +374,7 @@ class ToolManager: ...@@ -374,7 +374,7 @@ class ToolManager:
schema = provider.get_credentials_schema() schema = provider.get_credentials_schema()
for name, value in schema.items(): for name, value in schema.items():
result_providers[provider.identity.name].team_credentials[name] = \ result_providers[provider.identity.name].team_credentials[name] = \
ToolProviderCredentials.CredentialsType.defaut(value.type) ToolProviderCredentials.CredentialsType.default(value.type)
# check if the provider need credentials # check if the provider need credentials
if not provider.need_credentials: if not provider.need_credentials:
...@@ -476,7 +476,7 @@ class ToolManager: ...@@ -476,7 +476,7 @@ class ToolManager:
return BuiltinToolProviderSort.sort(list(result_providers.values())) return BuiltinToolProviderSort.sort(list(result_providers.values()))
@staticmethod @staticmethod
def get_api_provider_controller(tanent_id: str, provider_id: str) -> Tuple[ApiBasedToolProviderController, Dict[str, Any]]: def get_api_provider_controller(tenant_id: str, provider_id: str) -> Tuple[ApiBasedToolProviderController, Dict[str, Any]]:
""" """
get the api provider get the api provider
...@@ -486,7 +486,7 @@ class ToolManager: ...@@ -486,7 +486,7 @@ class ToolManager:
""" """
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.id == provider_id, ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tanent_id, ApiToolProvider.tenant_id == tenant_id,
).first() ).first()
if provider is None: if provider is None:
...@@ -513,7 +513,7 @@ class ToolManager: ...@@ -513,7 +513,7 @@ class ToolManager:
).first() ).first()
if provider is None: if provider is None:
raise ValueError(f'yout have not added provider {provider}') raise ValueError(f'you have not added provider {provider}')
try: try:
credentials = json.loads(provider.credentials_str) or {} credentials = json.loads(provider.credentials_str) or {}
......
...@@ -18,7 +18,7 @@ class ToolConfiguration(BaseModel): ...@@ -18,7 +18,7 @@ class ToolConfiguration(BaseModel):
def encrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]: def encrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]:
""" """
encrypt tool credentials with tanent id encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values return a deep copy of credentials with encrypted values
""" """
...@@ -59,7 +59,7 @@ class ToolConfiguration(BaseModel): ...@@ -59,7 +59,7 @@ class ToolConfiguration(BaseModel):
def decrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]: def decrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]:
""" """
decrypt tool credentials with tanent id decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values return a deep copy of credentials with decrypted values
""" """
......
from core.tools.entities.tool_bundle import ApiBasedToolBundle from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ToolParamter, ToolParamterOption, ApiProviderSchemaType from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ApiProviderSchemaType
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderNotFoundError, ToolNotSupportedError, \ from core.tools.errors import ToolProviderNotFoundError, ToolNotSupportedError, \
ToolApiSchemaError ToolApiSchemaError
...@@ -47,7 +47,7 @@ class ApiBasedToolSchemaParser: ...@@ -47,7 +47,7 @@ class ApiBasedToolSchemaParser:
parameters = [] parameters = []
if 'parameters' in interface['operation']: if 'parameters' in interface['operation']:
for parameter in interface['operation']['parameters']: for parameter in interface['operation']['parameters']:
parameters.append(ToolParamter( parameters.append(ToolParameter(
name=parameter['name'], name=parameter['name'],
label=I18nObject( label=I18nObject(
en_US=parameter['name'], en_US=parameter['name'],
...@@ -57,9 +57,9 @@ class ApiBasedToolSchemaParser: ...@@ -57,9 +57,9 @@ class ApiBasedToolSchemaParser:
en_US=parameter.get('description', ''), en_US=parameter.get('description', ''),
zh_Hans=parameter.get('description', '') zh_Hans=parameter.get('description', '')
), ),
type=ToolParamter.ToolParameterType.STRING, type=ToolParameter.ToolParameterType.STRING,
required=parameter.get('required', False), required=parameter.get('required', False),
form=ToolParamter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get('description'), llm_description=parameter.get('description'),
default=parameter['default'] if 'default' in parameter else None, default=parameter['default'] if 'default' in parameter else None,
)) ))
...@@ -87,7 +87,7 @@ class ApiBasedToolSchemaParser: ...@@ -87,7 +87,7 @@ class ApiBasedToolSchemaParser:
required = body_schema['required'] if 'required' in body_schema else [] required = body_schema['required'] if 'required' in body_schema else []
properties = body_schema['properties'] if 'properties' in body_schema else {} properties = body_schema['properties'] if 'properties' in body_schema else {}
for name, property in properties.items(): for name, property in properties.items():
parameters.append(ToolParamter( parameters.append(ToolParameter(
name=name, name=name,
label=I18nObject( label=I18nObject(
en_US=name, en_US=name,
...@@ -97,9 +97,9 @@ class ApiBasedToolSchemaParser: ...@@ -97,9 +97,9 @@ class ApiBasedToolSchemaParser:
en_US=property['description'] if 'description' in property else '', en_US=property['description'] if 'description' in property else '',
zh_Hans=property['description'] if 'description' in property else '' zh_Hans=property['description'] if 'description' in property else ''
), ),
type=ToolParamter.ToolParameterType.STRING, type=ToolParameter.ToolParameterType.STRING,
required=name in required, required=name in required,
form=ToolParamter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description=property['description'] if 'description' in property else '', llm_description=property['description'] if 'description' in property else '',
default=property['default'] if 'default' in property else None, default=property['default'] if 'default' in property else None,
)) ))
...@@ -114,6 +114,10 @@ class ApiBasedToolSchemaParser: ...@@ -114,6 +114,10 @@ class ApiBasedToolSchemaParser:
if count > 1: if count > 1:
warning['duplicated_parameter'] = f'Parameter {name} is duplicated.' warning['duplicated_parameter'] = f'Parameter {name} is duplicated.'
# check if there is a operation id, use $path_$method as operation id if not
if 'operationId' not in interface['operation']:
interface['operation']['operationId'] = f'{interface["path"]}_{interface["method"]}'
bundles.append(ApiBasedToolBundle( bundles.append(ApiBasedToolBundle(
server_url=server_url + interface['path'], server_url=server_url + interface['path'],
method=interface['method'], method=interface['method'],
......
...@@ -100,7 +100,7 @@ class ApiToolProvider(db.Model): ...@@ -100,7 +100,7 @@ class ApiToolProvider(db.Model):
schema_type_str = db.Column(db.String(40), nullable=False) schema_type_str = db.Column(db.String(40), nullable=False)
# who created this tool # who created this tool
user_id = db.Column(UUID, nullable=False) user_id = db.Column(UUID, nullable=False)
# tanent id # tenant id
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(UUID, nullable=False)
# description of the provider # description of the provider
description = db.Column(db.Text, nullable=False) description = db.Column(db.Text, nullable=False)
...@@ -135,7 +135,7 @@ class ApiToolProvider(db.Model): ...@@ -135,7 +135,7 @@ class ApiToolProvider(db.Model):
return db.session.query(Account).filter(Account.id == self.user_id).first() return db.session.query(Account).filter(Account.id == self.user_id).first()
@property @property
def tanent(self) -> Tenant: def tenant(self) -> Tenant:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
class ToolModelInvoke(db.Model): class ToolModelInvoke(db.Model):
...@@ -150,7 +150,7 @@ class ToolModelInvoke(db.Model): ...@@ -150,7 +150,7 @@ class ToolModelInvoke(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
# who invoke this tool # who invoke this tool
user_id = db.Column(UUID, nullable=False) user_id = db.Column(UUID, nullable=False)
# tanent id # tenant id
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(UUID, nullable=False)
# provider # provider
provider = db.Column(db.String(40), nullable=False) provider = db.Column(db.String(40), nullable=False)
...@@ -190,7 +190,7 @@ class ToolConversationVariables(db.Model): ...@@ -190,7 +190,7 @@ class ToolConversationVariables(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
# conversation user id # conversation user id
user_id = db.Column(UUID, nullable=False) user_id = db.Column(UUID, nullable=False)
# tanent id # tenant id
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(UUID, nullable=False)
# conversation id # conversation id
conversation_id = db.Column(UUID, nullable=False) conversation_id = db.Column(UUID, nullable=False)
...@@ -218,7 +218,7 @@ class ToolFile(db.Model): ...@@ -218,7 +218,7 @@ class ToolFile(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
# conversation user id # conversation user id
user_id = db.Column(UUID, nullable=False) user_id = db.Column(UUID, nullable=False)
# tanent id # tenant id
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(UUID, nullable=False)
# conversation id # conversation id
conversation_id = db.Column(UUID, nullable=False) conversation_id = db.Column(UUID, nullable=False)
......
coverage~=7.2.4 coverage~=7.2.4
beautifulsoup4==4.12.2 beautifulsoup4==4.12.2
flask~=2.3.2 flask~=3.0.1
Flask-SQLAlchemy~=3.0.3 Flask-SQLAlchemy~=3.0.5
SQLAlchemy~=1.4.28 SQLAlchemy~=1.4.28
flask-login==0.6.2 flask-login~=0.6.3
flask-migrate~=4.0.4 flask-migrate~=4.0.5
flask-restful==0.3.9 flask-restful~=0.3.10
flask-session2==1.3.1 flask-cors~=4.0.0
flask-cors==3.0.10
gunicorn~=21.2.0 gunicorn~=21.2.0
gevent~=23.9.1 gevent~=23.9.1
langchain==0.0.250 langchain==0.0.250
...@@ -25,7 +24,7 @@ cachetools~=5.3.0 ...@@ -25,7 +24,7 @@ cachetools~=5.3.0
weaviate-client~=3.21.0 weaviate-client~=3.21.0
mailchimp-transactional~=1.0.50 mailchimp-transactional~=1.0.50
scikit-learn==1.2.2 scikit-learn==1.2.2
sentry-sdk[flask]~=1.21.1 sentry-sdk[flask]~=1.39.2
sympy==1.12 sympy==1.12
jieba==0.42.1 jieba==0.42.1
celery==5.2.7 celery==5.2.7
...@@ -48,10 +47,10 @@ dashscope[tokenizer]~=1.14.0 ...@@ -48,10 +47,10 @@ dashscope[tokenizer]~=1.14.0
huggingface_hub~=0.16.4 huggingface_hub~=0.16.4
transformers~=4.31.0 transformers~=4.31.0
pandas==1.5.3 pandas==1.5.3
xinference-client~=0.6.4 xinference-client~=0.8.1
safetensors==0.3.2 safetensors==0.3.2
zhipuai==1.0.7 zhipuai==1.0.7
werkzeug==2.3.8 werkzeug~=3.0.1
pymilvus==2.3.0 pymilvus==2.3.0
qdrant-client==1.6.4 qdrant-client==1.6.4
cohere~=4.44 cohere~=4.44
......
...@@ -12,7 +12,7 @@ from core.tools.provider.tool_provider import ToolProviderController ...@@ -12,7 +12,7 @@ from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.utils.parser import ApiBasedToolSchemaParser from core.tools.utils.parser import ApiBasedToolSchemaParser
from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict
from core.tools.utils.configration import ToolConfiguration from core.tools.utils.configuration import ToolConfiguration
from core.tools.errors import ToolProviderCredentialValidationError, ToolProviderNotFoundError, ToolNotFoundError from core.tools.errors import ToolProviderCredentialValidationError, ToolProviderNotFoundError, ToolNotFoundError
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
...@@ -26,26 +26,26 @@ import json ...@@ -26,26 +26,26 @@ import json
class ToolManageService: class ToolManageService:
@staticmethod @staticmethod
def list_tool_providers(user_id: str, tanent_id: str): def list_tool_providers(user_id: str, tenant_id: str):
""" """
list tool providers list tool providers
:return: the list of tool providers :return: the list of tool providers
""" """
result = [provider.to_dict() for provider in ToolManager.user_list_providers( result = [provider.to_dict() for provider in ToolManager.user_list_providers(
user_id, tanent_id user_id, tenant_id
)] )]
# add icon url prefix # add icon url prefix
for provider in result: for provider in result:
ToolManageService.repacket_provider(provider) ToolManageService.repack_provider(provider)
return result return result
@staticmethod @staticmethod
def repacket_provider(provider: dict): def repack_provider(provider: dict):
""" """
repacket provider repack provider
:param provider: the provider dict :param provider: the provider dict
""" """
...@@ -290,7 +290,7 @@ class ToolManageService: ...@@ -290,7 +290,7 @@ class ToolManageService:
).first() ).first()
if provider is None: if provider is None:
raise ValueError(f'yout have not added provider {provider}') raise ValueError(f'you have not added provider {provider}')
return json.loads( return json.loads(
serialize_base_model_array([ serialize_base_model_array([
...@@ -341,25 +341,33 @@ class ToolManageService: ...@@ -341,25 +341,33 @@ class ToolManageService:
""" """
update builtin tool provider update builtin tool provider
""" """
# get if the provider exists
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
try: try:
# get provider # get provider
provider_controller = ToolManager.get_builtin_provider(provider_name) provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
raise ValueError(f'provider {provider_name} does not need credentials') raise ValueError(f'provider {provider_name} does not need credentials')
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
# validate credentials # validate credentials
provider_controller.validate_credentials(credentials) provider_controller.validate_credentials(credentials)
# encrypt credentials # encrypt credentials
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
credentials = tool_configuration.encrypt_tool_credentials(credentials) credentials = tool_configuration.encrypt_tool_credentials(credentials)
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e: except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
raise ValueError(str(e)) raise ValueError(str(e))
# get if the provider exists
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
if provider is None: if provider is None:
# create provider # create provider
provider = BuiltinToolProvider( provider = BuiltinToolProvider(
...@@ -444,7 +452,7 @@ class ToolManageService: ...@@ -444,7 +452,7 @@ class ToolManageService:
).first() ).first()
if provider is None: if provider is None:
raise ValueError(f'yout have not added provider {provider}') raise ValueError(f'you have not added provider {provider}')
db.session.delete(provider) db.session.delete(provider)
db.session.commit() db.session.commit()
...@@ -493,7 +501,7 @@ class ToolManageService: ...@@ -493,7 +501,7 @@ class ToolManageService:
).first() ).first()
if provider is None: if provider is None:
raise ValueError(f'yout have not added provider {provider}') raise ValueError(f'you have not added provider {provider}')
db.session.delete(provider) db.session.delete(provider)
db.session.commit() db.session.commit()
...@@ -521,10 +529,10 @@ class ToolManageService: ...@@ -521,10 +529,10 @@ class ToolManageService:
if schema_type not in [member.value for member in ApiProviderSchemaType]: if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f'invalid schema type {schema_type}') raise ValueError(f'invalid schema type {schema_type}')
if schema_type == ApiProviderSchemaType.OPENAPI.value: try:
tool_bundles = ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(schema) tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
else: except Exception as e:
raise ValueError(f'invalid schema type {schema_type}') raise ValueError(f'invalid schema')
# get tool bundle # get tool bundle
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
......
...@@ -19,58 +19,86 @@ class MockXinferenceClass(object): ...@@ -19,58 +19,86 @@ class MockXinferenceClass(object):
raise RuntimeError('404 Not Found') raise RuntimeError('404 Not Found')
if 'generate' == model_uid: if 'generate' == model_uid:
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url) return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'chat' == model_uid: if 'chat' == model_uid:
return RESTfulChatModelHandle(model_uid, base_url=self.base_url) return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'embedding' == model_uid: if 'embedding' == model_uid:
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url) return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'rerank' == model_uid: if 'rerank' == model_uid:
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url) return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
raise RuntimeError('404 Not Found') raise RuntimeError('404 Not Found')
def get(self: Session, url: str, **kwargs): def get(self: Session, url: str, **kwargs):
if '/v1/models/' in url: response = Response()
response = Response() if 'v1/models/' in url:
# get model uid # get model uid
model_uid = url.split('/')[-1] model_uid = url.split('/')[-1]
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \ if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']: model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
response.status_code = 404 response.status_code = 404
raise ConnectionError('404 Not Found') return response
# check if url is valid # check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url): if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
response.status_code = 404 response.status_code = 404
raise ConnectionError('404 Not Found') return response
if model_uid in ['generate', 'chat']:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
return response
elif model_uid == 'embedding':
response.status_code = 200
response._content = b'''{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
return response
elif 'v1/cluster/auth' in url:
response.status_code = 200 response.status_code = 200
response._content = b'''{ response._content = b'''{
"model_type": "LLM", "auth": true
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}''' }'''
return response return response
def _check_cluster_authenticated(self):
self._cluster_authed = True
def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict: def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
# check if self._model_uid is a valid uuid # check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
...@@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' ...@@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
def setup_xinference_mock(request, monkeypatch: MonkeyPatch): def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
if MOCK: if MOCK:
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model) monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get) monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding) monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank) monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
......
""" """
LocalAI Embedding Interface is temporarily unavaliable due to LocalAI Embedding Interface is temporarily unavailable due to
we could not find a way to test it for now. we could not find a way to test it for now.
""" """
\ No newline at end of file
...@@ -153,7 +153,7 @@ const NewAppDialog = ({ show, onSuccess, onClose }: NewAppDialogProps) => { ...@@ -153,7 +153,7 @@ const NewAppDialog = ({ show, onSuccess, onClose }: NewAppDialogProps) => {
<div className={style.listItemHeading}> <div className={style.listItemHeading}>
<div className={style.listItemHeadingContent}>{t('app.newApp.chatApp')}</div> <div className={style.listItemHeadingContent}>{t('app.newApp.chatApp')}</div>
</div> </div>
<div className='shrink-0 flex items-center h-[18px] border border-indigo-300 px-1 rounded-[5px] text-xs font-medium text-indigo-600 uppercase'>{t('app.newApp.agentAssistant')}</div> <div className='flex items-center h-[18px] border border-indigo-300 px-1 rounded-[5px] text-xs font-medium text-indigo-600 uppercase truncate'>{t('app.newApp.agentAssistant')}</div>
</div> </div>
<div className={`${style.listItemDescription} ${style.noClip}`}>{t('app.newApp.chatAppIntro')}</div> <div className={`${style.listItemDescription} ${style.noClip}`}>{t('app.newApp.chatAppIntro')}</div>
{/* <div className={classNames(style.listItemFooter, 'justify-end')}> {/* <div className={classNames(style.listItemFooter, 'justify-end')}>
......
...@@ -132,6 +132,7 @@ const EditAnnotationModal: FC<Props> = ({ ...@@ -132,6 +132,7 @@ const EditAnnotationModal: FC<Props> = ({
onRemove={() => { onRemove={() => {
onRemove() onRemove()
setShowModal(false) setShowModal(false)
onHide()
}} }}
text={t('appDebug.feature.annotation.removeConfirm') as string} text={t('appDebug.feature.annotation.removeConfirm') as string}
/> />
......
...@@ -48,7 +48,7 @@ const Popup: FC<PopupProps> = ({ ...@@ -48,7 +48,7 @@ const Popup: FC<PopupProps> = ({
> >
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}> <PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
<div className='flex items-center px-2 max-w-[240px] h-7 bg-white rounded-lg'> <div className='flex items-center px-2 max-w-[240px] h-7 bg-white rounded-lg'>
<FileIcon type={fileType} className='mr-1 w-4 h-4' /> <FileIcon type={fileType} className='shrink-0 mr-1 w-4 h-4' />
<div className='text-xs text-gray-600 truncate'>{data.documentName}</div> <div className='text-xs text-gray-600 truncate'>{data.documentName}</div>
</div> </div>
</PortalToFollowElemTrigger> </PortalToFollowElemTrigger>
...@@ -56,7 +56,7 @@ const Popup: FC<PopupProps> = ({ ...@@ -56,7 +56,7 @@ const Popup: FC<PopupProps> = ({
<div className='w-[360px] bg-gray-50 rounded-xl shadow-lg'> <div className='w-[360px] bg-gray-50 rounded-xl shadow-lg'>
<div className='px-4 pt-3 pb-2'> <div className='px-4 pt-3 pb-2'>
<div className='flex items-center h-[18px]'> <div className='flex items-center h-[18px]'>
<FileIcon type={fileType} className='mr-1 w-4 h-4' /> <FileIcon type={fileType} className='shrink-0 mr-1 w-4 h-4' />
<div className='text-xs font-medium text-gray-600 truncate'>{data.documentName}</div> <div className='text-xs font-medium text-gray-600 truncate'>{data.documentName}</div>
</div> </div>
</div> </div>
......
...@@ -43,7 +43,7 @@ const ConfigSelect: FC<IConfigSelectProps> = ({ ...@@ -43,7 +43,7 @@ const ConfigSelect: FC<IConfigSelectProps> = ({
<div className={`${s.inputWrap} relative`} key={index}> <div className={`${s.inputWrap} relative`} key={index}>
<div className='handle flex items-center justify-center w-4 h-4 cursor-grab'> <div className='handle flex items-center justify-center w-4 h-4 cursor-grab'>
<svg width="6" height="10" viewBox="0 0 6 10" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="6" height="10" viewBox="0 0 6 10" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fillRule="evenodd" clipRule="evenodd" d="M1 2C1.55228 2 2 1.55228 2 1C2 0.447715 1.55228 0 1 0C0.447715 0 0 0.447715 0 1C0 1.55228 0.447715 2 1 2ZM1 6C1.55228 6 2 5.55228 2 5C2 4.44772 1.55228 4 1 4C0.447715 4 0 4.44772 0 5C0 5.55228 0.447715 6 1 6ZM6 1C6 1.55228 5.55228 2 5 2C4.44772 2 4 1.55228 4 1C4 0.447715 4.44772 0 5 0C5.55228 0 6 0.447715 6 1ZM5 6C5.55228 6 6 5.55228 6 5C6 4.44772 5.55228 4 5 4C4.44772 4 4 4.44772 4 5C4 5.55228 4.44772 6 5 6ZM2 9C2 9.55229 1.55228 10 1 10C0.447715 10 0 9.55229 0 9C0 8.44771 0.447715 8 1 8C1.55228 8 2 8.44771 2 9ZM5 10C5.55228 10 6 9.55229 6 9C6 8.44771 5.55228 8 5 8C4.44772 8 4 8.44771 4 9C4 9.55229 4.44772 10 5 10Z" fill="#98A2B3"/> <path fillRule="evenodd" clipRule="evenodd" d="M1 2C1.55228 2 2 1.55228 2 1C2 0.447715 1.55228 0 1 0C0.447715 0 0 0.447715 0 1C0 1.55228 0.447715 2 1 2ZM1 6C1.55228 6 2 5.55228 2 5C2 4.44772 1.55228 4 1 4C0.447715 4 0 4.44772 0 5C0 5.55228 0.447715 6 1 6ZM6 1C6 1.55228 5.55228 2 5 2C4.44772 2 4 1.55228 4 1C4 0.447715 4.44772 0 5 0C5.55228 0 6 0.447715 6 1ZM5 6C5.55228 6 6 5.55228 6 5C6 4.44772 5.55228 4 5 4C4.44772 4 4 4.44772 4 5C4 5.55228 4.44772 6 5 6ZM2 9C2 9.55229 1.55228 10 1 10C0.447715 10 0 9.55229 0 9C0 8.44771 0.447715 8 1 8C1.55228 8 2 8.44771 2 9ZM5 10C5.55228 10 6 9.55229 6 9C6 8.44771 5.55228 8 5 8C4.44772 8 4 8.44771 4 9C4 9.55229 4.44772 10 5 10Z" fill="#98A2B3" />
</svg> </svg>
</div> </div>
<input <input
...@@ -59,7 +59,7 @@ const ConfigSelect: FC<IConfigSelectProps> = ({ ...@@ -59,7 +59,7 @@ const ConfigSelect: FC<IConfigSelectProps> = ({
return item return item
})) }))
}} }}
className={`${s.input} w-full px-1.5 text-sm leading-9 text-gray-900 border-0 grow h-9 bg-transparent focus:outline-none cursor-pointer`} className={'w-full pl-1.5 pr-8 text-sm leading-9 text-gray-900 border-0 grow h-9 bg-transparent focus:outline-none cursor-pointer'}
/> />
<RemoveIcon <RemoveIcon
className={`${s.deleteBtn} absolute top-1/2 translate-y-[-50%] right-1.5 items-center justify-center w-6 h-6 rounded-md cursor-pointer hover:bg-[#FEE4E2]`} className={`${s.deleteBtn} absolute top-1/2 translate-y-[-50%] right-1.5 items-center justify-center w-6 h-6 rounded-md cursor-pointer hover:bg-[#FEE4E2]`}
......
...@@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next' ...@@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next'
import cn from 'classnames' import cn from 'classnames'
import { useContext } from 'use-context-selector' import { useContext } from 'use-context-selector'
import produce from 'immer' import produce from 'immer'
import { useFormattingChangedDispatcher } from '../../../debug/hooks'
import ChooseTool from './choose-tool' import ChooseTool from './choose-tool'
import SettingBuiltInTool from './setting-built-in-tool' import SettingBuiltInTool from './setting-built-in-tool'
import Panel from '@/app/components/app/configuration/base/feature-panel' import Panel from '@/app/components/app/configuration/base/feature-panel'
...@@ -27,6 +28,7 @@ const AgentTools: FC = () => { ...@@ -27,6 +28,7 @@ const AgentTools: FC = () => {
const { t } = useTranslation() const { t } = useTranslation()
const [isShowChooseTool, setIsShowChooseTool] = useState(false) const [isShowChooseTool, setIsShowChooseTool] = useState(false)
const { modelConfig, setModelConfig, collectionList } = useContext(ConfigContext) const { modelConfig, setModelConfig, collectionList } = useContext(ConfigContext)
const formattingChangedDispatcher = useFormattingChangedDispatcher()
const [currentTool, setCurrentTool] = useState<AgentToolWithMoreInfo>(null) const [currentTool, setCurrentTool] = useState<AgentToolWithMoreInfo>(null)
const [selectedProviderId, setSelectedProviderId] = useState<string | undefined>(undefined) const [selectedProviderId, setSelectedProviderId] = useState<string | undefined>(undefined)
...@@ -49,6 +51,7 @@ const AgentTools: FC = () => { ...@@ -49,6 +51,7 @@ const AgentTools: FC = () => {
}) })
setModelConfig(newModelConfig) setModelConfig(newModelConfig)
setIsShowSettingTool(false) setIsShowSettingTool(false)
formattingChangedDispatcher()
} }
return ( return (
...@@ -141,6 +144,7 @@ const AgentTools: FC = () => { ...@@ -141,6 +144,7 @@ const AgentTools: FC = () => {
draft.agentConfig.tools.splice(index, 1) draft.agentConfig.tools.splice(index, 1)
}) })
setModelConfig(newModelConfig) setModelConfig(newModelConfig)
formattingChangedDispatcher()
}}> }}>
<Trash03 className='w-4 h-4 text-gray-500' /> <Trash03 className='w-4 h-4 text-gray-500' />
</div> </div>
...@@ -167,6 +171,7 @@ const AgentTools: FC = () => { ...@@ -167,6 +171,7 @@ const AgentTools: FC = () => {
draft.agentConfig.tools.splice(index, 1) draft.agentConfig.tools.splice(index, 1)
}) })
setModelConfig(newModelConfig) setModelConfig(newModelConfig)
formattingChangedDispatcher()
}}> }}>
<Trash03 className='w-4 h-4 text-gray-500' /> <Trash03 className='w-4 h-4 text-gray-500' />
</div> </div>
...@@ -183,6 +188,7 @@ const AgentTools: FC = () => { ...@@ -183,6 +188,7 @@ const AgentTools: FC = () => {
(draft.agentConfig.tools[index] as any).enabled = enabled (draft.agentConfig.tools[index] as any).enabled = enabled
}) })
setModelConfig(newModelConfig) setModelConfig(newModelConfig)
formattingChangedDispatcher()
}} /> }} />
</div> </div>
</div> </div>
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment