Commit d175d82b authored by takatost's avatar takatost

fix workflow app bugs

parent b368d9ab
...@@ -129,18 +129,14 @@ class DraftWorkflowRunApi(Resource): ...@@ -129,18 +129,14 @@ class DraftWorkflowRunApi(Resource):
args=args, args=args,
invoke_from=InvokeFrom.DEBUGGER invoke_from=InvokeFrom.DEBUGGER
) )
return compact_response(response)
except ValueError as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
def generate() -> Generator:
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
class WorkflowTaskStopApi(Resource): class WorkflowTaskStopApi(Resource):
@setup_required @setup_required
......
...@@ -235,36 +235,39 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -235,36 +235,39 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._yield_response(response) yield self._yield_response(response)
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent): elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
if isinstance(event, QueueWorkflowFinishedEvent): if isinstance(event, QueueStopEvent):
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
else:
workflow_run = self._get_workflow_run(event.workflow_run_id) workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '')
else:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
data = self._error_to_stream_response_data(self._handle_error(err_event))
yield self._yield_response(data)
break
workflow_run_response = { if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
'event': 'workflow_finished', outputs = workflow_run.outputs_dict
'task_id': self._application_generate_entity.task_id, self._task_state.answer = outputs.get('text', '')
'workflow_run_id': event.workflow_run_id, else:
'data': { err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
'id': workflow_run.id, data = self._error_to_stream_response_data(self._handle_error(err_event))
'workflow_id': workflow_run.workflow_id, yield self._yield_response(data)
'status': workflow_run.status, break
'outputs': workflow_run.outputs_dict,
'error': workflow_run.error, workflow_run_response = {
'elapsed_time': workflow_run.elapsed_time, 'event': 'workflow_finished',
'total_tokens': workflow_run.total_tokens, 'task_id': self._application_generate_entity.task_id,
'total_steps': workflow_run.total_steps, 'workflow_run_id': event.workflow_run_id,
'created_at': int(workflow_run.created_at.timestamp()), 'data': {
'finished_at': int(workflow_run.finished_at.timestamp()) 'id': workflow_run.id,
} 'workflow_id': workflow_run.workflow_id,
'status': workflow_run.status,
'outputs': workflow_run.outputs_dict,
'error': workflow_run.error,
'elapsed_time': workflow_run.elapsed_time,
'total_tokens': workflow_run.total_tokens,
'total_steps': workflow_run.total_steps,
'created_at': int(workflow_run.created_at.timestamp()),
'finished_at': int(workflow_run.finished_at.timestamp())
} }
}
yield self._yield_response(workflow_run_response) yield self._yield_response(workflow_run_response)
# response moderation # response moderation
if self._output_moderation_handler: if self._output_moderation_handler:
......
...@@ -2,6 +2,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager ...@@ -2,6 +2,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
MessageQueueMessage,
QueueMessage, QueueMessage,
) )
...@@ -20,7 +21,7 @@ class MessageBasedAppQueueManager(AppQueueManager): ...@@ -20,7 +21,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
self._message_id = str(message_id) self._message_id = str(message_id)
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
return QueueMessage( return MessageQueueMessage(
task_id=self._task_id, task_id=self._task_id,
message_id=self._message_id, message_id=self._message_id,
conversation_id=self._conversation_id, conversation_id=self._conversation_id,
......
...@@ -3,6 +3,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom ...@@ -3,6 +3,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
QueueMessage, QueueMessage,
WorkflowQueueMessage,
) )
...@@ -16,7 +17,7 @@ class WorkflowAppQueueManager(AppQueueManager): ...@@ -16,7 +17,7 @@ class WorkflowAppQueueManager(AppQueueManager):
self._app_mode = app_mode self._app_mode = app_mode
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
return QueueMessage( return WorkflowQueueMessage(
task_id=self._task_id, task_id=self._task_id,
app_mode=self._app_mode, app_mode=self._app_mode,
event=event event=event
......
...@@ -86,7 +86,7 @@ class WorkflowAppGenerateTaskPipeline: ...@@ -86,7 +86,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_run = self._get_workflow_run(event.workflow_run_id) workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '') self._task_state.answer = outputs.get('text', '')
else: else:
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))) raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
...@@ -136,12 +136,11 @@ class WorkflowAppGenerateTaskPipeline: ...@@ -136,12 +136,11 @@ class WorkflowAppGenerateTaskPipeline:
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
self._task_state.workflow_run_id = event.workflow_run_id self._task_state.workflow_run_id = event.workflow_run_id
workflow_run = self._get_workflow_run(event.workflow_run_id) workflow_run = self._get_workflow_run(event.workflow_run_id)
response = { response = {
'event': 'workflow_started', 'event': 'workflow_started',
'task_id': self._application_generate_entity.task_id, 'task_id': self._application_generate_entity.task_id,
'workflow_run_id': event.workflow_run_id, 'workflow_run_id': workflow_run.id,
'data': { 'data': {
'id': workflow_run.id, 'id': workflow_run.id,
'workflow_id': workflow_run.workflow_id, 'workflow_id': workflow_run.workflow_id,
...@@ -198,7 +197,7 @@ class WorkflowAppGenerateTaskPipeline: ...@@ -198,7 +197,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_run = self._get_workflow_run(event.workflow_run_id) workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value: if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '') self._task_state.answer = outputs.get('text', '')
else: else:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
...@@ -228,6 +227,9 @@ class WorkflowAppGenerateTaskPipeline: ...@@ -228,6 +227,9 @@ class WorkflowAppGenerateTaskPipeline:
yield self._yield_response(replace_response) yield self._yield_response(replace_response)
# save workflow app log
self._save_workflow_app_log()
workflow_run_response = { workflow_run_response = {
'event': 'workflow_finished', 'event': 'workflow_finished',
'task_id': self._application_generate_entity.task_id, 'task_id': self._application_generate_entity.task_id,
...@@ -295,7 +297,13 @@ class WorkflowAppGenerateTaskPipeline: ...@@ -295,7 +297,13 @@ class WorkflowAppGenerateTaskPipeline:
:param workflow_run_id: workflow run id :param workflow_run_id: workflow run id
:return: :return:
""" """
return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
if workflow_run:
# Because the workflow_run will be modified in the sub-thread,
# and the first query in the main thread will cache the entity,
# you need to expire the entity after the query
db.session.expire(workflow_run)
return workflow_run
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
""" """
...@@ -303,7 +311,21 @@ class WorkflowAppGenerateTaskPipeline: ...@@ -303,7 +311,21 @@ class WorkflowAppGenerateTaskPipeline:
:param workflow_node_execution_id: workflow node execution id :param workflow_node_execution_id: workflow node execution id
:return: :return:
""" """
return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first() workflow_node_execution = (db.session.query(WorkflowNodeExecution)
.filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
if workflow_node_execution:
# Because the workflow_node_execution will be modified in the sub-thread,
# and the first query in the main thread will cache the entity,
# you need to expire the entity after the query
db.session.expire(workflow_node_execution)
return workflow_node_execution
def _save_workflow_app_log(self) -> None:
"""
Save workflow app log.
:return:
"""
pass # todo
def _handle_chunk(self, text: str) -> dict: def _handle_chunk(self, text: str) -> dict:
""" """
......
...@@ -176,7 +176,20 @@ class QueueMessage(BaseModel): ...@@ -176,7 +176,20 @@ class QueueMessage(BaseModel):
QueueMessage entity QueueMessage entity
""" """
task_id: str task_id: str
message_id: str
conversation_id: str
app_mode: str app_mode: str
event: AppQueueEvent event: AppQueueEvent
class MessageQueueMessage(QueueMessage):
"""
MessageQueueMessage entity
"""
message_id: str
conversation_id: str
class WorkflowQueueMessage(QueueMessage):
"""
WorkflowQueueMessage entity
"""
pass
...@@ -143,7 +143,7 @@ class Workflow(db.Model): ...@@ -143,7 +143,7 @@ class Workflow(db.Model):
return [] return []
# get user_input_form from start node # get user_input_form from start node
return start_node.get('variables', []) return start_node.get('data', {}).get('variables', [])
class WorkflowRunTriggeredFrom(Enum): class WorkflowRunTriggeredFrom(Enum):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment