Commit ab933a25 authored by takatost

lint fix

parent 98507edc
@@ -14,12 +14,12 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
 )
 from core.app.entities.queue_entities import (
-    AnnotationReplyEvent,
     QueueAgentMessageEvent,
     QueueAgentThoughtEvent,
+    QueueAnnotationReplyEvent,
     QueueErrorEvent,
+    QueueLLMChunkEvent,
     QueueMessageEndEvent,
-    QueueMessageEvent,
     QueueMessageFileEvent,
     QueueMessageReplaceEvent,
     QueuePingEvent,
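
These renames give every queue event a consistent Queue* prefix: AnnotationReplyEvent becomes QueueAnnotationReplyEvent, and QueueMessageEvent becomes QueueLLMChunkEvent, which names the payload rather than the transport. A minimal sketch of the renamed classes, assuming pydantic models (the real definitions live in core.app.entities.queue_entities; only the two fields read later in this diff are shown):

from typing import Any
from pydantic import BaseModel

class QueueAnnotationReplyEvent(BaseModel):  # was AnnotationReplyEvent
    # id of the matched annotation; read below as event.message_annotation_id
    message_annotation_id: str

class QueueLLMChunkEvent(BaseModel):  # was QueueMessageEvent
    # one streamed chunk; read below as event.chunk (an LLMResultChunk)
    chunk: Any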
@@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.moderation.output_moderation import ModerationRule, OutputModeration
+from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.tools.tool_file_manager import ToolFileManager
 from events.message_event import message_was_created
@@ -58,9 +59,9 @@ class TaskState(BaseModel):
     metadata: dict = {}
 
 
-class GenerateTaskPipeline:
+class EasyUIBasedGenerateTaskPipeline:
     """
-    GenerateTaskPipeline is a class that generate stream output and state management for Application.
+    EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
 
     def __init__(self, application_generate_entity: Union[
@@ -79,12 +80,13 @@ class GenerateTaskPipeline:
         :param message: message
         """
         self._application_generate_entity = application_generate_entity
+        self._model_config = application_generate_entity.model_config
         self._queue_manager = queue_manager
         self._conversation = conversation
         self._message = message
         self._task_state = TaskState(
             llm_result=LLMResult(
-                model=self._application_generate_entity.model_config.model,
+                model=self._model_config.model,
                 prompt_messages=[],
                 message=AssistantPromptMessage(content=""),
                 usage=LLMUsage.empty_usage()
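
Caching application_generate_entity.model_config on self in __init__ shortens every later access; the same substitution repeats through the hunks below. In effect:

# before: traverse the entity on every access
model = self._application_generate_entity.model_config.model

# after: one reference, set once in __init__
self._model_config = application_generate_entity.model_config
model = self._model_config.model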
@@ -115,7 +117,7 @@ class GenerateTaskPipeline:
                 raise self._handle_error(event)
             elif isinstance(event, QueueRetrieverResourcesEvent):
                 self._task_state.metadata['retriever_resources'] = event.retriever_resources
-            elif isinstance(event, AnnotationReplyEvent):
+            elif isinstance(event, QueueAnnotationReplyEvent):
                 annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                 if annotation:
                     account = annotation.account
@@ -132,7 +134,7 @@ class GenerateTaskPipeline:
                 if isinstance(event, QueueMessageEndEvent):
                     self._task_state.llm_result = event.llm_result
                 else:
-                    model_config = self._application_generate_entity.model_config
+                    model_config = self._model_config
                     model = model_config.model
                     model_type_instance = model_config.provider_model_bundle.model_type_instance
                     model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -189,7 +191,7 @@ class GenerateTaskPipeline:
                 'created_at': int(self._message.created_at.timestamp())
             }
 
-            if self._conversation.mode == 'chat':
+            if self._conversation.mode != AppMode.COMPLETION.value:
                 response['conversation_id'] = self._conversation.id
 
             if self._task_state.metadata:
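
Comparing self._conversation.mode to the literal 'chat' matched only plain chat apps; inverting the check to exclude completion mode also covers agent-chat (and any future conversational mode), so every conversational response carries its conversation_id. The same inversion repeats in several hunks below. A sketch of the AppMode enum the new check relies on; the member names appear in this diff, but the string values are assumptions:

from enum import Enum

class AppMode(str, Enum):
    COMPLETION = 'completion'        # the one mode with no conversation
    CHAT = 'chat'
    AGENT_CHAT = 'agent-chat'        # value assumed
    ADVANCED_CHAT = 'advanced-chat'  # value assumed

# every non-completion mode now gets the conversation id
if conversation.mode != AppMode.COMPLETION.value:
    response['conversation_id'] = conversation.id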
@@ -215,7 +217,7 @@ class GenerateTaskPipeline:
                 if isinstance(event, QueueMessageEndEvent):
                     self._task_state.llm_result = event.llm_result
                 else:
-                    model_config = self._application_generate_entity.model_config
+                    model_config = self._model_config
                     model = model_config.model
                     model_type_instance = model_config.provider_model_bundle.model_type_instance
                     model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -268,7 +270,7 @@ class GenerateTaskPipeline:
                     'created_at': int(self._message.created_at.timestamp())
                 }
 
-                if self._conversation.mode == 'chat':
+                if self._conversation.mode != AppMode.COMPLETION.value:
                     replace_response['conversation_id'] = self._conversation.id
 
                 yield self._yield_response(replace_response)
@@ -283,7 +285,7 @@ class GenerateTaskPipeline:
                     'message_id': self._message.id,
                 }
 
-                if self._conversation.mode == 'chat':
+                if self._conversation.mode != AppMode.COMPLETION.value:
                     response['conversation_id'] = self._conversation.id
 
                 if self._task_state.metadata:
@@ -292,7 +294,7 @@ class GenerateTaskPipeline:
                 yield self._yield_response(response)
             elif isinstance(event, QueueRetrieverResourcesEvent):
                 self._task_state.metadata['retriever_resources'] = event.retriever_resources
-            elif isinstance(event, AnnotationReplyEvent):
+            elif isinstance(event, QueueAnnotationReplyEvent):
                 annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                 if annotation:
                     account = annotation.account
@@ -329,7 +331,7 @@ class GenerateTaskPipeline:
                     'message_files': agent_thought.files
                 }
 
-                if self._conversation.mode == 'chat':
+                if self._conversation.mode != AppMode.COMPLETION.value:
                     response['conversation_id'] = self._conversation.id
 
                 yield self._yield_response(response)
@@ -358,12 +360,12 @@ class GenerateTaskPipeline:
                     'url': url
                 }
 
-                if self._conversation.mode == 'chat':
+                if self._conversation.mode != AppMode.COMPLETION.value:
                     response['conversation_id'] = self._conversation.id
 
                 yield self._yield_response(response)
 
-            elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
+            elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
                 chunk = event.chunk
                 delta_text = chunk.delta.message.content
                 if delta_text is None:
@@ -376,7 +378,7 @@ class GenerateTaskPipeline:
                     if self._output_moderation_handler.should_direct_output():
                         # stop subscribe new token when output moderation should direct output
                         self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
-                        self._queue_manager.publish_chunk_message(LLMResultChunk(
+                        self._queue_manager.publish_llm_chunk(LLMResultChunk(
                             model=self._task_state.llm_result.model,
                             prompt_messages=self._task_state.llm_result.prompt_messages,
                             delta=LLMResultChunkDelta(
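
publish_chunk_message becomes publish_llm_chunk, mirroring the QueueMessageEvent to QueueLLMChunkEvent rename above: the queue manager publishes an LLMResultChunk, and the pipeline consumes it as a QueueLLMChunkEvent. The call shape, with the delta fields elided exactly as the hunk leaves them:

# sketch of the renamed call; argument values as in the hunk above
self._queue_manager.publish_llm_chunk(LLMResultChunk(
    model=self._task_state.llm_result.model,
    prompt_messages=self._task_state.llm_result.prompt_messages,
    delta=LLMResultChunkDelta(
        ...  # index/message arguments continue past this hunk
    )
))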
@@ -404,7 +406,7 @@ class GenerateTaskPipeline:
                 'created_at': int(self._message.created_at.timestamp())
             }
 
-            if self._conversation.mode == 'chat':
+            if self._conversation.mode != AppMode.COMPLETION.value:
                 response['conversation_id'] = self._conversation.id
 
             yield self._yield_response(response)
@@ -444,8 +446,7 @@ class GenerateTaskPipeline:
             conversation=self._conversation,
             is_first_message=self._application_generate_entity.app_config.app_mode in [
                 AppMode.AGENT_CHAT,
-                AppMode.CHAT,
-                AppMode.ADVANCED_CHAT
+                AppMode.CHAT
             ] and self._application_generate_entity.conversation_id is None,
             extras=self._application_generate_entity.extras
         )
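
Dropping AppMode.ADVANCED_CHAT from the is_first_message modes matches the class rename: EasyUIBasedGenerateTaskPipeline now serves only the easy-UI app modes, with advanced chat presumably routed to a separate pipeline. The resulting condition:

# first message only in the easy-UI conversational modes,
# and only when no conversation exists yet
is_first_message = (
    app_mode in (AppMode.AGENT_CHAT, AppMode.CHAT)
    and conversation_id is None
)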
@@ -465,7 +466,7 @@ class GenerateTaskPipeline:
             'created_at': int(self._message.created_at.timestamp())
         }
 
-        if self._conversation.mode == 'chat':
+        if self._conversation.mode != AppMode.COMPLETION.value:
             response['conversation_id'] = self._conversation.id
 
         return response
@@ -575,7 +576,7 @@ class GenerateTaskPipeline:
         :return:
         """
         prompts = []
-        if self._application_generate_entity.model_config.mode == 'chat':
+        if self._model_config.mode == ModelMode.CHAT.value:
             for prompt_message in prompt_messages:
                 if prompt_message.role == PromptMessageRole.USER:
                     role = 'user'
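
The final hunk swaps another 'chat' literal for the newly imported ModelMode enum. Note the distinction: AppMode describes the application (chat vs. completion app), while ModelMode describes the LLM interface (role-tagged messages vs. a single text prompt). A minimal sketch of ModelMode, assuming two members with string values (only ModelMode.CHAT.value is confirmed by this diff):

from enum import Enum

class ModelMode(str, Enum):
    COMPLETION = 'completion'  # single text prompt
    CHAT = 'chat'              # list of role-tagged messages

if self._model_config.mode == ModelMode.CHAT.value:
    # map PromptMessageRole.USER -> 'user', etc., as the loop above does
    ...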