Commit 2db67c41 authored by takatost

refactor pipeline and remove run_args from node run()

parent 80b4db08
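For orientation before the hunks below: a minimal, standalone sketch of the calling convention this commit moves to. The stub classes are illustrative stand-ins, not code from the commit; the point is that run() now takes only a VariablePool and the separate run_args dict is gone, with user inputs carried on the pool instead.

from abc import ABC, abstractmethod
from typing import Any


class VariablePoolStub:
    """Stand-in shape of the refactored VariablePool: user inputs ride on the pool."""

    def __init__(self, system_variables: dict[str, Any], user_inputs: dict) -> None:
        self.system_variables = system_variables
        self.user_inputs = user_inputs


class NodeStub(ABC):
    @abstractmethod
    def _run(self, variable_pool: VariablePoolStub) -> dict:
        raise NotImplementedError

    def run(self, variable_pool: VariablePoolStub) -> dict:
        # the old entry point run(variable_pool=None, run_args=None) is gone after this commit
        return self._run(variable_pool=variable_pool)


class StartNodeStub(NodeStub):
    def _run(self, variable_pool: VariablePoolStub) -> dict:
        # was roughly self._get_cleaned_inputs(variables, run_args); inputs now come from the pool
        return dict(variable_pool.user_inputs)


pool = VariablePoolStub(system_variables={'query': 'abc', 'files': []},
                        user_inputs={'name': 'Dify'})
print(StartNodeStub().run(variable_pool=pool))  # -> {'name': 'Dify'}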
@@ -55,6 +55,19 @@ class TaskState(BaseModel):
     """
     TaskState entity
     """
+    class NodeExecutionInfo(BaseModel):
+        """
+        NodeExecutionInfo entity
+        """
+        workflow_node_execution: WorkflowNodeExecution
+        start_at: float
+
+        class Config:
+            """Configuration for this pydantic object."""
+
+            extra = Extra.forbid
+            arbitrary_types_allowed = True
+
     answer: str = ""
     metadata: dict = {}
     usage: LLMUsage
@@ -64,8 +77,8 @@ class TaskState(BaseModel):
     total_tokens: int = 0
     total_steps: int = 0
 
-    current_node_execution: Optional[WorkflowNodeExecution] = None
-    current_node_execution_start_at: Optional[float] = None
+    running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
+    latest_node_execution_info: Optional[NodeExecutionInfo] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -218,7 +231,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeStartedEvent):
                 self._on_node_start(event)
 
-                workflow_node_execution = self._task_state.current_node_execution
+                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
 
                 response = {
                     'event': 'node_started',
@@ -237,7 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                 self._on_node_finished(event)
 
-                workflow_node_execution = self._task_state.current_node_execution
+                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
 
                 if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
                     if workflow_node_execution.node_type == NodeType.LLM.value:
@@ -447,15 +460,21 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             predecessor_node_id=event.predecessor_node_id
         )
 
-        self._task_state.current_node_execution = workflow_node_execution
-        self._task_state.current_node_execution_start_at = time.perf_counter()
+        latest_node_execution_info = TaskState.NodeExecutionInfo(
+            workflow_node_execution=workflow_node_execution,
+            start_at=time.perf_counter()
+        )
+
+        self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
+        self._task_state.latest_node_execution_info = latest_node_execution_info
+
         self._task_state.total_steps += 1
 
     def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
+        current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
+
         if isinstance(event, QueueNodeSucceededEvent):
             workflow_node_execution = self._workflow_node_execution_success(
-                workflow_node_execution=self._task_state.current_node_execution,
-                start_at=self._task_state.current_node_execution_start_at,
+                workflow_node_execution=current_node_execution.workflow_node_execution,
+                start_at=current_node_execution.start_at,
                 inputs=event.inputs,
                 process_data=event.process_data,
                 outputs=event.outputs,
@@ -472,12 +491,14 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 self._task_state.metadata['usage'] = usage_dict
         else:
             workflow_node_execution = self._workflow_node_execution_failed(
-                workflow_node_execution=self._task_state.current_node_execution,
-                start_at=self._task_state.current_node_execution_start_at,
+                workflow_node_execution=current_node_execution.workflow_node_execution,
+                start_at=current_node_execution.start_at,
                 error=event.error
             )
 
-        self._task_state.current_node_execution = workflow_node_execution
+        # remove running node execution info
+        del self._task_state.running_node_execution_infos[event.node_id]
+
+        self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution
     def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
         if isinstance(event, QueueStopEvent):
@@ -504,8 +525,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
-                outputs=self._task_state.current_node_execution.outputs
-                if self._task_state.current_node_execution else None
+                outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs
+                if self._task_state.latest_node_execution_info else None
             )
 
             self._task_state.workflow_run = workflow_run
...
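The pipeline hunks above swap the single current_node_execution slot for a dict keyed by node_id plus a latest_node_execution_info pointer, so several running nodes can be tracked at once. A standalone sketch of that bookkeeping, with simplified stand-in types rather than the pipeline's real classes:

import time
from dataclasses import dataclass


@dataclass
class NodeExecutionInfoStub:
    node_id: str
    start_at: float


running_node_execution_infos: dict[str, NodeExecutionInfoStub] = {}
latest_node_execution_info = None

# _on_node_start: register each started node under its node_id and remember the most recent one
for node_id in ('start', 'llm'):
    info = NodeExecutionInfoStub(node_id=node_id, start_at=time.perf_counter())
    running_node_execution_infos[node_id] = info
    latest_node_execution_info = info

# _on_node_finished: look the execution up by node_id, then drop it from the running map
finished = running_node_execution_infos.pop('start')
print(finished.node_id, latest_node_execution_info.node_id)  # -> start llm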
@@ -41,6 +41,19 @@ class TaskState(BaseModel):
     """
     TaskState entity
     """
+    class NodeExecutionInfo(BaseModel):
+        """
+        NodeExecutionInfo entity
+        """
+        workflow_node_execution: WorkflowNodeExecution
+        start_at: float
+
+        class Config:
+            """Configuration for this pydantic object."""
+
+            extra = Extra.forbid
+            arbitrary_types_allowed = True
+
     answer: str = ""
     metadata: dict = {}
@@ -49,8 +62,8 @@ class TaskState(BaseModel):
     total_tokens: int = 0
     total_steps: int = 0
 
-    current_node_execution: Optional[WorkflowNodeExecution] = None
-    current_node_execution_start_at: Optional[float] = None
+    running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
+    latest_node_execution_info: Optional[NodeExecutionInfo] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -179,7 +192,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeStartedEvent):
                 self._on_node_start(event)
 
-                workflow_node_execution = self._task_state.current_node_execution
+                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
 
                 response = {
                     'event': 'node_started',
@@ -198,7 +211,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                 self._on_node_finished(event)
 
-                workflow_node_execution = self._task_state.current_node_execution
+                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
 
                 response = {
                     'event': 'node_finished',
@@ -339,15 +352,22 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             predecessor_node_id=event.predecessor_node_id
         )
 
-        self._task_state.current_node_execution = workflow_node_execution
-        self._task_state.current_node_execution_start_at = time.perf_counter()
+        latest_node_execution_info = TaskState.NodeExecutionInfo(
+            workflow_node_execution=workflow_node_execution,
+            start_at=time.perf_counter()
+        )
+
+        self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
+        self._task_state.latest_node_execution_info = latest_node_execution_info
+
         self._task_state.total_steps += 1
 
     def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
+        current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
+
         if isinstance(event, QueueNodeSucceededEvent):
             workflow_node_execution = self._workflow_node_execution_success(
-                workflow_node_execution=self._task_state.current_node_execution,
-                start_at=self._task_state.current_node_execution_start_at,
+                workflow_node_execution=current_node_execution.workflow_node_execution,
+                start_at=current_node_execution.start_at,
                 inputs=event.inputs,
                 process_data=event.process_data,
                 outputs=event.outputs,
@@ -359,12 +379,14 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                     int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
         else:
             workflow_node_execution = self._workflow_node_execution_failed(
-                workflow_node_execution=self._task_state.current_node_execution,
-                start_at=self._task_state.current_node_execution_start_at,
+                workflow_node_execution=current_node_execution.workflow_node_execution,
+                start_at=current_node_execution.start_at,
                 error=event.error
            )
 
-        self._task_state.current_node_execution = workflow_node_execution
+        # remove running node execution info
+        del self._task_state.running_node_execution_infos[event.node_id]
+
+        self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution
     def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
         if isinstance(event, QueueStopEvent):
@@ -391,8 +413,8 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
-                outputs=self._task_state.current_node_execution.outputs
-                if self._task_state.current_node_execution else None
+                outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs
+                if self._task_state.latest_node_execution_info else None
            )
 
             self._task_state.workflow_run = workflow_run
...
@@ -19,14 +19,17 @@ class ValueType(Enum):
 class VariablePool:
     variables_mapping = {}
+    user_inputs: dict
 
-    def __init__(self, system_variables: dict[SystemVariable, Any]) -> None:
+    def __init__(self, system_variables: dict[SystemVariable, Any],
+                 user_inputs: dict) -> None:
         # system variables
         # for example:
         # {
         #     'query': 'abc',
         #     'files': []
         # }
+        self.user_inputs = user_inputs
+
         for system_variable, value in system_variables.items():
             self.append_variable('sys', [system_variable.value], value)
...
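A standalone sketch of the VariablePool change above (stand-in class with simplified selector handling, not the real implementation): user inputs are now stored as-is on the pool, while system variables continue to be registered under the 'sys' node id.

from typing import Any


class VariablePoolSketch:
    """Stand-in for the refactored VariablePool; selector handling is simplified."""

    def __init__(self, system_variables: dict[str, Any], user_inputs: dict) -> None:
        self.variables_mapping: dict[str, Any] = {}
        self.user_inputs = user_inputs  # new field: raw user inputs, read by nodes such as StartNode
        for name, value in system_variables.items():
            self.append_variable('sys', [name], value)

    def append_variable(self, node_id: str, variable_key_list: list[str], value: Any) -> None:
        self.variables_mapping['.'.join([node_id, *variable_key_list])] = value


pool = VariablePoolSketch(system_variables={'query': 'abc', 'files': []},
                          user_inputs={'name': 'Dify'})
print(pool.variables_mapping['sys.query'], pool.user_inputs['name'])  # -> abc Dify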
@@ -18,15 +18,13 @@ class WorkflowNodeAndResult:
 class WorkflowRunState:
     workflow: Workflow
     start_at: float
-    user_inputs: dict
     variable_pool: VariablePool
 
     total_tokens: int = 0
 
     workflow_nodes_and_results: list[WorkflowNodeAndResult] = []
 
-    def __init__(self, workflow: Workflow, start_at: float, user_inputs: dict, variable_pool: VariablePool):
+    def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool):
         self.workflow = workflow
         self.start_at = start_at
-        self.user_inputs = user_inputs
         self.variable_pool = variable_pool
@@ -28,31 +28,23 @@ class BaseNode(ABC):
         self.callbacks = callbacks or []
 
     @abstractmethod
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         raise NotImplementedError
 
-    def run(self, variable_pool: Optional[VariablePool] = None,
-            run_args: Optional[dict] = None) -> NodeRunResult:
+    def run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node entry
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
-        if variable_pool is None and run_args is None:
-            raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
-
         try:
             result = self._run(
-                variable_pool=variable_pool,
-                run_args=run_args
+                variable_pool=variable_pool
             )
         except Exception as e:
             # process unhandled exception
@@ -77,6 +69,26 @@ class BaseNode(ABC):
                 text=text
             )
 
+    @classmethod
+    def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict:
+        """
+        Extract variable selector to variable mapping
+        :param config: node config
+        :return:
+        """
+        node_data = cls._node_data_cls(**config.get("data", {}))
+        return cls._extract_variable_selector_to_variable_mapping(node_data)
+
+    @classmethod
+    @abstractmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
...
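base_node.py above also introduces the extract_variable_selector_to_variable_mapping / _extract_variable_selector_to_variable_mapping pair that the node classes below implement, mostly as placeholders for now. A standalone sketch of that hook pattern, using stand-in classes and an illustrative node data type:

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class NodeDataStub:
    title: str = ''


class NodeStubBase(ABC):
    _node_data_cls = NodeDataStub

    @classmethod
    def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict:
        # public entry point: parse raw node config into node data, then delegate to the subclass hook
        node_data = cls._node_data_cls(**config.get("data", {}))
        return cls._extract_variable_selector_to_variable_mapping(node_data)

    @classmethod
    @abstractmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: NodeDataStub) -> dict:
        raise NotImplementedError


class DirectAnswerNodeStub(NodeStubBase):
    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: NodeDataStub) -> dict:
        return {}  # same placeholder the concrete nodes ship with in this commit


print(DirectAnswerNodeStub.extract_variable_selector_to_variable_mapping({"data": {"title": "answer"}}))  # -> {}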
 from typing import Optional, Union, cast
 
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -15,6 +16,7 @@ MAX_STRING_LENGTH = 1000
 MAX_STRING_ARRAY_LENGTH = 30
 MAX_NUMBER_ARRAY_LENGTH = 1000
 
 class CodeNode(BaseNode):
     _node_data_cls = CodeNodeData
     node_type = NodeType.CODE
@@ -78,21 +80,15 @@ class CodeNode(BaseNode):
             }
         }
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run code
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
-        node_data: CodeNodeData = cast(self._node_data_cls, node_data)
-
-        # SINGLE DEBUG NOT IMPLEMENTED YET
-        if variable_pool is None and run_args:
-            raise ValueError("Not support single step debug.")
+        node_data = cast(self._node_data_cls, node_data)
 
         # Get code language
         code_language = node_data.code_language
         code = node_data.code
@@ -134,7 +130,6 @@ class CodeNode(BaseNode):
         Check string
         :param value: value
         :param variable: variable
-        :param max_length: max length
         :return:
         """
         if not isinstance(value, str):
@@ -142,9 +137,9 @@ class CodeNode(BaseNode):
         if len(value) > MAX_STRING_LENGTH:
             raise ValueError(f'{variable} in input form must be less than {MAX_STRING_LENGTH} characters')
 
         return value.replace('\x00', '')
 
     def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
         """
         Check number
@@ -157,13 +152,13 @@ class CodeNode(BaseNode):
         if value > MAX_NUMBER or value < MIN_NUMBER:
             raise ValueError(f'{variable} in input form is out of range.')
 
         if isinstance(value, float):
             value = round(value, MAX_PRECISION)
 
         return value
 
     def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output],
                           prefix: str = '',
                           depth: int = 1) -> dict:
         """
@@ -174,7 +169,7 @@ class CodeNode(BaseNode):
         """
         if depth > MAX_DEPTH:
             raise ValueError("Depth limit reached, object too deep.")
 
         transformed_result = {}
         for output_name, output_config in output_schema.items():
             if output_config.type == 'object':
@@ -183,7 +178,7 @@ class CodeNode(BaseNode):
                 raise ValueError(
                     f'Output {prefix}.{output_name} is not an object, got {type(result.get(output_name))} instead.'
                 )
 
                 transformed_result[output_name] = self._transform_result(
                     result=result[output_name],
                     output_schema=output_config.children,
@@ -208,7 +203,7 @@ class CodeNode(BaseNode):
                 raise ValueError(
                     f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
                 )
 
                 if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
                     raise ValueError(
                         f'{prefix}.{output_name} in input form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters'
@@ -227,12 +222,12 @@ class CodeNode(BaseNode):
                 raise ValueError(
                     f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
                 )
 
                 if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
                     raise ValueError(
                         f'{prefix}.{output_name} in input form must be less than {MAX_STRING_ARRAY_LENGTH} characters'
                     )
 
                 transformed_result[output_name] = [
                     self._check_string(
                         value=value,
@@ -242,5 +237,15 @@ class CodeNode(BaseNode):
                 ]
             else:
                 raise ValueError(f'Output type {output_config.type} is not supported.')
 
-        return transformed_result
\ No newline at end of file
+        return transformed_result
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        # TODO extract variable selector to variable mapping for single step debugging
+        return {}
 import time
-from typing import Optional, cast
+from typing import cast
 
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -13,20 +14,15 @@ class DirectAnswerNode(BaseNode):
     _node_data_cls = DirectAnswerNodeData
     node_type = NodeType.DIRECT_ANSWER
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
         node_data = cast(self._node_data_cls, node_data)
 
-        if variable_pool is None and run_args:
-            raise ValueError("Not support single step debug.")
-
         variable_values = {}
         for variable_selector in node_data.variables:
             value = variable_pool.get_variable_value(
@@ -43,7 +39,7 @@ class DirectAnswerNode(BaseNode):
         # publish answer as stream
         for word in answer:
             self.publish_text_chunk(word)
-            time.sleep(0.01)  # todo sleep 0.01
+            time.sleep(0.01)
 
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -52,3 +48,12 @@ class DirectAnswerNode(BaseNode):
                 "answer": answer
             }
         )
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        return {}
-from typing import Optional, cast
+from typing import cast
 
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -11,50 +12,54 @@ class EndNode(BaseNode):
     _node_data_cls = EndNodeData
     node_type = NodeType.END
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
         node_data = cast(self._node_data_cls, node_data)
         outputs_config = node_data.outputs
 
-        if variable_pool is not None:
-            outputs = None
-            if outputs_config:
-                if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
-                    plain_text_selector = outputs_config.plain_text_selector
-                    if plain_text_selector:
-                        outputs = {
-                            'text': variable_pool.get_variable_value(
-                                variable_selector=plain_text_selector,
-                                target_value_type=ValueType.STRING
-                            )
-                        }
-                    else:
-                        outputs = {
-                            'text': ''
-                        }
-                elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
-                    structured_variables = outputs_config.structured_variables
-                    if structured_variables:
-                        outputs = {}
-                        for variable_selector in structured_variables:
-                            variable_value = variable_pool.get_variable_value(
-                                variable_selector=variable_selector.value_selector
-                            )
-
-                            outputs[variable_selector.variable] = variable_value
-                    else:
-                        outputs = {}
-        else:
-            raise ValueError("Not support single step debug.")
+        outputs = None
+        if outputs_config:
+            if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
+                plain_text_selector = outputs_config.plain_text_selector
+                if plain_text_selector:
+                    outputs = {
+                        'text': variable_pool.get_variable_value(
+                            variable_selector=plain_text_selector,
+                            target_value_type=ValueType.STRING
+                        )
+                    }
+                else:
+                    outputs = {
+                        'text': ''
+                    }
+            elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
+                structured_variables = outputs_config.structured_variables
+                if structured_variables:
+                    outputs = {}
+                    for variable_selector in structured_variables:
+                        variable_value = variable_pool.get_variable_value(
+                            variable_selector=variable_selector.value_selector
+                        )
+
+                        outputs[variable_selector.variable] = variable_value
+                else:
+                    outputs = {}
 
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=outputs,
             outputs=outputs
         )
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        return {}
 from typing import Optional, cast
 
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -10,12 +11,10 @@ class LLMNode(BaseNode):
     _node_data_cls = LLMNodeData
     node_type = NodeType.LLM
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
@@ -23,6 +22,17 @@ class LLMNode(BaseNode):
         pass
 
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        # TODO extract variable selector to variable mapping for single step debugging
+        return {}
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
...
-from typing import Optional, cast
+from typing import cast
 
 from core.app.app_config.entities import VariableEntity
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -12,12 +13,10 @@ class StartNode(BaseNode):
     _node_data_cls = StartNodeData
     node_type = NodeType.START
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
@@ -25,7 +24,7 @@ class StartNode(BaseNode):
         variables = node_data.variables
 
         # Get cleaned inputs
-        cleaned_inputs = self._get_cleaned_inputs(variables, run_args)
+        cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs)
 
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -68,3 +67,12 @@ class StartNode(BaseNode):
             filtered_inputs[variable] = value.replace('\x00', '') if value else None
 
         return filtered_inputs
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        return {}
@@ -109,9 +109,9 @@ class WorkflowEngineManager:
         workflow_run_state = WorkflowRunState(
             workflow=workflow,
             start_at=time.perf_counter(),
-            user_inputs=user_inputs,
             variable_pool=VariablePool(
                 system_variables=system_inputs,
+                user_inputs=user_inputs
             )
         )
@@ -292,9 +292,7 @@ class WorkflowEngineManager:
             # run node, result must have inputs, process_data, outputs, execution_metadata
             node_run_result = node.run(
-                variable_pool=workflow_run_state.variable_pool,
-                run_args=workflow_run_state.user_inputs
-                if (not predecessor_node and node.node_type == NodeType.START) else None  # only on start node
+                variable_pool=workflow_run_state.variable_pool
             )
 
             if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
...