Commit 5963e7d1 authored by takatost

completed workflow engine main logic

parent c7618fc3
......@@ -83,7 +83,6 @@ class AdvancedChatAppRunner(AppRunner):
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
app_model=app_record,
workflow=workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
......@@ -94,7 +93,7 @@ class AdvancedChatAppRunner(AppRunner):
SystemVariable.FILES: files,
SystemVariable.CONVERSATION: conversation.id,
},
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)],
)
def handle_input_moderation(self, queue_manager: AppQueueManager,
......
......@@ -253,8 +253,6 @@ class AdvancedChatAppGenerateTaskPipeline:
'error': workflow_run.error,
'elapsed_time': workflow_run.elapsed_time,
'total_tokens': workflow_run.total_tokens,
'total_price': workflow_run.total_price,
'currency': workflow_run.currency,
'total_steps': workflow_run.total_steps,
'created_at': int(workflow_run.created_at.timestamp()),
'finished_at': int(workflow_run.finished_at.timestamp())
......
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from models.workflow import WorkflowNodeExecution, WorkflowRun
......@@ -43,3 +43,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
workflow_node_execution_id=workflow_node_execution.id,
pub_from=PublishFrom.TASK_PIPELINE
)
def on_text_chunk(self, text: str) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish_text_chunk(
text=text,
pub_from=PublishFrom.TASK_PIPELINE
)
......@@ -31,3 +31,11 @@ class BaseWorkflowCallback:
Workflow node execute finished
"""
raise NotImplementedError
@abstractmethod
def on_text_chunk(self, text: str) -> None:
"""
Publish text chunk
"""
raise NotImplementedError
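# Illustrative sketch (not part of this commit): a minimal concrete callback that
# implements the new on_text_chunk hook. LoggingWorkflowCallback is a hypothetical
# name, and it assumes the other hooks used elsewhere in this diff
# (on_workflow_run_started/finished, on_workflow_node_execute_started/finished)
# are the remaining abstract methods of BaseWorkflowCallback.
class LoggingWorkflowCallback(BaseWorkflowCallback):
    def on_workflow_run_started(self, workflow_run) -> None:
        print(f"workflow run {workflow_run.id} started")

    def on_workflow_run_finished(self, workflow_run) -> None:
        print(f"workflow run {workflow_run.id} finished: {workflow_run.status}")

    def on_workflow_node_execute_started(self, workflow_node_execution) -> None:
        print(f"node execution {workflow_node_execution.id} started")

    def on_workflow_node_execute_finished(self, workflow_node_execution) -> None:
        print(f"node execution {workflow_node_execution.id} finished")

    def on_text_chunk(self, text: str) -> None:
        # stream each text chunk to stdout as soon as it is produced
        print(text, end="", flush=True)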
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from models.workflow import WorkflowNodeExecutionStatus
class NodeType(Enum):
......@@ -39,3 +44,19 @@ class SystemVariable(Enum):
QUERY = 'query'
FILES = 'files'
CONVERSATION = 'conversation'
class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[dict] = None # node outputs
metadata: Optional[dict] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
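# Illustrative sketch (not part of this commit): how a node implementation might
# report its result with the NodeRunResult model above. The field values are
# made up for demonstration.
_example_success = NodeRunResult(
    status=WorkflowNodeExecutionStatus.SUCCEEDED,
    inputs={'query': 'What is Dify?'},
    process_data={'prompt': 'What is Dify?'},
    outputs={'text': 'Dify is an LLM app development platform.'},
    metadata={'total_tokens': 12},
)

_example_failure = NodeRunResult(
    status=WorkflowNodeExecutionStatus.FAILED,
    error='provider quota exceeded',
)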
from decimal import Decimal
from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecution, WorkflowRun
......@@ -10,7 +8,10 @@ class WorkflowRunState:
variable_pool: VariablePool
total_tokens: int = 0
total_price: Decimal = Decimal(0)
currency: str = "USD"
workflow_node_executions: list[WorkflowNodeExecution]
def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
self.workflow_run = workflow_run
self.start_at = start_at
self.variable_pool = variable_pool
# avoid a shared class-level mutable default; track executions per run
self.workflow_node_executions = []
from abc import abstractmethod
from typing import Optional
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecutionStatus
class BaseNode:
......@@ -13,17 +14,23 @@ class BaseNode:
node_id: str
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
def __init__(self, config: dict) -> None:
stream_output_supported: bool = False
callbacks: list[BaseWorkflowCallback]
def __init__(self, config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None:
self.node_id = config.get("id")
if not self.node_id:
raise ValueError("Node ID is required.")
self.node_data = self._node_data_cls(**config.get("data", {}))
self.callbacks = callbacks or []
@abstractmethod
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
......@@ -33,22 +40,41 @@ class BaseNode:
raise NotImplementedError
def run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None) -> dict:
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node entry
:param variable_pool: variable pool
:param run_args: run args
:param callbacks: callbacks
:return:
"""
if variable_pool is None and run_args is None:
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
return self._run(
try:
result = self._run(
variable_pool=variable_pool,
run_args=run_args
)
except Exception as e:
# process unhandled exception
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
self.node_run_result = result
return result
def publish_text_chunk(self, text: str) -> None:
"""
Publish text chunk
:param text: chunk text
:return:
"""
if self.stream_output_supported:
if self.callbacks:
for callback in self.callbacks:
callback.on_text_chunk(text)
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
......
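# Illustrative sketch (not part of this commit): a minimal BaseNode subclass
# following the updated contract, where _run returns a NodeRunResult and run()
# converts unhandled exceptions into a FAILED result. EchoNodeData and EchoNode
# are hypothetical names, assuming BaseNodeData is a pydantic model.
class EchoNodeData(BaseNodeData):
    text: str = ''

class EchoNode(BaseNode):
    _node_data_cls = EchoNodeData

    def _run(self, variable_pool: Optional[VariablePool] = None,
             run_args: Optional[dict] = None) -> NodeRunResult:
        text = self.node_data.text
        # stream the text if the engine marked this node as streamable
        self.publish_text_chunk(text)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=run_args or {},
            outputs={'text': text}
        )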
import json
import time
from datetime import datetime
from typing import Optional, Union
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowRunState
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.code_node import CodeNode
......@@ -31,6 +32,7 @@ from models.workflow import (
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
WorkflowType,
)
node_classes = {
......@@ -120,8 +122,7 @@ class WorkflowEngineManager:
return default_config
def run_workflow(self, app_model: App,
workflow: Workflow,
def run_workflow(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
......@@ -129,7 +130,6 @@ class WorkflowEngineManager:
callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Run workflow
:param app_model: App instance
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
......@@ -143,13 +143,23 @@ class WorkflowEngineManager:
if not graph:
raise ValueError('workflow graph not found')
if 'nodes' not in graph or 'edges' not in graph:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init workflow run
workflow_run = self._init_workflow_run(
workflow=workflow,
triggered_from=triggered_from,
user=user,
user_inputs=user_inputs,
system_inputs=system_inputs
system_inputs=system_inputs,
callbacks=callbacks
)
# init workflow run state
......@@ -161,44 +171,54 @@ class WorkflowEngineManager:
)
)
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started(workflow_run)
# fetch start node
start_node = self._get_entry_node(graph)
if not start_node:
self._workflow_run_failed(
workflow_run_state=workflow_run_state,
error='Start node not found in workflow graph',
callbacks=callbacks
)
return
# fetch predecessor node ids before end node (include: llm, direct answer)
streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
try:
predecessor_node = None
current_node = start_node
while True:
# run workflow
self._run_workflow_node(
workflow_run_state=workflow_run_state,
node=current_node,
# get next node, multiple target nodes in the future
next_node = self._get_next_node(
graph=graph,
predecessor_node=predecessor_node,
callbacks=callbacks
)
if current_node.node_type == NodeType.END:
if not next_node:
break
# todo fetch next node until end node finished or no next node
current_node = None
# check if node is streamable
if next_node.node_id in streamable_node_ids:
next_node.stream_output_supported = True
if not current_node:
break
# max steps 30 reached
if len(workflow_run_state.workflow_node_executions) > 30:
raise ValueError('Max steps 30 reached.')
predecessor_node = current_node
# or max steps 30 reached
# or max execution time 10min reached
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600):
raise ValueError('Max execution time 10min reached.')
# run workflow, run multiple target nodes in the future
self._run_workflow_node(
workflow_run_state=workflow_run_state,
node=next_node,
predecessor_node=predecessor_node,
callbacks=callbacks
)
if next_node.node_type == NodeType.END:
break
predecessor_node = next_node
if not predecessor_node and not next_node:
self._workflow_run_failed(
workflow_run_state=workflow_run_state,
error='Start node not found in workflow graph.',
callbacks=callbacks
)
return
except Exception as e:
self._workflow_run_failed(
workflow_run_state=workflow_run_state,
......@@ -213,11 +233,40 @@ class WorkflowEngineManager:
callbacks=callbacks
)
def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
"""
Fetch streamable node ids
When the workflow type is chat, only LLM and Direct Answer nodes directly preceding an END node can stream output
When the workflow type is workflow, only LLM nodes directly preceding an END node in plain-text output mode can stream output
:param workflow: Workflow instance
:param graph: workflow graph
:return:
"""
workflow_type = WorkflowType.value_of(workflow.type)
streamable_node_ids = []
end_node_ids = []
for node_config in graph.get('nodes'):
if node_config.get('type') == NodeType.END.value:
if workflow_type == WorkflowType.WORKFLOW:
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
end_node_ids.append(node_config.get('id'))
else:
end_node_ids.append(node_config.get('id'))
for edge_config in graph.get('edges'):
if edge_config.get('target') in end_node_ids:
streamable_node_ids.append(edge_config.get('source'))
return streamable_node_ids
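# Illustrative sketch (not part of this commit): for a chat workflow whose graph is
#     {'nodes': [{'id': 'llm_1', 'type': 'llm'}, {'id': 'end_1', 'type': 'end'}],
#      'edges': [{'source': 'llm_1', 'target': 'end_1'}]}
# this method returns ['llm_1'], and the run loop above then sets
# stream_output_supported = True on that node before executing it.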
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> WorkflowRun:
system_inputs: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
"""
Init workflow run
:param workflow: Workflow instance
......@@ -225,6 +274,7 @@ class WorkflowEngineManager:
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks
:return:
"""
try:
......@@ -260,6 +310,39 @@ class WorkflowEngineManager:
db.session.rollback()
raise
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started(workflow_run)
return workflow_run
def _workflow_run_success(self, workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
"""
Workflow run success
:param workflow_run_state: workflow run state
:param callbacks: workflow callbacks
:return:
"""
workflow_run = workflow_run_state.workflow_run
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
# fetch the last workflow node execution, if any
last_workflow_node_execution = (workflow_run_state.workflow_node_executions[-1]
if workflow_run_state.workflow_node_executions else None)
if last_workflow_node_execution:
workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs)
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
workflow_run.total_tokens = workflow_run_state.total_tokens
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_run_finished(workflow_run)
return workflow_run
def _workflow_run_failed(self, workflow_run_state: WorkflowRunState,
......@@ -277,9 +360,8 @@ class WorkflowEngineManager:
workflow_run.error = error
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
workflow_run.total_tokens = workflow_run_state.total_tokens
workflow_run.total_price = workflow_run_state.total_price
workflow_run.currency = workflow_run_state.currency
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
......@@ -289,22 +371,78 @@ class WorkflowEngineManager:
return workflow_run
def _get_entry_node(self, graph: dict) -> Optional[StartNode]:
def _get_next_node(self, graph: dict,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
"""
Get entry node
Get next node
multiple target nodes in the future.
:param graph: workflow graph
:param predecessor_node: predecessor node
:param callbacks: workflow callbacks
:return:
"""
nodes = graph.get('nodes')
if not nodes:
return None
for node_config in nodes.items():
if not predecessor_node:
for node_config in nodes:
if node_config.get('type') == NodeType.START.value:
return StartNode(config=node_config)
else:
edges = graph.get('edges')
source_node_id = predecessor_node.node_id
# fetch all outgoing edges from source node
outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id]
if not outgoing_edges:
return None
# fetch target node id from outgoing edges
outgoing_edge = None
source_handle = predecessor_node.node_run_result.edge_source_handle
if source_handle:
for edge in outgoing_edges:
if edge.get('source_handle') == source_handle:
outgoing_edge = edge
break
else:
outgoing_edge = outgoing_edges[0]
if not outgoing_edge:
return None
target_node_id = outgoing_edge.get('target')
# fetch target node from target node id
target_node_config = None
for node in nodes:
if node.get('id') == target_node_id:
target_node_config = node
break
if not target_node_config:
return None
# get next node
target_node = node_classes.get(NodeType.value_of(target_node_config.get('type')))
return target_node(
config=target_node_config,
callbacks=callbacks
)
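# Illustrative sketch (not part of this commit): with outgoing edges
#     [{'source': 'if_1', 'source_handle': 'true', 'target': 'llm_1'},
#      {'source': 'if_1', 'source_handle': 'false', 'target': 'end_1'}]
# and a predecessor whose NodeRunResult.edge_source_handle is 'true', the edge
# whose source_handle matches is followed and the llm_1 node is returned;
# without a source handle, the first outgoing edge is used.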
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout
:param start_at: start time
:param max_execution_time: max execution time
:return:
"""
# TODO check queue is stopped
return time.perf_counter() - start_at > max_execution_time
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
node: BaseNode,
predecessor_node: Optional[BaseNode] = None,
......@@ -320,28 +458,41 @@ class WorkflowEngineManager:
# add to workflow node executions
workflow_run_state.workflow_node_executions.append(workflow_node_execution)
try:
# run node, result must have inputs, process_data, outputs, execution_metadata
node_run_result = node.run(
variable_pool=workflow_run_state.variable_pool,
callbacks=callbacks
variable_pool=workflow_run_state.variable_pool
)
except Exception as e:
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
# node run failed
self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
error=str(e),
start_at=start_at,
error=node_run_result.error,
callbacks=callbacks
)
raise
raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")
# node run success
self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=start_at,
result=node_run_result,
callbacks=callbacks
)
# append node outputs to the variable pool (outputs may be None for some nodes)
if node_run_result.outputs:
for variable_key, variable_value in node_run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
variable_pool=workflow_run_state.variable_pool,
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value
)
if node_run_result.metadata and node_run_result.metadata.get('total_tokens'):
workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens'))
return workflow_node_execution
def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState,
......@@ -384,3 +535,86 @@ class WorkflowEngineManager:
callback.on_workflow_node_execute_started(workflow_node_execution)
return workflow_node_execution
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
result: NodeRunResult,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
"""
Workflow node execution success
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param result: node run result
:param callbacks: workflow callbacks
:return:
"""
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(result.inputs)
workflow_node_execution.process_data = json.dumps(result.process_data)
workflow_node_execution.outputs = json.dumps(result.outputs)
workflow_node_execution.execution_metadata = json.dumps(result.metadata)
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_finished(workflow_node_execution)
return workflow_node_execution
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
error: str,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param error: error message
:param callbacks: workflow callbacks
:return:
"""
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_finished(workflow_node_execution)
return workflow_node_execution
def _append_variables_recursively(self, variable_pool: VariablePool,
node_id: str,
variable_key_list: list[str],
variable_value: VariableValue):
"""
Append variables recursively
:param variable_pool: variable pool
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
variable_pool.append_variable(
node_id=node_id,
variable_key_list=variable_key_list,
value=variable_value
)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, dict):
for key, value in variable_value.items():
# construct new key list
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
variable_pool=variable_pool,
node_id=node_id,
variable_key_list=new_key_list,
variable_value=value
)
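# Illustrative sketch (not part of this commit): for node_id 'llm_1' and a nested
# output value {'usage': {'total_tokens': 12}}, the recursion above appends both
#     ('llm_1', ['usage'])                 -> {'total_tokens': 12}
#     ('llm_1', ['usage', 'total_tokens']) -> 12
# to the variable pool, so downstream nodes can reference either the whole dict
# or the leaf value by key path.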
......@@ -11,8 +11,6 @@ workflow_run_for_log_fields = {
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField
......@@ -29,8 +27,6 @@ workflow_run_for_list_fields = {
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
"created_at": TimestampField,
......@@ -56,8 +52,6 @@ workflow_run_detail_fields = {
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer,
"created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
......
......@@ -216,8 +216,6 @@ class WorkflowRun(db.Model):
- error (string) `optional` Error reason
- elapsed_time (float) `optional` Time consumption (s)
- total_tokens (int) `optional` Total tokens used
- total_price (decimal) `optional` Total cost
- currency (string) `optional` Currency, such as USD / RMB
- total_steps (int) Total steps (redundant), default 0
- created_by_role (string) Creator role
......@@ -251,8 +249,6 @@ class WorkflowRun(db.Model):
error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0'))
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
total_price = db.Column(db.Numeric(10, 7))
currency = db.Column(db.String(255))
total_steps = db.Column(db.Integer, server_default=db.text('0'))
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
......