Commit 8a322d49 authored by takatost's avatar takatost

completed workflow engine main logic

parent c7e5ba1e
...@@ -83,7 +83,6 @@ class AdvancedChatAppRunner(AppRunner): ...@@ -83,7 +83,6 @@ class AdvancedChatAppRunner(AppRunner):
# RUN WORKFLOW # RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow( workflow_engine_manager.run_workflow(
app_model=app_record,
workflow=workflow, workflow=workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN, if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
...@@ -94,7 +93,7 @@ class AdvancedChatAppRunner(AppRunner): ...@@ -94,7 +93,7 @@ class AdvancedChatAppRunner(AppRunner):
SystemVariable.FILES: files, SystemVariable.FILES: files,
SystemVariable.CONVERSATION: conversation.id, SystemVariable.CONVERSATION: conversation.id,
}, },
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)],
) )
def handle_input_moderation(self, queue_manager: AppQueueManager, def handle_input_moderation(self, queue_manager: AppQueueManager,
......
...@@ -253,8 +253,6 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -253,8 +253,6 @@ class AdvancedChatAppGenerateTaskPipeline:
'error': workflow_run.error, 'error': workflow_run.error,
'elapsed_time': workflow_run.elapsed_time, 'elapsed_time': workflow_run.elapsed_time,
'total_tokens': workflow_run.total_tokens, 'total_tokens': workflow_run.total_tokens,
'total_price': workflow_run.total_price,
'currency': workflow_run.currency,
'total_steps': workflow_run.total_steps, 'total_steps': workflow_run.total_steps,
'created_at': int(workflow_run.created_at.timestamp()), 'created_at': int(workflow_run.created_at.timestamp()),
'finished_at': int(workflow_run.finished_at.timestamp()) 'finished_at': int(workflow_run.finished_at.timestamp())
......
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.workflow.callbacks.base_callback import BaseWorkflowCallback from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from models.workflow import WorkflowNodeExecution, WorkflowRun from models.workflow import WorkflowNodeExecution, WorkflowRun
...@@ -43,3 +43,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ...@@ -43,3 +43,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
workflow_node_execution_id=workflow_node_execution.id, workflow_node_execution_id=workflow_node_execution.id,
pub_from=PublishFrom.TASK_PIPELINE pub_from=PublishFrom.TASK_PIPELINE
) )
def on_text_chunk(self, text: str) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish_text_chunk(
text=text,
pub_from=PublishFrom.TASK_PIPELINE
)
...@@ -31,3 +31,11 @@ class BaseWorkflowCallback: ...@@ -31,3 +31,11 @@ class BaseWorkflowCallback:
Workflow node execute finished Workflow node execute finished
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def on_text_chunk(self, text: str) -> None:
"""
Publish text chunk
"""
raise NotImplementedError
from enum import Enum from enum import Enum
from typing import Optional
from pydantic import BaseModel
from models.workflow import WorkflowNodeExecutionStatus
class NodeType(Enum): class NodeType(Enum):
...@@ -39,3 +44,19 @@ class SystemVariable(Enum): ...@@ -39,3 +44,19 @@ class SystemVariable(Enum):
QUERY = 'query' QUERY = 'query'
FILES = 'files' FILES = 'files'
CONVERSATION = 'conversation' CONVERSATION = 'conversation'
class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[dict] = None # node outputs
metadata: Optional[dict] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
from decimal import Decimal
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecution, WorkflowRun from models.workflow import WorkflowNodeExecution, WorkflowRun
...@@ -10,7 +8,10 @@ class WorkflowRunState: ...@@ -10,7 +8,10 @@ class WorkflowRunState:
variable_pool: VariablePool variable_pool: VariablePool
total_tokens: int = 0 total_tokens: int = 0
total_price: Decimal = Decimal(0)
currency: str = "USD"
workflow_node_executions: list[WorkflowNodeExecution] = [] workflow_node_executions: list[WorkflowNodeExecution] = []
def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
self.workflow_run = workflow_run
self.start_at = start_at
self.variable_pool = variable_pool
from abc import abstractmethod from abc import abstractmethod
from typing import Optional from typing import Optional
from core.workflow.callbacks.base_callback import BaseWorkflowCallback from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecutionStatus
class BaseNode: class BaseNode:
...@@ -13,17 +14,23 @@ class BaseNode: ...@@ -13,17 +14,23 @@ class BaseNode:
node_id: str node_id: str
node_data: BaseNodeData node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
def __init__(self, config: dict) -> None: stream_output_supported: bool = False
callbacks: list[BaseWorkflowCallback]
def __init__(self, config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None:
self.node_id = config.get("id") self.node_id = config.get("id")
if not self.node_id: if not self.node_id:
raise ValueError("Node ID is required.") raise ValueError("Node ID is required.")
self.node_data = self._node_data_cls(**config.get("data", {})) self.node_data = self._node_data_cls(**config.get("data", {}))
self.callbacks = callbacks or []
@abstractmethod @abstractmethod
def _run(self, variable_pool: Optional[VariablePool] = None, def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict: run_args: Optional[dict] = None) -> NodeRunResult:
""" """
Run node Run node
:param variable_pool: variable pool :param variable_pool: variable pool
...@@ -33,22 +40,41 @@ class BaseNode: ...@@ -33,22 +40,41 @@ class BaseNode:
raise NotImplementedError raise NotImplementedError
def run(self, variable_pool: Optional[VariablePool] = None, def run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None, run_args: Optional[dict] = None) -> NodeRunResult:
callbacks: list[BaseWorkflowCallback] = None) -> dict:
""" """
Run node entry Run node entry
:param variable_pool: variable pool :param variable_pool: variable pool
:param run_args: run args :param run_args: run args
:param callbacks: callbacks
:return: :return:
""" """
if variable_pool is None and run_args is None: if variable_pool is None and run_args is None:
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.") raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
return self._run( try:
variable_pool=variable_pool, result = self._run(
run_args=run_args variable_pool=variable_pool,
) run_args=run_args
)
except Exception as e:
# process unhandled exception
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
self.node_run_result = result
return result
def publish_text_chunk(self, text: str) -> None:
"""
Publish text chunk
:param text: chunk text
:return:
"""
if self.stream_output_supported:
if self.callbacks:
for callback in self.callbacks:
callback.on_text_chunk(text)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
......
...@@ -11,8 +11,6 @@ workflow_run_for_log_fields = { ...@@ -11,8 +11,6 @@ workflow_run_for_log_fields = {
"error": fields.String, "error": fields.String,
"elapsed_time": fields.Float, "elapsed_time": fields.Float,
"total_tokens": fields.Integer, "total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer, "total_steps": fields.Integer,
"created_at": TimestampField, "created_at": TimestampField,
"finished_at": TimestampField "finished_at": TimestampField
...@@ -29,8 +27,6 @@ workflow_run_for_list_fields = { ...@@ -29,8 +27,6 @@ workflow_run_for_list_fields = {
"error": fields.String, "error": fields.String,
"elapsed_time": fields.Float, "elapsed_time": fields.Float,
"total_tokens": fields.Integer, "total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer, "total_steps": fields.Integer,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
"created_at": TimestampField, "created_at": TimestampField,
...@@ -56,8 +52,6 @@ workflow_run_detail_fields = { ...@@ -56,8 +52,6 @@ workflow_run_detail_fields = {
"error": fields.String, "error": fields.String,
"elapsed_time": fields.Float, "elapsed_time": fields.Float,
"total_tokens": fields.Integer, "total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer, "total_steps": fields.Integer,
"created_by_role": fields.String, "created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
......
...@@ -216,8 +216,6 @@ class WorkflowRun(db.Model): ...@@ -216,8 +216,6 @@ class WorkflowRun(db.Model):
- error (string) `optional` Error reason - error (string) `optional` Error reason
- elapsed_time (float) `optional` Time consumption (s) - elapsed_time (float) `optional` Time consumption (s)
- total_tokens (int) `optional` Total tokens used - total_tokens (int) `optional` Total tokens used
- total_price (decimal) `optional` Total cost
- currency (string) `optional` Currency, such as USD / RMB
- total_steps (int) Total steps (redundant), default 0 - total_steps (int) Total steps (redundant), default 0
- created_by_role (string) Creator role - created_by_role (string) Creator role
...@@ -251,8 +249,6 @@ class WorkflowRun(db.Model): ...@@ -251,8 +249,6 @@ class WorkflowRun(db.Model):
error = db.Column(db.Text) error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0'))
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
total_price = db.Column(db.Numeric(10, 7))
currency = db.Column(db.String(255))
total_steps = db.Column(db.Integer, server_default=db.text('0')) total_steps = db.Column(db.Integer, server_default=db.text('0'))
created_by_role = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(UUID, nullable=False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment