Unverified Commit 94f3cf1a authored by Yeuoly's avatar Yeuoly

feat: tool entity

parent 8e491ace
......@@ -315,7 +315,7 @@ class ToolManager:
for parameter in parameters:
# save tool parameter to tool entity memory
value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_parameters)
value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
......
from typing import Literal, Union
from typing import Literal, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
......@@ -13,11 +13,20 @@ class ToolEntity(BaseModel):
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_parameters: dict[str, ToolParameterValue]
tool_configurations: dict[str, ToolParameterValue]
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(VariableSelector):
variable_type: Literal['selector', 'static']
value: Optional[str]
@validator('value')
def check_value(cls, value, values, **kwargs):
if values['variable_type'] == 'static' and value is None:
raise ValueError('value is required for static variable')
return value
"""
Tool Node Schema
"""
tool_inputs: list[VariableSelector]
tool_parameters: list[ToolInput]
......@@ -27,14 +27,8 @@ class ToolNode(BaseNode):
node_data = cast(ToolNodeData, self.node_data)
# extract tool parameters
parameters = {
k.variable: variable_pool.get_variable_value(k.value_selector)
for k in node_data.tool_inputs
}
if len(parameters) != len(node_data.tool_inputs):
raise ValueError('Invalid tool parameters')
# get parameters
parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime
try:
......@@ -47,6 +41,7 @@ class ToolNode(BaseNode):
)
try:
# TODO: user_id
messages = tool_runtime.invoke(None, parameters)
except Exception as e:
return NodeRunResult(
......@@ -59,13 +54,24 @@ class ToolNode(BaseNode):
plain_text, files = self._convert_tool_messages(messages)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCESS,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'text': plain_text,
'files': files
},
)
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
"""
Generate parameters
"""
return {
k.variable:
k.value if k.variable_type == 'static' else
variable_pool.get_variable_value(k.value) if k.variable_type == 'selector' else ''
for k in node_data.tool_parameters
}
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
......@@ -125,11 +131,6 @@ class ToolNode(BaseNode):
for message in tool_response
])
def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict:
"""
Convert ToolInvokeMessage into file
"""
pass
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment