Commit 5963e7d1 authored by takatost's avatar takatost

completed workflow engine main logic

parent c7618fc3
......@@ -83,7 +83,6 @@ class AdvancedChatAppRunner(AppRunner):
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
app_model=app_record,
workflow=workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
......@@ -94,7 +93,7 @@ class AdvancedChatAppRunner(AppRunner):
SystemVariable.FILES: files,
SystemVariable.CONVERSATION: conversation.id,
},
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)],
)
def handle_input_moderation(self, queue_manager: AppQueueManager,
......
......@@ -253,8 +253,6 @@ class AdvancedChatAppGenerateTaskPipeline:
'error': workflow_run.error,
'elapsed_time': workflow_run.elapsed_time,
'total_tokens': workflow_run.total_tokens,
'total_price': workflow_run.total_price,
'currency': workflow_run.currency,
'total_steps': workflow_run.total_steps,
'created_at': int(workflow_run.created_at.timestamp()),
'finished_at': int(workflow_run.finished_at.timestamp())
......
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from models.workflow import WorkflowNodeExecution, WorkflowRun
......@@ -43,3 +43,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
workflow_node_execution_id=workflow_node_execution.id,
pub_from=PublishFrom.TASK_PIPELINE
)
def on_text_chunk(self, text: str) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish_text_chunk(
text=text,
pub_from=PublishFrom.TASK_PIPELINE
)
......@@ -31,3 +31,11 @@ class BaseWorkflowCallback:
Workflow node execute finished
"""
raise NotImplementedError
@abstractmethod
def on_text_chunk(self, text: str) -> None:
"""
Publish text chunk
"""
raise NotImplementedError
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from models.workflow import WorkflowNodeExecutionStatus
class NodeType(Enum):
......@@ -39,3 +44,19 @@ class SystemVariable(Enum):
QUERY = 'query'
FILES = 'files'
CONVERSATION = 'conversation'
class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[dict] = None # node outputs
metadata: Optional[dict] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
from decimal import Decimal
from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecution, WorkflowRun
......@@ -10,7 +8,10 @@ class WorkflowRunState:
variable_pool: VariablePool
total_tokens: int = 0
total_price: Decimal = Decimal(0)
currency: str = "USD"
workflow_node_executions: list[WorkflowNodeExecution] = []
def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
self.workflow_run = workflow_run
self.start_at = start_at
self.variable_pool = variable_pool
from abc import abstractmethod
from typing import Optional
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecutionStatus
class BaseNode:
......@@ -13,17 +14,23 @@ class BaseNode:
node_id: str
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
def __init__(self, config: dict) -> None:
stream_output_supported: bool = False
callbacks: list[BaseWorkflowCallback]
def __init__(self, config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None:
self.node_id = config.get("id")
if not self.node_id:
raise ValueError("Node ID is required.")
self.node_data = self._node_data_cls(**config.get("data", {}))
self.callbacks = callbacks or []
@abstractmethod
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
......@@ -33,22 +40,41 @@ class BaseNode:
raise NotImplementedError
def run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None) -> dict:
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node entry
:param variable_pool: variable pool
:param run_args: run args
:param callbacks: callbacks
:return:
"""
if variable_pool is None and run_args is None:
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
return self._run(
variable_pool=variable_pool,
run_args=run_args
)
try:
result = self._run(
variable_pool=variable_pool,
run_args=run_args
)
except Exception as e:
# process unhandled exception
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
self.node_run_result = result
return result
def publish_text_chunk(self, text: str) -> None:
"""
Publish text chunk
:param text: chunk text
:return:
"""
if self.stream_output_supported:
if self.callbacks:
for callback in self.callbacks:
callback.on_text_chunk(text)
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
......
......@@ -11,8 +11,6 @@ workflow_run_for_log_fields = {
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField
......@@ -29,8 +27,6 @@ workflow_run_for_list_fields = {
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
"created_at": TimestampField,
......@@ -56,8 +52,6 @@ workflow_run_detail_fields = {
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer,
"created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
......
......@@ -216,8 +216,6 @@ class WorkflowRun(db.Model):
- error (string) `optional` Error reason
- elapsed_time (float) `optional` Time consumption (s)
- total_tokens (int) `optional` Total tokens used
- total_price (decimal) `optional` Total cost
- currency (string) `optional` Currency, such as USD / RMB
- total_steps (int) Total steps (redundant), default 0
- created_by_role (string) Creator role
......@@ -251,8 +249,6 @@ class WorkflowRun(db.Model):
error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0'))
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
total_price = db.Column(db.Numeric(10, 7))
currency = db.Column(db.String(255))
total_steps = db.Column(db.Integer, server_default=db.text('0'))
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment