Unverified commit 0c746f5c, authored by takatost and committed by GitHub

fix: generate not stop when pressing stop link (#1961)

parent a8cedea1
import time import time
from typing import cast, Optional, List, Tuple, Generator, Union from typing import cast, Optional, List, Tuple, Generator, Union
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
...@@ -183,7 +183,7 @@ class AppRunner: ...@@ -183,7 +183,7 @@ class AppRunner:
index=index, index=index,
message=AssistantPromptMessage(content=token) message=AssistantPromptMessage(content=token)
) )
)) ), PublishFrom.APPLICATION_MANAGER)
index += 1 index += 1
time.sleep(0.01) time.sleep(0.01)
...@@ -193,7 +193,8 @@ class AppRunner: ...@@ -193,7 +193,8 @@ class AppRunner:
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text), message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage() usage=usage if usage else LLMUsage.empty_usage()
) ),
pub_from=PublishFrom.APPLICATION_MANAGER
) )
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
...@@ -226,7 +227,8 @@ class AppRunner: ...@@ -226,7 +227,8 @@ class AppRunner:
:return: :return:
""" """
queue_manager.publish_message_end( queue_manager.publish_message_end(
llm_result=invoke_result llm_result=invoke_result,
pub_from=PublishFrom.APPLICATION_MANAGER
) )
def _handle_invoke_result_stream(self, invoke_result: Generator, def _handle_invoke_result_stream(self, invoke_result: Generator,
...@@ -242,7 +244,7 @@ class AppRunner: ...@@ -242,7 +244,7 @@ class AppRunner:
text = '' text = ''
usage = None usage = None
for result in invoke_result: for result in invoke_result:
queue_manager.publish_chunk_message(result) queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
text += result.delta.message.content text += result.delta.message.content
...@@ -263,5 +265,6 @@ class AppRunner: ...@@ -263,5 +265,6 @@ class AppRunner:
) )
queue_manager.publish_message_end( queue_manager.publish_message_end(
llm_result=llm_result llm_result=llm_result,
pub_from=PublishFrom.APPLICATION_MANAGER
) )
...@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner ...@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \ from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.features.annotation_reply import AnnotationReplyFeature from core.features.annotation_reply import AnnotationReplyFeature
from core.features.dataset_retrieval import DatasetRetrievalFeature from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature from core.features.external_data_fetch import ExternalDataFetchFeature
...@@ -121,7 +121,8 @@ class BasicApplicationRunner(AppRunner): ...@@ -121,7 +121,8 @@ class BasicApplicationRunner(AppRunner):
if annotation_reply: if annotation_reply:
queue_manager.publish_annotation_reply( queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id message_annotation_id=annotation_reply.id,
pub_from=PublishFrom.APPLICATION_MANAGER
) )
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
......
...@@ -7,7 +7,7 @@ from pydantic import BaseModel ...@@ -7,7 +7,7 @@ from pydantic import BaseModel
from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
from core.entities.application_entities import ApplicationGenerateEntity from core.entities.application_entities import ApplicationGenerateEntity
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \ from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \ QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
AnnotationReplyEvent AnnotationReplyEvent
...@@ -312,8 +312,11 @@ class GenerateTaskPipeline: ...@@ -312,8 +312,11 @@ class GenerateTaskPipeline:
index=0, index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
) )
)) ), PublishFrom.TASK_PIPELINE)
self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION)) self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
)
continue continue
else: else:
self._output_moderation_handler.append_new_token(delta_text) self._output_moderation_handler.append_new_token(delta_text)
......
...@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict ...@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict
from flask import current_app, Flask from flask import current_app, Flask
from pydantic import BaseModel from pydantic import BaseModel
from core.application_queue_manager import PublishFrom
from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory from core.moderation.factory import ModerationFactory
...@@ -66,7 +67,7 @@ class OutputModerationHandler(BaseModel): ...@@ -66,7 +67,7 @@ class OutputModerationHandler(BaseModel):
final_output = result.text final_output = result.text
if public_event: if public_event:
self.on_message_replace_func(final_output) self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
return final_output return final_output
......
...@@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr ...@@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.model import EndUser, Conversation, Message, MessageFile, App from models.model import EndUser, Conversation, Message, MessageFile, App
...@@ -169,15 +169,18 @@ class ApplicationManager: ...@@ -169,15 +169,18 @@ class ApplicationManager:
except ConversationTaskStoppedException: except ConversationTaskStoppedException:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided')) queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
queue_manager.publish_error(e) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:
logger.exception("Unknown Error when generating") logger.exception("Unknown Error when generating")
queue_manager.publish_error(e) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally: finally:
db.session.remove() db.session.remove()
......
import queue import queue
import time import time
from enum import Enum
from typing import Generator, Any from typing import Generator, Any
from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import DeclarativeMeta
...@@ -13,6 +14,11 @@ from extensions.ext_redis import redis_client ...@@ -13,6 +14,11 @@ from extensions.ext_redis import redis_client
from models.model import MessageAgentThought from models.model import MessageAgentThought
class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2
class ApplicationQueueManager: class ApplicationQueueManager:
def __init__(self, task_id: str, def __init__(self, task_id: str,
user_id: str, user_id: str,
...@@ -61,11 +67,14 @@ class ApplicationQueueManager: ...@@ -61,11 +67,14 @@ class ApplicationQueueManager:
if elapsed_time >= listen_timeout or self._is_stopped(): if elapsed_time >= listen_timeout or self._is_stopped():
# publish two messages to make sure the client can receive the stop signal # publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed # and stop listening after the stop signal processed
self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)) self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
PublishFrom.TASK_PIPELINE
)
self.stop_listen() self.stop_listen()
if elapsed_time // 10 > last_ping_time: if elapsed_time // 10 > last_ping_time:
self.publish(QueuePingEvent()) self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10 last_ping_time = elapsed_time // 10
def stop_listen(self) -> None: def stop_listen(self) -> None:
...@@ -75,76 +84,83 @@ class ApplicationQueueManager: ...@@ -75,76 +84,83 @@ class ApplicationQueueManager:
""" """
self._q.put(None) self._q.put(None)
def publish_chunk_message(self, chunk: LLMResultChunk) -> None: def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
""" """
Publish chunk message to channel Publish chunk message to channel
:param chunk: chunk :param chunk: chunk
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueMessageEvent( self.publish(QueueMessageEvent(
chunk=chunk chunk=chunk
)) ), pub_from)
def publish_message_replace(self, text: str) -> None: def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
""" """
Publish message replace Publish message replace
:param text: text :param text: text
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueMessageReplaceEvent( self.publish(QueueMessageReplaceEvent(
text=text text=text
)) ), pub_from)
def publish_retriever_resources(self, retriever_resources: list[dict]) -> None: def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
""" """
Publish retriever resources Publish retriever resources
:return: :return:
""" """
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources)) self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
def publish_annotation_reply(self, message_annotation_id: str) -> None: def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
""" """
Publish annotation reply Publish annotation reply
:param message_annotation_id: message annotation id :param message_annotation_id: message annotation id
:param pub_from: publish from
:return: :return:
""" """
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id)) self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
def publish_message_end(self, llm_result: LLMResult) -> None: def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
""" """
Publish message end Publish message end
:param llm_result: llm result :param llm_result: llm result
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueMessageEndEvent(llm_result=llm_result)) self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
self.stop_listen() self.stop_listen()
def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None: def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
""" """
Publish agent thought Publish agent thought
:param message_agent_thought: message agent thought :param message_agent_thought: message agent thought
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueAgentThoughtEvent( self.publish(QueueAgentThoughtEvent(
agent_thought_id=message_agent_thought.id agent_thought_id=message_agent_thought.id
)) ), pub_from)
def publish_error(self, e) -> None: def publish_error(self, e, pub_from: PublishFrom) -> None:
""" """
Publish error Publish error
:param e: error :param e: error
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueErrorEvent( self.publish(QueueErrorEvent(
error=e error=e
)) ), pub_from)
self.stop_listen() self.stop_listen()
def publish(self, event: AppQueueEvent) -> None: def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
""" """
Publish event to queue Publish event to queue
:param event: :param event:
:param pub_from:
:return: :return:
""" """
self._check_for_sqlalchemy_models(event.dict()) self._check_for_sqlalchemy_models(event.dict())
...@@ -162,6 +178,9 @@ class ApplicationQueueManager: ...@@ -162,6 +178,9 @@ class ApplicationQueueManager:
if isinstance(event, QueueStopEvent): if isinstance(event, QueueStopEvent):
self.stop_listen() self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise ConversationTaskStoppedException()
@classmethod @classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
""" """
...@@ -187,7 +206,6 @@ class ApplicationQueueManager: ...@@ -187,7 +206,6 @@ class ApplicationQueueManager:
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id) stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
result = redis_client.get(stopped_cache_key) result = redis_client.get(stopped_cache_key)
if result is not None: if result is not None:
redis_client.delete(stopped_cache_key)
return True return True
return False return False
......
...@@ -8,7 +8,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen ...@@ -8,7 +8,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.agent_loop import AgentLoop
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
...@@ -232,7 +232,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -232,7 +232,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
db.session.add(message_agent_thought) db.session.add(message_agent_thought)
db.session.commit() db.session.commit()
self.queue_manager.publish_agent_thought(message_agent_thought) self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER)
return message_agent_thought return message_agent_thought
......
...@@ -2,7 +2,7 @@ from typing import List, Union ...@@ -2,7 +2,7 @@ from typing import List, Union
from langchain.schema import Document from langchain.schema import Document
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment, DatasetQuery from models.dataset import DocumentSegment, DatasetQuery
...@@ -80,4 +80,4 @@ class DatasetIndexToolCallbackHandler: ...@@ -80,4 +80,4 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_retriever_resource) db.session.add(dataset_retriever_resource)
db.session.commit() db.session.commit()
self._queue_manager.publish_retriever_resources(resource) self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment