Commit ab933a25 authored by takatost's avatar takatost

lint fix

parent 98507edc
This diff is collapsed.
...@@ -14,12 +14,12 @@ from core.app.entities.app_invoke_entities import ( ...@@ -14,12 +14,12 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom, InvokeFrom,
) )
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AnnotationReplyEvent,
QueueAgentMessageEvent, QueueAgentMessageEvent,
QueueAgentThoughtEvent, QueueAgentThoughtEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent, QueueErrorEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent, QueueMessageEndEvent,
QueueMessageEvent,
QueueMessageFileEvent, QueueMessageFileEvent,
QueueMessageReplaceEvent, QueueMessageReplaceEvent,
QueuePingEvent, QueuePingEvent,
...@@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr ...@@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.moderation.output_moderation import ModerationRule, OutputModeration from core.moderation.output_moderation import ModerationRule, OutputModeration
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from events.message_event import message_was_created from events.message_event import message_was_created
...@@ -58,9 +59,9 @@ class TaskState(BaseModel): ...@@ -58,9 +59,9 @@ class TaskState(BaseModel):
metadata: dict = {} metadata: dict = {}
class GenerateTaskPipeline: class EasyUIBasedGenerateTaskPipeline:
""" """
GenerateTaskPipeline is a class that generate stream output and state management for Application. EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
def __init__(self, application_generate_entity: Union[ def __init__(self, application_generate_entity: Union[
...@@ -79,12 +80,13 @@ class GenerateTaskPipeline: ...@@ -79,12 +80,13 @@ class GenerateTaskPipeline:
:param message: message :param message: message
""" """
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._model_config = application_generate_entity.model_config
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._conversation = conversation self._conversation = conversation
self._message = message self._message = message
self._task_state = TaskState( self._task_state = TaskState(
llm_result=LLMResult( llm_result=LLMResult(
model=self._application_generate_entity.model_config.model, model=self._model_config.model,
prompt_messages=[], prompt_messages=[],
message=AssistantPromptMessage(content=""), message=AssistantPromptMessage(content=""),
usage=LLMUsage.empty_usage() usage=LLMUsage.empty_usage()
...@@ -115,7 +117,7 @@ class GenerateTaskPipeline: ...@@ -115,7 +117,7 @@ class GenerateTaskPipeline:
raise self._handle_error(event) raise self._handle_error(event)
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._task_state.metadata['retriever_resources'] = event.retriever_resources self._task_state.metadata['retriever_resources'] = event.retriever_resources
elif isinstance(event, AnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation: if annotation:
account = annotation.account account = annotation.account
...@@ -132,7 +134,7 @@ class GenerateTaskPipeline: ...@@ -132,7 +134,7 @@ class GenerateTaskPipeline:
if isinstance(event, QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result self._task_state.llm_result = event.llm_result
else: else:
model_config = self._application_generate_entity.model_config model_config = self._model_config
model = model_config.model model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_type_instance = cast(LargeLanguageModel, model_type_instance)
...@@ -189,7 +191,7 @@ class GenerateTaskPipeline: ...@@ -189,7 +191,7 @@ class GenerateTaskPipeline:
'created_at': int(self._message.created_at.timestamp()) 'created_at': int(self._message.created_at.timestamp())
} }
if self._conversation.mode == 'chat': if self._conversation.mode != AppMode.COMPLETION.value:
response['conversation_id'] = self._conversation.id response['conversation_id'] = self._conversation.id
if self._task_state.metadata: if self._task_state.metadata:
...@@ -215,7 +217,7 @@ class GenerateTaskPipeline: ...@@ -215,7 +217,7 @@ class GenerateTaskPipeline:
if isinstance(event, QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result self._task_state.llm_result = event.llm_result
else: else:
model_config = self._application_generate_entity.model_config model_config = self._model_config
model = model_config.model model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_type_instance = cast(LargeLanguageModel, model_type_instance)
...@@ -268,7 +270,7 @@ class GenerateTaskPipeline: ...@@ -268,7 +270,7 @@ class GenerateTaskPipeline:
'created_at': int(self._message.created_at.timestamp()) 'created_at': int(self._message.created_at.timestamp())
} }
if self._conversation.mode == 'chat': if self._conversation.mode != AppMode.COMPLETION.value:
replace_response['conversation_id'] = self._conversation.id replace_response['conversation_id'] = self._conversation.id
yield self._yield_response(replace_response) yield self._yield_response(replace_response)
...@@ -283,7 +285,7 @@ class GenerateTaskPipeline: ...@@ -283,7 +285,7 @@ class GenerateTaskPipeline:
'message_id': self._message.id, 'message_id': self._message.id,
} }
if self._conversation.mode == 'chat': if self._conversation.mode != AppMode.COMPLETION.value:
response['conversation_id'] = self._conversation.id response['conversation_id'] = self._conversation.id
if self._task_state.metadata: if self._task_state.metadata:
...@@ -292,7 +294,7 @@ class GenerateTaskPipeline: ...@@ -292,7 +294,7 @@ class GenerateTaskPipeline:
yield self._yield_response(response) yield self._yield_response(response)
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._task_state.metadata['retriever_resources'] = event.retriever_resources self._task_state.metadata['retriever_resources'] = event.retriever_resources
elif isinstance(event, AnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation: if annotation:
account = annotation.account account = annotation.account
...@@ -329,7 +331,7 @@ class GenerateTaskPipeline: ...@@ -329,7 +331,7 @@ class GenerateTaskPipeline:
'message_files': agent_thought.files 'message_files': agent_thought.files
} }
if self._conversation.mode == 'chat': if self._conversation.mode != AppMode.COMPLETION.value:
response['conversation_id'] = self._conversation.id response['conversation_id'] = self._conversation.id
yield self._yield_response(response) yield self._yield_response(response)
...@@ -358,12 +360,12 @@ class GenerateTaskPipeline: ...@@ -358,12 +360,12 @@ class GenerateTaskPipeline:
'url': url 'url': url
} }
if self._conversation.mode == 'chat': if self._conversation.mode != AppMode.COMPLETION.value:
response['conversation_id'] = self._conversation.id response['conversation_id'] = self._conversation.id
yield self._yield_response(response) yield self._yield_response(response)
elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent): elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
chunk = event.chunk chunk = event.chunk
delta_text = chunk.delta.message.content delta_text = chunk.delta.message.content
if delta_text is None: if delta_text is None:
...@@ -376,7 +378,7 @@ class GenerateTaskPipeline: ...@@ -376,7 +378,7 @@ class GenerateTaskPipeline:
if self._output_moderation_handler.should_direct_output(): if self._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output # stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._queue_manager.publish_chunk_message(LLMResultChunk( self._queue_manager.publish_llm_chunk(LLMResultChunk(
model=self._task_state.llm_result.model, model=self._task_state.llm_result.model,
prompt_messages=self._task_state.llm_result.prompt_messages, prompt_messages=self._task_state.llm_result.prompt_messages,
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
...@@ -404,7 +406,7 @@ class GenerateTaskPipeline: ...@@ -404,7 +406,7 @@ class GenerateTaskPipeline:
'created_at': int(self._message.created_at.timestamp()) 'created_at': int(self._message.created_at.timestamp())
} }
if self._conversation.mode == 'chat': if self._conversation.mode != AppMode.COMPLETION.value:
response['conversation_id'] = self._conversation.id response['conversation_id'] = self._conversation.id
yield self._yield_response(response) yield self._yield_response(response)
...@@ -444,8 +446,7 @@ class GenerateTaskPipeline: ...@@ -444,8 +446,7 @@ class GenerateTaskPipeline:
conversation=self._conversation, conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in [ is_first_message=self._application_generate_entity.app_config.app_mode in [
AppMode.AGENT_CHAT, AppMode.AGENT_CHAT,
AppMode.CHAT, AppMode.CHAT
AppMode.ADVANCED_CHAT
] and self._application_generate_entity.conversation_id is None, ] and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras extras=self._application_generate_entity.extras
) )
...@@ -465,7 +466,7 @@ class GenerateTaskPipeline: ...@@ -465,7 +466,7 @@ class GenerateTaskPipeline:
'created_at': int(self._message.created_at.timestamp()) 'created_at': int(self._message.created_at.timestamp())
} }
if self._conversation.mode == 'chat': if self._conversation.mode != AppMode.COMPLETION.value:
response['conversation_id'] = self._conversation.id response['conversation_id'] = self._conversation.id
return response return response
...@@ -575,7 +576,7 @@ class GenerateTaskPipeline: ...@@ -575,7 +576,7 @@ class GenerateTaskPipeline:
:return: :return:
""" """
prompts = [] prompts = []
if self._application_generate_entity.model_config.mode == 'chat': if self._model_config.mode == ModelMode.CHAT.value:
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
if prompt_message.role == PromptMessageRole.USER: if prompt_message.role == PromptMessageRole.USER:
role = 'user' role = 'user'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment