Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
464a3615
Commit
464a3615
authored
Jul 10, 2023
by
John Wang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: router agent instead of router chain
parent
937061cf
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
159 additions
and
456 deletions
+159
-456
multi_dataset_router_agent.py
api/core/agent/agent/multi_dataset_router_agent.py
+83
-0
agent_executor.py
api/core/agent/agent_executor.py
+11
-1
dataset_tool_callback_handler.py
api/core/callback_handler/dataset_tool_callback_handler.py
+4
-2
llm_router_chain.py
api/core/chain/llm_router_chain.py
+0
-111
main_chain_builder.py
api/core/chain/main_chain_builder.py
+0
-68
multi_dataset_router_chain.py
api/core/chain/multi_dataset_router_chain.py
+0
-198
completion.py
api/core/completion.py
+4
-12
orchestrator_rule_parser.py
api/core/orchestrator_rule_parser.py
+47
-56
dataset_retriever_tool.py
api/core/tool/dataset_retriever_tool.py
+10
-7
web_reader_tool.py
api/core/tool/web_reader_tool.py
+0
-1
No files found.
api/core/agent/agent/multi_dataset_router_agent.py
0 → 100644
View file @
464a3615
from
typing
import
Tuple
,
List
,
Any
,
Union
,
Sequence
,
Optional
from
langchain.agents
import
OpenAIFunctionsAgent
,
BaseSingleActionAgent
from
langchain.callbacks.base
import
BaseCallbackManager
from
langchain.callbacks.manager
import
Callbacks
from
langchain.prompts.chat
import
BaseMessagePromptTemplate
from
langchain.schema
import
AgentAction
,
AgentFinish
,
BaseLanguageModel
,
SystemMessage
from
langchain.tools
import
BaseTool
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    """
    A multi-dataset retrieval agent driven by an OpenAI-functions router.

    Routes a user query to one of several dataset retriever tools and
    returns that tool's observation as the final answer.
    """

    def should_use_agent(self, query: str):
        """
        Whether the agent loop should handle this query.

        :param query: user query text (unused; this agent always runs)
        :return: always True
        """
        return True

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        tool_count = len(self.tools)

        # No dataset tool available: nothing to retrieve, finish immediately.
        if tool_count == 0:
            return AgentFinish(return_values={"output": ''}, log='')

        # A single dataset tool needs no LLM routing: run it directly.
        if tool_count == 1:
            only_tool = next(iter(self.tools))
            observation = only_tool.run(kwargs['input'])
            return AgentFinish(return_values={"output": observation}, log=observation)

        # A routed tool has already produced an observation; use it as the
        # final answer (at most one retrieval round per query).
        if intermediate_steps:
            _, last_observation = intermediate_steps[-1]
            return AgentFinish(
                return_values={"output": last_observation},
                log=last_observation
            )

        # Otherwise defer to the OpenAI functions agent to pick a tool.
        return super().plan(intermediate_steps, callbacks, **kwargs)

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        # Async planning is intentionally unsupported for this agent.
        raise NotImplementedError()

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        """
        Build the router agent, keeping only dataset retriever tools.

        NOTE(review): mutates the caller's ``llm`` by pinning ``model_name``
        to 'gpt-3.5-turbo' — routing is presumably forced onto the cheaper
        model; confirm this side effect is intended.
        """
        dataset_tools = [t for t in tools if isinstance(t, DatasetRetrieverTool)]
        llm.model_name = 'gpt-3.5-turbo'
        return super().from_llm_and_tools(
            llm=llm,
            tools=dataset_tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
            **kwargs,
        )
api/core/agent/agent_executor.py
View file @
464a3615
...
@@ -4,9 +4,11 @@ from typing import Union, Optional
...
@@ -4,9 +4,11 @@ from typing import Union, Optional
from
langchain.agents
import
BaseSingleActionAgent
,
BaseMultiActionAgent
,
AgentExecutor
from
langchain.agents
import
BaseSingleActionAgent
,
BaseMultiActionAgent
,
AgentExecutor
from
langchain.base_language
import
BaseLanguageModel
from
langchain.base_language
import
BaseLanguageModel
from
langchain.callbacks.manager
import
Callbacks
from
langchain.callbacks.manager
import
Callbacks
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.tools
import
BaseTool
from
langchain.tools
import
BaseTool
from
pydantic
import
BaseModel
,
Extra
from
pydantic
import
BaseModel
,
Extra
from
core.agent.agent.multi_dataset_router_agent
import
MultiDatasetRouterAgent
from
core.agent.agent.openai_function_call
import
AutoSummarizingOpenAIFunctionCallAgent
from
core.agent.agent.openai_function_call
import
AutoSummarizingOpenAIFunctionCallAgent
from
core.agent.agent.openai_multi_function_call
import
AutoSummarizingOpenMultiAIFunctionCallAgent
from
core.agent.agent.openai_multi_function_call
import
AutoSummarizingOpenMultiAIFunctionCallAgent
from
core.agent.agent.structured_chat
import
AutoSummarizingStructuredChatAgent
from
core.agent.agent.structured_chat
import
AutoSummarizingStructuredChatAgent
...
@@ -16,6 +18,7 @@ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
...
@@ -16,6 +18,7 @@ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
class
PlanningStrategy
(
str
,
enum
.
Enum
):
class
PlanningStrategy
(
str
,
enum
.
Enum
):
ROUTER
=
'router'
REACT
=
'react'
REACT
=
'react'
FUNCTION_CALL
=
'function_call'
FUNCTION_CALL
=
'function_call'
MULTI_FUNCTION_CALL
=
'multi_function_call'
MULTI_FUNCTION_CALL
=
'multi_function_call'
...
@@ -26,7 +29,7 @@ class AgentConfiguration(BaseModel):
...
@@ -26,7 +29,7 @@ class AgentConfiguration(BaseModel):
llm
:
BaseLanguageModel
llm
:
BaseLanguageModel
tools
:
list
[
BaseTool
]
tools
:
list
[
BaseTool
]
summary_llm
:
BaseLanguageModel
summary_llm
:
BaseLanguageModel
memory
:
ReadOnlyConversationTokenDBBufferSharedMemory
=
None
memory
:
Optional
[
BaseChatMemory
]
=
None
callbacks
:
Callbacks
=
None
callbacks
:
Callbacks
=
None
max_iterations
:
int
=
6
max_iterations
:
int
=
6
max_execution_time
:
Optional
[
float
]
=
None
max_execution_time
:
Optional
[
float
]
=
None
...
@@ -69,6 +72,13 @@ class AgentExecutor:
...
@@ -69,6 +72,13 @@ class AgentExecutor:
summary_llm
=
self
.
configuration
.
summary_llm
,
summary_llm
=
self
.
configuration
.
summary_llm
,
verbose
=
True
verbose
=
True
)
)
elif
self
.
configuration
.
strategy
==
PlanningStrategy
.
ROUTER
:
agent
=
MultiDatasetRouterAgent
.
from_llm_and_tools
(
llm
=
self
.
configuration
.
llm
,
tools
=
self
.
configuration
.
tools
,
extra_prompt_messages
=
self
.
configuration
.
memory
.
buffer
if
self
.
configuration
.
memory
else
None
,
verbose
=
True
)
else
:
else
:
raise
NotImplementedError
(
f
"Unknown Agent Strategy: {self.configuration.strategy}"
)
raise
NotImplementedError
(
f
"Unknown Agent Strategy: {self.configuration.strategy}"
)
...
...
api/core/callback_handler/dataset_tool_callback_handler.py
View file @
464a3615
import
json
import
logging
import
logging
from
typing
import
Any
,
Dict
,
List
,
Union
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Union
,
Optional
...
@@ -43,8 +44,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
...
@@ -43,8 +44,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
input_str
:
str
,
input_str
:
str
,
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
None
:
)
->
None
:
tool_name
=
serialized
.
get
(
'name'
)
# tool_name = serialized.get('name')
dataset_id
=
tool_name
[
len
(
"dataset-"
):]
input_dict
=
json
.
loads
(
input_str
.
replace
(
"'"
,
"
\"
"
))
dataset_id
=
input_dict
.
get
(
'dataset_id'
)
self
.
conversation_message_task
.
on_dataset_query_end
(
DatasetQueryObj
(
dataset_id
=
dataset_id
,
query
=
input_str
))
self
.
conversation_message_task
.
on_dataset_query_end
(
DatasetQueryObj
(
dataset_id
=
dataset_id
,
query
=
input_str
))
def
on_tool_end
(
def
on_tool_end
(
...
...
api/core/chain/llm_router_chain.py
deleted
100644 → 0
View file @
937061cf
"""Base classes for LLM-powered router chains."""
from
__future__
import
annotations
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
cast
,
NamedTuple
from
langchain.base_language
import
BaseLanguageModel
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
langchain.chains.base
import
Chain
from
pydantic
import
root_validator
from
langchain.chains
import
LLMChain
from
langchain.prompts
import
BasePromptTemplate
from
langchain.schema
import
BaseOutputParser
,
OutputParserException
from
libs.json_in_md_parser
import
parse_and_check_json_markdown
class Route(NamedTuple):
    """Routing decision: the chosen destination plus the inputs to pass on."""

    # Name of the destination chain, or None for the default route.
    destination: Optional[str]
    # Inputs forwarded to the destination chain.
    next_inputs: Dict[str, Any]
class LLMRouterChain(Chain):
    """A router chain that uses an LLM chain to perform routing."""

    llm_chain: LLMChain
    """LLM chain used to perform routing"""

    @root_validator()
    def validate_prompt(cls, values: dict) -> dict:
        """Ensure the wrapped LLM chain's prompt carries an output parser.

        The parser must convert raw LLM text into a dict with the keys
        'destination' and 'next_inputs'.
        """
        prompt = values["llm_chain"].prompt
        if prompt.output_parser is None:
            raise ValueError(
                "LLMRouterChain requires base llm_chain prompt to have an output"
                " parser that converts LLM text output to a dictionary with keys"
                " 'destination' and 'next_inputs'. Received a prompt with no output"
                " parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the LLM chain prompt expects.

        :meta private:
        """
        return self.llm_chain.input_keys

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        """Validate outputs; 'next_inputs' must be a dict."""
        super()._validate_outputs(outputs)
        if not isinstance(outputs["next_inputs"], dict):
            # Original raised a bare ValueError with no message; include one
            # so parse failures are diagnosable from the traceback alone.
            raise ValueError(
                "Expected 'next_inputs' to be a dict, got "
                f"{type(outputs['next_inputs']).__name__}."
            )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the LLM chain and parse its output into routing keys."""
        output = cast(
            Dict[str, Any],
            self.llm_chain.predict_and_parse(**inputs),
        )
        return output

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
    ) -> "LLMRouterChain":
        """Convenience constructor."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)

    @property
    def output_keys(self) -> List[str]:
        return ["destination", "next_inputs"]

    def route(self, inputs: Dict[str, Any]) -> Route:
        """Run the chain and package the result as a Route."""
        result = self(inputs)
        return Route(result["destination"], result["next_inputs"])
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
    """Parser for output of router chain in the multi-prompt chain."""

    # Destination name that maps to "no specific route" (parsed to None).
    default_destination: str = "DEFAULT"
    # Expected type of the raw 'next_inputs' value before wrapping.
    next_inputs_type: Type = str
    # Key under which the raw 'next_inputs' value is nested.
    next_inputs_inner_key: str = "input"

    def parse(self, text: str) -> Dict[str, Any]:
        """Parse the router LLM's markdown-JSON reply into routing keys.

        Returns a dict with 'destination' (str or None for the default
        route) and 'next_inputs' (wrapped as {next_inputs_inner_key: value}).

        Raises:
            OutputParserException: if the text is not valid markdown JSON
                with the expected keys/types.
        """
        try:
            expected_keys = ["destination", "next_inputs"]
            parsed = parse_and_check_json_markdown(text, expected_keys)
            if not isinstance(parsed["destination"], str):
                raise ValueError("Expected 'destination' to be a string.")
            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
                raise ValueError(
                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
                )
            # Nest the raw next_inputs value under the configured inner key.
            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
            if (
                parsed["destination"].strip().lower()
                == self.default_destination.lower()
            ):
                # The default destination means "no specific route".
                parsed["destination"] = None
            else:
                parsed["destination"] = parsed["destination"].strip()
            return parsed
        except Exception as e:
            # Chain the original error so the root cause stays visible.
            raise OutputParserException(
                f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
            ) from e
api/core/chain/main_chain_builder.py
deleted
100644 → 0
View file @
937061cf
# from typing import Optional, List, cast, Tuple
#
# from langchain.chains import SequentialChain
# from langchain.chains.base import Chain
# from langchain.memory.chat_memory import BaseChatMemory
#
# from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
# from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
# from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
# from core.conversation_message_task import ConversationMessageTask
# from core.orchestrator_rule_parser import OrchestratorRuleParser
# from extensions.ext_database import db
# from models.dataset import Dataset
# from models.model import AppModelConfig
#
#
# class MainChainBuilder:
# @classmethod
# def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory],
# rest_tokens: int, conversation_message_task: ConversationMessageTask):
# first_input_key = "input"
# final_output_key = "output"
#
# chains = []
#
# # init orchestrator rule parser
# orchestrator_rule_parser = OrchestratorRuleParser(
# tenant_id=tenant_id,
# app_model_config=app_model_config
# )
#
# # parse sensitive_word_avoidance_chain
# sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain()
# if sensitive_word_avoidance_chain:
# chains.append(sensitive_word_avoidance_chain)
#
# # parse agent chain
# agent_executor = orchestrator_rule_parser.to_agent_executor(
# conversation_message_task=conversation_message_task,
# memory=memory,
# rest_tokens=rest_tokens,
# callbacks=[DifyStdOutCallbackHandler()]
# )
#
# if agent_executor:
# if isinstance(agent_executor, MultiDatasetRouterChain):
# chains.append(agent_executor)
# final_output_key = agent_executor.output_keys[0]
# chains.append(agent_chain)
# final_output_key = agent_chain.output_keys[0]
#
# if len(chains) == 0:
# return None
#
# chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
# for chain in chains:
# chain = cast(Chain, chain)
# chain.callbacks.append(chain_callback)
#
# # build main chain
# overall_chain = SequentialChain(
# chains=chains,
# input_variables=[first_input_key],
# output_variables=[final_output_key],
# memory=memory, # only for use the memory prompt input key
# )
#
# return overall_chain
api/core/chain/multi_dataset_router_chain.py
deleted
100644 → 0
View file @
937061cf
import
math
import
re
from
typing
import
Mapping
,
List
,
Dict
,
Any
,
Optional
from
langchain
import
PromptTemplate
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
langchain.chains.base
import
Chain
from
pydantic
import
Extra
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.llm_router_chain
import
LLMRouterChain
,
RouterOutputParser
from
core.conversation_message_task
import
ConversationMessageTask
from
core.llm.llm_builder
import
LLMBuilder
from
core.tool.dataset_index_tool
import
DatasetTool
from
models.dataset
import
Dataset
,
DatasetProcessRule
DEFAULT_K
=
2
CONTEXT_TOKENS_PERCENT
=
0.3
MULTI_PROMPT_ROUTER_TEMPLATE
=
"""
Given a raw text input to a language model select the model prompt best suited for
\
the input. You will be given the names of the available prompts and a description of
\
what the prompt is best suited for. You may also revise the original input if you
\
think that revising it will ultimately lead to a better response from the language
\
model.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like,
\
no any other string out of markdown code snippet:
```json
{{{{
"destination": string
\\
name of the prompt to use or "DEFAULT"
"next_inputs": string
\\
a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR
\
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any
\
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT >>
"""
class MultiDatasetRouterChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: LLMRouterChain
    """Chain for deciding a destination chain and the input to it."""
    dataset_tools: Mapping[str, DatasetTool]
    """Map of name to candidate chains that inputs can be routed to."""

    class Config:
        """Configuration for this pydantic object."""
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        return ["text"]

    @classmethod
    def from_datasets(
        cls,
        tenant_id: str,
        datasets: List[Dataset],
        conversation_message_task: ConversationMessageTask,
        rest_tokens: int,
        **kwargs: Any,
    ):
        """Convenience constructor for instantiating from destination prompts."""
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        # One "[[id]]: description" line per dataset for the router prompt.
        destinations = [
            "[[{}]]: {}".format(
                d.id,
                d.description.replace('\n', ' ') if d.description
                else ('useful for when you want to answer queries about the ' + d.name)
            )
            for d in datasets
        ]
        destinations_str = "\n".join(destinations)
        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=destinations_str
        )
        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        dataset_tools = {}
        for dataset in datasets:
            # Skip datasets with no retrievable documents.
            # (Original tested `available_document_count == 0` twice with
            # `or` — a duplicated operand; a single test is equivalent.)
            if dataset.available_document_count == 0:
                continue

            # Fulfill description when it is empty.
            description = dataset.description
            if not description:
                description = 'useful for when you want to answer queries about the ' \
                              + dataset.name

            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
            if k == 0:
                continue

            dataset_tool = DatasetTool(
                name=f"dataset-{dataset.id}",
                description=description,
                k=k,
                dataset=dataset,
                callbacks=[
                    DatasetToolCallbackHandler(conversation_message_task),
                    DifyStdOutCallbackHandler()
                ]
            )
            dataset_tools[str(dataset.id)] = dataset_tool

        return cls(
            router_chain=router_chain,
            dataset_tools=dataset_tools,
            **kwargs,
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        """Choose a retrieval k from the dataset's segment size and the
        remaining token budget.

        :param dataset: dataset whose processing rule supplies segment size
        :param rest_tokens: tokens left for context and memory
        :return: number of segments to retrieve (may be 0)
        """
        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = \
                DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens is less than default context tokens
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than default context tokens, use default_k
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # Expand the k value when there's still some room left in the 30% rest tokens space
        return context_limit_tokens // segment_max_tokens

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route the input to a dataset tool and return its text output."""
        if len(self.dataset_tools) == 0:
            return {"text": ''}
        elif len(self.dataset_tools) == 1:
            # Only one candidate: no routing needed.
            return {"text": next(iter(self.dataset_tools.values())).run(
                inputs['input']
            )}

        route = self.router_chain.route(inputs)

        destination = ''
        if route.destination:
            # Extract the dataset UUID the router named (it may be wrapped
            # in extra text such as "[[...]]").
            pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
            match = re.search(pattern, route.destination, re.IGNORECASE)
            if match:
                destination = match.group()

        if not destination:
            return {"text": ''}
        elif destination in self.dataset_tools:
            return {"text": self.dataset_tools[destination].run(
                route.next_inputs['input']
            )}
        else:
            raise ValueError(
                f"Received invalid destination chain name '{destination}'"
            )
api/core/completion.py
View file @
464a3615
...
@@ -10,7 +10,6 @@ from langchain.schema import BaseMessage, HumanMessage
...
@@ -10,7 +10,6 @@ from langchain.schema import BaseMessage, HumanMessage
from
requests.exceptions
import
ChunkedEncodingError
from
requests.exceptions
import
ChunkedEncodingError
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.chain.multi_dataset_router_chain
import
MultiDatasetRouterChain
from
core.constant
import
llm_constant
from
core.constant
import
llm_constant
from
core.callback_handler.llm_callback_handler
import
LLMCallbackHandler
from
core.callback_handler.llm_callback_handler
import
LLMCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStreamingStdOutCallbackHandler
,
\
from
core.callback_handler.std_out_callback_handler
import
DifyStreamingStdOutCallbackHandler
,
\
...
@@ -22,8 +21,6 @@ from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
...
@@ -22,8 +21,6 @@ from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from
core.llm.streamable_open_ai
import
StreamableOpenAI
from
core.llm.streamable_open_ai
import
StreamableOpenAI
from
core.memory.read_only_conversation_token_db_buffer_shared_memory
import
\
from
core.memory.read_only_conversation_token_db_buffer_shared_memory
import
\
ReadOnlyConversationTokenDBBufferSharedMemory
ReadOnlyConversationTokenDBBufferSharedMemory
from
core.memory.read_only_conversation_token_db_string_buffer_shared_memory
import
\
ReadOnlyConversationTokenDBStringBufferSharedMemory
from
core.orchestrator_rule_parser
import
OrchestratorRuleParser
from
core.orchestrator_rule_parser
import
OrchestratorRuleParser
from
core.prompt.prompt_builder
import
PromptBuilder
from
core.prompt.prompt_builder
import
PromptBuilder
from
core.prompt.prompt_template
import
JinjaPromptTemplate
from
core.prompt.prompt_template
import
JinjaPromptTemplate
...
@@ -88,26 +85,21 @@ class Completion:
...
@@ -88,26 +85,21 @@ class Completion:
# get agent executor
# get agent executor
agent_executor
=
orchestrator_rule_parser
.
to_agent_executor
(
agent_executor
=
orchestrator_rule_parser
.
to_agent_executor
(
conversation_message_task
=
conversation_message_task
,
conversation_message_task
=
conversation_message_task
,
memory
=
ReadOnlyConversationTokenDBStringBufferSharedMemory
(
memory
=
memory
)
if
memory
else
None
,
memory
=
memory
,
rest_tokens
=
rest_tokens_for_context_and_memory
,
rest_tokens
=
rest_tokens_for_context_and_memory
,
callbacks
=
[
chain_callback
]
callbacks
=
[
chain_callback
]
)
)
# run agent executor
# run agent executor
executor_output
=
''
executor_output
=
''
is_agent_output
=
False
if
agent_executor
:
if
agent_executor
:
if
isinstance
(
agent_executor
,
MultiDatasetRouterChain
):
executor_output
=
agent_executor
.
run
(
query
)
else
:
should_use_agent
=
agent_executor
.
should_use_agent
(
query
)
should_use_agent
=
agent_executor
.
should_use_agent
(
query
)
if
should_use_agent
:
if
should_use_agent
:
executor_output
=
agent_executor
.
run
(
query
)
executor_output
=
agent_executor
.
run
(
query
)
is_agent_output
=
True
# run the final llm
# run the final llm
try
:
try
:
# if
is_agent
_output and not app_model_config.pre_prompt:
# if
executor
_output and not app_model_config.pre_prompt:
# # todo streaming flush the agent result to user, not call final llm
# # todo streaming flush the agent result to user, not call final llm
# pass
# pass
...
...
api/core/orchestrator_rule_parser.py
View file @
464a3615
from
typing
import
Optional
,
Union
import
math
from
typing
import
Optional
from
langchain
import
WikipediaAPIWrapper
from
langchain
import
WikipediaAPIWrapper
from
langchain.callbacks.manager
import
Callbacks
from
langchain.callbacks.manager
import
Callbacks
from
langchain.chains.base
import
Chain
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.tools
import
BaseTool
,
Tool
,
WikipediaQueryRun
from
langchain.tools
import
BaseTool
,
Tool
,
WikipediaQueryRun
from
core.agent.agent_executor
import
AgentExecutor
,
PlanningStrategy
,
AgentConfiguration
from
core.agent.agent_executor
import
AgentExecutor
,
PlanningStrategy
,
AgentConfiguration
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.multi_dataset_router_chain
import
MultiDatasetRouterChain
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.conversation_message_task
import
ConversationMessageTask
from
core.conversation_message_task
import
ConversationMessageTask
from
core.llm.llm_builder
import
LLMBuilder
from
core.llm.llm_builder
import
LLMBuilder
...
@@ -18,7 +17,7 @@ from core.tool.provider.serpapi_provider import SerpAPIToolProvider
...
@@ -18,7 +17,7 @@ from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from
core.tool.serpapi_wrapper
import
OptimizedSerpAPIWrapper
from
core.tool.serpapi_wrapper
import
OptimizedSerpAPIWrapper
from
core.tool.web_reader_tool
import
WebReaderTool
from
core.tool.web_reader_tool
import
WebReaderTool
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
from
models.dataset
import
Dataset
,
DatasetProcessRule
from
models.model
import
AppModelConfig
from
models.model
import
AppModelConfig
...
@@ -32,7 +31,7 @@ class OrchestratorRuleParser:
...
@@ -32,7 +31,7 @@ class OrchestratorRuleParser:
def
to_agent_executor
(
self
,
conversation_message_task
:
ConversationMessageTask
,
memory
:
Optional
[
BaseChatMemory
],
def
to_agent_executor
(
self
,
conversation_message_task
:
ConversationMessageTask
,
memory
:
Optional
[
BaseChatMemory
],
rest_tokens
:
int
,
callbacks
:
Callbacks
=
None
)
\
rest_tokens
:
int
,
callbacks
:
Callbacks
=
None
)
\
->
Optional
[
Union
[
AgentExecutor
|
MultiDatasetRouterChain
]
]:
->
Optional
[
AgentExecutor
]:
if
not
self
.
app_model_config
.
agent_mode_dict
:
if
not
self
.
app_model_config
.
agent_mode_dict
:
return
None
return
None
...
@@ -41,11 +40,6 @@ class OrchestratorRuleParser:
...
@@ -41,11 +40,6 @@ class OrchestratorRuleParser:
chain
=
None
chain
=
None
if
agent_mode_config
and
agent_mode_config
.
get
(
'enabled'
):
if
agent_mode_config
and
agent_mode_config
.
get
(
'enabled'
):
tool_configs
=
agent_mode_config
.
get
(
'tools'
,
[])
tool_configs
=
agent_mode_config
.
get
(
'tools'
,
[])
# use router chain if planning strategy is router or not set
if
not
agent_mode_config
.
get
(
'strategy'
)
or
agent_mode_config
.
get
(
'strategy'
)
==
'router'
:
return
self
.
to_router_chain
(
tool_configs
,
conversation_message_task
,
rest_tokens
)
agent_model_name
=
agent_mode_config
.
get
(
'model_name'
,
'gpt-4'
)
agent_model_name
=
agent_mode_config
.
get
(
'model_name'
,
'gpt-4'
)
agent_llm
=
LLMBuilder
.
to_llm
(
agent_llm
=
LLMBuilder
.
to_llm
(
...
@@ -64,15 +58,15 @@ class OrchestratorRuleParser:
...
@@ -64,15 +58,15 @@ class OrchestratorRuleParser:
callbacks
=
[
DifyStdOutCallbackHandler
()]
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
)
tools
=
self
.
to_tools
(
tool_configs
,
conversation_message_task
)
tools
=
self
.
to_tools
(
tool_configs
,
conversation_message_task
,
rest_tokens
)
if
len
(
tools
)
==
0
:
if
len
(
tools
)
==
0
:
return
None
return
None
agent_configuration
=
AgentConfiguration
(
agent_configuration
=
AgentConfiguration
(
strategy
=
PlanningStrategy
(
agent_mode_config
.
get
(
'strategy'
)),
strategy
=
PlanningStrategy
(
agent_mode_config
.
get
(
'strategy'
,
'router'
)),
llm
=
agent_llm
,
llm
=
agent_llm
,
tools
=
self
.
to_tools
(
tool_configs
,
conversation_message_task
)
,
tools
=
tools
,
summary_llm
=
summary_llm
,
summary_llm
=
summary_llm
,
memory
=
memory
,
memory
=
memory
,
callbacks
=
callbacks
,
callbacks
=
callbacks
,
...
@@ -85,44 +79,6 @@ class OrchestratorRuleParser:
...
@@ -85,44 +79,6 @@ class OrchestratorRuleParser:
return
chain
return
chain
def
to_router_chain
(
self
,
tool_configs
:
list
,
conversation_message_task
:
ConversationMessageTask
,
rest_tokens
:
int
)
->
Optional
[
Chain
]:
"""
Convert tool configs to router chain if planning strategy is router
:param tool_configs:
:param conversation_message_task:
:param rest_tokens:
:return:
"""
chain
=
None
datasets
=
[]
for
tool_config
in
tool_configs
:
tool_type
=
list
(
tool_config
.
keys
())[
0
]
tool_val
=
list
(
tool_config
.
values
())[
0
]
if
tool_type
==
"dataset"
:
# get dataset from dataset id
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
Dataset
.
tenant_id
==
self
.
tenant_id
,
Dataset
.
id
==
tool_val
.
get
(
"id"
)
)
.
first
()
if
dataset
:
datasets
.
append
(
dataset
)
if
len
(
datasets
)
>
0
:
# tool to chain
multi_dataset_router_chain
=
MultiDatasetRouterChain
.
from_datasets
(
tenant_id
=
self
.
tenant_id
,
datasets
=
datasets
,
conversation_message_task
=
conversation_message_task
,
rest_tokens
=
rest_tokens
,
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
chain
=
multi_dataset_router_chain
return
chain
def
to_sensitive_word_avoidance_chain
(
self
,
callbacks
:
Callbacks
=
None
,
**
kwargs
)
\
def
to_sensitive_word_avoidance_chain
(
self
,
callbacks
:
Callbacks
=
None
,
**
kwargs
)
\
->
Optional
[
SensitiveWordAvoidanceChain
]:
->
Optional
[
SensitiveWordAvoidanceChain
]:
"""
"""
...
@@ -147,10 +103,12 @@ class OrchestratorRuleParser:
...
@@ -147,10 +103,12 @@ class OrchestratorRuleParser:
return
None
return
None
def
to_tools
(
self
,
tool_configs
:
list
,
conversation_message_task
:
ConversationMessageTask
)
->
list
[
BaseTool
]:
def
to_tools
(
self
,
tool_configs
:
list
,
conversation_message_task
:
ConversationMessageTask
,
rest_tokens
:
int
)
->
list
[
BaseTool
]:
"""
"""
Convert app agent tool configs to tools
Convert app agent tool configs to tools
:param rest_tokens:
:param tool_configs: app agent tool configs
:param tool_configs: app agent tool configs
:param conversation_message_task:
:param conversation_message_task:
:return:
:return:
...
@@ -164,8 +122,7 @@ class OrchestratorRuleParser:
...
@@ -164,8 +122,7 @@ class OrchestratorRuleParser:
tool
=
None
tool
=
None
if
tool_type
==
"dataset"
:
if
tool_type
==
"dataset"
:
tool
=
None
tool
=
self
.
to_dataset_retriever_tool
(
tool_val
,
conversation_message_task
,
rest_tokens
)
# tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task)
elif
tool_type
==
"web_reader"
:
elif
tool_type
==
"web_reader"
:
tool
=
self
.
to_web_reader_tool
()
tool
=
self
.
to_web_reader_tool
()
elif
tool_type
==
"google_search"
:
elif
tool_type
==
"google_search"
:
...
@@ -178,10 +135,12 @@ class OrchestratorRuleParser:
...
@@ -178,10 +135,12 @@ class OrchestratorRuleParser:
return
tools
return
tools
def
to_dataset_retriever_tool
(
self
,
tool_config
:
dict
,
conversation_message_task
:
ConversationMessageTask
)
\
def
to_dataset_retriever_tool
(
self
,
tool_config
:
dict
,
conversation_message_task
:
ConversationMessageTask
,
rest_tokens
:
int
)
\
->
Optional
[
BaseTool
]:
->
Optional
[
BaseTool
]:
"""
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
:param tool_config:
:param conversation_message_task:
:param conversation_message_task:
:return:
:return:
...
@@ -195,9 +154,10 @@ class OrchestratorRuleParser:
...
@@ -195,9 +154,10 @@ class OrchestratorRuleParser:
if
dataset
and
dataset
.
available_document_count
==
0
and
dataset
.
available_document_count
==
0
:
if
dataset
and
dataset
.
available_document_count
==
0
and
dataset
.
available_document_count
==
0
:
return
None
return
None
k
=
self
.
_dynamic_calc_retrieve_k
(
dataset
,
rest_tokens
)
tool
=
DatasetRetrieverTool
.
from_dataset
(
tool
=
DatasetRetrieverTool
.
from_dataset
(
dataset
=
dataset
,
dataset
=
dataset
,
k
=
3
,
k
=
k
,
callbacks
=
[
DatasetToolCallbackHandler
(
conversation_message_task
),
DifyStdOutCallbackHandler
()]
callbacks
=
[
DatasetToolCallbackHandler
(
conversation_message_task
),
DifyStdOutCallbackHandler
()]
)
)
...
@@ -249,3 +209,34 @@ class OrchestratorRuleParser:
...
@@ -249,3 +209,34 @@ class OrchestratorRuleParser:
api_wrapper
=
WikipediaAPIWrapper
(
doc_content_chars_max
=
4000
),
api_wrapper
=
WikipediaAPIWrapper
(
doc_content_chars_max
=
4000
),
callbacks
=
[
DifyStdOutCallbackHandler
()]
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
)
@
classmethod
def
_dynamic_calc_retrieve_k
(
cls
,
dataset
:
Dataset
,
rest_tokens
:
int
)
->
int
:
DEFAULT_K
=
2
CONTEXT_TOKENS_PERCENT
=
0.3
processing_rule
=
dataset
.
latest_process_rule
if
not
processing_rule
:
return
DEFAULT_K
if
processing_rule
.
mode
==
"custom"
:
rules
=
processing_rule
.
rules_dict
if
not
rules
:
return
DEFAULT_K
segmentation
=
rules
[
"segmentation"
]
segment_max_tokens
=
segmentation
[
"max_tokens"
]
else
:
segment_max_tokens
=
DatasetProcessRule
.
AUTOMATIC_RULES
[
'segmentation'
][
'max_tokens'
]
# when rest_tokens is less than default context tokens
if
rest_tokens
<
segment_max_tokens
*
DEFAULT_K
:
return
rest_tokens
//
segment_max_tokens
context_limit_tokens
=
math
.
floor
(
rest_tokens
*
CONTEXT_TOKENS_PERCENT
)
# when context_limit_tokens is less than default context tokens, use default_k
if
context_limit_tokens
<=
segment_max_tokens
*
DEFAULT_K
:
return
DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return
context_limit_tokens
//
segment_max_tokens
api/core/tool/dataset_retriever_tool.py
View file @
464a3615
...
@@ -83,6 +83,7 @@ class DatasetRetrieverTool(BaseTool):
...
@@ -83,6 +83,7 @@ class DatasetRetrieverTool(BaseTool):
embeddings
=
embeddings
embeddings
=
embeddings
)
)
if
self
.
k
>
0
:
documents
=
vector_index
.
search
(
documents
=
vector_index
.
search
(
query
,
query
,
search_type
=
'similarity'
,
search_type
=
'similarity'
,
...
@@ -90,6 +91,8 @@ class DatasetRetrieverTool(BaseTool):
...
@@ -90,6 +91,8 @@ class DatasetRetrieverTool(BaseTool):
'k'
:
self
.
k
'k'
:
self
.
k
}
}
)
)
else
:
documents
=
[]
hit_callback
=
DatasetIndexToolCallbackHandler
(
dataset
.
id
)
hit_callback
=
DatasetIndexToolCallbackHandler
(
dataset
.
id
)
hit_callback
.
on_tool_end
(
documents
)
hit_callback
.
on_tool_end
(
documents
)
...
...
api/core/tool/web_reader_tool.py
View file @
464a3615
...
@@ -88,7 +88,6 @@ class WebReaderTool(BaseTool):
...
@@ -88,7 +88,6 @@ class WebReaderTool(BaseTool):
if
len
(
docs
)
>
10
:
if
len
(
docs
)
>
10
:
docs
=
docs
[:
10
]
docs
=
docs
[:
10
]
print
(
"summary docs: "
,
docs
)
chain
=
load_summarize_chain
(
self
.
llm
,
chain_type
=
"refine"
,
callbacks
=
self
.
callbacks
)
chain
=
load_summarize_chain
(
self
.
llm
,
chain_type
=
"refine"
,
callbacks
=
self
.
callbacks
)
page_contents
=
chain
.
run
(
docs
)
page_contents
=
chain
.
run
(
docs
)
# todo use cache
# todo use cache
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment