Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
ab933a25
Commit
ab933a25
authored
Mar 03, 2024
by
takatost
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
lint fix
parent
98507edc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
585 additions
and
21 deletions
+585
-21
generate_task_pipeline.py
api/core/app/apps/advanced_chat/generate_task_pipeline.py
+563
-0
easy_ui_based_generate_task_pipeline.py
api/core/app/apps/easy_ui_based_generate_task_pipeline.py
+22
-21
No files found.
api/core/app/apps/advanced_chat/generate_task_pipeline.py
0 → 100644
View file @
ab933a25
import
json
import
logging
import
time
from
collections.abc
import
Generator
from
typing
import
Optional
,
Union
from
pydantic
import
BaseModel
from
core.app.app_queue_manager
import
AppQueueManager
,
PublishFrom
from
core.app.entities.app_invoke_entities
import
(
AdvancedChatAppGenerateEntity
,
InvokeFrom
,
)
from
core.app.entities.queue_entities
import
(
QueueAnnotationReplyEvent
,
QueueErrorEvent
,
QueueMessageFileEvent
,
QueueMessageReplaceEvent
,
QueueNodeFinishedEvent
,
QueueNodeStartedEvent
,
QueuePingEvent
,
QueueRetrieverResourcesEvent
,
QueueStopEvent
,
QueueTextChunkEvent
,
QueueWorkflowFinishedEvent
,
QueueWorkflowStartedEvent
,
)
from
core.errors.error
import
ModelCurrentlyNotSupportError
,
ProviderTokenNotInitError
,
QuotaExceededError
from
core.model_runtime.entities.llm_entities
import
LLMUsage
from
core.model_runtime.errors.invoke
import
InvokeAuthorizationError
,
InvokeError
from
core.moderation.output_moderation
import
ModerationRule
,
OutputModeration
from
core.tools.tool_file_manager
import
ToolFileManager
from
events.message_event
import
message_was_created
from
extensions.ext_database
import
db
from
models.model
import
Conversation
,
Message
,
MessageFile
from
models.workflow
import
WorkflowNodeExecution
,
WorkflowNodeExecutionStatus
,
WorkflowRun
,
WorkflowRunStatus
from
services.annotation_service
import
AppAnnotationService
logger
=
logging
.
getLogger
(
__name__
)
class
TaskState
(
BaseModel
):
"""
TaskState entity
"""
answer
:
str
=
""
metadata
:
dict
=
{}
class
AdvancedChatAppGenerateTaskPipeline
:
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
def
__init__
(
self
,
application_generate_entity
:
AdvancedChatAppGenerateEntity
,
queue_manager
:
AppQueueManager
,
conversation
:
Conversation
,
message
:
Message
)
->
None
:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
"""
self
.
_application_generate_entity
=
application_generate_entity
self
.
_queue_manager
=
queue_manager
self
.
_conversation
=
conversation
self
.
_message
=
message
self
.
_task_state
=
TaskState
(
usage
=
LLMUsage
.
empty_usage
()
)
self
.
_start_at
=
time
.
perf_counter
()
self
.
_output_moderation_handler
=
self
.
_init_output_moderation
()
def
process
(
self
,
stream
:
bool
)
->
Union
[
dict
,
Generator
]:
"""
Process generate task pipeline.
:return:
"""
if
stream
:
return
self
.
_process_stream_response
()
else
:
return
self
.
_process_blocking_response
()
def
_process_blocking_response
(
self
)
->
dict
:
"""
Process blocking response.
:return:
"""
for
queue_message
in
self
.
_queue_manager
.
listen
():
event
=
queue_message
.
event
if
isinstance
(
event
,
QueueErrorEvent
):
raise
self
.
_handle_error
(
event
)
elif
isinstance
(
event
,
QueueRetrieverResourcesEvent
):
self
.
_task_state
.
metadata
[
'retriever_resources'
]
=
event
.
retriever_resources
elif
isinstance
(
event
,
QueueAnnotationReplyEvent
):
annotation
=
AppAnnotationService
.
get_annotation_by_id
(
event
.
message_annotation_id
)
if
annotation
:
account
=
annotation
.
account
self
.
_task_state
.
metadata
[
'annotation_reply'
]
=
{
'id'
:
annotation
.
id
,
'account'
:
{
'id'
:
annotation
.
account_id
,
'name'
:
account
.
name
if
account
else
'Dify user'
}
}
self
.
_task_state
.
answer
=
annotation
.
content
elif
isinstance
(
event
,
QueueNodeFinishedEvent
):
workflow_node_execution
=
self
.
_get_workflow_node_execution
(
event
.
workflow_node_execution_id
)
if
workflow_node_execution
.
status
==
WorkflowNodeExecutionStatus
.
SUCCEEDED
.
value
:
if
workflow_node_execution
.
node_type
==
'llm'
:
# todo use enum
outputs
=
workflow_node_execution
.
outputs_dict
usage_dict
=
outputs
.
get
(
'usage'
,
{})
self
.
_task_state
.
metadata
[
'usage'
]
=
usage_dict
elif
isinstance
(
event
,
QueueStopEvent
|
QueueWorkflowFinishedEvent
):
if
isinstance
(
event
,
QueueWorkflowFinishedEvent
):
workflow_run
=
self
.
_get_workflow_run
(
event
.
workflow_run_id
)
if
workflow_run
.
status
==
WorkflowRunStatus
.
SUCCEEDED
.
value
:
outputs
=
workflow_run
.
outputs
self
.
_task_state
.
answer
=
outputs
.
get
(
'text'
,
''
)
else
:
raise
self
.
_handle_error
(
QueueErrorEvent
(
error
=
ValueError
(
f
'Run failed: {workflow_run.error}'
)))
# response moderation
if
self
.
_output_moderation_handler
:
self
.
_output_moderation_handler
.
stop_thread
()
self
.
_task_state
.
answer
=
self
.
_output_moderation_handler
.
moderation_completion
(
completion
=
self
.
_task_state
.
answer
,
public_event
=
False
)
# Save message
self
.
_save_message
()
response
=
{
'event'
:
'message'
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'id'
:
self
.
_message
.
id
,
'message_id'
:
self
.
_message
.
id
,
'conversation_id'
:
self
.
_conversation
.
id
,
'mode'
:
self
.
_conversation
.
mode
,
'answer'
:
self
.
_task_state
.
answer
,
'metadata'
:
{},
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
}
if
self
.
_task_state
.
metadata
:
response
[
'metadata'
]
=
self
.
_get_response_metadata
()
return
response
else
:
continue
def
_process_stream_response
(
self
)
->
Generator
:
"""
Process stream response.
:return:
"""
for
message
in
self
.
_queue_manager
.
listen
():
event
=
message
.
event
if
isinstance
(
event
,
QueueErrorEvent
):
data
=
self
.
_error_to_stream_response_data
(
self
.
_handle_error
(
event
))
yield
self
.
_yield_response
(
data
)
break
elif
isinstance
(
event
,
QueueWorkflowStartedEvent
):
workflow_run
=
self
.
_get_workflow_run
(
event
.
workflow_run_id
)
response
=
{
'event'
:
'workflow_started'
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'workflow_run_id'
:
event
.
workflow_run_id
,
'data'
:
{
'id'
:
workflow_run
.
id
,
'workflow_id'
:
workflow_run
.
workflow_id
,
'created_at'
:
int
(
workflow_run
.
created_at
.
timestamp
())
}
}
yield
self
.
_yield_response
(
response
)
elif
isinstance
(
event
,
QueueNodeStartedEvent
):
workflow_node_execution
=
self
.
_get_workflow_node_execution
(
event
.
workflow_node_execution_id
)
response
=
{
'event'
:
'node_started'
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'workflow_run_id'
:
workflow_node_execution
.
workflow_run_id
,
'data'
:
{
'id'
:
workflow_node_execution
.
id
,
'node_id'
:
workflow_node_execution
.
node_id
,
'index'
:
workflow_node_execution
.
index
,
'predecessor_node_id'
:
workflow_node_execution
.
predecessor_node_id
,
'inputs'
:
workflow_node_execution
.
inputs_dict
,
'created_at'
:
int
(
workflow_node_execution
.
created_at
.
timestamp
())
}
}
yield
self
.
_yield_response
(
response
)
elif
isinstance
(
event
,
QueueNodeFinishedEvent
):
workflow_node_execution
=
self
.
_get_workflow_node_execution
(
event
.
workflow_node_execution_id
)
if
workflow_node_execution
.
status
==
WorkflowNodeExecutionStatus
.
SUCCEEDED
.
value
:
if
workflow_node_execution
.
node_type
==
'llm'
:
# todo use enum
outputs
=
workflow_node_execution
.
outputs_dict
usage_dict
=
outputs
.
get
(
'usage'
,
{})
self
.
_task_state
.
metadata
[
'usage'
]
=
usage_dict
response
=
{
'event'
:
'node_finished'
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'workflow_run_id'
:
workflow_node_execution
.
workflow_run_id
,
'data'
:
{
'id'
:
workflow_node_execution
.
id
,
'node_id'
:
workflow_node_execution
.
node_id
,
'index'
:
workflow_node_execution
.
index
,
'predecessor_node_id'
:
workflow_node_execution
.
predecessor_node_id
,
'inputs'
:
workflow_node_execution
.
inputs_dict
,
'process_data'
:
workflow_node_execution
.
process_data_dict
,
'outputs'
:
workflow_node_execution
.
outputs_dict
,
'status'
:
workflow_node_execution
.
status
,
'error'
:
workflow_node_execution
.
error
,
'elapsed_time'
:
workflow_node_execution
.
elapsed_time
,
'execution_metadata'
:
workflow_node_execution
.
execution_metadata_dict
,
'created_at'
:
int
(
workflow_node_execution
.
created_at
.
timestamp
()),
'finished_at'
:
int
(
workflow_node_execution
.
finished_at
.
timestamp
())
}
}
yield
self
.
_yield_response
(
response
)
elif
isinstance
(
event
,
QueueStopEvent
|
QueueWorkflowFinishedEvent
):
if
isinstance
(
event
,
QueueWorkflowFinishedEvent
):
workflow_run
=
self
.
_get_workflow_run
(
event
.
workflow_run_id
)
if
workflow_run
.
status
==
WorkflowRunStatus
.
SUCCEEDED
.
value
:
outputs
=
workflow_run
.
outputs
self
.
_task_state
.
answer
=
outputs
.
get
(
'text'
,
''
)
else
:
err_event
=
QueueErrorEvent
(
error
=
ValueError
(
f
'Run failed: {workflow_run.error}'
))
data
=
self
.
_error_to_stream_response_data
(
self
.
_handle_error
(
err_event
))
yield
self
.
_yield_response
(
data
)
break
workflow_run_response
=
{
'event'
:
'workflow_finished'
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'workflow_run_id'
:
event
.
workflow_run_id
,
'data'
:
{
'id'
:
workflow_run
.
id
,
'workflow_id'
:
workflow_run
.
workflow_id
,
'status'
:
workflow_run
.
status
,
'outputs'
:
workflow_run
.
outputs_dict
,
'error'
:
workflow_run
.
error
,
'elapsed_time'
:
workflow_run
.
elapsed_time
,
'total_tokens'
:
workflow_run
.
total_tokens
,
'total_price'
:
workflow_run
.
total_price
,
'currency'
:
workflow_run
.
currency
,
'total_steps'
:
workflow_run
.
total_steps
,
'created_at'
:
int
(
workflow_run
.
created_at
.
timestamp
()),
'finished_at'
:
int
(
workflow_run
.
finished_at
.
timestamp
())
}
}
yield
self
.
_yield_response
(
workflow_run_response
)
# response moderation
if
self
.
_output_moderation_handler
:
self
.
_output_moderation_handler
.
stop_thread
()
self
.
_task_state
.
answer
=
self
.
_output_moderation_handler
.
moderation_completion
(
completion
=
self
.
_task_state
.
answer
,
public_event
=
False
)
self
.
_output_moderation_handler
=
None
replace_response
=
{
'event'
:
'message_replace'
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'message_id'
:
self
.
_message
.
id
,
'conversation_id'
:
self
.
_conversation
.
id
,
'answer'
:
self
.
_task_state
.
answer
,
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
}
yield
self
.
_yield_response
(
replace_response
)
# Save message
self
.
_save_message
()
response
=
{
'event'
:
'message_end'
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'id'
:
self
.
_message
.
id
,
'message_id'
:
self
.
_message
.
id
,
'conversation_id'
:
self
.
_conversation
.
id
,
}
if
self
.
_task_state
.
metadata
:
response
[
'metadata'
]
=
self
.
_get_response_metadata
()
yield
self
.
_yield_response
(
response
)
elif
isinstance
(
event
,
QueueRetrieverResourcesEvent
):
self
.
_task_state
.
metadata
[
'retriever_resources'
]
=
event
.
retriever_resources
elif
isinstance
(
event
,
QueueAnnotationReplyEvent
):
annotation
=
AppAnnotationService
.
get_annotation_by_id
(
event
.
message_annotation_id
)
if
annotation
:
account
=
annotation
.
account
self
.
_task_state
.
metadata
[
'annotation_reply'
]
=
{
'id'
:
annotation
.
id
,
'account'
:
{
'id'
:
annotation
.
account_id
,
'name'
:
account
.
name
if
account
else
'Dify user'
}
}
self
.
_task_state
.
answer
=
annotation
.
content
elif
isinstance
(
event
,
QueueMessageFileEvent
):
message_file
:
MessageFile
=
(
db
.
session
.
query
(
MessageFile
)
.
filter
(
MessageFile
.
id
==
event
.
message_file_id
)
.
first
()
)
# get extension
if
'.'
in
message_file
.
url
:
extension
=
f
'.{message_file.url.split(".")[-1]}'
if
len
(
extension
)
>
10
:
extension
=
'.bin'
else
:
extension
=
'.bin'
# add sign url
url
=
ToolFileManager
.
sign_file
(
file_id
=
message_file
.
id
,
extension
=
extension
)
if
message_file
:
response
=
{
'event'
:
'message_file'
,
'conversation_id'
:
self
.
_conversation
.
id
,
'id'
:
message_file
.
id
,
'type'
:
message_file
.
type
,
'belongs_to'
:
message_file
.
belongs_to
or
'user'
,
'url'
:
url
}
yield
self
.
_yield_response
(
response
)
elif
isinstance
(
event
,
QueueTextChunkEvent
):
delta_text
=
event
.
chunk_text
if
delta_text
is
None
:
continue
if
self
.
_output_moderation_handler
:
if
self
.
_output_moderation_handler
.
should_direct_output
():
# stop subscribe new token when output moderation should direct output
self
.
_task_state
.
answer
=
self
.
_output_moderation_handler
.
get_final_output
()
self
.
_queue_manager
.
publish_text_chunk
(
self
.
_task_state
.
answer
,
PublishFrom
.
TASK_PIPELINE
)
self
.
_queue_manager
.
publish
(
QueueStopEvent
(
stopped_by
=
QueueStopEvent
.
StopBy
.
OUTPUT_MODERATION
),
PublishFrom
.
TASK_PIPELINE
)
continue
else
:
self
.
_output_moderation_handler
.
append_new_token
(
delta_text
)
self
.
_task_state
.
answer
+=
delta_text
response
=
self
.
_handle_chunk
(
delta_text
)
yield
self
.
_yield_response
(
response
)
elif
isinstance
(
event
,
QueueMessageReplaceEvent
):
response
=
{
'event'
:
'message_replace'
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'message_id'
:
self
.
_message
.
id
,
'conversation_id'
:
self
.
_conversation
.
id
,
'answer'
:
event
.
text
,
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
}
yield
self
.
_yield_response
(
response
)
elif
isinstance
(
event
,
QueuePingEvent
):
yield
"event: ping
\n\n
"
else
:
continue
def
_get_workflow_run
(
self
,
workflow_run_id
:
str
)
->
WorkflowRun
:
"""
Get workflow run.
:param workflow_run_id: workflow run id
:return:
"""
return
db
.
session
.
query
(
WorkflowRun
)
.
filter
(
WorkflowRun
.
id
==
workflow_run_id
)
.
first
()
def
_get_workflow_node_execution
(
self
,
workflow_node_execution_id
:
str
)
->
WorkflowNodeExecution
:
"""
Get workflow node execution.
:param workflow_node_execution_id: workflow node execution id
:return:
"""
return
db
.
session
.
query
(
WorkflowNodeExecution
)
.
filter
(
WorkflowNodeExecution
.
id
==
workflow_node_execution_id
)
.
first
()
def
_save_message
(
self
)
->
None
:
"""
Save message.
:return:
"""
self
.
_message
=
db
.
session
.
query
(
Message
)
.
filter
(
Message
.
id
==
self
.
_message
.
id
)
.
first
()
self
.
_message
.
answer
=
self
.
_task_state
.
answer
self
.
_message
.
provider_response_latency
=
time
.
perf_counter
()
-
self
.
_start_at
if
self
.
_task_state
.
metadata
and
self
.
_task_state
.
metadata
.
get
(
'usage'
):
usage
=
LLMUsage
(
**
self
.
_task_state
.
metadata
[
'usage'
])
self
.
_message
.
message_tokens
=
usage
.
prompt_tokens
self
.
_message
.
message_unit_price
=
usage
.
prompt_unit_price
self
.
_message
.
message_price_unit
=
usage
.
prompt_price_unit
self
.
_message
.
answer_tokens
=
usage
.
completion_tokens
self
.
_message
.
answer_unit_price
=
usage
.
completion_unit_price
self
.
_message
.
answer_price_unit
=
usage
.
completion_price_unit
self
.
_message
.
provider_response_latency
=
time
.
perf_counter
()
-
self
.
_start_at
self
.
_message
.
total_price
=
usage
.
total_price
self
.
_message
.
currency
=
usage
.
currency
db
.
session
.
commit
()
message_was_created
.
send
(
self
.
_message
,
application_generate_entity
=
self
.
_application_generate_entity
,
conversation
=
self
.
_conversation
,
is_first_message
=
self
.
_application_generate_entity
.
conversation_id
is
None
,
extras
=
self
.
_application_generate_entity
.
extras
)
def
_handle_chunk
(
self
,
text
:
str
)
->
dict
:
"""
Handle completed event.
:param text: text
:return:
"""
response
=
{
'event'
:
'message'
,
'id'
:
self
.
_message
.
id
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'message_id'
:
self
.
_message
.
id
,
'conversation_id'
:
self
.
_conversation
.
id
,
'answer'
:
text
,
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
}
return
response
def
_handle_error
(
self
,
event
:
QueueErrorEvent
)
->
Exception
:
"""
Handle error event.
:param event: event
:return:
"""
logger
.
debug
(
"error:
%
s"
,
event
.
error
)
e
=
event
.
error
if
isinstance
(
e
,
InvokeAuthorizationError
):
return
InvokeAuthorizationError
(
'Incorrect API key provided'
)
elif
isinstance
(
e
,
InvokeError
)
or
isinstance
(
e
,
ValueError
):
return
e
else
:
return
Exception
(
e
.
description
if
getattr
(
e
,
'description'
,
None
)
is
not
None
else
str
(
e
))
def
_error_to_stream_response_data
(
self
,
e
:
Exception
)
->
dict
:
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses
=
{
ValueError
:
{
'code'
:
'invalid_param'
,
'status'
:
400
},
ProviderTokenNotInitError
:
{
'code'
:
'provider_not_initialize'
,
'status'
:
400
},
QuotaExceededError
:
{
'code'
:
'provider_quota_exceeded'
,
'message'
:
"Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
,
'status'
:
400
},
ModelCurrentlyNotSupportError
:
{
'code'
:
'model_currently_not_support'
,
'status'
:
400
},
InvokeError
:
{
'code'
:
'completion_request_error'
,
'status'
:
400
}
}
# Determine the response based on the type of exception
data
=
None
for
k
,
v
in
error_responses
.
items
():
if
isinstance
(
e
,
k
):
data
=
v
if
data
:
data
.
setdefault
(
'message'
,
getattr
(
e
,
'description'
,
str
(
e
)))
else
:
logging
.
error
(
e
)
data
=
{
'code'
:
'internal_server_error'
,
'message'
:
'Internal Server Error, please contact support.'
,
'status'
:
500
}
return
{
'event'
:
'error'
,
'task_id'
:
self
.
_application_generate_entity
.
task_id
,
'message_id'
:
self
.
_message
.
id
,
**
data
}
def
_get_response_metadata
(
self
)
->
dict
:
"""
Get response metadata by invoke from.
:return:
"""
metadata
=
{}
# show_retrieve_source
if
'retriever_resources'
in
self
.
_task_state
.
metadata
:
if
self
.
_application_generate_entity
.
invoke_from
in
[
InvokeFrom
.
DEBUGGER
,
InvokeFrom
.
SERVICE_API
]:
metadata
[
'retriever_resources'
]
=
self
.
_task_state
.
metadata
[
'retriever_resources'
]
else
:
metadata
[
'retriever_resources'
]
=
[]
for
resource
in
self
.
_task_state
.
metadata
[
'retriever_resources'
]:
metadata
[
'retriever_resources'
]
.
append
({
'segment_id'
:
resource
[
'segment_id'
],
'position'
:
resource
[
'position'
],
'document_name'
:
resource
[
'document_name'
],
'score'
:
resource
[
'score'
],
'content'
:
resource
[
'content'
],
})
# show annotation reply
if
'annotation_reply'
in
self
.
_task_state
.
metadata
:
if
self
.
_application_generate_entity
.
invoke_from
in
[
InvokeFrom
.
DEBUGGER
,
InvokeFrom
.
SERVICE_API
]:
metadata
[
'annotation_reply'
]
=
self
.
_task_state
.
metadata
[
'annotation_reply'
]
# show usage
if
self
.
_application_generate_entity
.
invoke_from
in
[
InvokeFrom
.
DEBUGGER
,
InvokeFrom
.
SERVICE_API
]:
metadata
[
'usage'
]
=
self
.
_task_state
.
metadata
[
'usage'
]
return
metadata
def
_yield_response
(
self
,
response
:
dict
)
->
str
:
"""
Yield response.
:param response: response
:return:
"""
return
"data: "
+
json
.
dumps
(
response
)
+
"
\n\n
"
def
_init_output_moderation
(
self
)
->
Optional
[
OutputModeration
]:
"""
Init output moderation.
:return:
"""
app_config
=
self
.
_application_generate_entity
.
app_config
sensitive_word_avoidance
=
app_config
.
sensitive_word_avoidance
if
sensitive_word_avoidance
:
return
OutputModeration
(
tenant_id
=
app_config
.
tenant_id
,
app_id
=
app_config
.
app_id
,
rule
=
ModerationRule
(
type
=
sensitive_word_avoidance
.
type
,
config
=
sensitive_word_avoidance
.
config
),
on_message_replace_func
=
self
.
_queue_manager
.
publish_message_replace
)
api/core/app/generate_task_pipeline.py
→
api/core/app/
apps/easy_ui_based_
generate_task_pipeline.py
View file @
ab933a25
...
@@ -14,12 +14,12 @@ from core.app.entities.app_invoke_entities import (
...
@@ -14,12 +14,12 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom
,
InvokeFrom
,
)
)
from
core.app.entities.queue_entities
import
(
from
core.app.entities.queue_entities
import
(
AnnotationReplyEvent
,
QueueAgentMessageEvent
,
QueueAgentMessageEvent
,
QueueAgentThoughtEvent
,
QueueAgentThoughtEvent
,
QueueAnnotationReplyEvent
,
QueueErrorEvent
,
QueueErrorEvent
,
QueueLLMChunkEvent
,
QueueMessageEndEvent
,
QueueMessageEndEvent
,
QueueMessageEvent
,
QueueMessageFileEvent
,
QueueMessageFileEvent
,
QueueMessageReplaceEvent
,
QueueMessageReplaceEvent
,
QueuePingEvent
,
QueuePingEvent
,
...
@@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
...
@@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.utils.encoders
import
jsonable_encoder
from
core.model_runtime.utils.encoders
import
jsonable_encoder
from
core.moderation.output_moderation
import
ModerationRule
,
OutputModeration
from
core.moderation.output_moderation
import
ModerationRule
,
OutputModeration
from
core.prompt.simple_prompt_transform
import
ModelMode
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
from
core.prompt.utils.prompt_template_parser
import
PromptTemplateParser
from
core.tools.tool_file_manager
import
ToolFileManager
from
core.tools.tool_file_manager
import
ToolFileManager
from
events.message_event
import
message_was_created
from
events.message_event
import
message_was_created
...
@@ -58,9 +59,9 @@ class TaskState(BaseModel):
...
@@ -58,9 +59,9 @@ class TaskState(BaseModel):
metadata
:
dict
=
{}
metadata
:
dict
=
{}
class
GenerateTaskPipeline
:
class
EasyUIBased
GenerateTaskPipeline
:
"""
"""
GenerateTaskPipeline is a class that generate stream output and state management for Application.
EasyUIBased
GenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
"""
def
__init__
(
self
,
application_generate_entity
:
Union
[
def
__init__
(
self
,
application_generate_entity
:
Union
[
...
@@ -79,12 +80,13 @@ class GenerateTaskPipeline:
...
@@ -79,12 +80,13 @@ class GenerateTaskPipeline:
:param message: message
:param message: message
"""
"""
self
.
_application_generate_entity
=
application_generate_entity
self
.
_application_generate_entity
=
application_generate_entity
self
.
_model_config
=
application_generate_entity
.
model_config
self
.
_queue_manager
=
queue_manager
self
.
_queue_manager
=
queue_manager
self
.
_conversation
=
conversation
self
.
_conversation
=
conversation
self
.
_message
=
message
self
.
_message
=
message
self
.
_task_state
=
TaskState
(
self
.
_task_state
=
TaskState
(
llm_result
=
LLMResult
(
llm_result
=
LLMResult
(
model
=
self
.
_
application_generate_entity
.
model_config
.
model
,
model
=
self
.
_model_config
.
model
,
prompt_messages
=
[],
prompt_messages
=
[],
message
=
AssistantPromptMessage
(
content
=
""
),
message
=
AssistantPromptMessage
(
content
=
""
),
usage
=
LLMUsage
.
empty_usage
()
usage
=
LLMUsage
.
empty_usage
()
...
@@ -115,7 +117,7 @@ class GenerateTaskPipeline:
...
@@ -115,7 +117,7 @@ class GenerateTaskPipeline:
raise
self
.
_handle_error
(
event
)
raise
self
.
_handle_error
(
event
)
elif
isinstance
(
event
,
QueueRetrieverResourcesEvent
):
elif
isinstance
(
event
,
QueueRetrieverResourcesEvent
):
self
.
_task_state
.
metadata
[
'retriever_resources'
]
=
event
.
retriever_resources
self
.
_task_state
.
metadata
[
'retriever_resources'
]
=
event
.
retriever_resources
elif
isinstance
(
event
,
AnnotationReplyEvent
):
elif
isinstance
(
event
,
Queue
AnnotationReplyEvent
):
annotation
=
AppAnnotationService
.
get_annotation_by_id
(
event
.
message_annotation_id
)
annotation
=
AppAnnotationService
.
get_annotation_by_id
(
event
.
message_annotation_id
)
if
annotation
:
if
annotation
:
account
=
annotation
.
account
account
=
annotation
.
account
...
@@ -132,7 +134,7 @@ class GenerateTaskPipeline:
...
@@ -132,7 +134,7 @@ class GenerateTaskPipeline:
if
isinstance
(
event
,
QueueMessageEndEvent
):
if
isinstance
(
event
,
QueueMessageEndEvent
):
self
.
_task_state
.
llm_result
=
event
.
llm_result
self
.
_task_state
.
llm_result
=
event
.
llm_result
else
:
else
:
model_config
=
self
.
_
application_generate_entity
.
model_config
model_config
=
self
.
_model_config
model
=
model_config
.
model
model
=
model_config
.
model
model_type_instance
=
model_config
.
provider_model_bundle
.
model_type_instance
model_type_instance
=
model_config
.
provider_model_bundle
.
model_type_instance
model_type_instance
=
cast
(
LargeLanguageModel
,
model_type_instance
)
model_type_instance
=
cast
(
LargeLanguageModel
,
model_type_instance
)
...
@@ -189,7 +191,7 @@ class GenerateTaskPipeline:
...
@@ -189,7 +191,7 @@ class GenerateTaskPipeline:
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
}
}
if
self
.
_conversation
.
mode
==
'chat'
:
if
self
.
_conversation
.
mode
!=
AppMode
.
COMPLETION
.
value
:
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
if
self
.
_task_state
.
metadata
:
if
self
.
_task_state
.
metadata
:
...
@@ -215,7 +217,7 @@ class GenerateTaskPipeline:
...
@@ -215,7 +217,7 @@ class GenerateTaskPipeline:
if
isinstance
(
event
,
QueueMessageEndEvent
):
if
isinstance
(
event
,
QueueMessageEndEvent
):
self
.
_task_state
.
llm_result
=
event
.
llm_result
self
.
_task_state
.
llm_result
=
event
.
llm_result
else
:
else
:
model_config
=
self
.
_
application_generate_entity
.
model_config
model_config
=
self
.
_model_config
model
=
model_config
.
model
model
=
model_config
.
model
model_type_instance
=
model_config
.
provider_model_bundle
.
model_type_instance
model_type_instance
=
model_config
.
provider_model_bundle
.
model_type_instance
model_type_instance
=
cast
(
LargeLanguageModel
,
model_type_instance
)
model_type_instance
=
cast
(
LargeLanguageModel
,
model_type_instance
)
...
@@ -268,7 +270,7 @@ class GenerateTaskPipeline:
...
@@ -268,7 +270,7 @@ class GenerateTaskPipeline:
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
}
}
if
self
.
_conversation
.
mode
==
'chat'
:
if
self
.
_conversation
.
mode
!=
AppMode
.
COMPLETION
.
value
:
replace_response
[
'conversation_id'
]
=
self
.
_conversation
.
id
replace_response
[
'conversation_id'
]
=
self
.
_conversation
.
id
yield
self
.
_yield_response
(
replace_response
)
yield
self
.
_yield_response
(
replace_response
)
...
@@ -283,7 +285,7 @@ class GenerateTaskPipeline:
...
@@ -283,7 +285,7 @@ class GenerateTaskPipeline:
'message_id'
:
self
.
_message
.
id
,
'message_id'
:
self
.
_message
.
id
,
}
}
if
self
.
_conversation
.
mode
==
'chat'
:
if
self
.
_conversation
.
mode
!=
AppMode
.
COMPLETION
.
value
:
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
if
self
.
_task_state
.
metadata
:
if
self
.
_task_state
.
metadata
:
...
@@ -292,7 +294,7 @@ class GenerateTaskPipeline:
...
@@ -292,7 +294,7 @@ class GenerateTaskPipeline:
yield
self
.
_yield_response
(
response
)
yield
self
.
_yield_response
(
response
)
elif
isinstance
(
event
,
QueueRetrieverResourcesEvent
):
elif
isinstance
(
event
,
QueueRetrieverResourcesEvent
):
self
.
_task_state
.
metadata
[
'retriever_resources'
]
=
event
.
retriever_resources
self
.
_task_state
.
metadata
[
'retriever_resources'
]
=
event
.
retriever_resources
elif
isinstance
(
event
,
AnnotationReplyEvent
):
elif
isinstance
(
event
,
Queue
AnnotationReplyEvent
):
annotation
=
AppAnnotationService
.
get_annotation_by_id
(
event
.
message_annotation_id
)
annotation
=
AppAnnotationService
.
get_annotation_by_id
(
event
.
message_annotation_id
)
if
annotation
:
if
annotation
:
account
=
annotation
.
account
account
=
annotation
.
account
...
@@ -329,7 +331,7 @@ class GenerateTaskPipeline:
...
@@ -329,7 +331,7 @@ class GenerateTaskPipeline:
'message_files'
:
agent_thought
.
files
'message_files'
:
agent_thought
.
files
}
}
if
self
.
_conversation
.
mode
==
'chat'
:
if
self
.
_conversation
.
mode
!=
AppMode
.
COMPLETION
.
value
:
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
yield
self
.
_yield_response
(
response
)
yield
self
.
_yield_response
(
response
)
...
@@ -358,12 +360,12 @@ class GenerateTaskPipeline:
...
@@ -358,12 +360,12 @@ class GenerateTaskPipeline:
'url'
:
url
'url'
:
url
}
}
if
self
.
_conversation
.
mode
==
'chat'
:
if
self
.
_conversation
.
mode
!=
AppMode
.
COMPLETION
.
value
:
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
yield
self
.
_yield_response
(
response
)
yield
self
.
_yield_response
(
response
)
elif
isinstance
(
event
,
Queue
Message
Event
|
QueueAgentMessageEvent
):
elif
isinstance
(
event
,
Queue
LLMChunk
Event
|
QueueAgentMessageEvent
):
chunk
=
event
.
chunk
chunk
=
event
.
chunk
delta_text
=
chunk
.
delta
.
message
.
content
delta_text
=
chunk
.
delta
.
message
.
content
if
delta_text
is
None
:
if
delta_text
is
None
:
...
@@ -376,7 +378,7 @@ class GenerateTaskPipeline:
...
@@ -376,7 +378,7 @@ class GenerateTaskPipeline:
if
self
.
_output_moderation_handler
.
should_direct_output
():
if
self
.
_output_moderation_handler
.
should_direct_output
():
# stop subscribe new token when output moderation should direct output
# stop subscribe new token when output moderation should direct output
self
.
_task_state
.
llm_result
.
message
.
content
=
self
.
_output_moderation_handler
.
get_final_output
()
self
.
_task_state
.
llm_result
.
message
.
content
=
self
.
_output_moderation_handler
.
get_final_output
()
self
.
_queue_manager
.
publish_
chunk_message
(
LLMResultChunk
(
self
.
_queue_manager
.
publish_
llm_chunk
(
LLMResultChunk
(
model
=
self
.
_task_state
.
llm_result
.
model
,
model
=
self
.
_task_state
.
llm_result
.
model
,
prompt_messages
=
self
.
_task_state
.
llm_result
.
prompt_messages
,
prompt_messages
=
self
.
_task_state
.
llm_result
.
prompt_messages
,
delta
=
LLMResultChunkDelta
(
delta
=
LLMResultChunkDelta
(
...
@@ -404,7 +406,7 @@ class GenerateTaskPipeline:
...
@@ -404,7 +406,7 @@ class GenerateTaskPipeline:
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
}
}
if
self
.
_conversation
.
mode
==
'chat'
:
if
self
.
_conversation
.
mode
!=
AppMode
.
COMPLETION
.
value
:
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
yield
self
.
_yield_response
(
response
)
yield
self
.
_yield_response
(
response
)
...
@@ -444,8 +446,7 @@ class GenerateTaskPipeline:
...
@@ -444,8 +446,7 @@ class GenerateTaskPipeline:
conversation
=
self
.
_conversation
,
conversation
=
self
.
_conversation
,
is_first_message
=
self
.
_application_generate_entity
.
app_config
.
app_mode
in
[
is_first_message
=
self
.
_application_generate_entity
.
app_config
.
app_mode
in
[
AppMode
.
AGENT_CHAT
,
AppMode
.
AGENT_CHAT
,
AppMode
.
CHAT
,
AppMode
.
CHAT
AppMode
.
ADVANCED_CHAT
]
and
self
.
_application_generate_entity
.
conversation_id
is
None
,
]
and
self
.
_application_generate_entity
.
conversation_id
is
None
,
extras
=
self
.
_application_generate_entity
.
extras
extras
=
self
.
_application_generate_entity
.
extras
)
)
...
@@ -465,7 +466,7 @@ class GenerateTaskPipeline:
...
@@ -465,7 +466,7 @@ class GenerateTaskPipeline:
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
'created_at'
:
int
(
self
.
_message
.
created_at
.
timestamp
())
}
}
if
self
.
_conversation
.
mode
==
'chat'
:
if
self
.
_conversation
.
mode
!=
AppMode
.
COMPLETION
.
value
:
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
response
[
'conversation_id'
]
=
self
.
_conversation
.
id
return
response
return
response
...
@@ -575,7 +576,7 @@ class GenerateTaskPipeline:
...
@@ -575,7 +576,7 @@ class GenerateTaskPipeline:
:return:
:return:
"""
"""
prompts
=
[]
prompts
=
[]
if
self
.
_
application_generate_entity
.
model_config
.
mode
==
'chat'
:
if
self
.
_
model_config
.
mode
==
ModelMode
.
CHAT
.
value
:
for
prompt_message
in
prompt_messages
:
for
prompt_message
in
prompt_messages
:
if
prompt_message
.
role
==
PromptMessageRole
.
USER
:
if
prompt_message
.
role
==
PromptMessageRole
.
USER
:
role
=
'user'
role
=
'user'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment