Commit fe126ec3 authored by takatost's avatar takatost

Merge branch 'feat/workflow-backend' into deploy/dev

parents 823d423f 463b68b7
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context from flask import Response, stream_with_context
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
...@@ -65,7 +66,7 @@ class DraftWorkflowApi(Resource): ...@@ -65,7 +66,7 @@ class DraftWorkflowApi(Resource):
return { return {
"result": "success", "result": "success",
"updated_at": TimestampField().format(workflow.updated_at) "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
} }
...@@ -79,9 +80,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource): ...@@ -79,9 +80,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
Run draft workflow Run draft workflow
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, location='json')
parser.add_argument('query', type=str, location='json', default='') parser.add_argument('query', type=str, required=True, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('files', type=list, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json')
args = parser.parse_args() args = parser.parse_args()
...@@ -93,6 +94,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource): ...@@ -93,6 +94,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
args=args, args=args,
invoke_from=InvokeFrom.DEBUGGER invoke_from=InvokeFrom.DEBUGGER
) )
return compact_response(response)
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError: except services.errors.conversation.ConversationCompletedError:
...@@ -103,12 +106,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource): ...@@ -103,12 +106,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
def generate() -> Generator:
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
class DraftWorkflowRunApi(Resource): class DraftWorkflowRunApi(Resource):
@setup_required @setup_required
...@@ -120,7 +117,7 @@ class DraftWorkflowRunApi(Resource): ...@@ -120,7 +117,7 @@ class DraftWorkflowRunApi(Resource):
Run draft workflow Run draft workflow
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
workflow_service = WorkflowService() workflow_service = WorkflowService()
...@@ -280,6 +277,17 @@ class ConvertToWorkflowApi(Resource): ...@@ -280,6 +277,17 @@ class ConvertToWorkflowApi(Resource):
return workflow return workflow
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft') api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run') api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run') api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
......
...@@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline:
response = { response = {
'event': 'workflow_started', 'event': 'workflow_started',
'task_id': self._application_generate_entity.task_id, 'task_id': self._application_generate_entity.task_id,
'workflow_run_id': event.workflow_run_id, 'workflow_run_id': workflow_run.id,
'data': { 'data': {
'id': workflow_run.id, 'id': workflow_run.id,
'workflow_id': workflow_run.workflow_id, 'workflow_id': workflow_run.workflow_id,
......
...@@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ...@@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
""" """
......
...@@ -46,7 +46,7 @@ class ChatAppConfigManager(BaseAppConfigManager): ...@@ -46,7 +46,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
else: else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS: if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict() app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy() config_dict = app_model_config_dict.copy()
else: else:
......
...@@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ...@@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
""" """
......
...@@ -5,7 +5,7 @@ from typing import Optional, Union ...@@ -5,7 +5,7 @@ from typing import Optional, Union
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowRunState from core.workflow.entities.workflow_entities import WorkflowRunState
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
...@@ -122,10 +122,10 @@ class WorkflowEngineManager: ...@@ -122,10 +122,10 @@ class WorkflowEngineManager:
if 'nodes' not in graph or 'edges' not in graph: if 'nodes' not in graph or 'edges' not in graph:
raise ValueError('nodes or edges not found in workflow graph') raise ValueError('nodes or edges not found in workflow graph')
if isinstance(graph.get('nodes'), list): if not isinstance(graph.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list') raise ValueError('nodes in workflow graph must be a list')
if isinstance(graph.get('edges'), list): if not isinstance(graph.get('edges'), list):
raise ValueError('edges in workflow graph must be a list') raise ValueError('edges in workflow graph must be a list')
# init workflow run # init workflow run
...@@ -150,6 +150,7 @@ class WorkflowEngineManager: ...@@ -150,6 +150,7 @@ class WorkflowEngineManager:
try: try:
predecessor_node = None predecessor_node = None
has_entry_node = False
while True: while True:
# get next node, multiple target nodes in the future # get next node, multiple target nodes in the future
next_node = self._get_next_node( next_node = self._get_next_node(
...@@ -161,6 +162,8 @@ class WorkflowEngineManager: ...@@ -161,6 +162,8 @@ class WorkflowEngineManager:
if not next_node: if not next_node:
break break
has_entry_node = True
# max steps 30 reached # max steps 30 reached
if len(workflow_run_state.workflow_node_executions) > 30: if len(workflow_run_state.workflow_node_executions) > 30:
raise ValueError('Max steps 30 reached.') raise ValueError('Max steps 30 reached.')
...@@ -182,7 +185,7 @@ class WorkflowEngineManager: ...@@ -182,7 +185,7 @@ class WorkflowEngineManager:
predecessor_node = next_node predecessor_node = next_node
if not predecessor_node and not next_node: if not has_entry_node:
self._workflow_run_failed( self._workflow_run_failed(
workflow_run_state=workflow_run_state, workflow_run_state=workflow_run_state,
error='Start node not found in workflow graph.', error='Start node not found in workflow graph.',
...@@ -219,38 +222,31 @@ class WorkflowEngineManager: ...@@ -219,38 +222,31 @@ class WorkflowEngineManager:
:param callbacks: workflow callbacks :param callbacks: workflow callbacks
:return: :return:
""" """
try: max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
db.session.begin() .filter(WorkflowRun.tenant_id == workflow.tenant_id) \
.filter(WorkflowRun.app_id == workflow.app_id) \
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ .scalar() or 0
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \ new_sequence_number = max_sequence + 1
.filter(WorkflowRun.app_id == workflow.app_id) \
.for_update() \ # init workflow run
.scalar() or 0 workflow_run = WorkflowRun(
new_sequence_number = max_sequence + 1 tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
# init workflow run sequence_number=new_sequence_number,
workflow_run = WorkflowRun( workflow_id=workflow.id,
tenant_id=workflow.tenant_id, type=workflow.type,
app_id=workflow.app_id, triggered_from=triggered_from.value,
sequence_number=new_sequence_number, version=workflow.version,
workflow_id=workflow.id, graph=workflow.graph,
type=workflow.type, inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}),
triggered_from=triggered_from.value, status=WorkflowRunStatus.RUNNING.value,
version=workflow.version, created_by_role=(CreatedByRole.ACCOUNT.value
graph=workflow.graph, if isinstance(user, Account) else CreatedByRole.END_USER.value),
inputs=json.dumps({**user_inputs, **system_inputs}), created_by=user.id
status=WorkflowRunStatus.RUNNING.value, )
created_by_role=(CreatedByRole.ACCOUNT.value
if isinstance(user, Account) else CreatedByRole.END_USER.value),
created_by=user.id
)
db.session.add(workflow_run) db.session.add(workflow_run)
db.session.commit() db.session.commit()
except:
db.session.rollback()
raise
if callbacks: if callbacks:
for callback in callbacks: for callback in callbacks:
...@@ -330,7 +326,7 @@ class WorkflowEngineManager: ...@@ -330,7 +326,7 @@ class WorkflowEngineManager:
if not predecessor_node: if not predecessor_node:
for node_config in nodes: for node_config in nodes:
if node_config.get('type') == NodeType.START.value: if node_config.get('data', {}).get('type', '') == NodeType.START.value:
return StartNode(config=node_config) return StartNode(config=node_config)
else: else:
edges = graph.get('edges') edges = graph.get('edges')
...@@ -368,7 +364,7 @@ class WorkflowEngineManager: ...@@ -368,7 +364,7 @@ class WorkflowEngineManager:
return None return None
# get next node # get next node
target_node = node_classes.get(NodeType.value_of(target_node_config.get('type'))) target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
return target_node( return target_node(
config=target_node_config, config=target_node_config,
...@@ -424,17 +420,18 @@ class WorkflowEngineManager: ...@@ -424,17 +420,18 @@ class WorkflowEngineManager:
callbacks=callbacks callbacks=callbacks
) )
for variable_key, variable_value in node_run_result.outputs.items(): if node_run_result.outputs:
# append variables to variable pool recursively for variable_key, variable_value in node_run_result.outputs.items():
self._append_variables_recursively( # append variables to variable pool recursively
variable_pool=workflow_run_state.variable_pool, self._append_variables_recursively(
node_id=node.node_id, variable_pool=workflow_run_state.variable_pool,
variable_key_list=[variable_key], node_id=node.node_id,
variable_value=variable_value variable_key_list=[variable_key],
) variable_value=variable_value
)
if node_run_result.metadata.get('total_tokens'): if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens')) workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
return workflow_node_execution return workflow_node_execution
...@@ -464,7 +461,6 @@ class WorkflowEngineManager: ...@@ -464,7 +461,6 @@ class WorkflowEngineManager:
node_id=node.node_id, node_id=node.node_id,
node_type=node.node_type.value, node_type=node.node_type.value,
title=node.node_data.title, title=node.node_data.title,
type=node.node_type.value,
status=WorkflowNodeExecutionStatus.RUNNING.value, status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role, created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by created_by=workflow_run.created_by
...@@ -493,10 +489,11 @@ class WorkflowEngineManager: ...@@ -493,10 +489,11 @@ class WorkflowEngineManager:
""" """
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(result.inputs) workflow_node_execution.inputs = json.dumps(result.inputs) if result.inputs else None
workflow_node_execution.process_data = json.dumps(result.process_data) workflow_node_execution.process_data = json.dumps(result.process_data) if result.process_data else None
workflow_node_execution.outputs = json.dumps(result.outputs) workflow_node_execution.outputs = json.dumps(result.outputs) if result.outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) \
if result.metadata else None
workflow_node_execution.finished_at = datetime.utcnow() workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit() db.session.commit()
......
...@@ -48,7 +48,7 @@ app_detail_fields = { ...@@ -48,7 +48,7 @@ app_detail_fields = {
'icon_background': fields.String, 'icon_background': fields.String,
'enable_site': fields.Boolean, 'enable_site': fields.Boolean,
'enable_api': fields.Boolean, 'enable_api': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True),
'created_at': TimestampField 'created_at': TimestampField
} }
...@@ -68,7 +68,7 @@ app_partial_fields = { ...@@ -68,7 +68,7 @@ app_partial_fields = {
'mode': fields.String, 'mode': fields.String,
'icon': fields.String, 'icon': fields.String,
'icon_background': fields.String, 'icon_background': fields.String,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True),
'created_at': TimestampField 'created_at': TimestampField
} }
...@@ -118,7 +118,7 @@ app_detail_fields_with_site = { ...@@ -118,7 +118,7 @@ app_detail_fields_with_site = {
'icon_background': fields.String, 'icon_background': fields.String,
'enable_site': fields.Boolean, 'enable_site': fields.Boolean,
'enable_api': fields.Boolean, 'enable_api': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True),
'site': fields.Nested(site_fields), 'site': fields.Nested(site_fields),
'api_base_url': fields.String, 'api_base_url': fields.String,
'created_at': TimestampField, 'created_at': TimestampField,
......
...@@ -45,8 +45,8 @@ def upgrade(): ...@@ -45,8 +45,8 @@ def upgrade():
sa.Column('node_id', sa.String(length=255), nullable=False), sa.Column('node_id', sa.String(length=255), nullable=False),
sa.Column('node_type', sa.String(length=255), nullable=False), sa.Column('node_type', sa.String(length=255), nullable=False),
sa.Column('title', sa.String(length=255), nullable=False), sa.Column('title', sa.String(length=255), nullable=False),
sa.Column('inputs', sa.Text(), nullable=False), sa.Column('inputs', sa.Text(), nullable=True),
sa.Column('process_data', sa.Text(), nullable=False), sa.Column('process_data', sa.Text(), nullable=True),
sa.Column('outputs', sa.Text(), nullable=True), sa.Column('outputs', sa.Text(), nullable=True),
sa.Column('status', sa.String(length=255), nullable=False), sa.Column('status', sa.String(length=255), nullable=False),
sa.Column('error', sa.Text(), nullable=True), sa.Column('error', sa.Text(), nullable=True),
......
"""messages columns set nullable
Revision ID: b5429b71023c
Revises: 42e85ed5564d
Create Date: 2024-03-07 09:52:00.846136
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'b5429b71023c'
down_revision = '42e85ed5564d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True)
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False)
# ### end Alembic commands ###
...@@ -585,8 +585,8 @@ class Message(db.Model): ...@@ -585,8 +585,8 @@ class Message(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(UUID, nullable=False)
model_provider = db.Column(db.String(255), nullable=False) model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=False) model_id = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text) override_model_configs = db.Column(db.Text)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False)
inputs = db.Column(db.JSON) inputs = db.Column(db.JSON)
......
...@@ -138,7 +138,7 @@ class Workflow(db.Model): ...@@ -138,7 +138,7 @@ class Workflow(db.Model):
if 'nodes' not in graph_dict: if 'nodes' not in graph_dict:
return [] return []
start_node = next((node for node in graph_dict['nodes'] if node['type'] == 'start'), None) start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None)
if not start_node: if not start_node:
return [] return []
...@@ -392,8 +392,8 @@ class WorkflowNodeExecution(db.Model): ...@@ -392,8 +392,8 @@ class WorkflowNodeExecution(db.Model):
node_id = db.Column(db.String(255), nullable=False) node_id = db.Column(db.String(255), nullable=False)
node_type = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False)
title = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False)
inputs = db.Column(db.Text, nullable=False) inputs = db.Column(db.Text)
process_data = db.Column(db.Text, nullable=False) process_data = db.Column(db.Text)
outputs = db.Column(db.Text) outputs = db.Column(db.Text)
status = db.Column(db.String(255), nullable=False) status = db.Column(db.String(255), nullable=False)
error = db.Column(db.Text) error = db.Column(db.Text)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment