Commit 295a2485 authored by takatost's avatar takatost

add tenant_id / app_id / workflow_id for nodes

parent 4630f9c7
......@@ -3,7 +3,7 @@ from typing import Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowType
class WorkflowNodeAndResult:
......@@ -16,7 +16,11 @@ class WorkflowNodeAndResult:
class WorkflowRunState:
workflow: Workflow
tenant_id: str
app_id: str
workflow_id: str
workflow_type: WorkflowType
start_at: float
variable_pool: VariablePool
......@@ -25,6 +29,10 @@ class WorkflowRunState:
workflow_nodes_and_results: list[WorkflowNodeAndResult] = []
def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool):
self.workflow = workflow
self.workflow_id = workflow.id
self.tenant_id = workflow.tenant_id
self.app_id = workflow.app_id
self.workflow_type = WorkflowType.value_of(workflow.type)
self.start_at = start_at
self.variable_pool = variable_pool
......@@ -12,14 +12,25 @@ class BaseNode(ABC):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType
tenant_id: str
app_id: str
workflow_id: str
node_id: str
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
callbacks: list[BaseWorkflowCallback]
def __init__(self, config: dict,
def __init__(self, tenant_id: str,
app_id: str,
workflow_id: str,
config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.node_id = config.get("id")
if not self.node_id:
raise ValueError("Node ID is required.")
......
......@@ -122,6 +122,7 @@ class WorkflowEngineManager:
while True:
# get next node, multiple target nodes in the future
next_node = self._get_next_node(
workflow_run_state=workflow_run_state,
graph=graph,
predecessor_node=predecessor_node,
callbacks=callbacks
......@@ -198,7 +199,8 @@ class WorkflowEngineManager:
error=error
)
def _get_next_node(self, graph: dict,
def _get_next_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
"""
......@@ -216,7 +218,13 @@ class WorkflowEngineManager:
if not predecessor_node:
for node_config in nodes:
if node_config.get('data', {}).get('type', '') == NodeType.START.value:
return StartNode(config=node_config)
return StartNode(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
config=node_config,
callbacks=callbacks
)
else:
edges = graph.get('edges')
source_node_id = predecessor_node.node_id
......@@ -256,6 +264,9 @@ class WorkflowEngineManager:
target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
return target_node(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
config=target_node_config,
callbacks=callbacks
)
......@@ -354,7 +365,7 @@ class WorkflowEngineManager:
:param node_run_result: node run result
:return:
"""
if workflow_run_state.workflow.type == WorkflowType.CHAT.value and node.node_type == NodeType.END:
if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END:
workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2]
if workflow_nodes_and_result_before_end:
if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment