Commit 100fb0c5 authored by takatost's avatar takatost

optimize workflow db connections

parent b75cd251
......@@ -59,7 +59,7 @@ class TaskState(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution: WorkflowNodeExecution
workflow_node_execution_id: str
start_at: float
class Config:
......@@ -72,7 +72,7 @@ class TaskState(BaseModel):
metadata: dict = {}
usage: LLMUsage
workflow_run: Optional[WorkflowRun] = None
workflow_run_id: Optional[str] = None
start_at: Optional[float] = None
total_tokens: int = 0
total_steps: int = 0
......@@ -168,8 +168,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
self._on_node_finished(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
self._on_workflow_finished(event)
workflow_run = self._task_state.workflow_run
workflow_run = self._on_workflow_finished(event)
if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
......@@ -218,8 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(data)
break
elif isinstance(event, QueueWorkflowStartedEvent):
self._on_workflow_start()
workflow_run = self._task_state.workflow_run
workflow_run = self._on_workflow_start()
response = {
'event': 'workflow_started',
......@@ -234,8 +232,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueNodeStartedEvent):
self._on_node_start(event)
workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
workflow_node_execution = self._on_node_start(event)
response = {
'event': 'node_started',
......@@ -253,8 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
self._on_node_finished(event)
workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
workflow_node_execution = self._on_node_finished(event)
if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
if workflow_node_execution.node_type == NodeType.LLM.value:
......@@ -285,8 +281,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
self._on_workflow_finished(event)
workflow_run = self._task_state.workflow_run
workflow_run = self._on_workflow_finished(event)
if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
......@@ -435,7 +430,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
else:
continue
def _on_workflow_start(self) -> None:
def _on_workflow_start(self) -> WorkflowRun:
self._task_state.start_at = time.perf_counter()
workflow_run = self._init_workflow_run(
......@@ -452,11 +447,16 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
}
)
self._task_state.workflow_run = workflow_run
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
def _on_node_start(self, event: QueueNodeStartedEvent) -> None:
def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_node_execution_from_workflow_run(
workflow_run=self._task_state.workflow_run,
workflow_run=workflow_run,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_data.title,
......@@ -465,19 +465,26 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
)
latest_node_execution_info = TaskState.NodeExecutionInfo(
workflow_node_execution=workflow_node_execution,
workflow_node_execution_id=workflow_node_execution.id,
start_at=time.perf_counter()
)
self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._task_state.total_steps += 1
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
db.session.close()
return workflow_node_execution
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=current_node_execution.workflow_node_execution,
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
......@@ -495,19 +502,24 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
self._task_state.metadata['usage'] = usage_dict
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=current_node_execution.workflow_node_execution,
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error
)
# remove running node execution info
del self._task_state.running_node_execution_infos[event.node_id]
self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution
# remove running node execution info
del self._task_state.running_node_execution_infos[event.node_id]
db.session.close()
return workflow_node_execution
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-> WorkflowRun:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=self._task_state.workflow_run,
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
......@@ -516,7 +528,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
)
elif isinstance(event, QueueWorkflowFailedEvent):
workflow_run = self._workflow_run_failed(
workflow_run=self._task_state.workflow_run,
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
......@@ -524,39 +536,30 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
error=event.error
)
else:
if self._task_state.latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
outputs = workflow_node_execution.outputs
else:
outputs = None
workflow_run = self._workflow_run_success(
workflow_run=self._task_state.workflow_run,
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs
if self._task_state.latest_node_execution_info else None
outputs=outputs
)
self._task_state.workflow_run = workflow_run
self._task_state.workflow_run_id = workflow_run.id
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '')
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Get workflow run.
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
return workflow_run
db.session.close()
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
"""
Get workflow node execution.
:param workflow_node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (db.session.query(WorkflowNodeExecution)
.filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
return workflow_node_execution
return workflow_run
def _save_message(self) -> None:
"""
......@@ -567,7 +570,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.workflow_run_id = self._task_state.workflow_run.id
self._message.workflow_run_id = self._task_state.workflow_run_id
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
......
......@@ -45,7 +45,7 @@ class TaskState(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution: WorkflowNodeExecution
workflow_node_execution_id: str
start_at: float
class Config:
......@@ -57,7 +57,7 @@ class TaskState(BaseModel):
answer: str = ""
metadata: dict = {}
workflow_run: Optional[WorkflowRun] = None
workflow_run_id: Optional[str] = None
start_at: Optional[float] = None
total_tokens: int = 0
total_steps: int = 0
......@@ -130,8 +130,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
self._on_node_finished(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
self._on_workflow_finished(event)
workflow_run = self._task_state.workflow_run
workflow_run = self._on_workflow_finished(event)
# response moderation
if self._output_moderation_handler:
......@@ -179,8 +178,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(data)
break
elif isinstance(event, QueueWorkflowStartedEvent):
self._on_workflow_start()
workflow_run = self._task_state.workflow_run
workflow_run = self._on_workflow_start()
response = {
'event': 'workflow_started',
......@@ -195,8 +193,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueNodeStartedEvent):
self._on_node_start(event)
workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
workflow_node_execution = self._on_node_start(event)
response = {
'event': 'node_started',
......@@ -214,8 +211,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
self._on_node_finished(event)
workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
workflow_node_execution = self._on_node_finished(event)
response = {
'event': 'node_finished',
......@@ -240,8 +236,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
self._on_workflow_finished(event)
workflow_run = self._task_state.workflow_run
workflow_run = self._on_workflow_finished(event)
# response moderation
if self._output_moderation_handler:
......@@ -257,7 +252,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
replace_response = {
'event': 'text_replace',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': self._task_state.workflow_run.id,
'workflow_run_id': self._task_state.workflow_run_id,
'data': {
'text': self._task_state.answer
}
......@@ -317,7 +312,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
response = {
'event': 'text_replace',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': self._task_state.workflow_run.id,
'workflow_run_id': self._task_state.workflow_run_id,
'data': {
'text': event.text
}
......@@ -329,7 +324,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
else:
continue
def _on_workflow_start(self) -> None:
def _on_workflow_start(self) -> WorkflowRun:
self._task_state.start_at = time.perf_counter()
workflow_run = self._init_workflow_run(
......@@ -344,11 +339,16 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
}
)
self._task_state.workflow_run = workflow_run
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
def _on_node_start(self, event: QueueNodeStartedEvent) -> None:
def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_node_execution_from_workflow_run(
workflow_run=self._task_state.workflow_run,
workflow_run=workflow_run,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_data.title,
......@@ -357,7 +357,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
)
latest_node_execution_info = TaskState.NodeExecutionInfo(
workflow_node_execution=workflow_node_execution,
workflow_node_execution_id=workflow_node_execution.id,
start_at=time.perf_counter()
)
......@@ -366,11 +366,17 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
self._task_state.total_steps += 1
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
db.session.close()
return workflow_node_execution
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=current_node_execution.workflow_node_execution,
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
......@@ -383,19 +389,24 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=current_node_execution.workflow_node_execution,
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error
)
# remove running node execution info
del self._task_state.running_node_execution_infos[event.node_id]
self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
db.session.close()
return workflow_node_execution
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-> WorkflowRun:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=self._task_state.workflow_run,
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
......@@ -404,7 +415,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
)
elif isinstance(event, QueueWorkflowFailedEvent):
workflow_run = self._workflow_run_failed(
workflow_run=self._task_state.workflow_run,
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
......@@ -412,39 +423,30 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
error=event.error
)
else:
if self._task_state.latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
outputs = workflow_node_execution.outputs
else:
outputs = None
workflow_run = self._workflow_run_success(
workflow_run=self._task_state.workflow_run,
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs
if self._task_state.latest_node_execution_info else None
outputs=outputs
)
self._task_state.workflow_run = workflow_run
self._task_state.workflow_run_id = workflow_run.id
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '')
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Get workflow run.
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
return workflow_run
db.session.close()
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
"""
Get workflow node execution.
:param workflow_node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (db.session.query(WorkflowNodeExecution)
.filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
return workflow_node_execution
return workflow_run
def _save_workflow_app_log(self) -> None:
"""
......@@ -461,7 +463,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
"""
response = {
'event': 'text_chunk',
'workflow_run_id': self._task_state.workflow_run.id,
'workflow_run_id': self._task_state.workflow_run_id,
'task_id': self._application_generate_entity.task_id,
'data': {
'text': text
......
......@@ -87,6 +87,7 @@ class WorkflowBasedGenerateTaskPipeline:
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
return workflow_run
......@@ -115,6 +116,7 @@ class WorkflowBasedGenerateTaskPipeline:
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
return workflow_run
......@@ -185,6 +187,7 @@ class WorkflowBasedGenerateTaskPipeline:
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
......@@ -205,6 +208,7 @@ class WorkflowBasedGenerateTaskPipeline:
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment