Commit 44ba9011 authored by takatost's avatar takatost

use callback to filter workflow stream output

parent 2d351c62
...@@ -3,6 +3,7 @@ import time ...@@ -3,6 +3,7 @@ import time
from typing import cast from typing import cast
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
...@@ -10,7 +11,6 @@ from core.app.entities.app_invoke_entities import ( ...@@ -10,7 +11,6 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom, InvokeFrom,
) )
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.moderation.base import ModerationException from core.moderation.base import ModerationException
from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.node_entities import SystemVariable
from core.workflow.workflow_engine_manager import WorkflowEngineManager from core.workflow.workflow_engine_manager import WorkflowEngineManager
...@@ -93,7 +93,10 @@ class AdvancedChatAppRunner(AppRunner): ...@@ -93,7 +93,10 @@ class AdvancedChatAppRunner(AppRunner):
SystemVariable.FILES: files, SystemVariable.FILES: files,
SystemVariable.CONVERSATION: conversation.id, SystemVariable.CONVERSATION: conversation.id,
}, },
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] callbacks=[WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
) )
def handle_input_moderation(self, queue_manager: AppQueueManager, def handle_input_moderation(self, queue_manager: AppQueueManager,
......
...@@ -7,13 +7,15 @@ from core.app.entities.queue_entities import ( ...@@ -7,13 +7,15 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent, QueueWorkflowStartedEvent,
) )
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from models.workflow import WorkflowNodeExecution, WorkflowRun from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
class WorkflowEventTriggerCallback(BaseWorkflowCallback): class WorkflowEventTriggerCallback(BaseWorkflowCallback):
def __init__(self, queue_manager: AppQueueManager): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
""" """
...@@ -51,13 +53,34 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ...@@ -51,13 +53,34 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
PublishFrom.APPLICATION_MANAGER PublishFrom.APPLICATION_MANAGER
) )
def on_node_text_chunk(self, node_id: str, text: str) -> None:
def on_text_chunk(self, text: str) -> None:
""" """
Publish text chunk Publish text chunk
""" """
self._queue_manager.publish( if node_id in self._streamable_node_ids:
QueueTextChunkEvent( self._queue_manager.publish(
text=text QueueTextChunkEvent(
), PublishFrom.APPLICATION_MANAGER text=text
) ), PublishFrom.APPLICATION_MANAGER
)
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
    """
    Fetch streamable node ids.

    For a chat-type workflow, any node wired directly into an END node
    (expected to be LLM or Direct Answer nodes) may stream its output.

    :param graph: workflow graph dict, with 'nodes' and 'edges' lists
    :return: ids of nodes whose text chunks may be streamed
    """
    # Default to [] so a graph missing 'nodes'/'edges' doesn't raise
    # TypeError from iterating None.
    end_node_ids = [
        node_config.get('id')
        for node_config in graph.get('nodes', [])
        if node_config.get('type') == NodeType.END.value
    ]

    # A node is streamable when it is the direct predecessor of an END node.
    return [
        edge_config.get('source')
        for edge_config in graph.get('edges', [])
        if edge_config.get('target') in end_node_ids
    ]
...@@ -4,13 +4,13 @@ from typing import cast ...@@ -4,13 +4,13 @@ from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
AppGenerateEntity, AppGenerateEntity,
InvokeFrom, InvokeFrom,
WorkflowAppGenerateEntity, WorkflowAppGenerateEntity,
) )
from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.moderation.base import ModerationException from core.moderation.base import ModerationException
from core.moderation.input_moderation import InputModeration from core.moderation.input_moderation import InputModeration
from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.node_entities import SystemVariable
...@@ -76,7 +76,10 @@ class WorkflowAppRunner: ...@@ -76,7 +76,10 @@ class WorkflowAppRunner:
system_inputs={ system_inputs={
SystemVariable.FILES: files SystemVariable.FILES: files
}, },
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] callbacks=[WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
) )
def handle_input_moderation(self, queue_manager: AppQueueManager, def handle_input_moderation(self, queue_manager: AppQueueManager,
......
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
QueueNodeFinishedEvent,
QueueNodeStartedEvent,
QueueTextChunkEvent,
QueueWorkflowFinishedEvent,
QueueWorkflowStartedEvent,
)
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
    """
    Bridges workflow-engine callbacks onto the app queue for a *workflow*-type app.

    Lifecycle events (run/node started and finished) are republished verbatim;
    text chunks are forwarded only for nodes wired directly into a plain-text
    END node, so non-streamable node output never reaches the client stream.
    """

    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
        """
        :param queue_manager: queue to republish engine events on
        :param workflow: workflow whose graph decides which nodes may stream
        """
        self._queue_manager = queue_manager
        # Pre-compute once so each text chunk only needs a membership test.
        self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)

    def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
        """
        Workflow run started
        """
        self._queue_manager.publish(
            QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
        """
        Workflow run finished
        """
        self._queue_manager.publish(
            QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        """
        Workflow node execute started
        """
        self._queue_manager.publish(
            QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        """
        Workflow node execute finished
        """
        self._queue_manager.publish(
            QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_node_text_chunk(self, node_id: str, text: str) -> None:
        """
        Publish a text chunk, but only for nodes marked streamable at init time.
        """
        if node_id in self._streamable_node_ids:
            self._queue_manager.publish(
                QueueTextChunkEvent(
                    text=text
                ), PublishFrom.APPLICATION_MANAGER
            )

    def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
        """
        Fetch streamable node ids.

        For a workflow-type app, only nodes feeding directly into an END node
        whose output type is 'plain-text' (expected to be LLM nodes) may
        stream their output.

        :param graph: workflow graph dict, with 'nodes' and 'edges' lists
        :return: ids of nodes whose text chunks may be streamed
        """
        # Default to [] so a graph missing 'nodes'/'edges' doesn't raise
        # TypeError from iterating None.
        end_node_ids = [
            node_config.get('id')
            for node_config in graph.get('nodes', [])
            if node_config.get('type') == NodeType.END.value
            # Only plain-text END nodes participate in streaming.
            and node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text'
        ]

        # A node is streamable when it is the direct predecessor of such an END node.
        return [
            edge_config.get('source')
            for edge_config in graph.get('edges', [])
            if edge_config.get('target') in end_node_ids
        ]
from abc import abstractmethod from abc import ABC, abstractmethod
from models.workflow import WorkflowNodeExecution, WorkflowRun from models.workflow import WorkflowNodeExecution, WorkflowRun
class BaseWorkflowCallback: class BaseWorkflowCallback(ABC):
@abstractmethod @abstractmethod
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
""" """
...@@ -33,7 +33,7 @@ class BaseWorkflowCallback: ...@@ -33,7 +33,7 @@ class BaseWorkflowCallback:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def on_text_chunk(self, text: str) -> None: def on_node_text_chunk(self, node_id: str, text: str) -> None:
""" """
Publish text chunk Publish text chunk
""" """
......
...@@ -16,7 +16,6 @@ class BaseNode: ...@@ -16,7 +16,6 @@ class BaseNode:
node_data: BaseNodeData node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None node_run_result: Optional[NodeRunResult] = None
stream_output_supported: bool = False
callbacks: list[BaseWorkflowCallback] callbacks: list[BaseWorkflowCallback]
def __init__(self, config: dict, def __init__(self, config: dict,
...@@ -71,10 +70,12 @@ class BaseNode: ...@@ -71,10 +70,12 @@ class BaseNode:
:param text: chunk text :param text: chunk text
:return: :return:
""" """
if self.stream_output_supported: if self.callbacks:
if self.callbacks: for callback in self.callbacks:
for callback in self.callbacks: callback.on_node_text_chunk(
callback.on_text_chunk(text) node_id=self.node_id,
text=text
)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
......
...@@ -32,7 +32,6 @@ from models.workflow import ( ...@@ -32,7 +32,6 @@ from models.workflow import (
WorkflowRun, WorkflowRun,
WorkflowRunStatus, WorkflowRunStatus,
WorkflowRunTriggeredFrom, WorkflowRunTriggeredFrom,
WorkflowType,
) )
node_classes = { node_classes = {
...@@ -171,9 +170,6 @@ class WorkflowEngineManager: ...@@ -171,9 +170,6 @@ class WorkflowEngineManager:
) )
) )
# fetch predecessor node ids before end node (include: llm, direct answer)
streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
try: try:
predecessor_node = None predecessor_node = None
while True: while True:
...@@ -187,10 +183,6 @@ class WorkflowEngineManager: ...@@ -187,10 +183,6 @@ class WorkflowEngineManager:
if not next_node: if not next_node:
break break
# check if node is streamable
if next_node.node_id in streamable_node_ids:
next_node.stream_output_supported = True
# max steps 30 reached # max steps 30 reached
if len(workflow_run_state.workflow_node_executions) > 30: if len(workflow_run_state.workflow_node_executions) > 30:
raise ValueError('Max steps 30 reached.') raise ValueError('Max steps 30 reached.')
...@@ -233,34 +225,6 @@ class WorkflowEngineManager: ...@@ -233,34 +225,6 @@ class WorkflowEngineManager:
callbacks=callbacks callbacks=callbacks
) )
def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
"""
Fetch streamable node ids
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
:param workflow: Workflow instance
:param graph: workflow graph
:return:
"""
workflow_type = WorkflowType.value_of(workflow.type)
streamable_node_ids = []
end_node_ids = []
for node_config in graph.get('nodes'):
if node_config.get('type') == NodeType.END.value:
if workflow_type == WorkflowType.WORKFLOW:
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
end_node_ids.append(node_config.get('id'))
else:
end_node_ids.append(node_config.get('id'))
for edge_config in graph.get('edges'):
if edge_config.get('target') in end_node_ids:
streamable_node_ids.append(edge_config.get('source'))
return streamable_node_ids
def _init_workflow_run(self, workflow: Workflow, def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom, triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser], user: Union[Account, EndUser],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment