Commit 82daa14c authored by takatost

Merge branch 'feat/workflow-backend' into deploy/dev

parents 3f7a9330 14973190
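
The change that recurs throughout this merge: `AppQueueManager` moves from core.app.app_queue_manager to core.app.apps.base_app_queue_manager, and its per-event helper methods (publish_agent_thought, publish_message_file, publish_message_end, ...) collapse into one generic publish(event, pub_from) fed a typed queue event. A minimal before/after sketch of the calling pattern, using only names that appear in the hunks below (queue_manager and agent_thought are assumed in scope, as in the runners):

from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent

# Before (removed): one dedicated helper per event type.
# queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

# After: one generic publish() fed a typed event entity.
queue_manager.publish(
    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id),
    PublishFrom.APPLICATION_MANAGER
)
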
......@@ -21,7 +21,7 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.app_queue_manager import AppQueueManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
......
......@@ -21,7 +21,7 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_queue_manager import AppQueueManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
......
......@@ -19,7 +19,7 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.app_queue_manager import AppQueueManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
......
......@@ -20,7 +20,7 @@ from controllers.web.error import (
ProviderQuotaExceededError,
)
from controllers.web.wraps import WebApiResource
from core.app.app_queue_manager import AppQueueManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
......
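
The four controller hunks above (console/app, console/explore, service_api, web) are the same one-line move; any out-of-tree code that imports the queue manager needs the identical update:

# Old module path, removed by this merge:
# from core.app.app_queue_manager import AppQueueManager

# New module path:
from core.app.apps.base_app_queue_manager import AppQueueManager
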
......@@ -6,8 +6,8 @@ from mimetypes import guess_extension
from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_queue_manager import AppQueueManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
......
......@@ -5,7 +5,8 @@ from typing import Literal, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
from core.app.app_queue_manager import PublishFrom
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
......@@ -121,7 +122,9 @@ class CotAgentRunner(BaseAgentRunner):
)
if iteration_step > 1:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt messages
prompt_messages = self._organize_cot_prompt_messages(
......@@ -163,7 +166,9 @@ class CotAgentRunner(BaseAgentRunner):
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
for chunk in react_chunks:
if isinstance(chunk, dict):
......@@ -225,7 +230,9 @@ class CotAgentRunner(BaseAgentRunner):
llm_usage=usage_dict['usage'])
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
if not scratchpad.action:
# failed to extract action, return final answer directly
......@@ -255,7 +262,9 @@ class CotAgentRunner(BaseAgentRunner):
observation=answer,
answer=answer,
messages_ids=[])
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
else:
# invoke tool
error_response = None
......@@ -282,7 +291,9 @@ class CotAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name,
value=message_file.id,
name=save_as)
self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
message_file_ids = [message_file.id for message_file, _ in message_files]
except ToolProviderCredentialValidationError as e:
......@@ -318,7 +329,9 @@ class CotAgentRunner(BaseAgentRunner):
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool message
for prompt_tool in prompt_messages_tools:
......@@ -352,7 +365,7 @@ class CotAgentRunner(BaseAgentRunner):
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish_message_end(LLMResult(
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
......@@ -360,7 +373,7 @@ class CotAgentRunner(BaseAgentRunner):
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
), PublishFrom.APPLICATION_MANAGER)
)), PublishFrom.APPLICATION_MANAGER)
def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
-> Generator[Union[str, dict], None, None]:
......
......@@ -4,7 +4,8 @@ from collections.abc import Generator
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.app_queue_manager import PublishFrom
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
......@@ -135,7 +136,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
......@@ -195,7 +198,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if not result.message.content:
result.message.content = ''
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
yield LLMResultChunk(
model=model_instance.model,
......@@ -233,8 +238,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
messages_ids=[],
llm_usage=current_llm_usage
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
final_answer += response + '\n'
......@@ -275,7 +281,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
# publish message file
self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file.id)
......@@ -331,7 +339,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
answer=None,
messages_ids=message_file_ids
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool
for prompt_tool in prompt_messages_tools:
......@@ -341,15 +351,15 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish_message_end(LLMResult(
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer,
content=final_answer
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
), PublishFrom.APPLICATION_MANAGER)
)), PublishFrom.APPLICATION_MANAGER)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
......
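
Both agent runners now end a run the same way. Note that publishing QueueMessageEndEvent also stops the listener implicitly: the rewritten base publish() (see the base_app_queue_manager hunks below) calls stop_listen() for QueueStopEvent, QueueErrorEvent and QueueMessageEndEvent, so the explicit stop_listen() inside the old publish_message_end helper needs no replacement at the call sites. A condensed sketch of the end-of-run publish (queue_manager, model_instance, prompt_messages, final_answer and llm_usage are assumed in scope, as in the runners):

from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueMessageEndEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage

queue_manager.publish(
    QueueMessageEndEvent(llm_result=LLMResult(
        model=model_instance.model,
        prompt_messages=prompt_messages,
        message=AssistantPromptMessage(content=final_answer),
        usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
        system_fingerprint=''
    )),
    PublishFrom.APPLICATION_MANAGER  # stop_listen() now happens inside publish()
)
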
......@@ -8,11 +8,12 @@ from flask import Flask, current_app
from pydantic import ValidationError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
......@@ -101,7 +102,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = AppQueueManager(
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
......
......@@ -2,14 +2,14 @@ import logging
import time
from typing import cast
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import QueueStopEvent
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.moderation.base import ModerationException
from core.workflow.entities.node_entities import SystemVariable
......@@ -83,7 +83,6 @@ class AdvancedChatAppRunner(AppRunner):
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
app_model=app_record,
workflow=workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
......@@ -154,9 +153,9 @@ class AdvancedChatAppRunner(AppRunner):
)
if annotation_reply:
queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id,
pub_from=PublishFrom.APPLICATION_MANAGER
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
)
self._stream_output(
......@@ -183,7 +182,11 @@ class AdvancedChatAppRunner(AppRunner):
if stream:
index = 0
for token in text:
queue_manager.publish_text_chunk(token, PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.APPLICATION_MANAGER
)
index += 1
time.sleep(0.01)
......@@ -191,4 +194,3 @@ class AdvancedChatAppRunner(AppRunner):
QueueStopEvent(stopped_by=stopped_by),
PublishFrom.APPLICATION_MANAGER
)
queue_manager.stop_listen()
......@@ -6,7 +6,7 @@ from typing import Optional, Union
from pydantic import BaseModel
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
......@@ -46,6 +46,7 @@ class TaskState(BaseModel):
"""
answer: str = ""
metadata: dict = {}
usage: LLMUsage
class AdvancedChatAppGenerateTaskPipeline:
......@@ -253,8 +254,6 @@ class AdvancedChatAppGenerateTaskPipeline:
'error': workflow_run.error,
'elapsed_time': workflow_run.elapsed_time,
'total_tokens': workflow_run.total_tokens,
'total_price': workflow_run.total_price,
'currency': workflow_run.currency,
'total_steps': workflow_run.total_steps,
'created_at': int(workflow_run.created_at.timestamp()),
'finished_at': int(workflow_run.finished_at.timestamp())
......@@ -351,7 +350,12 @@ class AdvancedChatAppGenerateTaskPipeline:
if self._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish_text_chunk(self._task_state.answer, PublishFrom.TASK_PIPELINE)
self._queue_manager.publish(
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
......@@ -560,5 +564,5 @@ class AdvancedChatAppGenerateTaskPipeline:
type=sensitive_word_avoidance.type,
config=sensitive_word_avoidance.config
),
on_message_replace_func=self._queue_manager.publish_message_replace
queue_manager=self._queue_manager
)
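
The task pipelines now hand OutputModeration the queue manager itself rather than a bound publish_message_replace callback (which this merge deletes). Judging from the new QueueMessageReplaceEvent import in the output_moderation hunk at the end of this diff, the moderation side presumably publishes the replace event directly; a hedged sketch of that call (the helper name is hypothetical):

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent

def publish_message_replace(queue_manager: AppQueueManager, text: str) -> None:
    # Hypothetical stand-in for the removed on_message_replace_func callback:
    # OutputModeration can now emit the event through the generic publish().
    queue_manager.publish(
        QueueMessageReplaceEvent(text=text),
        PublishFrom.TASK_PIPELINE
    )
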
......@@ -9,10 +9,11 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
......@@ -119,7 +120,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = AppQueueManager(
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
......
......@@ -4,10 +4,11 @@ from typing import cast
from core.agent.cot_agent_runner import CotAgentRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
......@@ -120,10 +121,11 @@ class AgentChatAppRunner(AppRunner):
)
if annotation_reply:
queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id,
pub_from=PublishFrom.APPLICATION_MANAGER
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
)
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
......
import queue
import time
from abc import abstractmethod
from collections.abc import Generator
from enum import Enum
from typing import Any
......@@ -9,27 +10,13 @@ from sqlalchemy.orm import DeclarativeMeta
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentMessageEvent,
QueueAgentThoughtEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueLLMChunkEvent,
QueueMessage,
QueueMessageEndEvent,
QueueMessageFileEvent,
QueueMessageReplaceEvent,
QueueNodeFinishedEvent,
QueueNodeStartedEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFinishedEvent,
QueueWorkflowStartedEvent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from extensions.ext_redis import redis_client
from models.model import MessageAgentThought, MessageFile
class PublishFrom(Enum):
......@@ -40,19 +27,13 @@ class PublishFrom(Enum):
class AppQueueManager:
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
conversation_id: str,
app_mode: str,
message_id: str) -> None:
invoke_from: InvokeFrom) -> None:
if not user_id:
raise ValueError("user is required")
self._task_id = task_id
self._user_id = user_id
self._invoke_from = invoke_from
self._conversation_id = str(conversation_id)
self._app_mode = app_mode
self._message_id = str(message_id)
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
......@@ -89,7 +70,6 @@ class AppQueueManager:
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
PublishFrom.TASK_PIPELINE
)
self.stop_listen()
if elapsed_time // 10 > last_ping_time:
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
......@@ -102,137 +82,6 @@ class AppQueueManager:
"""
self._q.put(None)
def publish_llm_chunk(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
"""
Publish llm chunk to channel
:param chunk: llm chunk
:param pub_from: publish from
:return:
"""
self.publish(QueueLLMChunkEvent(
chunk=chunk
), pub_from)
def publish_text_chunk(self, text: str, pub_from: PublishFrom) -> None:
"""
Publish text chunk to channel
:param text: text
:param pub_from: publish from
:return:
"""
self.publish(QueueTextChunkEvent(
text=text
), pub_from)
def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
"""
Publish agent chunk message to channel
:param chunk: chunk
:param pub_from: publish from
:return:
"""
self.publish(QueueAgentMessageEvent(
chunk=chunk
), pub_from)
def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
"""
Publish message replace
:param text: text
:param pub_from: publish from
:return:
"""
self.publish(QueueMessageReplaceEvent(
text=text
), pub_from)
def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
"""
Publish retriever resources
:return:
"""
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
"""
Publish annotation reply
:param message_annotation_id: message annotation id
:param pub_from: publish from
:return:
"""
self.publish(QueueAnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
"""
Publish message end
:param llm_result: llm result
:param pub_from: publish from
:return:
"""
self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
self.stop_listen()
def publish_workflow_started(self, workflow_run_id: str, pub_from: PublishFrom) -> None:
"""
Publish workflow started
:param workflow_run_id: workflow run id
:param pub_from: publish from
:return:
"""
self.publish(QueueWorkflowStartedEvent(workflow_run_id=workflow_run_id), pub_from)
def publish_workflow_finished(self, workflow_run_id: str, pub_from: PublishFrom) -> None:
"""
Publish workflow finished
:param workflow_run_id: workflow run id
:param pub_from: publish from
:return:
"""
self.publish(QueueWorkflowFinishedEvent(workflow_run_id=workflow_run_id), pub_from)
def publish_node_started(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None:
"""
Publish node started
:param workflow_node_execution_id: workflow node execution id
:param pub_from: publish from
:return:
"""
self.publish(QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from)
def publish_node_finished(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None:
"""
Publish node finished
:param workflow_node_execution_id: workflow node execution id
:param pub_from: publish from
:return:
"""
self.publish(QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from)
def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
"""
Publish agent thought
:param message_agent_thought: message agent thought
:param pub_from: publish from
:return:
"""
self.publish(QueueAgentThoughtEvent(
agent_thought_id=message_agent_thought.id
), pub_from)
def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None:
"""
Publish message file
:param message_file: message file
:param pub_from: publish from
:return:
"""
self.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), pub_from)
def publish_error(self, e, pub_from: PublishFrom) -> None:
"""
Publish error
......@@ -243,7 +92,6 @@ class AppQueueManager:
self.publish(QueueErrorEvent(
error=e
), pub_from)
self.stop_listen()
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
......@@ -254,22 +102,20 @@ class AppQueueManager:
"""
self._check_for_sqlalchemy_models(event.dict())
message = QueueMessage(
task_id=self._task_id,
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
)
message = self.construct_queue_message(event)
self._q.put(message)
if isinstance(event, QueueStopEvent):
if isinstance(event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise ConversationTaskStoppedException()
@abstractmethod
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
raise NotImplementedError
@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
"""
......
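
After this rewrite the base class keeps only task-scoped state (task_id, user_id, invoke_from), and message construction becomes a subclass responsibility via the new abstract construct_queue_message(). A minimal sketch of what a subclass owes the base class (the class name and app_mode value are assumptions; compare the two concrete subclasses later in this diff):

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.queue_entities import AppQueueEvent, QueueMessage

class MinimalAppQueueManager(AppQueueManager):
    # Hypothetical subclass: the only obligation is wrapping each event
    # in the QueueMessage envelope this manager's consumers expect.
    def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
        return QueueMessage(
            task_id=self._task_id,
            app_mode='workflow',  # assumption: fixed mode for illustration
            event=event
        )
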
......@@ -3,13 +3,14 @@ from collections.abc import Generator
from typing import Optional, Union, cast
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AppGenerateEntity,
EasyUIBasedAppGenerateEntity,
InvokeFrom,
ModelConfigWithCredentialsEntity,
)
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
......@@ -187,25 +188,32 @@ class AppRunner:
if stream:
index = 0
for token in text:
queue_manager.publish_llm_chunk(LLMResultChunk(
chunk = LLMResultChunk(
model=app_generate_entity.model_config.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=token)
)
), PublishFrom.APPLICATION_MANAGER)
)
queue_manager.publish(
QueueLLMChunkEvent(
chunk=chunk
), PublishFrom.APPLICATION_MANAGER
)
index += 1
time.sleep(0.01)
queue_manager.publish_message_end(
queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=app_generate_entity.model_config.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage()
),
pub_from=PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER
)
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
......@@ -241,9 +249,10 @@ class AppRunner:
:param queue_manager: application queue manager
:return:
"""
queue_manager.publish_message_end(
queue_manager.publish(
QueueMessageEndEvent(
llm_result=invoke_result,
pub_from=PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER
)
def _handle_invoke_result_stream(self, invoke_result: Generator,
......@@ -261,9 +270,17 @@ class AppRunner:
usage = None
for result in invoke_result:
if not agent:
queue_manager.publish_llm_chunk(result, PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueLLMChunkEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
else:
queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueAgentMessageEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
text += result.delta.message.content
......@@ -286,9 +303,10 @@ class AppRunner:
usage=usage
)
queue_manager.publish_message_end(
queue_manager.publish(
QueueMessageEndEvent(
llm_result=llm_result,
pub_from=PublishFrom.APPLICATION_MANAGER
), PublishFrom.APPLICATION_MANAGER
)
def moderation_for_inputs(self, app_id: str,
......@@ -311,7 +329,7 @@ class AppRunner:
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=query,
query=query if query else ''
)
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
......
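
_handle_invoke_result_stream keeps the agent/non-agent split but now routes both branches through publish(); condensed as a helper (the function name is hypothetical):

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent
from core.model_runtime.entities.llm_entities import LLMResultChunk

def publish_chunk(queue_manager: AppQueueManager, chunk: LLMResultChunk, agent: bool) -> None:
    # Plain completions stream QueueLLMChunkEvent; agent runs stream
    # QueueAgentMessageEvent. Both go through the same generic publish().
    event = QueueAgentMessageEvent(chunk=chunk) if agent else QueueLLMChunkEvent(chunk=chunk)
    queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
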
......@@ -9,10 +9,11 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
......@@ -119,7 +120,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = AppQueueManager(
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
......
import logging
from typing import cast
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
from core.app.entities.app_invoke_entities import (
ChatAppGenerateEntity,
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
......@@ -117,10 +118,11 @@ class ChatAppRunner(AppRunner):
)
if annotation_reply:
queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id,
pub_from=PublishFrom.APPLICATION_MANAGER
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
)
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
......
......@@ -9,10 +9,11 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
......@@ -112,7 +113,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = AppQueueManager(
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
......@@ -263,7 +264,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = AppQueueManager(
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
......
import logging
from typing import cast
from core.app.app_queue_manager import AppQueueManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
from core.app.entities.app_invoke_entities import (
......
......@@ -6,7 +6,7 @@ from typing import Optional, Union, cast
from pydantic import BaseModel
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
......@@ -378,14 +378,19 @@ class EasyUIBasedGenerateTaskPipeline:
if self._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._queue_manager.publish_llm_chunk(LLMResultChunk(
self._queue_manager.publish(
QueueLLMChunkEvent(
chunk=LLMResultChunk(
model=self._task_state.llm_result.model,
prompt_messages=self._task_state.llm_result.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
)
), PublishFrom.TASK_PIPELINE)
)
), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
......@@ -656,5 +661,5 @@ class EasyUIBasedGenerateTaskPipeline:
type=sensitive_word_avoidance.type,
config=sensitive_word_avoidance.config
),
on_message_replace_func=self._queue_manager.publish_message_replace
queue_manager=self._queue_manager
)
......@@ -6,8 +6,8 @@ from typing import Optional, Union
from sqlalchemy import and_
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException
from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
......
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueMessage,
)
class MessageBasedAppQueueManager(AppQueueManager):
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
conversation_id: str,
app_mode: str,
message_id: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id)
self._app_mode = app_mode
self._message_id = str(message_id)
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
return QueueMessage(
task_id=self._task_id,
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
)
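
The chat, completion, agent-chat and advanced-chat generators above all switch to this subclass; their hunks truncate after invoke_from, but the remaining keywords follow the constructor defined in this new file. A sketch of the full call (the conversation and message variables are assumptions based on the generator context):

from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager

queue_manager = MessageBasedAppQueueManager(
    task_id=application_generate_entity.task_id,
    user_id=application_generate_entity.user_id,
    invoke_from=application_generate_entity.invoke_from,
    conversation_id=conversation.id,  # assumption: from _init_generate_records
    app_mode=conversation.mode,       # assumption
    message_id=message.id             # assumption
)
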
import logging
import threading
import uuid
from collections.abc import Generator
from typing import Union
from flask import Flask, current_app
from pydantic import ValidationError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
def generate(self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
inputs = args['inputs']
# parse files
files = args['files'] if 'files' in args and args['files'] else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict)
if file_upload_entity:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_upload_entity,
user
)
else:
file_objs = []
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs=self._get_cleaned_inputs(inputs, app_config),
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from
)
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=app_model.mode
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:return:
"""
with flask_app.app_context():
try:
# workflow app
runner = WorkflowAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool = False) -> Union[dict, Generator]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param stream: is stream
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise ConversationTaskStoppedException()
else:
logger.exception(e)
raise e
finally:
db.session.remove()
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueMessage,
)
class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
return QueueMessage(
task_id=self._task_id,
app_mode=self._app_mode,
event=event
)
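
Workflow runs drop the conversation/message fields from the envelope entirely. Stopping a task is unchanged by the refactor and works against either subclass through the Redis stop flag (task_id and user_id are assumed to be captured from the generate call):

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom

# The owning queue manager polls this flag while listening and publishes
# QueueStopEvent(stopped_by=StopBy.USER_MANUAL) when it is set.
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, user_id)
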
import logging
import time
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.entities.app_invoke_entities import (
AppGenerateEntity,
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.moderation.base import ModerationException
from core.moderation.input_moderation import InputModeration
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from models.workflow import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
class WorkflowAppRunner:
"""
Workflow Application Runner
"""
def run(self, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:return:
"""
app_config = application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
inputs = application_generate_entity.inputs
files = application_generate_entity.files
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs
):
return
# fetch user
if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]:
user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first()
else:
user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
user=user,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files
},
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
)
def handle_input_moderation(self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: WorkflowAppGenerateEntity,
inputs: dict) -> bool:
"""
Handle input moderation
:param queue_manager: application queue manager
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
:return:
"""
try:
# process sensitive_word_avoidance
moderation_feature = InputModeration()
_, inputs, query = moderation_feature.check(
app_id=app_record.id,
tenant_id=app_generate_entity.app_config.tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=''
)
except ModerationException as e:
if app_generate_entity.stream:
self._stream_output(
queue_manager=queue_manager,
text=str(e),
)
queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION),
PublishFrom.APPLICATION_MANAGER
)
return True
return False
def _stream_output(self, queue_manager: AppQueueManager,
text: str) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param text: text
:return:
"""
index = 0
for token in text:
queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.APPLICATION_MANAGER
)
index += 1
time.sleep(0.01)
def moderation_for_inputs(self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: dict) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_generate_entity: app generate entity
:param inputs: inputs
:return:
"""
moderation_feature = InputModeration()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=''
)
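
A sketch of driving the new workflow app end to end (app_model, workflow and user are assumed to be loaded from the database by the caller):

from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom

response = WorkflowAppGenerator().generate(
    app_model=app_model,
    workflow=workflow,
    user=user,
    args={'inputs': {'query': 'hello'}, 'files': []},
    invoke_from=InvokeFrom.DEBUGGER,
    stream=True  # a generator of SSE strings from the task pipeline below
)
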
import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union
from pydantic import BaseModel
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueErrorEvent,
QueueMessageReplaceEvent,
QueueNodeFinishedEvent,
QueueNodeStartedEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFinishedEvent,
QueueWorkflowStartedEvent,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus
logger = logging.getLogger(__name__)
class TaskState(BaseModel):
"""
TaskState entity
"""
answer: str = ""
metadata: dict = {}
workflow_run_id: Optional[str] = None
class WorkflowAppGenerateTaskPipeline:
"""
WorkflowAppGenerateTaskPipeline generates the stream output and manages state for the workflow application.
"""
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._task_state = TaskState()
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
def process(self) -> Union[dict, Generator]:
"""
Process generate task pipeline.
:return:
"""
if self._stream:
return self._process_stream_response()
else:
return self._process_blocking_response()
def _process_blocking_response(self) -> dict:
"""
Process blocking response.
:return:
"""
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueueErrorEvent):
raise self._handle_error(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
if isinstance(event, QueueStopEvent):
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
else:
workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs
self._task_state.answer = outputs.get('text', '')
else:
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
self._task_state.answer = self._output_moderation_handler.moderation_completion(
completion=self._task_state.answer,
public_event=False
)
response = {
'event': 'workflow_finished',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': event.workflow_run_id,
'data': {
'id': workflow_run.id,
'workflow_id': workflow_run.workflow_id,
'status': workflow_run.status,
'outputs': workflow_run.outputs_dict,
'error': workflow_run.error,
'elapsed_time': workflow_run.elapsed_time,
'total_tokens': workflow_run.total_tokens,
'total_steps': workflow_run.total_steps,
'created_at': int(workflow_run.created_at.timestamp()),
'finished_at': int(workflow_run.finished_at.timestamp())
}
}
return response
else:
continue
def _process_stream_response(self) -> Generator:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
event = message.event
if isinstance(event, QueueErrorEvent):
data = self._error_to_stream_response_data(self._handle_error(event))
yield self._yield_response(data)
break
elif isinstance(event, QueueWorkflowStartedEvent):
self._task_state.workflow_run_id = event.workflow_run_id
workflow_run = self._get_workflow_run(event.workflow_run_id)
response = {
'event': 'workflow_started',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': event.workflow_run_id,
'data': {
'id': workflow_run.id,
'workflow_id': workflow_run.workflow_id,
'created_at': int(workflow_run.created_at.timestamp())
}
}
yield self._yield_response(response)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
response = {
'event': 'node_started',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': workflow_node_execution.workflow_run_id,
'data': {
'id': workflow_node_execution.id,
'node_id': workflow_node_execution.node_id,
'index': workflow_node_execution.index,
'predecessor_node_id': workflow_node_execution.predecessor_node_id,
'inputs': workflow_node_execution.inputs_dict,
'created_at': int(workflow_node_execution.created_at.timestamp())
}
}
yield self._yield_response(response)
elif isinstance(event, QueueNodeFinishedEvent):
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
response = {
'event': 'node_finished',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': workflow_node_execution.workflow_run_id,
'data': {
'id': workflow_node_execution.id,
'node_id': workflow_node_execution.node_id,
'index': workflow_node_execution.index,
'predecessor_node_id': workflow_node_execution.predecessor_node_id,
'inputs': workflow_node_execution.inputs_dict,
'process_data': workflow_node_execution.process_data_dict,
'outputs': workflow_node_execution.outputs_dict,
'status': workflow_node_execution.status,
'error': workflow_node_execution.error,
'elapsed_time': workflow_node_execution.elapsed_time,
'execution_metadata': workflow_node_execution.execution_metadata_dict,
'created_at': int(workflow_node_execution.created_at.timestamp()),
'finished_at': int(workflow_node_execution.finished_at.timestamp())
}
}
yield self._yield_response(response)
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
if isinstance(event, QueueStopEvent):
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
else:
workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs
self._task_state.answer = outputs.get('text', '')
else:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
data = self._error_to_stream_response_data(self._handle_error(err_event))
yield self._yield_response(data)
break
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
self._task_state.answer = self._output_moderation_handler.moderation_completion(
completion=self._task_state.answer,
public_event=False
)
self._output_moderation_handler = None
replace_response = {
'event': 'text_replace',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': self._task_state.workflow_run_id,
'data': {
'text': self._task_state.answer
}
}
yield self._yield_response(replace_response)
workflow_run_response = {
'event': 'workflow_finished',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': event.workflow_run_id,
'data': {
'id': workflow_run.id,
'workflow_id': workflow_run.workflow_id,
'status': workflow_run.status,
'outputs': workflow_run.outputs_dict,
'error': workflow_run.error,
'elapsed_time': workflow_run.elapsed_time,
'total_tokens': workflow_run.total_tokens,
'total_steps': workflow_run.total_steps,
'created_at': int(workflow_run.created_at.timestamp()),
'finished_at': int(workflow_run.finished_at.timestamp())
}
}
yield self._yield_response(workflow_run_response)
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.chunk_text
if delta_text is None:
continue
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
)
continue
else:
self._output_moderation_handler.append_new_token(delta_text)
self._task_state.answer += delta_text
response = self._handle_chunk(delta_text)
yield self._yield_response(response)
elif isinstance(event, QueueMessageReplaceEvent):
response = {
'event': 'text_replace',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': self._task_state.workflow_run_id,
'data': {
'text': event.text
}
}
yield self._yield_response(response)
elif isinstance(event, QueuePingEvent):
yield "event: ping\n\n"
else:
continue
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Get workflow run.
:param workflow_run_id: workflow run id
:return:
"""
return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
"""
Get workflow node execution.
:param workflow_node_execution_id: workflow node execution id
:return:
"""
return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()
def _handle_chunk(self, text: str) -> dict:
"""
Handle text chunk event.
:param text: text
:return:
"""
response = {
'event': 'text_chunk',
'workflow_run_id': self._task_state.workflow_run_id,
'task_id': self._application_generate_entity.task_id,
'data': {
'text': text
}
}
return response
def _handle_error(self, event: QueueErrorEvent) -> Exception:
"""
Handle error event.
:param event: event
:return:
"""
logger.debug("error: %s", event.error)
e = event.error
if isinstance(e, InvokeAuthorizationError):
return InvokeAuthorizationError('Incorrect API key provided')
elif isinstance(e, (InvokeError, ValueError)):
return e
else:
return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
def _error_to_stream_response_data(self, e: Exception) -> dict:
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses = {
ValueError: {'code': 'invalid_param', 'status': 400},
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
QuotaExceededError: {
'code': 'provider_quota_exceeded',
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
'status': 400
},
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
InvokeError: {'code': 'completion_request_error', 'status': 400}
}
# Determine the response based on the type of exception
data = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v
if data:
data.setdefault('message', getattr(e, 'description', str(e)))
else:
logging.error(e)
data = {
'code': 'internal_server_error',
'message': 'Internal Server Error, please contact support.',
'status': 500
}
return {
'event': 'error',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': self._task_state.workflow_run_id,
**data
}
def _yield_response(self, response: dict) -> str:
"""
Yield response.
:param response: response
:return:
"""
return "data: " + json.dumps(response) + "\n\n"
def _init_output_moderation(self) -> Optional[OutputModeration]:
"""
Init output moderation.
:return:
"""
app_config = self._application_generate_entity.app_config
sensitive_word_avoidance = app_config.sensitive_word_avoidance
if sensitive_word_avoidance:
return OutputModeration(
tenant_id=app_config.tenant_id,
app_id=app_config.app_id,
rule=ModerationRule(
type=sensitive_word_avoidance.type,
config=sensitive_word_avoidance.config
),
queue_manager=self._queue_manager
)
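
_yield_response frames every payload as "data: <json>\n\n", and pings arrive as bare "event: ping" lines, so a streaming consumer can be as small as (the function name is hypothetical):

import json

def iter_events(stream):
    # Parse the strings yielded by _process_stream_response(); skip pings.
    for frame in stream:
        if frame.startswith('data: '):
            yield json.loads(frame[len('data: '):])

for event in iter_events(response):  # `response` from the generate() sketch above
    print(event['event'], event.get('workflow_run_id'))
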
......@@ -127,9 +127,9 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
query: Optional[str] = None
class WorkflowUIBasedAppGenerateEntity(AppGenerateEntity):
class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
Workflow UI Based Application Generate Entity.
Workflow Application Generate Entity.
"""
# app config
app_config: WorkflowUIBasedAppConfig
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import DatasetQuery, DocumentSegment
......@@ -82,4 +83,7 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource),
PublishFrom.APPLICATION_MANAGER
)
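The pattern in this hunk recurs throughout the merge: dedicated helpers such as publish_retriever_resources are replaced by one generic publish(event, PublishFrom) that accepts typed queue-entity objects. A toy stand-in (not the real AppQueueManager) to show the shape of that contract:

    import queue
    from enum import Enum

    class PublishFrom(Enum):  # mirrors core.app.apps.base_app_queue_manager.PublishFrom
        APPLICATION_MANAGER = 1
        TASK_PIPELINE = 2

    class ToyQueueManager:
        """Hypothetical sketch: one publish() for every event type, so adding
        a new event needs a new entity class, not a new manager method."""
        def __init__(self) -> None:
            self._q: queue.Queue = queue.Queue()

        def publish(self, event, pub_from: PublishFrom) -> None:
            self._q.put((event, pub_from))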
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
QueueNodeFinishedEvent,
QueueNodeStartedEvent,
QueueTextChunkEvent,
QueueWorkflowFinishedEvent,
QueueWorkflowStartedEvent,
)
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from models.workflow import WorkflowNodeExecution, WorkflowRun
......@@ -12,34 +19,45 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
"""
Workflow run started
"""
self._queue_manager.publish_workflow_started(
workflow_run_id=workflow_run.id,
pub_from=PublishFrom.TASK_PIPELINE
self._queue_manager.publish(
QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
"""
Workflow run finished
"""
self._queue_manager.publish_workflow_finished(
workflow_run_id=workflow_run.id,
pub_from=PublishFrom.TASK_PIPELINE
self._queue_manager.publish(
QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish_node_started(
workflow_node_execution_id=workflow_node_execution.id,
pub_from=PublishFrom.TASK_PIPELINE
self._queue_manager.publish(
QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
"""
Workflow node execute finished
"""
self._queue_manager.publish_node_finished(
workflow_node_execution_id=workflow_node_execution.id,
pub_from=PublishFrom.TASK_PIPELINE
self._queue_manager.publish(
QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
PublishFrom.APPLICATION_MANAGER
)
def on_text_chunk(self, text: str) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
......@@ -6,7 +6,8 @@ from typing import Any, Optional
from flask import Flask, current_app
from pydantic import BaseModel
from core.app.app_queue_manager import PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory
......@@ -25,7 +26,7 @@ class OutputModeration(BaseModel):
app_id: str
rule: ModerationRule
on_message_replace_func: Any
queue_manager: AppQueueManager
thread: Optional[threading.Thread] = None
thread_running: bool = True
......@@ -67,7 +68,12 @@ class OutputModeration(BaseModel):
final_output = result.text
if public_event:
self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
self.queue_manager.publish(
QueueMessageReplaceEvent(
text=final_output
),
PublishFrom.TASK_PIPELINE
)
return final_output
......@@ -117,7 +123,12 @@ class OutputModeration(BaseModel):
# trigger replace event
if self.thread_running:
self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
self.queue_manager.publish(
QueueMessageReplaceEvent(
text=final_output
),
PublishFrom.TASK_PIPELINE
)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
......
......@@ -31,3 +31,11 @@ class BaseWorkflowCallback:
Workflow node execute finished
"""
raise NotImplementedError
@abstractmethod
def on_text_chunk(self, text: str) -> None:
"""
Publish text chunk
"""
raise NotImplementedError
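With on_text_chunk added, BaseWorkflowCallback exposes five abstract hooks. A hypothetical concrete subclass that simply logs each hook, useful when tracing a run, might look like:

    from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
    from models.workflow import WorkflowNodeExecution, WorkflowRun

    class LoggingWorkflowCallback(BaseWorkflowCallback):
        """Hypothetical tracing callback; prints instead of publishing to a queue."""

        def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
            print(f"workflow run started: {workflow_run.id}")

        def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
            print(f"workflow run finished: {workflow_run.id}")

        def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
            print(f"node started: {workflow_node_execution.id}")

        def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
            print(f"node finished: {workflow_node_execution.id}")

        def on_text_chunk(self, text: str) -> None:
            print(text, end="")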
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from models.workflow import WorkflowNodeExecutionStatus
class NodeType(Enum):
......@@ -39,3 +44,19 @@ class SystemVariable(Enum):
QUERY = 'query'
FILES = 'files'
CONVERSATION = 'conversation'
class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[dict] = None # node outputs
metadata: Optional[dict] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
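For example, a successful node run might produce a result like the following (field values are purely illustrative):

    from core.workflow.entities.node_entities import NodeRunResult
    from models.workflow import WorkflowNodeExecutionStatus

    result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        inputs={'query': 'hello'},
        process_data={'prompt': 'Say hello back'},
        outputs={'text': 'hi there'},
        metadata={'total_tokens': 42},
    )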
from decimal import Decimal
from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecution, WorkflowRun
......@@ -10,7 +8,10 @@ class WorkflowRunState:
variable_pool: VariablePool
total_tokens: int = 0
total_price: Decimal = Decimal(0)
currency: str = "USD"
workflow_node_executions: list[WorkflowNodeExecution]
def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
self.workflow_run = workflow_run
self.start_at = start_at
self.variable_pool = variable_pool
self.workflow_node_executions = []  # per-instance list; a class-level [] default would be shared across runs
from abc import abstractmethod
from typing import Optional
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecutionStatus
class BaseNode:
......@@ -13,17 +14,23 @@ class BaseNode:
node_id: str
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
def __init__(self, config: dict) -> None:
stream_output_supported: bool = False
callbacks: list[BaseWorkflowCallback]
def __init__(self, config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None:
self.node_id = config.get("id")
if not self.node_id:
raise ValueError("Node ID is required.")
self.node_data = self._node_data_cls(**config.get("data", {}))
self.callbacks = callbacks or []
@abstractmethod
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> dict:
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
......@@ -33,22 +40,41 @@ class BaseNode:
raise NotImplementedError
def run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None) -> dict:
run_args: Optional[dict] = None) -> NodeRunResult:
"""
Run node entry
:param variable_pool: variable pool
:param run_args: run args
:param callbacks: callbacks
:return:
"""
if variable_pool is None and run_args is None:
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
return self._run(
try:
result = self._run(
variable_pool=variable_pool,
run_args=run_args
)
except Exception as e:
# process unhandled exception
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
self.node_run_result = result
return result
def publish_text_chunk(self, text: str) -> None:
"""
Publish text chunk
:param text: chunk text
:return:
"""
if self.stream_output_supported and self.callbacks:
for callback in self.callbacks:
callback.on_text_chunk(text)
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
......
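A minimal sketch of a concrete node under the new contract: _run returns a NodeRunResult, run() turns unhandled exceptions into a FAILED result, and publish_text_chunk only fans out when the engine has flagged the node as streamable. EchoNode is hypothetical; real nodes also declare _node_data_cls and a node_type, omitted here for brevity:

    from typing import Optional

    from core.workflow.entities.node_entities import NodeRunResult
    from core.workflow.entities.variable_pool import VariablePool
    from core.workflow.nodes.base_node import BaseNode
    from models.workflow import WorkflowNodeExecutionStatus

    class EchoNode(BaseNode):
        """Hypothetical node that echoes its run_args back as outputs."""

        def _run(self, variable_pool: Optional[VariablePool] = None,
                 run_args: Optional[dict] = None) -> NodeRunResult:
            outputs = dict(run_args or {})
            self.publish_text_chunk(str(outputs))  # no-op unless stream_output_supported
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=run_args,
                outputs=outputs,
            )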
import json
import time
from datetime import datetime
from typing import Optional, Union
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowRunState
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.code_node import CodeNode
......@@ -31,6 +32,7 @@ from models.workflow import (
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
WorkflowType,
)
node_classes = {
......@@ -120,8 +122,7 @@ class WorkflowEngineManager:
return default_config
def run_workflow(self, app_model: App,
workflow: Workflow,
def run_workflow(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
......@@ -129,7 +130,6 @@ class WorkflowEngineManager:
callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Run workflow
:param app_model: App instance
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
......@@ -143,13 +143,23 @@ class WorkflowEngineManager:
if not graph:
raise ValueError('workflow graph not found')
if 'nodes' not in graph or 'edges' not in graph:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init workflow run
workflow_run = self._init_workflow_run(
workflow=workflow,
triggered_from=triggered_from,
user=user,
user_inputs=user_inputs,
system_inputs=system_inputs
system_inputs=system_inputs,
callbacks=callbacks
)
# init workflow run state
......@@ -161,44 +171,54 @@ class WorkflowEngineManager:
)
)
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started(workflow_run)
# fetch start node
start_node = self._get_entry_node(graph)
if not start_node:
self._workflow_run_failed(
workflow_run_state=workflow_run_state,
error='Start node not found in workflow graph',
callbacks=callbacks
)
return
# fetch predecessor node ids before end node (include: llm, direct answer)
streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
try:
predecessor_node = None
current_node = start_node
while True:
# run workflow
self._run_workflow_node(
workflow_run_state=workflow_run_state,
node=current_node,
# get next node, multiple target nodes in the future
next_node = self._get_next_node(
graph=graph,
predecessor_node=predecessor_node,
callbacks=callbacks
)
if current_node.node_type == NodeType.END:
if not next_node:
break
# todo fetch next node until end node finished or no next node
current_node = None
# check if node is streamable
if next_node.node_id in streamable_node_ids:
next_node.stream_output_supported = True
if not current_node:
break
# max steps 30 reached
if len(workflow_run_state.workflow_node_executions) > 30:
raise ValueError('Max steps 30 reached.')
predecessor_node = current_node
# or max steps 30 reached
# or max execution time 10min reached
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600):
raise ValueError('Max execution time 10min reached.')
# run workflow, run multiple target nodes in the future
self._run_workflow_node(
workflow_run_state=workflow_run_state,
node=next_node,
predecessor_node=predecessor_node,
callbacks=callbacks
)
if next_node.node_type == NodeType.END:
break
predecessor_node = next_node
if not predecessor_node and not next_node:
self._workflow_run_failed(
workflow_run_state=workflow_run_state,
error='Start node not found in workflow graph.',
callbacks=callbacks
)
return
except Exception as e:
self._workflow_run_failed(
workflow_run_state=workflow_run_state,
......@@ -213,11 +233,40 @@ class WorkflowEngineManager:
callbacks=callbacks
)
def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
"""
Fetch streamable node ids
When the workflow type is chat, only LLM and Direct Answer nodes immediately preceding an END node can stream output.
When the workflow type is workflow, only LLM nodes immediately preceding an END node (Plain Text output mode only) can stream output.
:param workflow: Workflow instance
:param graph: workflow graph
:return:
"""
workflow_type = WorkflowType.value_of(workflow.type)
streamable_node_ids = []
end_node_ids = []
for node_config in graph.get('nodes'):
if node_config.get('type') == NodeType.END.value:
if workflow_type == WorkflowType.WORKFLOW:
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
end_node_ids.append(node_config.get('id'))
else:
end_node_ids.append(node_config.get('id'))
for edge_config in graph.get('edges'):
if edge_config.get('target') in end_node_ids:
streamable_node_ids.append(edge_config.get('source'))
return streamable_node_ids
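A standalone re-implementation of that selection rule against a toy three-node graph, to make the behaviour concrete (the dict shapes are assumptions inferred from the keys read above):

    def streamable_node_ids(graph: dict, is_workflow_type: bool) -> list[str]:
        # collect END nodes; for workflow-type apps only plain-text END nodes count
        end_ids = [
            n['id'] for n in graph['nodes']
            if n['type'] == 'end'
            and (not is_workflow_type
                 or n.get('data', {}).get('outputs', {}).get('type') == 'plain-text')
        ]
        # every direct predecessor of a qualifying END node may stream
        return [e['source'] for e in graph['edges'] if e['target'] in end_ids]

    graph = {
        'nodes': [
            {'id': 'start', 'type': 'start', 'data': {}},
            {'id': 'llm', 'type': 'llm', 'data': {}},
            {'id': 'end', 'type': 'end', 'data': {'outputs': {'type': 'plain-text'}}},
        ],
        'edges': [
            {'source': 'start', 'target': 'llm'},
            {'source': 'llm', 'target': 'end'},
        ],
    }
    assert streamable_node_ids(graph, is_workflow_type=True) == ['llm']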
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> WorkflowRun:
system_inputs: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
"""
Init workflow run
:param workflow: Workflow instance
......@@ -225,6 +274,7 @@ class WorkflowEngineManager:
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks
:return:
"""
try:
......@@ -260,6 +310,39 @@ class WorkflowEngineManager:
db.session.rollback()
raise
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started(workflow_run)
return workflow_run
def _workflow_run_success(self, workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
"""
Workflow run success
:param workflow_run_state: workflow run state
:param callbacks: workflow callbacks
:return:
"""
workflow_run = workflow_run_state.workflow_run
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
# fetch the last workflow node execution, if any
last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1] \
if workflow_run_state.workflow_node_executions else None
if last_workflow_node_execution:
workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs)
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
workflow_run.total_tokens = workflow_run_state.total_tokens
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_run_finished(workflow_run)
return workflow_run
def _workflow_run_failed(self, workflow_run_state: WorkflowRunState,
......@@ -277,9 +360,8 @@ class WorkflowEngineManager:
workflow_run.error = error
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
workflow_run.total_tokens = workflow_run_state.total_tokens
workflow_run.total_price = workflow_run_state.total_price
workflow_run.currency = workflow_run_state.currency
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
......@@ -289,22 +371,78 @@ class WorkflowEngineManager:
return workflow_run
def _get_entry_node(self, graph: dict) -> Optional[StartNode]:
def _get_next_node(self, graph: dict,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
"""
Get entry node
Get next node
multiple target nodes in the future.
:param graph: workflow graph
:param predecessor_node: predecessor node
:param callbacks: workflow callbacks
:return:
"""
nodes = graph.get('nodes')
if not nodes:
return None
for node_config in nodes.items():
if not predecessor_node:
for node_config in nodes:
if node_config.get('type') == NodeType.START.value:
return StartNode(config=node_config)
else:
edges = graph.get('edges')
source_node_id = predecessor_node.node_id
# fetch all outgoing edges from source node
outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id]
if not outgoing_edges:
return None
# fetch target node id from outgoing edges
outgoing_edge = None
source_handle = predecessor_node.node_run_result.edge_source_handle
if source_handle:
for edge in outgoing_edges:
if edge.get('source_handle') == source_handle:
outgoing_edge = edge
break
else:
outgoing_edge = outgoing_edges[0]
if not outgoing_edge:
return None
target_node_id = outgoing_edge.get('target')
# fetch target node from target node id
target_node_config = None
for node in nodes:
if node.get('id') == target_node_id:
target_node_config = node
break
if not target_node_config:
return None
# resolve the next node class; unknown node types end traversal
target_node_cls = node_classes.get(NodeType.value_of(target_node_config.get('type')))
if not target_node_cls:
return None
return target_node_cls(
config=target_node_config,
callbacks=callbacks
)
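The branch handling above in miniature: when the predecessor's NodeRunResult carries an edge_source_handle (e.g. from an if/else node), the matching edge wins; otherwise the first outgoing edge is taken. A self-contained sketch, with edge keys assumed as read above:

    from typing import Optional

    def pick_outgoing_edge(edges: list[dict], source_id: str,
                           source_handle: Optional[str] = None) -> Optional[dict]:
        outgoing = [e for e in edges if e.get('source') == source_id]
        if not outgoing:
            return None
        if source_handle:
            for e in outgoing:
                if e.get('source_handle') == source_handle:
                    return e
            return None  # a declared handle that matches nothing ends traversal
        return outgoing[0]

    edges = [
        {'source': 'if_node', 'source_handle': 'true', 'target': 'a'},
        {'source': 'if_node', 'source_handle': 'false', 'target': 'b'},
    ]
    assert pick_outgoing_edge(edges, 'if_node', 'false')['target'] == 'b'
    assert pick_outgoing_edge(edges, 'if_node')['target'] == 'a'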
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout
:param start_at: start time
:param max_execution_time: max execution time
:return:
"""
# TODO check queue is stopped
return time.perf_counter() - start_at > max_execution_time
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
node: BaseNode,
predecessor_node: Optional[BaseNode] = None,
......@@ -320,28 +458,41 @@ class WorkflowEngineManager:
# add to workflow node executions
workflow_run_state.workflow_node_executions.append(workflow_node_execution)
try:
# run node, result must have inputs, process_data, outputs, execution_metadata
node_run_result = node.run(
variable_pool=workflow_run_state.variable_pool,
callbacks=callbacks
variable_pool=workflow_run_state.variable_pool
)
except Exception as e:
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
# node run failed
self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
error=str(e),
start_at=start_at,
error=node_run_result.error,
callbacks=callbacks
)
raise
raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")
# node run success
self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=start_at,
result=node_run_result,
callbacks=callbacks
)
if node_run_result.outputs:
for variable_key, variable_value in node_run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
variable_pool=workflow_run_state.variable_pool,
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value
)
if node_run_result.metadata and node_run_result.metadata.get('total_tokens'):
workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens'))
return workflow_node_execution
def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState,
......@@ -384,3 +535,86 @@ class WorkflowEngineManager:
callback.on_workflow_node_execute_started(workflow_node_execution)
return workflow_node_execution
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
result: NodeRunResult,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
"""
Workflow node execution success
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param result: node run result
:param callbacks: workflow callbacks
:return:
"""
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(result.inputs)
workflow_node_execution.process_data = json.dumps(result.process_data)
workflow_node_execution.outputs = json.dumps(result.outputs)
workflow_node_execution.execution_metadata = json.dumps(result.metadata)
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_finished(workflow_node_execution)
return workflow_node_execution
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
error: str,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param error: error message
:param callbacks: workflow callbacks
:return:
"""
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_finished(workflow_node_execution)
return workflow_node_execution
def _append_variables_recursively(self, variable_pool: VariablePool,
node_id: str,
variable_key_list: list[str],
variable_value: VariableValue):
"""
Append variables recursively
:param variable_pool: variable pool
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
variable_pool.append_variable(
node_id=node_id,
variable_key_list=variable_key_list,
value=variable_value
)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, dict):
for key, value in variable_value.items():
# construct new key list
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
variable_pool=variable_pool,
node_id=node_id,
variable_key_list=new_key_list,
variable_value=value
)
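What the recursion yields for a nested output, sketched with a plain dict standing in for the VariablePool: every prefix path of a nested dict gets its own entry.

    def append_recursively(pool: dict, node_id: str, key_list: list[str], value) -> None:
        pool[(node_id, tuple(key_list))] = value
        if isinstance(value, dict):
            for k, v in value.items():
                append_recursively(pool, node_id, key_list + [k], v)

    pool: dict = {}
    append_recursively(pool, 'llm', ['result'], {'text': 'hi', 'usage': {'tokens': 3}})
    assert ('llm', ('result', 'usage', 'tokens')) in pool
    assert pool[('llm', ('result', 'text'))] == 'hi'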
......@@ -11,8 +11,6 @@ workflow_run_for_log_fields = {
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField
......@@ -29,8 +27,6 @@ workflow_run_for_list_fields = {
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
"created_at": TimestampField,
......@@ -56,8 +52,6 @@ workflow_run_detail_fields = {
"error": fields.String,
"elapsed_time": fields.Float,
"total_tokens": fields.Integer,
"total_price": fields.Float,
"currency": fields.String,
"total_steps": fields.Integer,
"created_by_role": fields.String,
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
......
......@@ -216,8 +216,6 @@ class WorkflowRun(db.Model):
- error (string) `optional` Error reason
- elapsed_time (float) `optional` Time consumption (s)
- total_tokens (int) `optional` Total tokens used
- total_price (decimal) `optional` Total cost
- currency (string) `optional` Currency, such as USD / RMB
- total_steps (int) Total steps (redundant), default 0
- created_by_role (string) Creator role
......@@ -251,8 +249,6 @@ class WorkflowRun(db.Model):
error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0'))
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
total_price = db.Column(db.Numeric(10, 7))
currency = db.Column(db.String(255))
total_steps = db.Column(db.Integer, server_default=db.text('0'))
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
......
......@@ -6,6 +6,7 @@ from typing import Optional, Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeType
from core.workflow.workflow_engine_manager import WorkflowEngineManager
......@@ -175,8 +176,24 @@ class WorkflowService:
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom) -> Union[dict, Generator]:
# TODO
pass
# fetch draft workflow by app_model
draft_workflow = self.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError('Workflow not initialized')
# run draft workflow
app_generator = WorkflowAppGenerator()
response = app_generator.generate(
app_model=app_model,
workflow=draft_workflow,
user=user,
args=args,
invoke_from=invoke_from,
stream=True
)
return response
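A hypothetical caller-side sketch of the new draft-run path. app_model and current_user are assumed to be loaded elsewhere, the service module path and the args shape are assumptions, and InvokeFrom.DEBUGGER stands in for whichever invoke source the console uses:

    from core.app.entities.app_invoke_entities import InvokeFrom
    from services.workflow_service import WorkflowService  # module path assumed

    service = WorkflowService()
    response = service.run_draft_workflow(
        app_model=app_model,      # assumed App instance
        user=current_user,        # assumed Account or EndUser
        args={'inputs': {}},      # assumed args shape
        invoke_from=InvokeFrom.DEBUGGER,
    )
    for chunk in response:        # SSE-formatted strings when stream=True
        print(chunk, end='')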
def convert_to_workflow(self, app_model: App, account: Account) -> App:
"""
......