ai-tech / dify

Commit c429005c
feat: optimize dataset_retriever tool

Authored Jul 10, 2023 by John Wang
Parent: c6c81164

Showing 7 changed files with 64 additions and 137 deletions.
api/core/agent/agent_executor.py              +2    -0
api/core/chain/chain_builder.py               +0   -32
api/core/chain/tool_chain.py                  +0   -51
api/core/orchestrator_rule_parser.py          +7   -11
api/core/tool/dataset_retriever_tool.py      +51   -39
api/core/tool/provider/errors.py              +2    -2
api/core/tool/provider/serpapi_provider.py    +2    -2
api/core/agent/agent_executor.py

```diff
@@ -63,6 +63,8 @@ class AgentExecutor:
                 summary_llm=self.configuration.summary_llm,
                 verbose=True
             )
+        else:
+            raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
 
         return agent
```
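This hunk completes the strategy dispatch in agent_executor.py: an unrecognized agent strategy now fails loudly instead of falling through with an uninitialized agent. A minimal sketch of the resulting pattern, assuming illustrative stand-ins for everything except the else/raise branch, which is verbatim from the diff:

```python
from dataclasses import dataclass


@dataclass
class Configuration:
    strategy: str  # e.g. "react"; strategy names here are assumptions, not from the commit


class AgentExecutor:
    def __init__(self, configuration: Configuration):
        self.configuration = configuration

    def _init_agent(self):
        if self.configuration.strategy == "react":  # assumed known strategy
            agent = object()  # stand-in for the real agent construction
        else:
            # the fallback this commit adds: fail loudly on unknown strategies
            raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
        return agent


AgentExecutor(Configuration(strategy="react"))._init_agent()   # ok
# AgentExecutor(Configuration(strategy="tree"))._init_agent()  # raises NotImplementedError
```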
api/core/chain/chain_builder.py (deleted, 100644 → 0)

```python
from typing import Optional

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain


class ChainBuilder:
    @classmethod
    def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
        return ToolChain(
            tool=tool,
            input_key=kwargs.get('input_key', 'input'),
            output_key=kwargs.get('output_key', 'tool_output'),
            callbacks=[DifyStdOutCallbackHandler()]
        )

    @classmethod
    def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[SensitiveWordAvoidanceChain]:
        sensitive_words = tool_config.get("words", "")
        if tool_config.get("enabled", False) \
                and sensitive_words:
            return SensitiveWordAvoidanceChain(
                sensitive_words=sensitive_words.split(","),
                canned_response=tool_config.get("canned_response", ''),
                output_key="sensitive_word_avoidance_output",
                callbacks=[DifyStdOutCallbackHandler()],
                **kwargs
            )

        return None
```
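With ToolChain gone (next file), ChainBuilder's to_tool_chain factory has no target, and the whole builder class is removed. For context, a rough before/after sketch of what the wrapper bought its callers versus invoking a tool directly; the echo tool and its strings below are made up for the demo, not from the repo:

```python
from langchain.tools import Tool

# An arbitrary tool; name/description/func are illustrative only.
echo = Tool(name="echo", description="Echo the input back.", func=lambda s: s)

# Before: callers asked ChainBuilder for a dict-in/dict-out wrapper, roughly:
#   chain = ChainBuilder.to_tool_chain(echo, output_key='tool_output')
#   result = chain.run({'input': 'hello'})['tool_output']
# After: the tool is used directly.
result = echo.run("hello")
assert result == "hello"
```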
api/core/chain/tool_chain.py (deleted, 100644 → 0)

```python
from typing import List, Dict, Optional, Any

from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool


class ToolChain(Chain):
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    tool: BaseTool

    @property
    def _chain_type(self) -> str:
        return "tool_chain"

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        input = inputs[self.input_key]
        output = self.tool.run(input, self.verbose)
        return {self.output_key: output}

    async def _acall(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""
        input = inputs[self.input_key]
        output = await self.tool.arun(input, self.verbose)
        return {self.output_key: output}
```
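The deleted ToolChain was a thin adapter: LangChain chains speak dicts keyed by input_key/output_key, while a tool's run() takes a plain string. A dependency-free toy equivalent makes the mapping explicit; ToyToolChain is a teaching stand-in, not the original class:

```python
from typing import Any, Callable, Dict


class ToyToolChain:
    """Dict-in/dict-out adapter around a plain string -> string tool."""

    def __init__(self, tool: Callable[[str], str],
                 input_key: str = "input", output_key: str = "output"):
        self.tool = tool
        self.input_key = input_key
        self.output_key = output_key

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # mirrors ToolChain._call: unpack by input_key, run the tool, repack by output_key
        return {self.output_key: self.tool(inputs[self.input_key])}


adapter = ToyToolChain(tool=str.upper)
assert adapter({"input": "hi"}) == {"output": "HI"}
```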
api/core/orchestrator_rule_parser.py

```diff
@@ -26,8 +26,9 @@ class OrchestratorRuleParser:
     def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
         self.tenant_id = tenant_id
         self.app_model_config = app_model_config
+        self.agent_summary_model_name = "gpt-3.5-turbo-16k"
 
-    def to_agent_arguments(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
+    def to_agent_chain(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                        rest_tokens: int, callbacks: Callbacks = None) -> Optional[Chain]:
         if not self.app_model_config.agent_mode_dict:
             return None
@@ -54,7 +55,7 @@ class OrchestratorRuleParser:
         summary_llm = LLMBuilder.to_llm(
             tenant_id=self.tenant_id,
-            model_name="gpt-3.5-turbo-16k",
+            model_name=self.agent_summary_model_name,
             temperature=0,
             max_tokens=500,
             callbacks=[DifyStdOutCallbackHandler()]
@@ -181,15 +182,9 @@ class OrchestratorRuleParser:
         if dataset and dataset.available_document_count == 0 \
                 and dataset.available_document_count == 0:
             return None
 
-        description = dataset.description
-        if not description:
-            description = 'useful for when you want to answer queries about the ' + dataset.name
-
-        tool = DatasetRetrieverTool(
-            name=f"dataset_retriever",
-            description=description,
-            k=3,
-            dataset=dataset,
+        tool = DatasetRetrieverTool.from_dataset(
+            dataset=dataset,
+            k=3,
             callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
         )
@@ -227,7 +222,8 @@ class OrchestratorRuleParser:
         tool = Tool(
             name="google_search",
             description="A tool for performing a Google search and extracting snippets and webpages "
-                        "when you need to search for something you don't know or when your information is not up to date."
+                        "when you need to search for something you don't know or when your information "
+                        "is not up to date."
                         "Input should be a search query.",
             func=OptimizedSerpAPIWrapper(**func_kwargs).run,
             callbacks=[DifyStdOutCallbackHandler]
```
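Two details here are easy to misread. First, the dataset tool is now built through the new DatasetRetrieverTool.from_dataset factory (next file), which moves the name and description plumbing out of the parser. Second, the google_search hunk only re-wraps a long literal: adjacent Python string literals concatenate at compile time, so the description value is unchanged, including the pre-existing missing space before "Input should be". A quick check:

```python
# The re-wrapped literal from the new code...
wrapped = ("A tool for performing a Google search and extracting snippets and webpages "
           "when you need to search for something you don't know or when your information "
           "is not up to date."
           "Input should be a search query.")
# ...equals the single-line form it replaced.
unwrapped = ("A tool for performing a Google search and extracting snippets and webpages "
             "when you need to search for something you don't know or when your information is not up to date."
             "Input should be a search query.")
assert wrapped == unwrapped
```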
api/core/tool/dataset_retriever_tool.py

```diff
@@ -1,39 +1,77 @@
+import re
+from typing import Type
+
 from flask import current_app
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.tools import BaseTool
+from pydantic import Field, BaseModel
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.llm.llm_builder import LLMBuilder
+from extensions.ext_database import db
 from models.dataset import Dataset
 
 
+class DatasetRetrieverToolInput(BaseModel):
+    dataset_id: str = Field(..., description="ID of dateset to be queried. MUST be UUID format.")
+    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
+
+
 class DatasetRetrieverTool(BaseTool):
     """Tool for querying a Dataset."""
 
-    # todo dataset id as tool argument
-    dataset: Dataset
-    k: int = 2
+    name: str = "dataset_retriever"
+    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
+    description: str = "use this to retrieve a dataset. "
+
+    tenant_id: str
+    k: int = 3
+
+    @classmethod
+    def from_dataset(cls, dataset: Dataset, **kwargs):
+        description = dataset.description
+        if not description:
+            description = 'useful for when you want to answer queries about the ' + dataset.name
+
+        description += '\nID of dataset MUST be ' + dataset.id
+        return cls(
+            tenant_id=dataset.tenant_id,
+            description=description,
+            **kwargs
+        )
 
-    def _run(self, tool_input: str) -> str:
-        if self.dataset.indexing_technique == "economy":
+    def _run(self, dataset_id: str, query: str) -> str:
+        pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
+        match = re.search(pattern, dataset_id, re.IGNORECASE)
+        if match:
+            dataset_id = match.group()
+
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == self.tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+
+        if not dataset:
+            return f'[{self.name} failed to find dataset with id {dataset_id}.]'
+
+        if dataset.indexing_technique == "economy":
             # use keyword table query
             kw_table_index = KeywordTableIndex(
-                dataset=self.dataset,
+                dataset=dataset,
                 config=KeywordTableConfig(
                     max_keywords_per_chunk=5
                 )
             )
 
-            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
+            documents = kw_table_index.search(query, search_kwargs={'k': self.k})
         else:
             model_credentials = LLMBuilder.get_model_credentials(
-                tenant_id=self.dataset.tenant_id,
-                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+                tenant_id=dataset.tenant_id,
+                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
                 model_name='text-embedding-ada-002'
             )
@@ -40,49 +78,23 @@ class DatasetRetrieverTool(BaseTool):
             ))
 
             vector_index = VectorIndex(
-                dataset=self.dataset,
+                dataset=dataset,
                 config=current_app.config,
                 embeddings=embeddings
             )
 
             documents = vector_index.search(
-                tool_input,
+                query,
                 search_type='similarity',
                 search_kwargs={
                     'k': self.k
                 }
             )
 
-        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
+        hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
         hit_callback.on_tool_end(documents)
         return str("\n".join([document.page_content for document in documents]))
 
     async def _arun(self, tool_input: str) -> str:
-        model_credentials = LLMBuilder.get_model_credentials(
-            tenant_id=self.dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
-            model_name='text-embedding-ada-002'
-        )
-
-        embeddings = CacheEmbedding(OpenAIEmbeddings(
-            **model_credentials
-        ))
-
-        vector_index = VectorIndex(
-            dataset=self.dataset,
-            config=current_app.config,
-            embeddings=embeddings
-        )
-
-        documents = await vector_index.asearch(
-            tool_input,
-            search_type='similarity',
-            search_kwargs={
-                'k': 10
-            }
-        )
-
-        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
-        hit_callback.on_tool_end(documents)
-        return str("\n".join([document.page_content for document in documents]))
+        raise NotImplementedError()
```
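This is the core of the commit: the dataset is no longer pinned to the tool instance. The args_schema tells the agent to supply both a dataset_id and a query, _run resolves the dataset per call (scoped to the tool's tenant_id), and the async path is dropped in favor of an explicit NotImplementedError. The regex step tolerates models that wrap the UUID in extra prose; a standalone sketch of that normalization, where extract_dataset_id is a hypothetical helper and the pattern is verbatim from _run:

```python
import re

# verbatim from _run: a case-insensitive UUID shape
UUID_PATTERN = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'


def extract_dataset_id(raw: str) -> str:
    """If the model wraps the UUID in extra text, keep only the UUID-shaped
    substring; otherwise pass the input through for the DB lookup to reject."""
    match = re.search(UUID_PATTERN, raw, re.IGNORECASE)
    return match.group() if match else raw


assert (extract_dataset_id("dataset id: 123E4567-e89b-12d3-a456-426614174000 please")
        == "123E4567-e89b-12d3-a456-426614174000")
assert extract_dataset_id("not-a-uuid") == "not-a-uuid"
```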
api/core/tool/provider/errors.py

```diff
-class ValidateFailedError(Exception):
-    description = "Provider Validate failed"
+class ToolValidateFailedError(Exception):
+    description = "Tool Provider Validate failed"
```
api/core/tool/provider/serpapi_provider.py

```diff
 from typing import Optional
 
-from core.llm.provider.errors import ValidateFailedError
 from core.tool.provider.base import BaseToolProvider
+from core.tool.provider.errors import ToolValidateFailedError
 from models.tool import ToolProviderName
@@ -56,4 +56,4 @@ class SerpAPIToolProvider(BaseToolProvider):
         :return:
         """
         if 'api_key' not in credentials or not credentials.get('api_key'):
-            raise ValidateFailedError("SerpAPI api_key is required.")
+            raise ToolValidateFailedError("SerpAPI api_key is required.")
```
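The rename (previous file) lets tool-provider validation failures be caught separately from the LLM-provider ValidateFailedError that was reused before. A minimal sketch of the validation contract the last hunk implements; credentials_validate is a free function here for illustration, while in the repo it is a method on SerpAPIToolProvider:

```python
class ToolValidateFailedError(Exception):
    description = "Tool Provider Validate failed"


def credentials_validate(credentials: dict) -> None:
    # mirrors the SerpAPI provider: reject a missing or empty api_key
    if 'api_key' not in credentials or not credentials.get('api_key'):
        raise ToolValidateFailedError("SerpAPI api_key is required.")


credentials_validate({'api_key': 'dummy'})  # passes
try:
    credentials_validate({'api_key': ''})
except ToolValidateFailedError as exc:
    print(exc)  # SerpAPI api_key is required.
```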