ai-tech / dify · Commits · edb86f5f

Commit edb86f5f (Unverified)
Authored Feb 21, 2024 by Yeuoly; committed by GitHub, Feb 21, 2024

Feat/stream react (#2498)

parent adf2651d

Showing 1 changed file with 159 additions and 150 deletions:

api/core/features/assistant_cot_runner.py  (+159, −150)
@@ -133,61 +133,95 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             # recale llm max tokens
             self.recale_llm_max_tokens(self.model_config, prompt_messages)
 
             # invoke model
-            llm_result: LLMResult = model_instance.invoke_llm(
+            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=app_orchestration_config.model_config.parameters,
                 tools=[],
                 stop=app_orchestration_config.model_config.stop,
-                stream=False,
+                stream=True,
                 user=self.user_id,
                 callbacks=[],
             )
 
             # check llm result
-            if not llm_result:
+            if not chunks:
                 raise ValueError("failed to invoke llm")
 
-            # get scratchpad
-            scratchpad = self._extract_response_scratchpad(llm_result.message.content)
-            agent_scratchpad.append(scratchpad)
-
-            # get llm usage
-            if llm_result.usage:
-                increase_usage(llm_usage, llm_result.usage)
+            usage_dict = {}
+            react_chunks = self._handle_stream_react(chunks, usage_dict)
+            scratchpad = AgentScratchpadUnit(
+                agent_response='',
+                thought='',
+                action_str='',
+                observation='',
+                action=None
+            )
 
             # publish agent thought if it's first iteration
             if iteration_step == 1:
                 self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
 
+            for chunk in react_chunks:
+                if isinstance(chunk, dict):
+                    scratchpad.agent_response += json.dumps(chunk)
+                    try:
+                        if scratchpad.action:
+                            raise Exception("")
+                        scratchpad.action_str = json.dumps(chunk)
+                        scratchpad.action = AgentScratchpadUnit.Action(
+                            action_name=chunk['action'],
+                            action_input=chunk['action_input']
+                        )
+                    except:
+                        scratchpad.thought += json.dumps(chunk)
+                        yield LLMResultChunk(
+                            model=self.model_config.model,
+                            prompt_messages=prompt_messages,
+                            system_fingerprint='',
+                            delta=LLMResultChunkDelta(
+                                index=0,
+                                message=AssistantPromptMessage(
+                                    content=json.dumps(chunk)
+                                ),
+                                usage=None
+                            )
+                        )
+                else:
+                    scratchpad.agent_response += chunk
+                    scratchpad.thought += chunk
+                    yield LLMResultChunk(
+                        model=self.model_config.model,
+                        prompt_messages=prompt_messages,
+                        system_fingerprint='',
+                        delta=LLMResultChunkDelta(
+                            index=0,
+                            message=AssistantPromptMessage(
+                                content=chunk
+                            ),
+                            usage=None
+                        )
+                    )
+
+            agent_scratchpad.append(scratchpad)
+
+            # get llm usage
+            if 'usage' in usage_dict:
+                increase_usage(llm_usage, usage_dict['usage'])
+            else:
+                usage_dict['usage'] = LLMUsage.empty_usage()
+
             self.save_agent_thought(agent_thought=agent_thought,
                                     tool_name=scratchpad.action.action_name if scratchpad.action else '',
                                     tool_input=scratchpad.action.action_input if scratchpad.action else '',
                                     thought=scratchpad.thought,
                                     observation='',
-                                    answer=llm_result.message.content,
+                                    answer=scratchpad.agent_response,
                                     messages_ids=[],
-                                    llm_usage=llm_result.usage)
+                                    llm_usage=usage_dict['usage'])
 
             if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                 self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
 
-            # publish agent thought if it's not empty and there is a action
-            if scratchpad.thought and scratchpad.action:
-                # check if final answer
-                if not scratchpad.action.action_name.lower() == "final answer":
-                    yield LLMResultChunk(
-                        model=model_instance.model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=0,
-                            message=AssistantPromptMessage(
-                                content=scratchpad.thought
-                            ),
-                            usage=llm_result.usage,
-                        ),
-                        system_fingerprint=''
-                    )
-
             if not scratchpad.action:
                 # failed to extract action, return final answer directly
                 final_answer = scratchpad.agent_response or ''
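A note on the hunk above: the blocking invoke_llm(..., stream=False) call, which returned one complete LLMResult, becomes a streaming call whose LLMResultChunk deltas are parsed and re-yielded as they arrive. Because _handle_stream_react is itself a generator, it cannot hand token usage back to its caller before iteration ends, so the runner threads a mutable usage_dict through it as an out-parameter and only reads usage_dict['usage'] after the chunk loop. A minimal, runnable sketch of that out-parameter pattern follows; every name in it is illustrative, not a dify API.

from typing import Generator

def stream_pieces(text: str, usage: dict) -> Generator[str, None, None]:
    """Yield words one at a time; report totals through the mutable
    `usage` dict, which the caller reads after iteration finishes."""
    count = 0
    for word in text.split():
        count += 1
        yield word
    usage['usage'] = {'words': count}  # becomes visible to the caller here

usage_dict: dict = {}
for piece in stream_pieces('thought action observation', usage_dict):
    print(piece)
print(usage_dict)  # {'usage': {'words': 3}}

Deferring the usage read until after the loop mirrors how the hunk calls increase_usage only once the react_chunks generator has been fully consumed.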
@@ -262,7 +296,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
 
             # save scratchpad
             scratchpad.observation = observation
-            scratchpad.agent_response = llm_result.message.content
 
             # save agent thought
             self.save_agent_thought(
@@ -271,7 +304,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                 tool_input=tool_call_args,
                 thought=None,
                 observation=observation,
-                answer=llm_result.message.content,
+                answer=scratchpad.agent_response,
                 messages_ids=message_file_ids,
             )
             self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
@@ -318,6 +351,97 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                     system_fingerprint=''
                 ), PublishFrom.APPLICATION_MANAGER)
 
+    def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
+        -> Generator[Union[str, dict], None, None]:
+        def parse_json(json_str):
+            try:
+                return json.loads(json_str.strip())
+            except:
+                return json_str
+
+        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
+            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
+            if not code_blocks:
+                return
+            for block in code_blocks:
+                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
+                yield parse_json(json_text)
+
+        code_block_cache = ''
+        code_block_delimiter_count = 0
+        in_code_block = False
+        json_cache = ''
+        json_quote_count = 0
+        in_json = False
+        got_json = False
+
+        for response in llm_response:
+            response = response.delta.message.content
+            if not isinstance(response, str):
+                continue
+
+            # stream
+            index = 0
+            while index < len(response):
+                steps = 1
+                delta = response[index:index+steps]
+
+                if delta == '`':
+                    code_block_cache += delta
+                    code_block_delimiter_count += 1
+                else:
+                    if not in_code_block:
+                        if code_block_delimiter_count > 0:
+                            yield code_block_cache
+                            code_block_cache = ''
+                    else:
+                        code_block_cache += delta
+                    code_block_delimiter_count = 0
+
+                if code_block_delimiter_count == 3:
+                    if in_code_block:
+                        yield from extra_json_from_code_block(code_block_cache)
+                        code_block_cache = ''
+
+                    in_code_block = not in_code_block
+                    code_block_delimiter_count = 0
+
+                if not in_code_block:
+                    # handle single json
+                    if delta == '{':
+                        json_quote_count += 1
+                        in_json = True
+                        json_cache += delta
+                    elif delta == '}':
+                        json_cache += delta
+                        if json_quote_count > 0:
+                            json_quote_count -= 1
+                            if json_quote_count == 0:
+                                in_json = False
+                                got_json = True
+                                index += steps
+                                continue
+                    else:
+                        if in_json:
+                            json_cache += delta
+
+                    if got_json:
+                        got_json = False
+                        yield parse_json(json_cache)
+                        json_cache = ''
+                        json_quote_count = 0
+                        in_json = False
+
+                if not in_code_block and not in_json:
+                    yield delta.replace('`', '')
+
+                index += steps
+
+        if code_block_cache:
+            yield code_block_cache
+
+        if json_cache:
+            yield parse_json(json_cache)
+
     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """
         fill in inputs from external data tools
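For intuition about what the new _handle_stream_react does: it walks the streamed text one character at a time, buffering backtick runs to detect ``` code fences and counting braces to detect bare JSON objects, so plain text is yielded as str while any recovered action JSON is yielded as dict. The following simplified, self-contained sketch shows the fence-splitting half of that idea; split_react_stream and its demo input are hypothetical stand-ins, not the dify implementation.

import json
import re
from typing import Generator, Iterable, Union

def split_react_stream(deltas: Iterable[str]) -> Generator[Union[str, dict], None, None]:
    """Yield plain text as str; yield the body of each ```-fenced block
    as a parsed dict, falling back to str when it is not valid JSON."""
    buffer = ''
    tick_run = 0
    in_block = False
    for delta in deltas:
        for ch in delta:
            if ch == '`':
                buffer += ch
                tick_run += 1
                if tick_run == 3:  # a complete ``` fence
                    if in_block:
                        # closing fence: strip both fences and the language tag
                        body = re.sub(r'^`{3}[a-zA-Z]*\n?|`{3}$', '', buffer)
                        try:
                            yield json.loads(body.strip())
                        except json.JSONDecodeError:
                            yield body
                        buffer = ''
                    in_block = not in_block
                    tick_run = 0
                continue
            tick_run = 0
            if in_block:
                buffer += ch
            else:
                text = buffer.replace('`', '')  # drop a partial backtick run
                buffer = ''
                if text:
                    yield text
                yield ch
    if buffer:  # flush anything left over, e.g. an unclosed fence
        yield buffer

text, actions = '', []
stream = ['Thought: call the tool.\n```json\n{"action": "weather",',
          ' "action_input": {"city": "Kyoto"}}\n```']
for piece in split_react_stream(stream):
    if isinstance(piece, dict):
        actions.append(piece)
    else:
        text += piece
print(text)     # Thought: call the tool.
print(actions)  # [{'action': 'weather', 'action_input': {'city': 'Kyoto'}}]

The real method additionally tracks brace depth (json_quote_count) so a bare {...} object outside any fence is caught as well, and it flushes its caches once the stream ends, much like the trailing yield above.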
@@ -363,121 +487,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
 
         return agent_scratchpad
 
-    def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
-        """
-        extract response from llm response
-        """
-        def extra_quotes() -> AgentScratchpadUnit:
-            agent_response = content
-            # try to extract all quotes
-            pattern = re.compile(r'```(.*?)```', re.DOTALL)
-            quotes = pattern.findall(content)
-
-            # try to extract action from end to start
-            for i in range(len(quotes) - 1, 0, -1):
-                """
-                1. use json load to parse action
-                2. use plain text `Action: xxx` to parse action
-                """
-                try:
-                    action = json.loads(quotes[i].replace('```', ''))
-                    action_name = action.get("action")
-                    action_input = action.get("action_input")
-                    agent_thought = agent_response.replace(quotes[i], '')
-
-                    if action_name and action_input:
-                        return AgentScratchpadUnit(
-                            agent_response=content,
-                            thought=agent_thought,
-                            action_str=quotes[i],
-                            action=AgentScratchpadUnit.Action(
-                                action_name=action_name,
-                                action_input=action_input,
-                            )
-                        )
-                except:
-                    # try to parse action from plain text
-                    action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
-                    action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
-                    # delete action from agent response
-                    agent_thought = agent_response.replace(quotes[i], '')
-                    # remove extra quotes
-                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
-                    # remove Action: xxx from agent thought
-                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
-
-                    if action_name and action_input:
-                        return AgentScratchpadUnit(
-                            agent_response=content,
-                            thought=agent_thought,
-                            action_str=quotes[i],
-                            action=AgentScratchpadUnit.Action(
-                                action_name=action_name[0],
-                                action_input=action_input[0],
-                            )
-                        )
-
-        def extra_json():
-            agent_response = content
-            # try to extract all json
-            structures, pair_match_stack = [], []
-            started_at, end_at = 0, 0
-            for i in range(len(content)):
-                if content[i] == '{':
-                    pair_match_stack.append(i)
-                    if len(pair_match_stack) == 1:
-                        started_at = i
-                elif content[i] == '}':
-                    begin = pair_match_stack.pop()
-                    if not pair_match_stack:
-                        end_at = i + 1
-                        structures.append((content[begin:i+1], (started_at, end_at)))
-
-            # handle the last character
-            if pair_match_stack:
-                end_at = len(content)
-                structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
-
-            for i in range(len(structures), 0, -1):
-                try:
-                    json_content, (started_at, end_at) = structures[i - 1]
-                    action = json.loads(json_content)
-                    action_name = action.get("action")
-                    action_input = action.get("action_input")
-                    # delete json content from agent response
-                    agent_thought = agent_response[:started_at] + agent_response[end_at:]
-                    # remove extra quotes like ```(json)*\n\n```
-                    agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
-                    # remove Action: xxx from agent thought
-                    agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
-
-                    if action_name and action_input is not None:
-                        return AgentScratchpadUnit(
-                            agent_response=content,
-                            thought=agent_thought,
-                            action_str=json_content,
-                            action=AgentScratchpadUnit.Action(
-                                action_name=action_name,
-                                action_input=action_input,
-                            )
-                        )
-                except:
-                    pass
-
-        agent_scratchpad = extra_quotes()
-        if agent_scratchpad:
-            return agent_scratchpad
-        agent_scratchpad = extra_json()
-        if agent_scratchpad:
-            return agent_scratchpad
-
-        return AgentScratchpadUnit(
-            agent_response=content,
-            thought=content,
-            action_str='',
-            action=None
-        )
-
     def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                    agent_prompt_message: AgentPromptEntity,
                                    ):
@@ -591,15 +600,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         # organize prompt messages
         if mode == "chat":
             # override system message
-            overrided = False
+            overridden = False
             prompt_messages = prompt_messages.copy()
             for prompt_message in prompt_messages:
                 if isinstance(prompt_message, SystemPromptMessage):
                     prompt_message.content = system_message
-                    overrided = True
+                    overridden = True
                     break
 
-            if not overrided:
+            if not overridden:
                 prompt_messages.insert(0, SystemPromptMessage(
                     content=system_message,
                 ))