Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
88545184
Unverified
Commit
88545184
authored
May 28, 2023
by
John Wang
Committed by
GitHub
May 28, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: support multi datasets router chain mode (#231)
parent
2c23caac
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
295 additions
and
43 deletions
+295
-43
llm_router_chain.py
api/core/chain/llm_router_chain.py
+132
-0
main_chain_builder.py
api/core/chain/main_chain_builder.py
+22
-30
multi_dataset_router_chain.py
api/core/chain/multi_dataset_router_chain.py
+138
-0
dataset_tool_builder.py
api/core/tool/dataset_tool_builder.py
+3
-13
No files found.
api/core/chain/llm_router_chain.py
0 → 100644
View file @
88545184
"""Base classes for LLM-powered router chains."""
from
__future__
import
annotations
import
json
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
cast
,
NamedTuple
from
langchain.chains.base
import
Chain
from
pydantic
import
root_validator
from
langchain.chains
import
LLMChain
from
langchain.prompts
import
BasePromptTemplate
from
langchain.schema
import
BaseOutputParser
,
OutputParserException
,
BaseLanguageModel
class
Route
(
NamedTuple
):
destination
:
Optional
[
str
]
next_inputs
:
Dict
[
str
,
Any
]
class
LLMRouterChain
(
Chain
):
"""A router chain that uses an LLM chain to perform routing."""
llm_chain
:
LLMChain
"""LLM chain used to perform routing"""
@
root_validator
()
def
validate_prompt
(
cls
,
values
:
dict
)
->
dict
:
prompt
=
values
[
"llm_chain"
]
.
prompt
if
prompt
.
output_parser
is
None
:
raise
ValueError
(
"LLMRouterChain requires base llm_chain prompt to have an output"
" parser that converts LLM text output to a dictionary with keys"
" 'destination' and 'next_inputs'. Received a prompt with no output"
" parser."
)
return
values
@
property
def
input_keys
(
self
)
->
List
[
str
]:
"""Will be whatever keys the LLM chain prompt expects.
:meta private:
"""
return
self
.
llm_chain
.
input_keys
def
_validate_outputs
(
self
,
outputs
:
Dict
[
str
,
Any
])
->
None
:
super
()
.
_validate_outputs
(
outputs
)
if
not
isinstance
(
outputs
[
"next_inputs"
],
dict
):
raise
ValueError
def
_call
(
self
,
inputs
:
Dict
[
str
,
Any
]
)
->
Dict
[
str
,
Any
]:
output
=
cast
(
Dict
[
str
,
Any
],
self
.
llm_chain
.
predict_and_parse
(
**
inputs
),
)
return
output
@
classmethod
def
from_llm
(
cls
,
llm
:
BaseLanguageModel
,
prompt
:
BasePromptTemplate
,
**
kwargs
:
Any
)
->
LLMRouterChain
:
"""Convenience constructor."""
llm_chain
=
LLMChain
(
llm
=
llm
,
prompt
=
prompt
)
return
cls
(
llm_chain
=
llm_chain
,
**
kwargs
)
@
property
def
output_keys
(
self
)
->
List
[
str
]:
return
[
"destination"
,
"next_inputs"
]
def
route
(
self
,
inputs
:
Dict
[
str
,
Any
])
->
Route
:
result
=
self
(
inputs
)
return
Route
(
result
[
"destination"
],
result
[
"next_inputs"
])
class
RouterOutputParser
(
BaseOutputParser
[
Dict
[
str
,
str
]]):
"""Parser for output of router chain int he multi-prompt chain."""
default_destination
:
str
=
"DEFAULT"
next_inputs_type
:
Type
=
str
next_inputs_inner_key
:
str
=
"input"
def
parse_json_markdown
(
self
,
json_string
:
str
)
->
dict
:
# Remove the triple backticks if present
json_string
=
json_string
.
replace
(
"```json"
,
""
)
.
replace
(
"```"
,
""
)
# Strip whitespace and newlines from the start and end
json_string
=
json_string
.
strip
()
# Parse the JSON string into a Python dictionary
parsed
=
json
.
loads
(
json_string
)
return
parsed
def
parse_and_check_json_markdown
(
self
,
text
:
str
,
expected_keys
:
List
[
str
])
->
dict
:
try
:
json_obj
=
self
.
parse_json_markdown
(
text
)
except
json
.
JSONDecodeError
as
e
:
raise
OutputParserException
(
f
"Got invalid JSON object. Error: {e}"
)
for
key
in
expected_keys
:
if
key
not
in
json_obj
:
raise
OutputParserException
(
f
"Got invalid return object. Expected key `{key}` "
f
"to be present, but got {json_obj}"
)
return
json_obj
def
parse
(
self
,
text
:
str
)
->
Dict
[
str
,
Any
]:
try
:
expected_keys
=
[
"destination"
,
"next_inputs"
]
parsed
=
self
.
parse_and_check_json_markdown
(
text
,
expected_keys
)
if
not
isinstance
(
parsed
[
"destination"
],
str
):
raise
ValueError
(
"Expected 'destination' to be a string."
)
if
not
isinstance
(
parsed
[
"next_inputs"
],
self
.
next_inputs_type
):
raise
ValueError
(
f
"Expected 'next_inputs' to be {self.next_inputs_type}."
)
parsed
[
"next_inputs"
]
=
{
self
.
next_inputs_inner_key
:
parsed
[
"next_inputs"
]}
if
(
parsed
[
"destination"
]
.
strip
()
.
lower
()
==
self
.
default_destination
.
lower
()
):
parsed
[
"destination"
]
=
None
else
:
parsed
[
"destination"
]
=
parsed
[
"destination"
]
.
strip
()
return
parsed
except
Exception
as
e
:
raise
OutputParserException
(
f
"Parsing text
\n
{text}
\n
raised following error:
\n
{e}"
)
api/core/chain/main_chain_builder.py
View file @
88545184
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
from
langchain.callbacks
import
SharedCallbackManager
from
langchain.callbacks
import
SharedCallbackManager
,
CallbackManager
from
langchain.chains
import
SequentialChain
from
langchain.chains
import
SequentialChain
from
langchain.chains.base
import
Chain
from
langchain.chains.base
import
Chain
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.memory.chat_memory
import
BaseChatMemory
from
core.agent.agent_builder
import
AgentBuilder
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.chain_builder
import
ChainBuilder
from
core.chain.chain_builder
import
ChainBuilder
from
core.c
onstant
import
llm_constant
from
core.c
hain.multi_dataset_router_chain
import
MultiDatasetRouterChain
from
core.conversation_message_task
import
ConversationMessageTask
from
core.conversation_message_task
import
ConversationMessageTask
from
core.tool.dataset_tool_builder
import
DatasetToolBuilder
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
class
MainChainBuilder
:
class
MainChainBuilder
:
...
@@ -31,8 +31,7 @@ class MainChainBuilder:
...
@@ -31,8 +31,7 @@ class MainChainBuilder:
tenant_id
=
tenant_id
,
tenant_id
=
tenant_id
,
agent_mode
=
agent_mode
,
agent_mode
=
agent_mode
,
memory
=
memory
,
memory
=
memory
,
dataset_tool_callback_handler
=
DatasetToolCallbackHandler
(
conversation_message_task
),
conversation_message_task
=
conversation_message_task
agent_loop_gather_callback_handler
=
chain_callback_handler
.
agent_loop_gather_callback_handler
)
)
chains
+=
tool_chains
chains
+=
tool_chains
...
@@ -59,15 +58,15 @@ class MainChainBuilder:
...
@@ -59,15 +58,15 @@ class MainChainBuilder:
@
classmethod
@
classmethod
def
get_agent_chains
(
cls
,
tenant_id
:
str
,
agent_mode
:
dict
,
memory
:
Optional
[
BaseChatMemory
],
def
get_agent_chains
(
cls
,
tenant_id
:
str
,
agent_mode
:
dict
,
memory
:
Optional
[
BaseChatMemory
],
dataset_tool_callback_handler
:
DatasetToolCallbackHandler
,
conversation_message_task
:
ConversationMessageTask
):
agent_loop_gather_callback_handler
:
AgentLoopGatherCallbackHandler
):
# agent mode
# agent mode
chains
=
[]
chains
=
[]
if
agent_mode
and
agent_mode
.
get
(
'enabled'
):
if
agent_mode
and
agent_mode
.
get
(
'enabled'
):
tools
=
agent_mode
.
get
(
'tools'
,
[])
tools
=
agent_mode
.
get
(
'tools'
,
[])
pre_fixed_chains
=
[]
pre_fixed_chains
=
[]
agent_tools
=
[]
# agent_tools = []
datasets
=
[]
for
tool
in
tools
:
for
tool
in
tools
:
tool_type
=
list
(
tool
.
keys
())[
0
]
tool_type
=
list
(
tool
.
keys
())[
0
]
tool_config
=
list
(
tool
.
values
())[
0
]
tool_config
=
list
(
tool
.
values
())[
0
]
...
@@ -76,34 +75,27 @@ class MainChainBuilder:
...
@@ -76,34 +75,27 @@ class MainChainBuilder:
if
chain
:
if
chain
:
pre_fixed_chains
.
append
(
chain
)
pre_fixed_chains
.
append
(
chain
)
elif
tool_type
==
"dataset"
:
elif
tool_type
==
"dataset"
:
dataset_tool
=
DatasetToolBuilder
.
build_dataset_tool
(
# get dataset from dataset id
tenant_id
=
tenant_id
,
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
dataset_id
=
tool_config
.
get
(
"id"
),
Dataset
.
tenant_id
==
tenant_id
,
response_mode
=
'no_synthesizer'
,
# "compact"
Dataset
.
id
==
tool_config
.
get
(
"id"
)
callback_handler
=
dataset_tool_callback_handler
)
.
first
()
)
if
dataset
_tool
:
if
dataset
:
agent_tools
.
append
(
dataset_tool
)
datasets
.
append
(
dataset
)
# add pre-fixed chains
# add pre-fixed chains
chains
+=
pre_fixed_chains
chains
+=
pre_fixed_chains
if
len
(
agent_tools
)
==
1
:
if
len
(
datasets
)
>
0
:
# tool to chain
# tool to chain
tool_chain
=
ChainBuilder
.
to_tool_chain
(
tool
=
agent_tools
[
0
],
output_key
=
'tool_output'
)
multi_dataset_router_chain
=
MultiDatasetRouterChain
.
from_datasets
(
chains
.
append
(
tool_chain
)
elif
len
(
agent_tools
)
>
1
:
# build agent config
agent_chain
=
AgentBuilder
.
to_agent_chain
(
tenant_id
=
tenant_id
,
tenant_id
=
tenant_id
,
tools
=
agent_tools
,
datasets
=
datasets
,
memory
=
memory
,
conversation_message_task
=
conversation_message_task
,
dataset_tool_callback_handler
=
dataset_tool_callback_handler
,
callback_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
agent_loop_gather_callback_handler
=
agent_loop_gather_callback_handler
)
)
chains
.
append
(
multi_dataset_router_chain
)
chains
.
append
(
agent_chain
)
final_output_key
=
cls
.
get_chains_output_key
(
chains
)
final_output_key
=
cls
.
get_chains_output_key
(
chains
)
...
...
api/core/chain/multi_dataset_router_chain.py
0 → 100644
View file @
88545184
from
typing
import
Mapping
,
List
,
Dict
,
Any
,
Optional
from
langchain
import
LLMChain
,
PromptTemplate
,
ConversationChain
from
langchain.callbacks
import
CallbackManager
from
langchain.chains.base
import
Chain
from
langchain.schema
import
BaseLanguageModel
from
pydantic
import
Extra
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.llm_router_chain
import
LLMRouterChain
,
RouterOutputParser
from
core.conversation_message_task
import
ConversationMessageTask
from
core.llm.llm_builder
import
LLMBuilder
from
core.tool.dataset_tool_builder
import
DatasetToolBuilder
from
core.tool.llama_index_tool
import
EnhanceLlamaIndexTool
from
models.dataset
import
Dataset
MULTI_PROMPT_ROUTER_TEMPLATE
=
"""
Given a raw text input to a language model select the model prompt best suited for
\
the input. You will be given the names of the available prompts and a description of
\
what the prompt is best suited for. You may also revise the original input if you
\
think that revising it will ultimately lead to a better response from the language
\
model.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:
```json
{{{{
"destination": string
\\
name of the prompt to use or "DEFAULT"
"next_inputs": string
\\
a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR
\
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any
\
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT >>
"""
class
MultiDatasetRouterChain
(
Chain
):
"""Use a single chain to route an input to one of multiple candidate chains."""
router_chain
:
LLMRouterChain
"""Chain for deciding a destination chain and the input to it."""
dataset_tools
:
Mapping
[
str
,
EnhanceLlamaIndexTool
]
"""Map of name to candidate chains that inputs can be routed to."""
class
Config
:
"""Configuration for this pydantic object."""
extra
=
Extra
.
forbid
arbitrary_types_allowed
=
True
@
property
def
input_keys
(
self
)
->
List
[
str
]:
"""Will be whatever keys the router chain prompt expects.
:meta private:
"""
return
self
.
router_chain
.
input_keys
@
property
def
output_keys
(
self
)
->
List
[
str
]:
return
[
"text"
]
@
classmethod
def
from_datasets
(
cls
,
tenant_id
:
str
,
datasets
:
List
[
Dataset
],
conversation_message_task
:
ConversationMessageTask
,
**
kwargs
:
Any
,
):
"""Convenience constructor for instantiating from destination prompts."""
llm_callback_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
tenant_id
,
model_name
=
'gpt-3.5-turbo'
,
temperature
=
0
,
max_tokens
=
1024
,
callback_manager
=
llm_callback_manager
)
destinations
=
[
f
"{d.id}: {d.description}"
for
d
in
datasets
]
destinations_str
=
"
\n
"
.
join
(
destinations
)
router_template
=
MULTI_PROMPT_ROUTER_TEMPLATE
.
format
(
destinations
=
destinations_str
)
router_prompt
=
PromptTemplate
(
template
=
router_template
,
input_variables
=
[
"input"
],
output_parser
=
RouterOutputParser
(),
)
router_chain
=
LLMRouterChain
.
from_llm
(
llm
,
router_prompt
)
dataset_tools
=
{}
for
dataset
in
datasets
:
dataset_tool
=
DatasetToolBuilder
.
build_dataset_tool
(
dataset
=
dataset
,
response_mode
=
'no_synthesizer'
,
# "compact"
callback_handler
=
DatasetToolCallbackHandler
(
conversation_message_task
)
)
dataset_tools
[
dataset
.
id
]
=
dataset_tool
return
cls
(
router_chain
=
router_chain
,
dataset_tools
=
dataset_tools
,
**
kwargs
,
)
def
_call
(
self
,
inputs
:
Dict
[
str
,
Any
]
)
->
Dict
[
str
,
Any
]:
if
len
(
self
.
dataset_tools
)
==
0
:
return
{
"text"
:
''
}
elif
len
(
self
.
dataset_tools
)
==
1
:
return
{
"text"
:
next
(
iter
(
self
.
dataset_tools
.
values
()))
.
run
(
inputs
[
'input'
])}
route
=
self
.
router_chain
.
route
(
inputs
)
if
not
route
.
destination
:
return
{
"text"
:
''
}
elif
route
.
destination
in
self
.
dataset_tools
:
return
{
"text"
:
self
.
dataset_tools
[
route
.
destination
]
.
run
(
route
.
next_inputs
[
'input'
]
)}
else
:
raise
ValueError
(
f
"Received invalid destination chain name '{route.destination}'"
)
api/core/tool/dataset_tool_builder.py
View file @
88545184
...
@@ -10,24 +10,14 @@ from core.index.keyword_table_index import KeywordTableIndex
...
@@ -10,24 +10,14 @@ from core.index.keyword_table_index import KeywordTableIndex
from
core.index.vector_index
import
VectorIndex
from
core.index.vector_index
import
VectorIndex
from
core.prompt.prompts
import
QUERY_KEYWORD_EXTRACT_TEMPLATE
from
core.prompt.prompts
import
QUERY_KEYWORD_EXTRACT_TEMPLATE
from
core.tool.llama_index_tool
import
EnhanceLlamaIndexTool
from
core.tool.llama_index_tool
import
EnhanceLlamaIndexTool
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
from
models.dataset
import
Dataset
class
DatasetToolBuilder
:
class
DatasetToolBuilder
:
@
classmethod
@
classmethod
def
build_dataset_tool
(
cls
,
tenant_id
:
str
,
dataset_id
:
str
,
def
build_dataset_tool
(
cls
,
dataset
:
Dataset
,
response_mode
:
str
=
"no_synthesizer"
,
response_mode
:
str
=
"no_synthesizer"
,
callback_handler
:
Optional
[
DatasetToolCallbackHandler
]
=
None
):
callback_handler
:
Optional
[
DatasetToolCallbackHandler
]
=
None
):
# get dataset from dataset id
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
Dataset
.
tenant_id
==
tenant_id
,
Dataset
.
id
==
dataset_id
)
.
first
()
if
not
dataset
:
return
None
if
dataset
.
indexing_technique
==
"economy"
:
if
dataset
.
indexing_technique
==
"economy"
:
# use keyword table query
# use keyword table query
index
=
KeywordTableIndex
(
dataset
=
dataset
)
.
query_index
index
=
KeywordTableIndex
(
dataset
=
dataset
)
.
query_index
...
@@ -65,7 +55,7 @@ class DatasetToolBuilder:
...
@@ -65,7 +55,7 @@ class DatasetToolBuilder:
index_tool_config
=
IndexToolConfig
(
index_tool_config
=
IndexToolConfig
(
index
=
index
,
index
=
index
,
name
=
f
"dataset-{dataset
_
id}"
,
name
=
f
"dataset-{dataset
.
id}"
,
description
=
description
,
description
=
description
,
index_query_kwargs
=
query_kwargs
,
index_query_kwargs
=
query_kwargs
,
tool_kwargs
=
{
tool_kwargs
=
{
...
@@ -75,7 +65,7 @@ class DatasetToolBuilder:
...
@@ -75,7 +65,7 @@ class DatasetToolBuilder:
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser
)
)
index_callback_handler
=
DatasetIndexToolCallbackHandler
(
dataset_id
=
dataset
_
id
)
index_callback_handler
=
DatasetIndexToolCallbackHandler
(
dataset_id
=
dataset
.
id
)
return
EnhanceLlamaIndexTool
.
from_tool_config
(
return
EnhanceLlamaIndexTool
.
from_tool_config
(
tool_config
=
index_tool_config
,
tool_config
=
index_tool_config
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment