Unverified Commit 94f3cf1a authored by Yeuoly's avatar Yeuoly

feat: tool entity

parent 8e491ace
...@@ -315,7 +315,7 @@ class ToolManager: ...@@ -315,7 +315,7 @@ class ToolManager:
for parameter in parameters: for parameter in parameters:
# save tool parameter to tool entity memory # save tool parameter to tool entity memory
value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_parameters) value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
runtime_parameters[parameter.name] = value runtime_parameters[parameter.name] = value
# decrypt runtime parameters # decrypt runtime parameters
......
from typing import Literal, Union from typing import Literal, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
...@@ -13,11 +13,20 @@ class ToolEntity(BaseModel): ...@@ -13,11 +13,20 @@ class ToolEntity(BaseModel):
provider_name: str # redundancy provider_name: str # redundancy
tool_name: str tool_name: str
tool_label: str # redundancy tool_label: str # redundancy
tool_parameters: dict[str, ToolParameterValue] tool_configurations: dict[str, ToolParameterValue]
class ToolNodeData(BaseNodeData, ToolEntity): class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(VariableSelector):
variable_type: Literal['selector', 'static']
value: Optional[str]
@validator('value')
def check_value(cls, value, values, **kwargs):
if values['variable_type'] == 'static' and value is None:
raise ValueError('value is required for static variable')
return value
""" """
Tool Node Schema Tool Node Schema
""" """
tool_inputs: list[VariableSelector] tool_parameters: list[ToolInput]
...@@ -27,14 +27,8 @@ class ToolNode(BaseNode): ...@@ -27,14 +27,8 @@ class ToolNode(BaseNode):
node_data = cast(ToolNodeData, self.node_data) node_data = cast(ToolNodeData, self.node_data)
# extract tool parameters # get parameters
parameters = { parameters = self._generate_parameters(variable_pool, node_data)
k.variable: variable_pool.get_variable_value(k.value_selector)
for k in node_data.tool_inputs
}
if len(parameters) != len(node_data.tool_inputs):
raise ValueError('Invalid tool parameters')
# get tool runtime # get tool runtime
try: try:
...@@ -47,6 +41,7 @@ class ToolNode(BaseNode): ...@@ -47,6 +41,7 @@ class ToolNode(BaseNode):
) )
try: try:
# TODO: user_id
messages = tool_runtime.invoke(None, parameters) messages = tool_runtime.invoke(None, parameters)
except Exception as e: except Exception as e:
return NodeRunResult( return NodeRunResult(
...@@ -59,12 +54,23 @@ class ToolNode(BaseNode): ...@@ -59,12 +54,23 @@ class ToolNode(BaseNode):
plain_text, files = self._convert_tool_messages(messages) plain_text, files = self._convert_tool_messages(messages)
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCESS, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={ outputs={
'text': plain_text, 'text': plain_text,
'files': files 'files': files
}, },
) )
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
"""
Generate parameters
"""
return {
k.variable:
k.value if k.variable_type == 'static' else
variable_pool.get_variable_value(k.value) if k.variable_type == 'selector' else ''
for k in node_data.tool_parameters
}
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]: def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]:
""" """
...@@ -125,11 +131,6 @@ class ToolNode(BaseNode): ...@@ -125,11 +131,6 @@ class ToolNode(BaseNode):
for message in tool_response for message in tool_response
]) ])
def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict:
"""
Convert ToolInvokeMessage into file
"""
pass
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment