Commit fb5118f0
authored Jun 19, 2023 by John Wang
parent 0d82aa8f

feat: support router chain using langchain index

Showing 5 changed files with 115 additions and 144 deletions:
  api/core/callback_handler/index_tool_callback_handler.py   +9   -20
  api/core/chain/multi_dataset_router_chain.py                +19  -10
  api/core/tool/dataset_index_tool.py                         +87  -0   (new file)
  api/core/tool/dataset_tool_builder.py                       +0   -73  (deleted)
  api/core/tool/llama_index_tool.py                           +0   -41  (deleted)
api/core/callback_handler/index_tool_callback_handler.py  (+9, -20)

-from extensions.ext_database import db
-from models.dataset import DocumentSegment
-
-
-class IndexToolCallbackHandler:
-
-    def __init__(self) -> None:
-        self._response = None
-
-    @property
-    def response(self) -> Response:
-        return self._response
-
-    def on_tool_end(self, response: Response) -> None:
-        """Handle tool end."""
-        self._response = response
-
-
-class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
+from typing import List
+
+from langchain.schema import Document
+
+from extensions.ext_database import db
+from models.dataset import DocumentSegment
+
+
+class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
     def __init__(self, dataset_id: str) -> None:
-        super().__init__()
         self.dataset_id = dataset_id
 
-    def on_tool_end(self, response: Response) -> None:
+    def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
-        for node in response.source_nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            doc_id = document.metadata['doc_id']
 
             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == self.dataset_id,
-                DocumentSegment.index_node_id == index_node_id
+                DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                 synchronize_session=False
             )
...
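For reference, a minimal sketch of how the reworked handler is driven after this change. It is illustrative only: the Document values and dataset id are hypothetical, the database write needs Dify's Flask app and database context, and in the commit itself the documents come from DatasetTool's index searches rather than being built by hand.

from langchain.schema import Document

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler

# Hypothetical retrieval results; metadata['doc_id'] must match a
# DocumentSegment.index_node_id for the hit count update to apply.
documents = [
    Document(page_content="...", metadata={'doc_id': 'node-1'}),
    Document(page_content="...", metadata={'doc_id': 'node-2'}),
]

handler = DatasetIndexToolCallbackHandler(dataset_id='a-dataset-uuid')
handler.on_tool_end(documents)  # increments hit_count for each matching segment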
api/core/chain/multi_dataset_router_chain.py  (+19, -10)

-from typing import Mapping, List, Dict, Any
+from typing import Mapping, List, Dict, Any, Optional
 
 from langchain import PromptTemplate
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from pydantic import Extra
...
@@ -9,8 +10,7 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
-from core.tool.dataset_tool_builder import DatasetToolBuilder
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
+from core.tool.dataset_index_tool import DatasetTool
 from models.dataset import Dataset
 
 MULTI_PROMPT_ROUTER_TEMPLATE = """
...
@@ -50,7 +50,7 @@ class MultiDatasetRouterChain(Chain):
     router_chain: LLMRouterChain
     """Chain for deciding a destination chain and the input to it."""
-    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    dataset_tools: Mapping[str, DatasetTool]
     """Map of name to candidate chains that inputs can be routed to."""
 
     class Config:
...
@@ -95,22 +95,30 @@ class MultiDatasetRouterChain(Chain):
         router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
             destinations=destinations_str
         )
         router_prompt = PromptTemplate(
             template=router_template,
             input_variables=["input"],
             output_parser=RouterOutputParser(),
         )
         router_chain = LLMRouterChain.from_llm(llm, router_prompt)
         dataset_tools = {}
         for dataset in datasets:
-            dataset_tool = DatasetToolBuilder.build_dataset_tool(
+            # fulfill description when it is empty
+            description = dataset.description
+
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            dataset_tool = DatasetTool(
+                name=f"dataset-{dataset.id}",
+                description=description,
+                k=2,  # todo set by llm tokens limit
                 dataset=dataset,
-                response_mode='no_synthesizer',  # "compact"
-                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
+                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
             )
 
-            if dataset_tool:
-                dataset_tools[dataset.id] = dataset_tool
+            dataset_tools[dataset.id] = dataset_tool
 
         return cls(
             router_chain=router_chain,
...
@@ -120,7 +128,8 @@ class MultiDatasetRouterChain(Chain):
     def _call(
             self,
-            inputs: Dict[str, Any]
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         if len(self.dataset_tools) == 0:
             return {"text": ''}
...
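The new run_manager parameter tracks langchain's updated Chain._call contract, which passes an optional CallbackManagerForChainRun into every chain invocation. A toy sketch of a conforming chain (EchoChain is illustrative and not part of this repository):

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class EchoChain(Chain):
    """Toy chain showing the updated _call signature."""

    @property
    def input_keys(self) -> List[str]:
        return ["input"]

    @property
    def output_keys(self) -> List[str]:
        return ["text"]

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # Chain.run()/Chain.__call__() supply run_manager when callbacks are configured.
        return {"text": inputs["input"]}


# Usage: EchoChain().run(input="hello") returns "hello".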
api/core/tool/dataset_index_tool.py  (new file, +87)

from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class DatasetTool(BaseTool):
    """Tool for querying a Dataset."""

    dataset: Dataset
    k: int = 2

    def _run(self, tool_input: str) -> str:
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            vector_index = VectorIndex(
                dataset=self.dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                tool_input,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(
            **model_credentials
        ))

        vector_index = VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        documents = await vector_index.asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={
                'k': 10
            }
        )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))
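A minimal sketch of exercising the new tool directly. Assumptions: the dataset id is hypothetical, and _run needs a Flask app context (it reads current_app.config), a Dataset row from Dify's database, and provider credentials for the non-"economy" vector path.

from extensions.ext_database import db
from models.dataset import Dataset
from core.tool.dataset_index_tool import DatasetTool

dataset = db.session.query(Dataset).filter(Dataset.id == 'a-dataset-uuid').first()

tool = DatasetTool(
    name=f"dataset-{dataset.id}",
    description=dataset.description or 'useful for when you want to answer queries about the ' + dataset.name,
    k=2,
    dataset=dataset,
)

# "economy" datasets hit the keyword-table index; everything else embeds
# the query and runs a similarity search against the vector index.
print(tool.run("What does the handbook say about refunds?"))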
api/core/tool/dataset_tool_builder.py  (deleted, -73)

from typing import Optional

from langchain.callbacks import CallbackManager
from llama_index.langchain_helpers.agents import IndexToolConfig

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset


class DatasetToolBuilder:
    @classmethod
    def build_dataset_tool(cls, dataset: Dataset,
                           response_mode: str = "no_synthesizer",
                           callback_handler: Optional[DatasetToolCallbackHandler] = None):
        if dataset.indexing_technique == "economy":
            # use keyword table query
            index = KeywordTableIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
                "max_keywords_per_query": 5,
                # If num_chunks_per_query is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "num_chunks_per_query": 2
            }
        else:
            index = VectorIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                # If top_k is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "similarity_top_k": 2
            }

        # fulfill description when it is empty
        description = dataset.description

        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        index_tool_config = IndexToolConfig(
            index=index,
            name=f"dataset-{dataset.id}",
            description=description,
            index_query_kwargs=query_kwargs,
            tool_kwargs={
                "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
            },
            # tool_kwargs={"return_direct": True},
            # return_direct: Whether to return LLM results directly
            # or process the output data with an Output Parser
        )

        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)

        return EnhanceLlamaIndexTool.from_tool_config(
            tool_config=index_tool_config,
            callback_handler=index_callback_handler
        )
api/core/tool/llama_index_tool.py  (deleted, -41)

from typing import Dict

from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field

from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler


class EnhanceLlamaIndexTool(BaseTool):
    """Tool for querying a LlamaIndex."""

    # NOTE: name/description still needs to be set
    index: BaseGPTIndex
    query_kwargs: Dict = Field(default_factory=dict)
    return_sources: bool = False
    callback_handler: IndexToolCallbackHandler

    @classmethod
    def from_tool_config(cls, tool_config: IndexToolConfig,
                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
        """Create a tool from a tool config."""
        return_sources = tool_config.tool_kwargs.pop("return_sources", False)
        return cls(
            index=tool_config.index,
            callback_handler=callback_handler,
            name=tool_config.name,
            description=tool_config.description,
            return_sources=return_sources,
            query_kwargs=tool_config.index_query_kwargs,
            **tool_config.tool_kwargs,
        )

    def _run(self, tool_input: str) -> str:
        response = self.index.query(tool_input, **self.query_kwargs)
        self.callback_handler.on_tool_end(response)
        return str(response)

    async def _arun(self, tool_input: str) -> str:
        response = await self.index.aquery(tool_input, **self.query_kwargs)
        self.callback_handler.on_tool_end(response)
        return str(response)