Commit f4f7cfd4 authored by takatost's avatar takatost

fix bugs

parent d214c047
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context from flask import Response, stream_with_context
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
...@@ -79,9 +80,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource): ...@@ -79,9 +80,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
Run draft workflow Run draft workflow
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, location='json')
parser.add_argument('query', type=str, location='json', default='') parser.add_argument('query', type=str, required=True, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('files', type=list, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json')
args = parser.parse_args() args = parser.parse_args()
...@@ -93,6 +94,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource): ...@@ -93,6 +94,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
args=args, args=args,
invoke_from=InvokeFrom.DEBUGGER invoke_from=InvokeFrom.DEBUGGER
) )
return compact_response(response)
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError: except services.errors.conversation.ConversationCompletedError:
...@@ -103,12 +106,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource): ...@@ -103,12 +106,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
def generate() -> Generator:
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
class DraftWorkflowRunApi(Resource): class DraftWorkflowRunApi(Resource):
@setup_required @setup_required
...@@ -120,7 +117,7 @@ class DraftWorkflowRunApi(Resource): ...@@ -120,7 +117,7 @@ class DraftWorkflowRunApi(Resource):
Run draft workflow Run draft workflow
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
workflow_service = WorkflowService() workflow_service = WorkflowService()
...@@ -280,6 +277,17 @@ class ConvertToWorkflowApi(Resource): ...@@ -280,6 +277,17 @@ class ConvertToWorkflowApi(Resource):
return workflow return workflow
def compact_response(response: Union[dict, Generator]) -> Response:
    """Build an HTTP response for either a blocking or a streaming result.

    A dict result is serialized once and returned as a JSON response; any
    other (generator) result is forwarded chunk-by-chunk as a server-sent
    event stream bound to the request context.
    """
    if not isinstance(response, dict):
        # Streaming path: re-yield the upstream generator inside the
        # request context so flask keeps it alive while the client reads.
        def stream() -> Generator:
            yield from response

        return Response(stream_with_context(stream()), status=200,
                        mimetype='text/event-stream')

    # Blocking path: single JSON payload.
    return Response(response=json.dumps(response), status=200,
                    mimetype='application/json')
api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft') api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run') api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run') api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
......
...@@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: ...@@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline:
response = { response = {
'event': 'workflow_started', 'event': 'workflow_started',
'task_id': self._application_generate_entity.task_id, 'task_id': self._application_generate_entity.task_id,
'workflow_run_id': event.workflow_run_id, 'workflow_run_id': workflow_run.id,
'data': { 'data': {
'id': workflow_run.id, 'id': workflow_run.id,
'workflow_id': workflow_run.workflow_id, 'workflow_id': workflow_run.workflow_id,
......
...@@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ...@@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
""" """
......
...@@ -46,7 +46,7 @@ class ChatAppConfigManager(BaseAppConfigManager): ...@@ -46,7 +46,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
else: else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS: if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict() app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy() config_dict = app_model_config_dict.copy()
else: else:
......
...@@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): ...@@ -15,7 +15,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
""" """
......
...@@ -5,7 +5,7 @@ from typing import Optional, Union ...@@ -5,7 +5,7 @@ from typing import Optional, Union
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowRunState from core.workflow.entities.workflow_entities import WorkflowRunState
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
...@@ -122,10 +122,10 @@ class WorkflowEngineManager: ...@@ -122,10 +122,10 @@ class WorkflowEngineManager:
if 'nodes' not in graph or 'edges' not in graph: if 'nodes' not in graph or 'edges' not in graph:
raise ValueError('nodes or edges not found in workflow graph') raise ValueError('nodes or edges not found in workflow graph')
if isinstance(graph.get('nodes'), list): if not isinstance(graph.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list') raise ValueError('nodes in workflow graph must be a list')
if isinstance(graph.get('edges'), list): if not isinstance(graph.get('edges'), list):
raise ValueError('edges in workflow graph must be a list') raise ValueError('edges in workflow graph must be a list')
# init workflow run # init workflow run
...@@ -150,6 +150,7 @@ class WorkflowEngineManager: ...@@ -150,6 +150,7 @@ class WorkflowEngineManager:
try: try:
predecessor_node = None predecessor_node = None
has_entry_node = False
while True: while True:
# get next node, multiple target nodes in the future # get next node, multiple target nodes in the future
next_node = self._get_next_node( next_node = self._get_next_node(
...@@ -161,6 +162,8 @@ class WorkflowEngineManager: ...@@ -161,6 +162,8 @@ class WorkflowEngineManager:
if not next_node: if not next_node:
break break
has_entry_node = True
# max steps 30 reached # max steps 30 reached
if len(workflow_run_state.workflow_node_executions) > 30: if len(workflow_run_state.workflow_node_executions) > 30:
raise ValueError('Max steps 30 reached.') raise ValueError('Max steps 30 reached.')
...@@ -182,7 +185,7 @@ class WorkflowEngineManager: ...@@ -182,7 +185,7 @@ class WorkflowEngineManager:
predecessor_node = next_node predecessor_node = next_node
if not predecessor_node and not next_node: if not has_entry_node:
self._workflow_run_failed( self._workflow_run_failed(
workflow_run_state=workflow_run_state, workflow_run_state=workflow_run_state,
error='Start node not found in workflow graph.', error='Start node not found in workflow graph.',
...@@ -219,38 +222,31 @@ class WorkflowEngineManager: ...@@ -219,38 +222,31 @@ class WorkflowEngineManager:
:param callbacks: workflow callbacks :param callbacks: workflow callbacks
:return: :return:
""" """
try: max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
db.session.begin() .filter(WorkflowRun.tenant_id == workflow.tenant_id) \
.filter(WorkflowRun.app_id == workflow.app_id) \
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ .scalar() or 0
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \ new_sequence_number = max_sequence + 1
.filter(WorkflowRun.app_id == workflow.app_id) \
.for_update() \ # init workflow run
.scalar() or 0 workflow_run = WorkflowRun(
new_sequence_number = max_sequence + 1 tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
# init workflow run sequence_number=new_sequence_number,
workflow_run = WorkflowRun( workflow_id=workflow.id,
tenant_id=workflow.tenant_id, type=workflow.type,
app_id=workflow.app_id, triggered_from=triggered_from.value,
sequence_number=new_sequence_number, version=workflow.version,
workflow_id=workflow.id, graph=workflow.graph,
type=workflow.type, inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}),
triggered_from=triggered_from.value, status=WorkflowRunStatus.RUNNING.value,
version=workflow.version, created_by_role=(CreatedByRole.ACCOUNT.value
graph=workflow.graph, if isinstance(user, Account) else CreatedByRole.END_USER.value),
inputs=json.dumps({**user_inputs, **system_inputs}), created_by=user.id
status=WorkflowRunStatus.RUNNING.value, )
created_by_role=(CreatedByRole.ACCOUNT.value
if isinstance(user, Account) else CreatedByRole.END_USER.value),
created_by=user.id
)
db.session.add(workflow_run) db.session.add(workflow_run)
db.session.commit() db.session.commit()
except:
db.session.rollback()
raise
if callbacks: if callbacks:
for callback in callbacks: for callback in callbacks:
...@@ -330,7 +326,7 @@ class WorkflowEngineManager: ...@@ -330,7 +326,7 @@ class WorkflowEngineManager:
if not predecessor_node: if not predecessor_node:
for node_config in nodes: for node_config in nodes:
if node_config.get('type') == NodeType.START.value: if node_config.get('data', {}).get('type', '') == NodeType.START.value:
return StartNode(config=node_config) return StartNode(config=node_config)
else: else:
edges = graph.get('edges') edges = graph.get('edges')
...@@ -368,7 +364,7 @@ class WorkflowEngineManager: ...@@ -368,7 +364,7 @@ class WorkflowEngineManager:
return None return None
# get next node # get next node
target_node = node_classes.get(NodeType.value_of(target_node_config.get('type'))) target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
return target_node( return target_node(
config=target_node_config, config=target_node_config,
...@@ -424,17 +420,18 @@ class WorkflowEngineManager: ...@@ -424,17 +420,18 @@ class WorkflowEngineManager:
callbacks=callbacks callbacks=callbacks
) )
for variable_key, variable_value in node_run_result.outputs.items(): if node_run_result.outputs:
# append variables to variable pool recursively for variable_key, variable_value in node_run_result.outputs.items():
self._append_variables_recursively( # append variables to variable pool recursively
variable_pool=workflow_run_state.variable_pool, self._append_variables_recursively(
node_id=node.node_id, variable_pool=workflow_run_state.variable_pool,
variable_key_list=[variable_key], node_id=node.node_id,
variable_value=variable_value variable_key_list=[variable_key],
) variable_value=variable_value
)
if node_run_result.metadata.get('total_tokens'): if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens')) workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
return workflow_node_execution return workflow_node_execution
...@@ -464,7 +461,6 @@ class WorkflowEngineManager: ...@@ -464,7 +461,6 @@ class WorkflowEngineManager:
node_id=node.node_id, node_id=node.node_id,
node_type=node.node_type.value, node_type=node.node_type.value,
title=node.node_data.title, title=node.node_data.title,
type=node.node_type.value,
status=WorkflowNodeExecutionStatus.RUNNING.value, status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role, created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by created_by=workflow_run.created_by
...@@ -493,10 +489,11 @@ class WorkflowEngineManager: ...@@ -493,10 +489,11 @@ class WorkflowEngineManager:
""" """
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(result.inputs) workflow_node_execution.inputs = json.dumps(result.inputs) if result.inputs else None
workflow_node_execution.process_data = json.dumps(result.process_data) workflow_node_execution.process_data = json.dumps(result.process_data) if result.process_data else None
workflow_node_execution.outputs = json.dumps(result.outputs) workflow_node_execution.outputs = json.dumps(result.outputs) if result.outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) \
if result.metadata else None
workflow_node_execution.finished_at = datetime.utcnow() workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit() db.session.commit()
......
...@@ -45,8 +45,8 @@ def upgrade(): ...@@ -45,8 +45,8 @@ def upgrade():
sa.Column('node_id', sa.String(length=255), nullable=False), sa.Column('node_id', sa.String(length=255), nullable=False),
sa.Column('node_type', sa.String(length=255), nullable=False), sa.Column('node_type', sa.String(length=255), nullable=False),
sa.Column('title', sa.String(length=255), nullable=False), sa.Column('title', sa.String(length=255), nullable=False),
sa.Column('inputs', sa.Text(), nullable=False), sa.Column('inputs', sa.Text(), nullable=True),
sa.Column('process_data', sa.Text(), nullable=False), sa.Column('process_data', sa.Text(), nullable=True),
sa.Column('outputs', sa.Text(), nullable=True), sa.Column('outputs', sa.Text(), nullable=True),
sa.Column('status', sa.String(length=255), nullable=False), sa.Column('status', sa.String(length=255), nullable=False),
sa.Column('error', sa.Text(), nullable=True), sa.Column('error', sa.Text(), nullable=True),
......
"""messages columns set nullable
Revision ID: b5429b71023c
Revises: 42e85ed5564d
Create Date: 2024-03-07 09:52:00.846136
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'b5429b71023c'  # this migration's unique id
down_revision = '42e85ed5564d'  # parent migration in the chain
branch_labels = None  # no named branch
depends_on = None  # no cross-branch dependency
def upgrade():
    """Relax NOT NULL on messages.model_provider and messages.model_id."""
    # ### commands auto generated by Alembic - please adjust! ###
    # batch_alter_table is used so the change also works on SQLite,
    # which cannot ALTER COLUMN in place.
    with op.batch_alter_table('messages', schema=None) as batch_op:
        for column_name in ('model_provider', 'model_id'):
            batch_op.alter_column(column_name,
                                  existing_type=sa.VARCHAR(length=255),
                                  nullable=True)
    # ### end Alembic commands ###
def downgrade():
    """Restore NOT NULL on messages.model_id and messages.model_provider.

    NOTE(review): will fail if NULL values were written while upgraded —
    those rows must be cleaned up before downgrading.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('messages', schema=None) as batch_op:
        # Reverse order of upgrade(): model_id first, then model_provider.
        for column_name in ('model_id', 'model_provider'):
            batch_op.alter_column(column_name,
                                  existing_type=sa.VARCHAR(length=255),
                                  nullable=False)
    # ### end Alembic commands ###
...@@ -585,8 +585,8 @@ class Message(db.Model): ...@@ -585,8 +585,8 @@ class Message(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False) app_id = db.Column(UUID, nullable=False)
model_provider = db.Column(db.String(255), nullable=False) model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=False) model_id = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text) override_model_configs = db.Column(db.Text)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False)
inputs = db.Column(db.JSON) inputs = db.Column(db.JSON)
......
...@@ -138,7 +138,7 @@ class Workflow(db.Model): ...@@ -138,7 +138,7 @@ class Workflow(db.Model):
if 'nodes' not in graph_dict: if 'nodes' not in graph_dict:
return [] return []
start_node = next((node for node in graph_dict['nodes'] if node['type'] == 'start'), None) start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None)
if not start_node: if not start_node:
return [] return []
...@@ -392,8 +392,8 @@ class WorkflowNodeExecution(db.Model): ...@@ -392,8 +392,8 @@ class WorkflowNodeExecution(db.Model):
node_id = db.Column(db.String(255), nullable=False) node_id = db.Column(db.String(255), nullable=False)
node_type = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False)
title = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False)
inputs = db.Column(db.Text, nullable=False) inputs = db.Column(db.Text)
process_data = db.Column(db.Text, nullable=False) process_data = db.Column(db.Text)
outputs = db.Column(db.Text) outputs = db.Column(db.Text)
status = db.Column(db.String(255), nullable=False) status = db.Column(db.String(255), nullable=False)
error = db.Column(db.Text) error = db.Column(db.Text)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment