Commit ea883b5e authored by takatost's avatar takatost

add start, end, direct answer node

parent 46296d77
......@@ -5,7 +5,5 @@ from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel):
type: str
title: str
desc: Optional[str] = None
from enum import Enum
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel
......@@ -46,6 +46,15 @@ class SystemVariable(Enum):
CONVERSATION = 'conversation'
class NodeRunMetadataKey(Enum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = 'total_tokens'
TOTAL_PRICE = 'total_price'
CURRENCY = 'currency'
class NodeRunResult(BaseModel):
"""
Node Run Result.
......@@ -55,7 +64,7 @@ class NodeRunResult(BaseModel):
inputs: Optional[dict] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[dict] = None # node outputs
metadata: Optional[dict] = None # node metadata
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
......
from pydantic import BaseModel
class VariableSelector(BaseModel):
"""
Variable Selector.
"""
variable: str
value_selector: list[str]
......@@ -5,13 +5,18 @@ from models.workflow import WorkflowNodeExecution, WorkflowRun
class WorkflowRunState:
workflow_run: WorkflowRun
start_at: float
user_inputs: dict
variable_pool: VariablePool
total_tokens: int = 0
workflow_node_executions: list[WorkflowNodeExecution] = []
def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
def __init__(self, workflow_run: WorkflowRun,
start_at: float,
user_inputs: dict,
variable_pool: VariablePool) -> None:
self.workflow_run = workflow_run
self.start_at = start_at
self.user_inputs = user_inputs
self.variable_pool = variable_pool
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Optional
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
......@@ -8,7 +8,7 @@ from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecutionStatus
class BaseNode:
class BaseNode(ABC):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType
......
import time
from typing import Optional, cast
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class DirectAnswerNode(BaseNode):
pass
_node_data_cls = DirectAnswerNodeData
node_type = NodeType.DIRECT_ANSWER
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
if variable_pool is None and run_args:
raise ValueError("Not support single step debug.")
variable_values = {}
for variable_selector in node_data.variables:
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector,
target_value_type=ValueType.STRING
)
variable_values[variable_selector.variable] = value
# format answer template
template_parser = PromptTemplateParser(node_data.answer)
answer = template_parser.format(variable_values)
# publish answer as stream
for word in answer:
self.publish_text_chunk(word)
time.sleep(0.01)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variable_values,
output={
"answer": answer
}
)
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class DirectAnswerNodeData(BaseNodeData):
"""
DirectAnswer Node Data.
"""
variables: list[VariableSelector] = []
answer: str
from typing import Optional, cast
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs
from models.workflow import WorkflowNodeExecutionStatus
class EndNode(BaseNode):
pass
_node_data_cls = EndNodeData
node_type = NodeType.END
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
outputs_config = node_data.outputs
if variable_pool is not None:
outputs = None
if outputs_config:
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
plain_text_selector = outputs_config.plain_text_selector
if plain_text_selector:
outputs = {
'text': variable_pool.get_variable_value(
variable_selector=plain_text_selector,
target_value_type=ValueType.STRING
)
}
else:
outputs = {
'text': ''
}
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
structured_variables = outputs_config.structured_variables
if structured_variables:
outputs = {}
for variable_selector in structured_variables:
variable_value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
outputs[variable_selector.variable] = variable_value
else:
outputs = {}
else:
raise ValueError("Not support single step debug.")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=outputs,
outputs=outputs
)
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class EndNodeOutputType(Enum):
......@@ -23,3 +29,40 @@ class EndNodeOutputType(Enum):
if output_type.value == value:
return output_type
raise ValueError(f'invalid output type value {value}')
class EndNodeDataOutputs(BaseModel):
"""
END Node Data Outputs.
"""
class OutputType(Enum):
"""
Output Types.
"""
NONE = 'none'
PLAIN_TEXT = 'plain-text'
STRUCTURED = 'structured'
@classmethod
def value_of(cls, value: str) -> 'OutputType':
"""
Get value of given output type.
:param value: output type value
:return: output type
"""
for output_type in cls:
if output_type.value == value:
return output_type
raise ValueError(f'invalid output type value {value}')
type: OutputType = OutputType.NONE
plain_text_selector: Optional[list[str]] = None
structured_variables: Optional[list[VariableSelector]] = None
class EndNodeData(BaseNodeData):
"""
END Node Data.
"""
outputs: Optional[EndNodeDataOutputs] = None
from core.workflow.entities.base_node_data_entities import BaseNodeData
class LLMNodeData(BaseNodeData):
"""
LLM Node Data.
"""
pass
from typing import Optional
from typing import Optional, cast
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData
class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
node_type = NodeType.LLM
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
pass
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
......
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class StartNodeData(BaseNodeData):
"""
- title (string) 节点标题
- desc (string) optional 节点描述
- type (string) 节点类型,固定为 start
- variables (array[object]) 表单变量列表
- type (string) 表单变量类型,text-input, paragraph, select, number, files(文件暂不支持自定义)
- label (string) 控件展示标签名
- variable (string) 变量 key
- max_length (int) 最大长度,适用于 text-input 和 paragraph
- default (string) optional 默认值
- required (bool) optional是否必填,默认 false
- hint (string) optional 提示信息
- options (array[string]) 选项值(仅 select 可用)
Start Node Data
"""
type: str = NodeType.START.value
variables: list[VariableEntity] = []
from typing import Optional
from typing import Optional, cast
from core.workflow.entities.node_entities import NodeType
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class StartNode(BaseNode):
......@@ -11,12 +13,58 @@ class StartNode(BaseNode):
node_type = NodeType.START
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
pass
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
variables = node_data.variables
# Get cleaned inputs
cleaned_inputs = self._get_cleaned_inputs(variables, run_args)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=cleaned_inputs,
outputs=cleaned_inputs
)
def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
if user_inputs is None:
user_inputs = {}
filtered_inputs = {}
for variable_config in variables:
variable = variable_config.variable
if variable not in user_inputs or not user_inputs[variable]:
if variable_config.required:
raise ValueError(f"Input form variable {variable} is required")
else:
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
continue
value = user_inputs[variable]
if value:
if not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
else:
if variable_config.max_length is not None:
max_length = variable_config.max_length
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value.replace('\x00', '') if value else None
return filtered_inputs
......@@ -3,6 +3,7 @@ import time
from datetime import datetime
from typing import Optional, Union
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue
......@@ -141,6 +142,7 @@ class WorkflowEngineManager:
workflow_run_state = WorkflowRunState(
workflow_run=workflow_run,
start_at=time.perf_counter(),
user_inputs=user_inputs,
variable_pool=VariablePool(
system_variables=system_inputs,
)
......@@ -399,7 +401,9 @@ class WorkflowEngineManager:
# run node, result must have inputs, process_data, outputs, execution_metadata
node_run_result = node.run(
variable_pool=workflow_run_state.variable_pool
variable_pool=workflow_run_state.variable_pool,
run_args=workflow_run_state.user_inputs
if (not predecessor_node and node.node_type == NodeType.START) else None # only on start node
)
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
......@@ -492,7 +496,7 @@ class WorkflowEngineManager:
workflow_node_execution.inputs = json.dumps(result.inputs)
workflow_node_execution.process_data = json.dumps(result.process_data)
workflow_node_execution.outputs = json.dumps(result.outputs)
workflow_node_execution.execution_metadata = json.dumps(result.metadata)
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata))
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment