Commit 823d423f authored by takatost

Merge branch 'feat/workflow-backend' into deploy/dev

parents 82daa14c 760c7f04
@@ -3,6 +3,7 @@ import time
 from typing import cast

 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
+from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.base_app_runner import AppRunner
 from core.app.entities.app_invoke_entities import (
@@ -10,7 +11,6 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
 )
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
-from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
 from core.moderation.base import ModerationException
 from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
@@ -93,7 +93,10 @@ class AdvancedChatAppRunner(AppRunner):
                 SystemVariable.FILES: files,
                 SystemVariable.CONVERSATION: conversation.id,
             },
-            callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
+            callbacks=[WorkflowEventTriggerCallback(
+                queue_manager=queue_manager,
+                workflow=workflow
+            )]
         )

     def handle_input_moderation(self, queue_manager: AppQueueManager,
...
@@ -7,13 +7,15 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
 )
 from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
-from models.workflow import WorkflowNodeExecution, WorkflowRun
+from core.workflow.entities.node_entities import NodeType
+from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun


 class WorkflowEventTriggerCallback(BaseWorkflowCallback):

-    def __init__(self, queue_manager: AppQueueManager):
+    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
         self._queue_manager = queue_manager
+        self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)

     def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
         """
@@ -51,13 +53,34 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
             PublishFrom.APPLICATION_MANAGER
         )

-    def on_text_chunk(self, text: str) -> None:
+    def on_node_text_chunk(self, node_id: str, text: str) -> None:
         """
         Publish text chunk
         """
-        self._queue_manager.publish(
-            QueueTextChunkEvent(
-                text=text
-            ), PublishFrom.APPLICATION_MANAGER
-        )
+        if node_id in self._streamable_node_ids:
+            self._queue_manager.publish(
+                QueueTextChunkEvent(
+                    text=text
+                ), PublishFrom.APPLICATION_MANAGER
+            )
+
+    def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
+        """
+        Fetch streamable node ids
+        When the workflow type is chat, only LLM or Direct Answer nodes that connect directly to an End node can stream output.
+        When the workflow type is workflow, only LLM nodes that connect directly to an End node in plain-text mode can stream output.
+        :param graph: workflow graph
+        :return: streamable node ids
+        """
+        streamable_node_ids = []
+        end_node_ids = []
+        for node_config in graph.get('nodes'):
+            if node_config.get('type') == NodeType.END.value:
+                end_node_ids.append(node_config.get('id'))
+
+        for edge_config in graph.get('edges'):
+            if edge_config.get('target') in end_node_ids:
+                streamable_node_ids.append(edge_config.get('source'))
+
+        return streamable_node_ids
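
To make the streaming rule concrete, here is a rough standalone sketch of the same traversal over a made-up chat graph; the node ids and the 'llm' / 'direct-answer' / 'end' type strings are illustrative assumptions about how NodeType values appear in a real graph dict.

# Toy graph in the shape the callback expects (ids and type strings are invented).
graph = {
    'nodes': [
        {'id': 'start-1', 'type': 'start'},
        {'id': 'llm-1', 'type': 'llm'},
        {'id': 'answer-1', 'type': 'direct-answer'},
        {'id': 'end-1', 'type': 'end'},
    ],
    'edges': [
        {'source': 'start-1', 'target': 'llm-1'},
        {'source': 'llm-1', 'target': 'answer-1'},
        {'source': 'answer-1', 'target': 'end-1'},
    ],
}

# Same two passes as the chat-mode callback: collect End node ids, then keep the
# source of every edge that targets an End node.
end_node_ids = [n['id'] for n in graph['nodes'] if n['type'] == 'end']
streamable_node_ids = [e['source'] for e in graph['edges'] if e['target'] in end_node_ids]
print(streamable_node_ids)  # ['answer-1'] -- only the node feeding the End node may stream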
@@ -4,13 +4,13 @@ from typing import cast

 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
+from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
 from core.app.entities.app_invoke_entities import (
     AppGenerateEntity,
     InvokeFrom,
     WorkflowAppGenerateEntity,
 )
 from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
-from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
 from core.moderation.base import ModerationException
 from core.moderation.input_moderation import InputModeration
 from core.workflow.entities.node_entities import SystemVariable
@@ -76,7 +76,10 @@ class WorkflowAppRunner:
                 SystemVariable.FILES: files
             },
-            callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
+            callbacks=[WorkflowEventTriggerCallback(
+                queue_manager=queue_manager,
+                workflow=workflow
+            )]
         )

     def handle_input_moderation(self, queue_manager: AppQueueManager,
...
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
    QueueNodeFinishedEvent,
    QueueNodeStartedEvent,
    QueueTextChunkEvent,
    QueueWorkflowFinishedEvent,
    QueueWorkflowStartedEvent,
)
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun


class WorkflowEventTriggerCallback(BaseWorkflowCallback):

    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
        self._queue_manager = queue_manager
        self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)

    def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
        """
        Workflow run started
        """
        self._queue_manager.publish(
            QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
        """
        Workflow run finished
        """
        self._queue_manager.publish(
            QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        """
        Workflow node execute started
        """
        self._queue_manager.publish(
            QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        """
        Workflow node execute finished
        """
        self._queue_manager.publish(
            QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_node_text_chunk(self, node_id: str, text: str) -> None:
        """
        Publish text chunk
        """
        if node_id in self._streamable_node_ids:
            self._queue_manager.publish(
                QueueTextChunkEvent(
                    text=text
                ), PublishFrom.APPLICATION_MANAGER
            )
    def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
        """
        Fetch streamable node ids
        When the workflow type is chat, only LLM or Direct Answer nodes that connect directly to an End node can stream output.
        When the workflow type is workflow, only LLM nodes that connect directly to an End node in plain-text mode can stream output.
        :param graph: workflow graph
        :return: streamable node ids
        """
        streamable_node_ids = []
        end_node_ids = []
        for node_config in graph.get('nodes'):
            if node_config.get('type') == NodeType.END.value:
                if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
                    end_node_ids.append(node_config.get('id'))

        for edge_config in graph.get('edges'):
            if edge_config.get('target') in end_node_ids:
                streamable_node_ids.append(edge_config.get('source'))

        return streamable_node_ids
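
For contrast with the chat variant above, a rough sketch of the workflow-mode filter: only End nodes whose outputs block is configured as 'plain-text' make their predecessor streamable. The graph below is invented for illustration.

# Two branches: one ends in a plain-text End node, the other in a structured one.
graph = {
    'nodes': [
        {'id': 'llm-1', 'type': 'llm'},
        {'id': 'end-plain', 'type': 'end', 'data': {'outputs': {'type': 'plain-text'}}},
        {'id': 'llm-2', 'type': 'llm'},
        {'id': 'end-structured', 'type': 'end', 'data': {'outputs': {'type': 'structured'}}},
    ],
    'edges': [
        {'source': 'llm-1', 'target': 'end-plain'},
        {'source': 'llm-2', 'target': 'end-structured'},
    ],
}

end_node_ids = [
    n['id'] for n in graph['nodes']
    if n['type'] == 'end' and n.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text'
]
streamable_node_ids = [e['source'] for e in graph['edges'] if e['target'] in end_node_ids]
print(streamable_node_ids)  # ['llm-1'] -- llm-2 feeds a structured End node, so it does not stream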
-from abc import abstractmethod
+from abc import ABC, abstractmethod

 from models.workflow import WorkflowNodeExecution, WorkflowRun


-class BaseWorkflowCallback:
+class BaseWorkflowCallback(ABC):
     @abstractmethod
     def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
         """
@@ -33,7 +33,7 @@ class BaseWorkflowCallback:
         raise NotImplementedError

     @abstractmethod
-    def on_text_chunk(self, text: str) -> None:
+    def on_node_text_chunk(self, node_id: str, text: str) -> None:
         """
         Publish text chunk
         """
...
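
Since BaseWorkflowCallback is now an ABC, any new consumer has to implement every hook before it can be instantiated. A minimal logging sketch, assuming the five hooks visible in this commit are the complete abstract interface:

import logging

from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from models.workflow import WorkflowNodeExecution, WorkflowRun

logger = logging.getLogger(__name__)


class LoggingWorkflowCallback(BaseWorkflowCallback):
    """Hypothetical callback that only logs; not part of this commit."""

    def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
        logger.info("workflow run %s started", workflow_run.id)

    def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
        logger.info("workflow run %s finished", workflow_run.id)

    def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        logger.info("node execution %s started", workflow_node_execution.id)

    def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        logger.info("node execution %s finished", workflow_node_execution.id)

    def on_node_text_chunk(self, node_id: str, text: str) -> None:
        logger.debug("text chunk from node %s: %s", node_id, text)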
@@ -5,7 +5,5 @@ from pydantic import BaseModel


 class BaseNodeData(ABC, BaseModel):
-    type: str
     title: str
     desc: Optional[str] = None
 from enum import Enum
-from typing import Optional
+from typing import Any, Optional

 from pydantic import BaseModel
@@ -46,6 +46,15 @@ class SystemVariable(Enum):
     CONVERSATION = 'conversation'


+class NodeRunMetadataKey(Enum):
+    """
+    Node Run Metadata Key.
+    """
+    TOTAL_TOKENS = 'total_tokens'
+    TOTAL_PRICE = 'total_price'
+    CURRENCY = 'currency'
+
+
 class NodeRunResult(BaseModel):
     """
     Node Run Result.
@@ -55,7 +64,7 @@ class NodeRunResult(BaseModel):
     inputs: Optional[dict] = None  # node inputs
     process_data: Optional[dict] = None  # process data
     outputs: Optional[dict] = None  # node outputs
-    metadata: Optional[dict] = None  # node metadata
+    metadata: Optional[dict[NodeRunMetadataKey, Any]] = None  # node metadata
     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
...
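
A sketch of what a node result can now carry with the typed metadata keys; the concrete values are invented for illustration.

from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus

# Enum members become the dict keys, which keeps metadata names consistent across nodes.
result = NodeRunResult(
    status=WorkflowNodeExecutionStatus.SUCCEEDED,
    inputs={'query': 'hello'},
    outputs={'text': 'hi there'},
    metadata={
        NodeRunMetadataKey.TOTAL_TOKENS: 42,
        NodeRunMetadataKey.TOTAL_PRICE: '0.000084',
        NodeRunMetadataKey.CURRENCY: 'USD',
    },
)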
from pydantic import BaseModel


class VariableSelector(BaseModel):
    """
    Variable Selector.
    """
    variable: str
    value_selector: list[str]
@@ -5,13 +5,18 @@ from models.workflow import WorkflowNodeExecution, WorkflowRun
 class WorkflowRunState:
     workflow_run: WorkflowRun
     start_at: float
+    user_inputs: dict
     variable_pool: VariablePool

     total_tokens: int = 0

     workflow_node_executions: list[WorkflowNodeExecution] = []

-    def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
+    def __init__(self, workflow_run: WorkflowRun,
+                 start_at: float,
+                 user_inputs: dict,
+                 variable_pool: VariablePool) -> None:
         self.workflow_run = workflow_run
         self.start_at = start_at
+        self.user_inputs = user_inputs
         self.variable_pool = variable_pool
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from typing import Optional

 from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
@@ -8,7 +8,7 @@ from core.workflow.entities.variable_pool import VariablePool
 from models.workflow import WorkflowNodeExecutionStatus


-class BaseNode:
+class BaseNode(ABC):
     _node_data_cls: type[BaseNodeData]
     _node_type: NodeType
@@ -16,7 +16,6 @@ class BaseNode:
     node_data: BaseNodeData
     node_run_result: Optional[NodeRunResult] = None
-    stream_output_supported: bool = False

     callbacks: list[BaseWorkflowCallback]

     def __init__(self, config: dict,
@@ -71,10 +70,12 @@ class BaseNode:
         :param text: chunk text
         :return:
         """
-        if self.stream_output_supported:
-            if self.callbacks:
-                for callback in self.callbacks:
-                    callback.on_text_chunk(text)
+        if self.callbacks:
+            for callback in self.callbacks:
+                callback.on_node_text_chunk(
+                    node_id=self.node_id,
+                    text=text
+                )

     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
...
+import time
+from typing import Optional, cast
+
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData
+from models.workflow import WorkflowNodeExecutionStatus


 class DirectAnswerNode(BaseNode):
-    pass
+    _node_data_cls = DirectAnswerNodeData
+    node_type = NodeType.DIRECT_ANSWER
+
+    def _run(self, variable_pool: Optional[VariablePool] = None,
+             run_args: Optional[dict] = None) -> NodeRunResult:
+        """
+        Run node
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        node_data = self.node_data
+        node_data = cast(self._node_data_cls, node_data)
+
+        if variable_pool is None and run_args:
+            raise ValueError("Not support single step debug.")
+
+        variable_values = {}
+        for variable_selector in node_data.variables:
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector,
+                target_value_type=ValueType.STRING
+            )
+
+            variable_values[variable_selector.variable] = value
+
+        # format answer template
+        template_parser = PromptTemplateParser(node_data.answer)
+        answer = template_parser.format(variable_values)
+
+        # publish answer as stream
+        for word in answer:
+            self.publish_text_chunk(word)
+            time.sleep(0.01)
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs=variable_values,
+            outputs={
+                "answer": answer
+            }
+        )
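
The answer-template step boils down to a variable lookup plus PromptTemplateParser.format. A quick sketch; the double-curly placeholder syntax is assumed here and the values are invented.

from core.prompt.utils.prompt_template_parser import PromptTemplateParser

# Build the final answer from resolved variable values before streaming it out.
template_parser = PromptTemplateParser("The weather in {{city}} is {{weather}}.")
answer = template_parser.format({'city': 'Berlin', 'weather': 'sunny'})
print(answer)  # The weather in Berlin is sunny.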
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector


class DirectAnswerNodeData(BaseNodeData):
    """
    DirectAnswer Node Data.
    """
    variables: list[VariableSelector] = []
    answer: str
+from typing import Optional, cast
+
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs
+from models.workflow import WorkflowNodeExecutionStatus


 class EndNode(BaseNode):
-    pass
+    _node_data_cls = EndNodeData
+    node_type = NodeType.END
+
+    def _run(self, variable_pool: Optional[VariablePool] = None,
+             run_args: Optional[dict] = None) -> NodeRunResult:
+        """
+        Run node
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        node_data = self.node_data
+        node_data = cast(self._node_data_cls, node_data)
+        outputs_config = node_data.outputs
+
+        if variable_pool is not None:
+            outputs = None
+            if outputs_config:
+                if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
+                    plain_text_selector = outputs_config.plain_text_selector
+                    if plain_text_selector:
+                        outputs = {
+                            'text': variable_pool.get_variable_value(
+                                variable_selector=plain_text_selector,
+                                target_value_type=ValueType.STRING
+                            )
+                        }
+                    else:
+                        outputs = {
+                            'text': ''
+                        }
+                elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
+                    structured_variables = outputs_config.structured_variables
+                    if structured_variables:
+                        outputs = {}
+                        for variable_selector in structured_variables:
+                            variable_value = variable_pool.get_variable_value(
+                                variable_selector=variable_selector.value_selector
+                            )
+
+                            outputs[variable_selector.variable] = variable_value
+                    else:
+                        outputs = {}
+        else:
+            raise ValueError("Not support single step debug.")
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs=outputs,
+            outputs=outputs
+        )
 from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector


 class EndNodeOutputType(Enum):
@@ -23,3 +29,40 @@ class EndNodeOutputType(Enum):
             if output_type.value == value:
                 return output_type
         raise ValueError(f'invalid output type value {value}')
+
+
+class EndNodeDataOutputs(BaseModel):
+    """
+    END Node Data Outputs.
+    """
+    class OutputType(Enum):
+        """
+        Output Types.
+        """
+        NONE = 'none'
+        PLAIN_TEXT = 'plain-text'
+        STRUCTURED = 'structured'
+
+        @classmethod
+        def value_of(cls, value: str) -> 'OutputType':
+            """
+            Get value of given output type.
+            :param value: output type value
+            :return: output type
+            """
+            for output_type in cls:
+                if output_type.value == value:
+                    return output_type
+            raise ValueError(f'invalid output type value {value}')
+
+    type: OutputType = OutputType.NONE
+    plain_text_selector: Optional[list[str]] = None
+    structured_variables: Optional[list[VariableSelector]] = None
+
+
+class EndNodeData(BaseNodeData):
+    """
+    END Node Data.
+    """
+    outputs: Optional[EndNodeDataOutputs] = None
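
As a rough sketch of how an End node's data block maps onto these entities, relying on pydantic's usual coercion of nested dicts and enum values; the selector path below is invented.

from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs

# 'plain-text' is coerced to OutputType.PLAIN_TEXT, the nested dict to EndNodeDataOutputs.
node_data = EndNodeData(
    title='End',
    outputs={
        'type': 'plain-text',
        'plain_text_selector': ['llm', 'text'],
    },
)
assert node_data.outputs.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT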
from core.workflow.entities.base_node_data_entities import BaseNodeData


class LLMNodeData(BaseNodeData):
    """
    LLM Node Data.
    """
    pass
-from typing import Optional
+from typing import Optional, cast

+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.llm.entities import LLMNodeData


 class LLMNode(BaseNode):
+    _node_data_cls = LLMNodeData
+    node_type = NodeType.LLM
+
+    def _run(self, variable_pool: Optional[VariablePool] = None,
+             run_args: Optional[dict] = None) -> NodeRunResult:
+        """
+        Run node
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        node_data = self.node_data
+        node_data = cast(self._node_data_cls, node_data)
+
+        pass
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
...
 from core.app.app_config.entities import VariableEntity
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType


 class StartNodeData(BaseNodeData):
     """
-    - title (string) Node title
-    - desc (string) optional Node description
-    - type (string) Node type, fixed to "start"
-    - variables (array[object]) List of form variables
-    - type (string) Form variable type: text-input, paragraph, select, number, files (custom file types not supported yet)
-    - label (string) Display label of the form control
-    - variable (string) Variable key
-    - max_length (int) Maximum length, applies to text-input and paragraph
-    - default (string) optional Default value
-    - required (bool) optional Whether the field is required, defaults to false
-    - hint (string) optional Hint text
-    - options (array[string]) Option values (select only)
+    Start Node Data
     """
-    type: str = NodeType.START.value
     variables: list[VariableEntity] = []
-from typing import Optional
+from typing import Optional, cast

-from core.workflow.entities.node_entities import NodeType
+from core.app.app_config.entities import VariableEntity
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.start.entities import StartNodeData
+from models.workflow import WorkflowNodeExecutionStatus


 class StartNode(BaseNode):
@@ -11,12 +13,58 @@ class StartNode(BaseNode):
     node_type = NodeType.START

     def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> dict:
+             run_args: Optional[dict] = None) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
         :param run_args: run args
         :return:
         """
-        pass
+        node_data = self.node_data
+        node_data = cast(self._node_data_cls, node_data)
+        variables = node_data.variables
+
+        # Get cleaned inputs
+        cleaned_inputs = self._get_cleaned_inputs(variables, run_args)
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs=cleaned_inputs,
+            outputs=cleaned_inputs
+        )
+
+    def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
+        if user_inputs is None:
+            user_inputs = {}
+
+        filtered_inputs = {}
+
+        for variable_config in variables:
+            variable = variable_config.variable
+
+            if variable not in user_inputs or not user_inputs[variable]:
+                if variable_config.required:
+                    raise ValueError(f"Input form variable {variable} is required")
+                else:
+                    filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
+
+                continue
+
+            value = user_inputs[variable]
+
+            if value:
+                if not isinstance(value, str):
+                    raise ValueError(f"{variable} in input form must be a string")
+
+            if variable_config.type == VariableEntity.Type.SELECT:
+                options = variable_config.options if variable_config.options is not None else []
+                if value not in options:
+                    raise ValueError(f"{variable} in input form must be one of the following: {options}")
+            else:
+                if variable_config.max_length is not None:
+                    max_length = variable_config.max_length
+
+                    if len(value) > max_length:
+                        raise ValueError(f'{variable} in input form must be less than {max_length} characters')
+
+            filtered_inputs[variable] = value.replace('\x00', '') if value else None
+
+        return filtered_inputs
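
The validation rules in _get_cleaned_inputs are easier to see in isolation. Below is a standalone simplification that uses plain dicts instead of VariableEntity objects so it runs on its own, with field names mirroring the code above; it is a sketch, not the method itself.

def clean_inputs(variables: list[dict], user_inputs: dict) -> dict:
    user_inputs = user_inputs or {}
    filtered_inputs = {}
    for var in variables:
        name = var['variable']
        value = user_inputs.get(name)
        if not value:
            # missing value: fail if required, otherwise fall back to the default
            if var.get('required'):
                raise ValueError(f"Input form variable {name} is required")
            filtered_inputs[name] = var.get('default') if var.get('default') is not None else ""
            continue
        if not isinstance(value, str):
            raise ValueError(f"{name} in input form must be a string")
        if var.get('type') == 'select':
            options = var.get('options') or []
            if value not in options:
                raise ValueError(f"{name} in input form must be one of the following: {options}")
        elif var.get('max_length') is not None and len(value) > var['max_length']:
            raise ValueError(f"{name} in input form must be less than {var['max_length']} characters")
        filtered_inputs[name] = value.replace('\x00', '')  # strip NUL bytes, as the node does
    return filtered_inputs


print(clean_inputs(
    [{'variable': 'name', 'required': True, 'max_length': 48, 'type': 'text-input'}],
    {'name': 'Alice\x00'},
))  # {'name': 'Alice'}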
@@ -3,6 +3,7 @@ import time
 from datetime import datetime
 from typing import Optional, Union

+from core.model_runtime.utils.encoders import jsonable_encoder
 from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
@@ -32,7 +33,6 @@ from models.workflow import (
     WorkflowRun,
     WorkflowRunStatus,
     WorkflowRunTriggeredFrom,
-    WorkflowType,
 )

 node_classes = {
@@ -52,30 +52,6 @@ node_classes = {
 class WorkflowEngineManager:
-    def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
-        """
-        Get draft workflow
-        """
-        # fetch draft workflow by app_model
-        workflow = db.session.query(Workflow).filter(
-            Workflow.tenant_id == app_model.tenant_id,
-            Workflow.app_id == app_model.id,
-            Workflow.version == 'draft'
-        ).first()
-
-        # return draft workflow
-        return workflow
-
-    def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
-        """
-        Get published workflow
-        """
-        if not app_model.workflow_id:
-            return None
-
-        # fetch published workflow by workflow_id
-        return self.get_workflow(app_model, app_model.workflow_id)
-
     def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
         """
         Get workflow
@@ -166,14 +142,12 @@ class WorkflowEngineManager:
         workflow_run_state = WorkflowRunState(
             workflow_run=workflow_run,
             start_at=time.perf_counter(),
+            user_inputs=user_inputs,
             variable_pool=VariablePool(
                 system_variables=system_inputs,
             )
         )

-        # fetch predecessor node ids before end node (include: llm, direct answer)
-        streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
-
         try:
             predecessor_node = None
             while True:
@@ -187,10 +161,6 @@ class WorkflowEngineManager:
                 if not next_node:
                     break

-                # check if node is streamable
-                if next_node.node_id in streamable_node_ids:
-                    next_node.stream_output_supported = True
-
                 # max steps 30 reached
                 if len(workflow_run_state.workflow_node_executions) > 30:
                     raise ValueError('Max steps 30 reached.')
@@ -233,34 +203,6 @@ class WorkflowEngineManager:
                 callbacks=callbacks
             )

-    def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
-        """
-        Fetch streamable node ids
-        When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
-        When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
-        :param workflow: Workflow instance
-        :param graph: workflow graph
-        :return:
-        """
-        workflow_type = WorkflowType.value_of(workflow.type)
-
-        streamable_node_ids = []
-        end_node_ids = []
-        for node_config in graph.get('nodes'):
-            if node_config.get('type') == NodeType.END.value:
-                if workflow_type == WorkflowType.WORKFLOW:
-                    if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
-                        end_node_ids.append(node_config.get('id'))
-                else:
-                    end_node_ids.append(node_config.get('id'))
-
-        for edge_config in graph.get('edges'):
-            if edge_config.get('target') in end_node_ids:
-                streamable_node_ids.append(edge_config.get('source'))
-
-        return streamable_node_ids
-
     def _init_workflow_run(self, workflow: Workflow,
                            triggered_from: WorkflowRunTriggeredFrom,
                            user: Union[Account, EndUser],
@@ -440,7 +382,6 @@ class WorkflowEngineManager:
         :param max_execution_time: max execution time
         :return:
         """
-        # TODO check queue is stopped
         return time.perf_counter() - start_at > max_execution_time

     def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
@@ -460,7 +401,9 @@ class WorkflowEngineManager:
         # run node, result must have inputs, process_data, outputs, execution_metadata
         node_run_result = node.run(
-            variable_pool=workflow_run_state.variable_pool
+            variable_pool=workflow_run_state.variable_pool,
+            run_args=workflow_run_state.user_inputs
+            if (not predecessor_node and node.node_type == NodeType.START) else None  # only on start node
         )

         if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
@@ -553,7 +496,7 @@ class WorkflowEngineManager:
         workflow_node_execution.inputs = json.dumps(result.inputs)
         workflow_node_execution.process_data = json.dumps(result.process_data)
         workflow_node_execution.outputs = json.dumps(result.outputs)
-        workflow_node_execution.execution_metadata = json.dumps(result.metadata)
+        workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata))
         workflow_node_execution.finished_at = datetime.utcnow()

         db.session.commit()
...
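
The jsonable_encoder wrapper matters because result.metadata is now keyed by NodeRunMetadataKey enum members, which json.dumps cannot use as dict keys. A small sketch, assuming the encoder behaves like the FastAPI implementation it is derived from and turns enum keys and values into plain JSON types:

import json

from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeRunMetadataKey

metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 42, NodeRunMetadataKey.CURRENCY: 'USD'}

# json.dumps(metadata) would raise TypeError: keys must be str, int, float, bool or None
print(json.dumps(jsonable_encoder(metadata)))  # expected: {"total_tokens": 42, "currency": "USD"}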
"""conversation columns set nullable
Revision ID: 42e85ed5564d
Revises: f9107f83abab
Create Date: 2024-03-07 08:30:29.133614
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '42e85ed5564d'
down_revision = 'f9107f83abab'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('app_model_config_id',
existing_type=postgresql.UUID(),
nullable=True)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True)
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('app_model_config_id',
existing_type=postgresql.UUID(),
nullable=False)
# ### end Alembic commands ###
@@ -78,8 +78,6 @@ def upgrade():
     sa.Column('error', sa.Text(), nullable=True),
     sa.Column('elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False),
     sa.Column('total_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False),
-    sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True),
-    sa.Column('currency', sa.String(length=255), nullable=True),
     sa.Column('total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True),
     sa.Column('created_by_role', sa.String(length=255), nullable=False),
     sa.Column('created_by', postgresql.UUID(), nullable=False),
...
@@ -26,22 +26,28 @@ class WorkflowService:
         """
         Get draft workflow
         """
-        workflow_engine_manager = WorkflowEngineManager()
+        # fetch draft workflow by app_model
+        workflow = db.session.query(Workflow).filter(
+            Workflow.tenant_id == app_model.tenant_id,
+            Workflow.app_id == app_model.id,
+            Workflow.version == 'draft'
+        ).first()

         # return draft workflow
-        return workflow_engine_manager.get_draft_workflow(app_model=app_model)
+        return workflow

     def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
         """
         Get published workflow
         """
         if not app_model.workflow_id:
             return None

         workflow_engine_manager = WorkflowEngineManager()

-        # return published workflow
-        return workflow_engine_manager.get_published_workflow(app_model=app_model)
+        # fetch published workflow by workflow_id
+        return workflow_engine_manager.get_workflow(app_model, app_model.workflow_id)

     def sync_draft_workflow(self, app_model: App,
                             graph: dict,
...