Commit 100fb0c5 authored by takatost

optimize workflow db connections

parent b75cd251
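The commit stops caching SQLAlchemy ORM objects (WorkflowRun, WorkflowNodeExecution) in the long-lived TaskState: it stores only their ids, re-queries the rows when an event handler needs them, refreshes updated objects before the session is closed, and closes db.session after each handler so the pooled connection is released between queue events instead of being held for the whole run. A minimal, self-contained sketch of that pattern, using plain SQLAlchemy and a toy WorkflowRun model rather than the actual Dify models and pipeline methods:

from datetime import datetime

from sqlalchemy import Column, DateTime, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class WorkflowRun(Base):
    __tablename__ = "workflow_runs"
    id = Column(String, primary_key=True)
    status = Column(String, default="running")
    finished_at = Column(DateTime, nullable=True)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)


def finish_run(workflow_run_id: str) -> WorkflowRun:
    session = Session(engine)
    # re-load the row by id instead of holding the ORM object across events
    workflow_run = session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
    workflow_run.status = "succeeded"
    workflow_run.finished_at = datetime.utcnow()
    session.commit()
    session.refresh(workflow_run)  # reload attributes before the instance detaches
    session.close()                # return the connection to the pool right away
    return workflow_run


if __name__ == "__main__":
    with Session(engine) as seed:
        seed.add(WorkflowRun(id="run-1"))
        seed.commit()
    run = finish_run("run-1")
    print(run.id, run.status)  # attributes refreshed before close() stay readable

The refresh() before close() matters: commit() expires loaded attributes, and reading an expired attribute on an instance detached by close() raises DetachedInstanceError, which is why the base pipeline methods below gain db.session.refresh(...) next to the existing close() calls.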
@@ -59,7 +59,7 @@ class TaskState(BaseModel):
         """
         NodeExecutionInfo entity
         """
-        workflow_node_execution: WorkflowNodeExecution
+        workflow_node_execution_id: str
         start_at: float

         class Config:
@@ -72,7 +72,7 @@ class TaskState(BaseModel):
     metadata: dict = {}
     usage: LLMUsage

-    workflow_run: Optional[WorkflowRun] = None
+    workflow_run_id: Optional[str] = None
     start_at: Optional[float] = None
     total_tokens: int = 0
     total_steps: int = 0
@@ -168,8 +168,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                 self._on_node_finished(event)
             elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
-                self._on_workflow_finished(event)
-                workflow_run = self._task_state.workflow_run
+                workflow_run = self._on_workflow_finished(event)

                 if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
                     raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
@@ -218,8 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(data)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
-                self._on_workflow_start()
-                workflow_run = self._task_state.workflow_run
+                workflow_run = self._on_workflow_start()

                 response = {
                     'event': 'workflow_started',
@@ -234,8 +232,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):

                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeStartedEvent):
-                self._on_node_start(event)
-                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
+                workflow_node_execution = self._on_node_start(event)

                 response = {
                     'event': 'node_started',
@@ -253,8 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):

                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
-                self._on_node_finished(event)
-                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
+                workflow_node_execution = self._on_node_finished(event)

                 if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
                     if workflow_node_execution.node_type == NodeType.LLM.value:
@@ -285,8 +281,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):

                 yield self._yield_response(response)
             elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
-                self._on_workflow_finished(event)
-                workflow_run = self._task_state.workflow_run
+                workflow_run = self._on_workflow_finished(event)

                 if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
                     err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
@@ -435,7 +430,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             else:
                 continue

-    def _on_workflow_start(self) -> None:
+    def _on_workflow_start(self) -> WorkflowRun:
         self._task_state.start_at = time.perf_counter()

         workflow_run = self._init_workflow_run(
@@ -452,11 +447,16 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             }
         )

-        self._task_state.workflow_run = workflow_run
+        self._task_state.workflow_run_id = workflow_run.id
+
+        db.session.close()
+
+        return workflow_run

-    def _on_node_start(self, event: QueueNodeStartedEvent) -> None:
+    def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
+        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
         workflow_node_execution = self._init_node_execution_from_workflow_run(
-            workflow_run=self._task_state.workflow_run,
+            workflow_run=workflow_run,
             node_id=event.node_id,
             node_type=event.node_type,
             node_title=event.node_data.title,
@@ -465,19 +465,26 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
         )

         latest_node_execution_info = TaskState.NodeExecutionInfo(
-            workflow_node_execution=workflow_node_execution,
+            workflow_node_execution_id=workflow_node_execution.id,
             start_at=time.perf_counter()
         )

         self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
         self._task_state.latest_node_execution_info = latest_node_execution_info

         self._task_state.total_steps += 1

-    def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
+        db.session.close()
+
+        return workflow_node_execution
+
+    def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
         current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
+        workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
+            WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
+
         if isinstance(event, QueueNodeSucceededEvent):
             workflow_node_execution = self._workflow_node_execution_success(
-                workflow_node_execution=current_node_execution.workflow_node_execution,
+                workflow_node_execution=workflow_node_execution,
                 start_at=current_node_execution.start_at,
                 inputs=event.inputs,
                 process_data=event.process_data,
@@ -495,19 +502,24 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 self._task_state.metadata['usage'] = usage_dict
         else:
             workflow_node_execution = self._workflow_node_execution_failed(
-                workflow_node_execution=current_node_execution.workflow_node_execution,
+                workflow_node_execution=workflow_node_execution,
                 start_at=current_node_execution.start_at,
                 error=event.error
             )

         # remove running node execution info
         del self._task_state.running_node_execution_infos[event.node_id]
-        self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution

-    def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
+        db.session.close()
+
+        return workflow_node_execution
+
+    def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
+            -> WorkflowRun:
+        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
         if isinstance(event, QueueStopEvent):
             workflow_run = self._workflow_run_failed(
-                workflow_run=self._task_state.workflow_run,
+                workflow_run=workflow_run,
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
@@ -516,7 +528,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             )
         elif isinstance(event, QueueWorkflowFailedEvent):
             workflow_run = self._workflow_run_failed(
-                workflow_run=self._task_state.workflow_run,
+                workflow_run=workflow_run,
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
@@ -524,39 +536,30 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 error=event.error
             )
         else:
+            if self._task_state.latest_node_execution_info:
+                workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
+                    WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
+                outputs = workflow_node_execution.outputs
+            else:
+                outputs = None
+
             workflow_run = self._workflow_run_success(
-                workflow_run=self._task_state.workflow_run,
+                workflow_run=workflow_run,
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
-                outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs
-                if self._task_state.latest_node_execution_info else None
+                outputs=outputs
             )

-        self._task_state.workflow_run = workflow_run
+        self._task_state.workflow_run_id = workflow_run.id

         if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
             outputs = workflow_run.outputs_dict
             self._task_state.answer = outputs.get('text', '')

-    def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
-        """
-        Get workflow run.
-        :param workflow_run_id: workflow run id
-        :return:
-        """
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
-        return workflow_run
-
-    def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
-        """
-        Get workflow node execution.
-        :param workflow_node_execution_id: workflow node execution id
-        :return:
-        """
-        workflow_node_execution = (db.session.query(WorkflowNodeExecution)
-                                   .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
-        return workflow_node_execution
+        db.session.close()
+
+        return workflow_run

     def _save_message(self) -> None:
         """
@@ -567,7 +570,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):

         self._message.answer = self._task_state.answer
         self._message.provider_response_latency = time.perf_counter() - self._start_at
-        self._message.workflow_run_id = self._task_state.workflow_run.id
+        self._message.workflow_run_id = self._task_state.workflow_run_id

         if self._task_state.metadata and self._task_state.metadata.get('usage'):
             usage = LLMUsage(**self._task_state.metadata['usage'])
......
@@ -45,7 +45,7 @@ class TaskState(BaseModel):
         """
         NodeExecutionInfo entity
         """
-        workflow_node_execution: WorkflowNodeExecution
+        workflow_node_execution_id: str
         start_at: float

         class Config:
@@ -57,7 +57,7 @@ class TaskState(BaseModel):
     answer: str = ""
     metadata: dict = {}

-    workflow_run: Optional[WorkflowRun] = None
+    workflow_run_id: Optional[str] = None
     start_at: Optional[float] = None
     total_tokens: int = 0
     total_steps: int = 0
@@ -130,8 +130,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                 self._on_node_finished(event)
             elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
-                self._on_workflow_finished(event)
-                workflow_run = self._task_state.workflow_run
+                workflow_run = self._on_workflow_finished(event)

                 # response moderation
                 if self._output_moderation_handler:
@@ -179,8 +178,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(data)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
-                self._on_workflow_start()
-                workflow_run = self._task_state.workflow_run
+                workflow_run = self._on_workflow_start()

                 response = {
                     'event': 'workflow_started',
@@ -195,8 +193,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):

                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeStartedEvent):
-                self._on_node_start(event)
-                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
+                workflow_node_execution = self._on_node_start(event)

                 response = {
                     'event': 'node_started',
@@ -214,8 +211,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):

                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
-                self._on_node_finished(event)
-                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
+                workflow_node_execution = self._on_node_finished(event)

                 response = {
                     'event': 'node_finished',
@@ -240,8 +236,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):

                 yield self._yield_response(response)
             elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
-                self._on_workflow_finished(event)
-                workflow_run = self._task_state.workflow_run
+                workflow_run = self._on_workflow_finished(event)

                 # response moderation
                 if self._output_moderation_handler:
@@ -257,7 +252,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                     replace_response = {
                         'event': 'text_replace',
                         'task_id': self._application_generate_entity.task_id,
-                        'workflow_run_id': self._task_state.workflow_run.id,
+                        'workflow_run_id': self._task_state.workflow_run_id,
                         'data': {
                             'text': self._task_state.answer
                         }
@@ -317,7 +312,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             response = {
                 'event': 'text_replace',
                 'task_id': self._application_generate_entity.task_id,
-                'workflow_run_id': self._task_state.workflow_run.id,
+                'workflow_run_id': self._task_state.workflow_run_id,
                 'data': {
                     'text': event.text
                 }
@@ -329,7 +324,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             else:
                 continue

-    def _on_workflow_start(self) -> None:
+    def _on_workflow_start(self) -> WorkflowRun:
         self._task_state.start_at = time.perf_counter()

         workflow_run = self._init_workflow_run(
@@ -344,11 +339,16 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             }
         )

-        self._task_state.workflow_run = workflow_run
+        self._task_state.workflow_run_id = workflow_run.id
+
+        db.session.close()
+
+        return workflow_run

-    def _on_node_start(self, event: QueueNodeStartedEvent) -> None:
+    def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
+        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
         workflow_node_execution = self._init_node_execution_from_workflow_run(
-            workflow_run=self._task_state.workflow_run,
+            workflow_run=workflow_run,
             node_id=event.node_id,
             node_type=event.node_type,
             node_title=event.node_data.title,
@@ -357,7 +357,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
         )

         latest_node_execution_info = TaskState.NodeExecutionInfo(
-            workflow_node_execution=workflow_node_execution,
+            workflow_node_execution_id=workflow_node_execution.id,
             start_at=time.perf_counter()
         )
@@ -366,11 +366,17 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):

         self._task_state.total_steps += 1

-    def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
+        db.session.close()
+
+        return workflow_node_execution
+
+    def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
         current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
+        workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
+            WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
         if isinstance(event, QueueNodeSucceededEvent):
             workflow_node_execution = self._workflow_node_execution_success(
-                workflow_node_execution=current_node_execution.workflow_node_execution,
+                workflow_node_execution=workflow_node_execution,
                 start_at=current_node_execution.start_at,
                 inputs=event.inputs,
                 process_data=event.process_data,
@@ -383,19 +389,24 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                     int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
         else:
             workflow_node_execution = self._workflow_node_execution_failed(
-                workflow_node_execution=current_node_execution.workflow_node_execution,
+                workflow_node_execution=workflow_node_execution,
                 start_at=current_node_execution.start_at,
                 error=event.error
             )

         # remove running node execution info
         del self._task_state.running_node_execution_infos[event.node_id]
-        self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution

-    def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
+        db.session.close()
+
+        return workflow_node_execution
+
+    def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
+            -> WorkflowRun:
+        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
         if isinstance(event, QueueStopEvent):
             workflow_run = self._workflow_run_failed(
-                workflow_run=self._task_state.workflow_run,
+                workflow_run=workflow_run,
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
@@ -404,7 +415,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             )
         elif isinstance(event, QueueWorkflowFailedEvent):
             workflow_run = self._workflow_run_failed(
-                workflow_run=self._task_state.workflow_run,
+                workflow_run=workflow_run,
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
@@ -412,39 +423,30 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 error=event.error
             )
         else:
+            if self._task_state.latest_node_execution_info:
+                workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
+                    WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
+                outputs = workflow_node_execution.outputs
+            else:
+                outputs = None
+
             workflow_run = self._workflow_run_success(
-                workflow_run=self._task_state.workflow_run,
+                workflow_run=workflow_run,
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
-                outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs
-                if self._task_state.latest_node_execution_info else None
+                outputs=outputs
             )

-        self._task_state.workflow_run = workflow_run
+        self._task_state.workflow_run_id = workflow_run.id

         if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
             outputs = workflow_run.outputs_dict
             self._task_state.answer = outputs.get('text', '')

-    def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
-        """
-        Get workflow run.
-        :param workflow_run_id: workflow run id
-        :return:
-        """
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
-        return workflow_run
-
-    def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
-        """
-        Get workflow node execution.
-        :param workflow_node_execution_id: workflow node execution id
-        :return:
-        """
-        workflow_node_execution = (db.session.query(WorkflowNodeExecution)
-                                   .filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
-        return workflow_node_execution
+        db.session.close()
+
+        return workflow_run

     def _save_workflow_app_log(self) -> None:
         """
@@ -461,7 +463,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
         """
         response = {
             'event': 'text_chunk',
-            'workflow_run_id': self._task_state.workflow_run.id,
+            'workflow_run_id': self._task_state.workflow_run_id,
             'task_id': self._application_generate_entity.task_id,
             'data': {
                 'text': text
......
@@ -87,6 +87,7 @@ class WorkflowBasedGenerateTaskPipeline:
         workflow_run.finished_at = datetime.utcnow()

         db.session.commit()
+        db.session.refresh(workflow_run)
         db.session.close()

         return workflow_run
@@ -115,6 +116,7 @@ class WorkflowBasedGenerateTaskPipeline:
         workflow_run.finished_at = datetime.utcnow()

         db.session.commit()
+        db.session.refresh(workflow_run)
         db.session.close()

         return workflow_run
@@ -185,6 +187,7 @@ class WorkflowBasedGenerateTaskPipeline:
         workflow_node_execution.finished_at = datetime.utcnow()

         db.session.commit()
+        db.session.refresh(workflow_node_execution)
         db.session.close()

         return workflow_node_execution
@@ -205,6 +208,7 @@ class WorkflowBasedGenerateTaskPipeline:
         workflow_node_execution.finished_at = datetime.utcnow()

         db.session.commit()
+        db.session.refresh(workflow_node_execution)
         db.session.close()

         return workflow_node_execution