Commit 295a2485 authored by takatost's avatar takatost

add tenant_id / app_id / workflow_id for nodes

parent 4630f9c7
...@@ -3,7 +3,7 @@ from typing import Optional ...@@ -3,7 +3,7 @@ from typing import Optional
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from models.workflow import Workflow from models.workflow import Workflow, WorkflowType
class WorkflowNodeAndResult: class WorkflowNodeAndResult:
...@@ -16,7 +16,11 @@ class WorkflowNodeAndResult: ...@@ -16,7 +16,11 @@ class WorkflowNodeAndResult:
class WorkflowRunState: class WorkflowRunState:
workflow: Workflow tenant_id: str
app_id: str
workflow_id: str
workflow_type: WorkflowType
start_at: float start_at: float
variable_pool: VariablePool variable_pool: VariablePool
...@@ -25,6 +29,10 @@ class WorkflowRunState: ...@@ -25,6 +29,10 @@ class WorkflowRunState:
workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] workflow_nodes_and_results: list[WorkflowNodeAndResult] = []
def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool):
self.workflow = workflow self.workflow_id = workflow.id
self.tenant_id = workflow.tenant_id
self.app_id = workflow.app_id
self.workflow_type = WorkflowType.value_of(workflow.type)
self.start_at = start_at self.start_at = start_at
self.variable_pool = variable_pool self.variable_pool = variable_pool
...@@ -12,14 +12,25 @@ class BaseNode(ABC): ...@@ -12,14 +12,25 @@ class BaseNode(ABC):
_node_data_cls: type[BaseNodeData] _node_data_cls: type[BaseNodeData]
_node_type: NodeType _node_type: NodeType
tenant_id: str
app_id: str
workflow_id: str
node_id: str node_id: str
node_data: BaseNodeData node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None node_run_result: Optional[NodeRunResult] = None
callbacks: list[BaseWorkflowCallback] callbacks: list[BaseWorkflowCallback]
def __init__(self, config: dict, def __init__(self, tenant_id: str,
app_id: str,
workflow_id: str,
config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None: callbacks: list[BaseWorkflowCallback] = None) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.node_id = config.get("id") self.node_id = config.get("id")
if not self.node_id: if not self.node_id:
raise ValueError("Node ID is required.") raise ValueError("Node ID is required.")
......
...@@ -122,6 +122,7 @@ class WorkflowEngineManager: ...@@ -122,6 +122,7 @@ class WorkflowEngineManager:
while True: while True:
# get next node, multiple target nodes in the future # get next node, multiple target nodes in the future
next_node = self._get_next_node( next_node = self._get_next_node(
workflow_run_state=workflow_run_state,
graph=graph, graph=graph,
predecessor_node=predecessor_node, predecessor_node=predecessor_node,
callbacks=callbacks callbacks=callbacks
...@@ -198,7 +199,8 @@ class WorkflowEngineManager: ...@@ -198,7 +199,8 @@ class WorkflowEngineManager:
error=error error=error
) )
def _get_next_node(self, graph: dict, def _get_next_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
predecessor_node: Optional[BaseNode] = None, predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]: callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
""" """
...@@ -216,7 +218,13 @@ class WorkflowEngineManager: ...@@ -216,7 +218,13 @@ class WorkflowEngineManager:
if not predecessor_node: if not predecessor_node:
for node_config in nodes: for node_config in nodes:
if node_config.get('data', {}).get('type', '') == NodeType.START.value: if node_config.get('data', {}).get('type', '') == NodeType.START.value:
return StartNode(config=node_config) return StartNode(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
config=node_config,
callbacks=callbacks
)
else: else:
edges = graph.get('edges') edges = graph.get('edges')
source_node_id = predecessor_node.node_id source_node_id = predecessor_node.node_id
...@@ -256,6 +264,9 @@ class WorkflowEngineManager: ...@@ -256,6 +264,9 @@ class WorkflowEngineManager:
target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
return target_node( return target_node(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
config=target_node_config, config=target_node_config,
callbacks=callbacks callbacks=callbacks
) )
...@@ -354,7 +365,7 @@ class WorkflowEngineManager: ...@@ -354,7 +365,7 @@ class WorkflowEngineManager:
:param node_run_result: node run result :param node_run_result: node run result
:return: :return:
""" """
if workflow_run_state.workflow.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END:
workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2]
if workflow_nodes_and_result_before_end: if workflow_nodes_and_result_before_end:
if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM: if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment