Commit 90bcb241 authored by takatost's avatar takatost

fix bugs

parent f4f7cfd4
...@@ -47,6 +47,7 @@ class TaskState(BaseModel): ...@@ -47,6 +47,7 @@ class TaskState(BaseModel):
answer: str = "" answer: str = ""
metadata: dict = {} metadata: dict = {}
usage: LLMUsage usage: LLMUsage
workflow_run_id: Optional[str] = None
class AdvancedChatAppGenerateTaskPipeline: class AdvancedChatAppGenerateTaskPipeline:
...@@ -110,6 +111,8 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -110,6 +111,8 @@ class AdvancedChatAppGenerateTaskPipeline:
} }
self._task_state.answer = annotation.content self._task_state.answer = annotation.content
elif isinstance(event, QueueWorkflowStartedEvent):
self._task_state.workflow_run_id = event.workflow_run_id
elif isinstance(event, QueueNodeFinishedEvent): elif isinstance(event, QueueNodeFinishedEvent):
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id) workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value: if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
...@@ -171,6 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -171,6 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline:
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._get_workflow_run(event.workflow_run_id) workflow_run = self._get_workflow_run(event.workflow_run_id)
self._task_state.workflow_run_id = workflow_run.id
response = { response = {
'event': 'workflow_started', 'event': 'workflow_started',
'task_id': self._application_generate_entity.task_id, 'task_id': self._application_generate_entity.task_id,
...@@ -234,7 +238,7 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -234,7 +238,7 @@ class AdvancedChatAppGenerateTaskPipeline:
if isinstance(event, QueueWorkflowFinishedEvent): if isinstance(event, QueueWorkflowFinishedEvent):
workflow_run = self._get_workflow_run(event.workflow_run_id) workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '') self._task_state.answer = outputs.get('text', '')
else: else:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
...@@ -389,7 +393,13 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -389,7 +393,13 @@ class AdvancedChatAppGenerateTaskPipeline:
:param workflow_run_id: workflow run id :param workflow_run_id: workflow run id
:return: :return:
""" """
return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
if workflow_run:
# Because the workflow_run will be modified in the sub-thread,
# and the first query in the main thread will cache the entity,
# you need to expire the entity after the query
db.session.expire(workflow_run)
return workflow_run
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
""" """
...@@ -397,7 +407,14 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -397,7 +407,14 @@ class AdvancedChatAppGenerateTaskPipeline:
:param workflow_node_execution_id: workflow node execution id :param workflow_node_execution_id: workflow node execution id
:return: :return:
""" """
return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() workflow_node_execution = (db.session.query(WorkflowNodeExecution)
.filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
if workflow_node_execution:
# Because the workflow_node_execution will be modified in the sub-thread,
# and the first query in the main thread will cache the entity,
# you need to expire the entity after the query
db.session.expire(workflow_node_execution)
return workflow_node_execution
def _save_message(self) -> None: def _save_message(self) -> None:
""" """
...@@ -408,6 +425,7 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -408,6 +425,7 @@ class AdvancedChatAppGenerateTaskPipeline:
self._message.answer = self._task_state.answer self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.workflow_run_id = self._task_state.workflow_run_id
if self._task_state.metadata and self._task_state.metadata.get('usage'): if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage']) usage = LLMUsage(**self._task_state.metadata['usage'])
......
...@@ -48,7 +48,7 @@ class DirectAnswerNode(BaseNode): ...@@ -48,7 +48,7 @@ class DirectAnswerNode(BaseNode):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variable_values, inputs=variable_values,
output={ outputs={
"answer": answer "answer": answer
} }
) )
...@@ -33,6 +33,7 @@ from models.workflow import ( ...@@ -33,6 +33,7 @@ from models.workflow import (
WorkflowRun, WorkflowRun,
WorkflowRunStatus, WorkflowRunStatus,
WorkflowRunTriggeredFrom, WorkflowRunTriggeredFrom,
WorkflowType,
) )
node_classes = { node_classes = {
...@@ -268,7 +269,7 @@ class WorkflowEngineManager: ...@@ -268,7 +269,7 @@ class WorkflowEngineManager:
# fetch last workflow_node_executions # fetch last workflow_node_executions
last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1] last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1]
if last_workflow_node_execution: if last_workflow_node_execution:
workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs) workflow_run.outputs = last_workflow_node_execution.outputs
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
workflow_run.total_tokens = workflow_run_state.total_tokens workflow_run.total_tokens = workflow_run_state.total_tokens
...@@ -390,6 +391,7 @@ class WorkflowEngineManager: ...@@ -390,6 +391,7 @@ class WorkflowEngineManager:
workflow_run_state=workflow_run_state, workflow_run_state=workflow_run_state,
node=node, node=node,
predecessor_node=predecessor_node, predecessor_node=predecessor_node,
callbacks=callbacks
) )
# add to workflow node executions # add to workflow node executions
...@@ -412,6 +414,9 @@ class WorkflowEngineManager: ...@@ -412,6 +414,9 @@ class WorkflowEngineManager:
) )
raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")
# set end node output if in chat
self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result)
# node run success # node run success
self._workflow_node_execution_success( self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
...@@ -529,6 +534,32 @@ class WorkflowEngineManager: ...@@ -529,6 +534,32 @@ class WorkflowEngineManager:
return workflow_node_execution return workflow_node_execution
def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState,
node: BaseNode,
node_run_result: NodeRunResult):
"""
Set end node output if in chat
:param workflow_run_state: workflow run state
:param node: current node
:param node_run_result: node run result
:return:
"""
if workflow_run_state.workflow_run.type == WorkflowType.CHAT.value and node.node_type == NodeType.END:
workflow_node_execution_before_end = workflow_run_state.workflow_node_executions[-2]
if workflow_node_execution_before_end:
if workflow_node_execution_before_end.node_type == NodeType.LLM.value:
if not node_run_result.outputs:
node_run_result.outputs = {}
node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('text')
elif workflow_node_execution_before_end.node_type == NodeType.DIRECT_ANSWER.value:
if not node_run_result.outputs:
node_run_result.outputs = {}
node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('answer')
return node_run_result
def _append_variables_recursively(self, variable_pool: VariablePool, def _append_variables_recursively(self, variable_pool: VariablePool,
node_id: str, node_id: str,
variable_key_list: list[str], variable_key_list: list[str],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment