Commit fcd470fc authored by takatost's avatar takatost

add answer output parse

parent fd8fe15d
...@@ -5,7 +5,6 @@ from core.app.entities.queue_entities import ( ...@@ -5,7 +5,6 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent, QueueNodeFailedEvent,
QueueNodeStartedEvent, QueueNodeStartedEvent,
QueueNodeSucceededEvent, QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent, QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent, QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent, QueueWorkflowSucceededEvent,
...@@ -20,7 +19,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ...@@ -20,7 +19,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
def on_workflow_run_started(self) -> None: def on_workflow_run_started(self) -> None:
""" """
...@@ -118,31 +116,4 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ...@@ -118,31 +116,4 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
""" """
Publish text chunk Publish text chunk
""" """
if node_id in self._streamable_node_ids: pass
self._queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
"""
Fetch streamable node ids
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
:param graph: workflow graph
:return:
"""
streamable_node_ids = []
end_node_ids = []
for node_config in graph.get('nodes'):
if node_config.get('data', {}).get('type') == NodeType.END.value:
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
end_node_ids.append(node_config.get('id'))
for edge_config in graph.get('edges'):
if edge_config.get('target') in end_node_ids:
streamable_node_ids.append(edge_config.get('source'))
return streamable_node_ids
import time
from typing import cast from typing import cast
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
...@@ -32,14 +31,49 @@ class AnswerNode(BaseNode): ...@@ -32,14 +31,49 @@ class AnswerNode(BaseNode):
variable_values[variable_selector.variable] = value variable_values[variable_selector.variable] = value
variable_keys = list(variable_values.keys())
# format answer template # format answer template
template_parser = PromptTemplateParser(node_data.answer) template_parser = PromptTemplateParser(node_data.answer)
answer = template_parser.format(variable_values) template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
split_template = [
{
"type": "var" if self._is_variable(part, variable_keys) else "text",
"value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
}
for part in template.split('Ω') if part
]
answer = []
for part in split_template:
if part["type"] == "var":
value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
answer_part = {
"type": "text",
"text": value
}
# TODO File
else:
answer_part = {
"type": "text",
"text": part["value"]
}
# publish answer as stream if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
for word in answer: answer[-1]["text"] += answer_part["text"]
self.publish_text_chunk(word) else:
time.sleep(10) # TODO for debug answer.append(answer_part)
if len(answer) == 1 and answer[0]["type"] == "text":
answer = answer[0]["text"]
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
...@@ -49,6 +83,10 @@ class AnswerNode(BaseNode): ...@@ -49,6 +83,10 @@ class AnswerNode(BaseNode):
} }
) )
def _is_variable(self, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
""" """
......
...@@ -6,7 +6,6 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback ...@@ -6,7 +6,6 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecutionStatus
class UserFrom(Enum): class UserFrom(Enum):
...@@ -80,16 +79,9 @@ class BaseNode(ABC): ...@@ -80,16 +79,9 @@ class BaseNode(ABC):
:param variable_pool: variable pool :param variable_pool: variable pool
:return: :return:
""" """
try: result = self._run(
result = self._run( variable_pool=variable_pool
variable_pool=variable_pool )
)
except Exception as e:
# process unhandled exception
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
self.node_run_result = result self.node_run_result = result
return result return result
......
...@@ -2,9 +2,9 @@ from typing import cast ...@@ -2,9 +2,9 @@ from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs from core.workflow.nodes.end.entities import EndNodeData
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
...@@ -20,34 +20,14 @@ class EndNode(BaseNode): ...@@ -20,34 +20,14 @@ class EndNode(BaseNode):
""" """
node_data = self.node_data node_data = self.node_data
node_data = cast(self._node_data_cls, node_data) node_data = cast(self._node_data_cls, node_data)
outputs_config = node_data.outputs output_variables = node_data.outputs
outputs = None outputs = {}
if outputs_config: for variable_selector in output_variables:
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT: variable_value = variable_pool.get_variable_value(
plain_text_selector = outputs_config.plain_text_selector variable_selector=variable_selector.value_selector
if plain_text_selector: )
outputs = { outputs[variable_selector.variable] = variable_value
'text': variable_pool.get_variable_value(
variable_selector=plain_text_selector,
target_value_type=ValueType.STRING
)
}
else:
outputs = {
'text': ''
}
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
structured_variables = outputs_config.structured_variables
if structured_variables:
outputs = {}
for variable_selector in structured_variables:
variable_value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
outputs[variable_selector.variable] = variable_value
else:
outputs = {}
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
......
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
class EndNodeOutputType(Enum):
"""
END Node Output Types.
none, plain-text, structured
"""
NONE = 'none'
PLAIN_TEXT = 'plain-text'
STRUCTURED = 'structured'
@classmethod
def value_of(cls, value: str) -> 'OutputType':
"""
Get value of given output type.
:param value: output type value
:return: output type
"""
for output_type in cls:
if output_type.value == value:
return output_type
raise ValueError(f'invalid output type value {value}')
class EndNodeDataOutputs(BaseModel):
"""
END Node Data Outputs.
"""
class OutputType(Enum):
"""
Output Types.
"""
NONE = 'none'
PLAIN_TEXT = 'plain-text'
STRUCTURED = 'structured'
@classmethod
def value_of(cls, value: str) -> 'OutputType':
"""
Get value of given output type.
:param value: output type value
:return: output type
"""
for output_type in cls:
if output_type.value == value:
return output_type
raise ValueError(f'invalid output type value {value}')
type: OutputType = OutputType.NONE
plain_text_selector: Optional[list[str]] = None
structured_variables: Optional[list[VariableSelector]] = None
class EndNodeData(BaseNodeData): class EndNodeData(BaseNodeData):
""" """
END Node Data. END Node Data.
""" """
outputs: Optional[EndNodeDataOutputs] = None outputs: list[VariableSelector]
import logging
import time import time
from typing import Optional from typing import Optional
...@@ -41,6 +42,8 @@ node_classes = { ...@@ -41,6 +42,8 @@ node_classes = {
NodeType.VARIABLE_ASSIGNER: VariableAssignerNode, NodeType.VARIABLE_ASSIGNER: VariableAssignerNode,
} }
logger = logging.getLogger(__name__)
class WorkflowEngineManager: class WorkflowEngineManager:
def get_default_configs(self) -> list[dict]: def get_default_configs(self) -> list[dict]:
...@@ -407,6 +410,7 @@ class WorkflowEngineManager: ...@@ -407,6 +410,7 @@ class WorkflowEngineManager:
variable_pool=workflow_run_state.variable_pool variable_pool=workflow_run_state.variable_pool
) )
except Exception as e: except Exception as e:
logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
node_run_result = NodeRunResult( node_run_result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
error=str(e) error=str(e)
......
...@@ -531,10 +531,10 @@ class WorkflowConverter: ...@@ -531,10 +531,10 @@ class WorkflowConverter:
"data": { "data": {
"title": "END", "title": "END",
"type": NodeType.END.value, "type": NodeType.END.value,
"outputs": { "outputs": [{
"variable": "result", "variable": "result",
"value_selector": ["llm", "text"] "value_selector": ["llm", "text"]
} }]
} }
} }
......
from unittest.mock import MagicMock
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
def test_execute_answer():
node = AnswerNode(
tenant_id='1',
app_id='1',
workflow_id='1',
user_id='1',
user_from=UserFrom.ACCOUNT,
config={
'id': 'answer',
'data': {
'title': '123',
'type': 'answer',
'variables': [
{
'value_selector': ['llm', 'text'],
'variable': 'text'
},
{
'value_selector': ['start', 'weather'],
'variable': 'weather'
},
],
'answer': 'Today\'s weather is {{weather}}\n{{text}}\n{{img}}\nFin.'
}
}
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
}, user_inputs={})
pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run(pool)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
# TODO test files
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment