Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
c6c81164
Commit
c6c81164
authored
Jul 10, 2023
by
John Wang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: optimize tool providers and tool parse
parent
d7712cf7
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
627 additions
and
115 deletions
+627
-115
app.py
api/controllers/console/app/app.py
+3
-0
model_config.py
api/controllers/console/app/model_config.py
+1
-0
agent_executor.py
api/core/agent/agent_executor.py
+39
-53
main_chain_builder.py
api/core/chain/main_chain_builder.py
+31
-38
completion.py
api/core/completion.py
+2
-2
conversation_message_task.py
api/core/conversation_message_task.py
+1
-0
orchestrator_rule_parser.py
api/core/orchestrator_rule_parser.py
+236
-0
dataset_index_tool.py
api/core/tool/dataset_index_tool.py
+4
-1
dataset_retriever_tool.py
api/core/tool/dataset_retriever_tool.py
+88
-0
base.py
api/core/tool/provider/base.py
+59
-0
errors.py
api/core/tool/provider/errors.py
+2
-0
serpapi_provider.py
api/core/tool/provider/serpapi_provider.py
+59
-0
tool_provider_service.py
api/core/tool/provider/tool_provider_service.py
+28
-0
web_reader_tool.py
api/core/tool/web_reader_tool.py
+1
-1
7ce5a52e4eee_add_tool_providers.py
api/migrations/versions/7ce5a52e4eee_add_tool_providers.py
+11
-5
model.py
api/models/model.py
+10
-0
tool.py
api/models/tool.py
+20
-0
app_model_config_service.py
api/services/app_model_config_service.py
+31
-15
completion_service.py
api/services/completion_service.py
+1
-0
No files found.
api/controllers/console/app/app.py
View file @
c6c81164
...
...
@@ -24,6 +24,7 @@ model_config_fields = {
'suggested_questions_after_answer'
:
fields
.
Raw
(
attribute
=
'suggested_questions_after_answer_dict'
),
'speech_to_text'
:
fields
.
Raw
(
attribute
=
'speech_to_text_dict'
),
'more_like_this'
:
fields
.
Raw
(
attribute
=
'more_like_this_dict'
),
'sensitive_word_avoidance'
:
fields
.
Raw
(
attribute
=
'sensitive_word_avoidance_dict'
),
'model'
:
fields
.
Raw
(
attribute
=
'model_dict'
),
'user_input_form'
:
fields
.
Raw
(
attribute
=
'user_input_form_list'
),
'pre_prompt'
:
fields
.
String
,
...
...
@@ -148,6 +149,7 @@ class AppListApi(Resource):
suggested_questions_after_answer
=
json
.
dumps
(
model_configuration
[
'suggested_questions_after_answer'
]),
speech_to_text
=
json
.
dumps
(
model_configuration
[
'speech_to_text'
]),
more_like_this
=
json
.
dumps
(
model_configuration
[
'more_like_this'
]),
sensitive_word_avoidance
=
json
.
dumps
(
model_configuration
[
'sensitive_word_avoidance'
]),
model
=
json
.
dumps
(
model_configuration
[
'model'
]),
user_input_form
=
json
.
dumps
(
model_configuration
[
'user_input_form'
]),
pre_prompt
=
model_configuration
[
'pre_prompt'
],
...
...
@@ -439,6 +441,7 @@ class AppCopy(Resource):
suggested_questions_after_answer
=
app_config
.
suggested_questions_after_answer
,
speech_to_text
=
app_config
.
speech_to_text
,
more_like_this
=
app_config
.
more_like_this
,
sensitive_word_avoidance
=
app_config
.
sensitive_word_avoidance
,
model
=
app_config
.
model
,
user_input_form
=
app_config
.
user_input_form
,
pre_prompt
=
app_config
.
pre_prompt
,
...
...
api/controllers/console/app/model_config.py
View file @
c6c81164
...
...
@@ -43,6 +43,7 @@ class ModelConfigResource(Resource):
suggested_questions_after_answer
=
json
.
dumps
(
model_configuration
[
'suggested_questions_after_answer'
]),
speech_to_text
=
json
.
dumps
(
model_configuration
[
'speech_to_text'
]),
more_like_this
=
json
.
dumps
(
model_configuration
[
'more_like_this'
]),
sensitive_word_avoidance
=
json
.
dumps
(
model_configuration
[
'sensitive_word_avoidance'
]),
model
=
json
.
dumps
(
model_configuration
[
'model'
]),
user_input_form
=
json
.
dumps
(
model_configuration
[
'user_input_form'
]),
pre_prompt
=
model_configuration
[
'pre_prompt'
],
...
...
api/core/agent/agent_executor.py
View file @
c6c81164
import
enum
from
typing
import
Union
,
Optional
from
langchain.agents
import
BaseSingleActionAgent
,
BaseMultiActionAgent
from
langchain.agents
import
BaseSingleActionAgent
,
BaseMultiActionAgent
,
AgentExecutor
from
langchain.base_language
import
BaseLanguageModel
from
langchain.callbacks.manager
import
Callbacks
from
langchain.tools
import
BaseTool
from
pydantic
import
BaseModel
from
core.agent.agent.openai_function_call
import
AutoSummarizingOpenAIFunctionCallAgent
from
core.agent.agent.openai_multi_function_call
import
AutoSummarizingOpenMultiAIFunctionCallAgent
...
...
@@ -20,55 +21,46 @@ class PlanningStrategy(str, enum.Enum):
MULTI_FUNCTION_CALL
=
'multi_function_call'
class
AgentExecutor
:
def
__init__
(
self
,
strategy
:
PlanningStrategy
,
llm
:
BaseLanguageModel
,
tools
:
list
[
BaseTool
],
summary_llm
:
BaseLanguageModel
,
memory
:
ReadOnlyConversationTokenDBBufferSharedMemory
,
callbacks
:
Callbacks
=
None
,
max_iterations
:
int
=
6
,
max_execution_time
:
Optional
[
float
]
=
None
,
early_stopping_method
:
str
=
"generate"
):
self
.
strategy
=
strategy
self
.
llm
=
llm
self
.
tools
=
tools
self
.
summary_llm
=
summary_llm
self
.
memory
=
memory
self
.
callbacks
=
callbacks
self
.
agent
=
self
.
_init_agent
(
strategy
,
llm
,
tools
,
memory
,
callbacks
)
class
AgentConfiguration
(
BaseModel
):
strategy
:
PlanningStrategy
llm
:
BaseLanguageModel
tools
:
list
[
BaseTool
]
summary_llm
:
BaseLanguageModel
memory
:
ReadOnlyConversationTokenDBBufferSharedMemory
callbacks
:
Callbacks
=
None
max_iterations
:
int
=
6
max_execution_time
:
Optional
[
float
]
=
None
early_stopping_method
:
str
=
"generate"
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
self
.
max_iterations
=
max_iterations
self
.
max_execution_time
=
max_execution_time
self
.
early_stopping_method
=
early_stopping_method
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
# summary_llm: StreamableChatOpenAI = LLMBuilder.to_llm(
# tenant_id=tenant_id,
# model_name='gpt-3.5-turbo-16k',
# max_tokens=300
# )
class
AgentExecutor
:
def
__init__
(
self
,
configuration
:
AgentConfiguration
):
self
.
configuration
=
configuration
self
.
agent
=
self
.
_init_agent
()
def
_init_agent
(
self
,
strategy
:
PlanningStrategy
,
llm
:
BaseLanguageModel
,
tools
:
list
[
BaseTool
],
memory
:
ReadOnlyConversationTokenDBBufferSharedMemory
,
callbacks
:
Callbacks
=
None
)
\
->
Union
[
BaseSingleActionAgent
|
BaseMultiActionAgent
]:
if
strategy
==
PlanningStrategy
.
REACT
:
def
_init_agent
(
self
)
->
Union
[
BaseSingleActionAgent
|
BaseMultiActionAgent
]:
if
self
.
configuration
.
strategy
==
PlanningStrategy
.
REACT
:
agent
=
AutoSummarizingStructuredChatAgent
.
from_llm_and_tools
(
llm
=
llm
,
tools
=
tools
,
summary_llm
=
self
.
summary_llm
,
llm
=
self
.
configuration
.
llm
,
tools
=
self
.
configuration
.
tools
,
summary_llm
=
self
.
configuration
.
summary_llm
,
verbose
=
True
)
elif
strategy
==
PlanningStrategy
.
FUNCTION_CALL
:
elif
s
elf
.
configuration
.
s
trategy
==
PlanningStrategy
.
FUNCTION_CALL
:
agent
=
AutoSummarizingOpenAIFunctionCallAgent
(
llm
=
llm
,
tools
=
tools
,
extra_prompt_messages
=
memory
.
buffer
,
# used for read chat histories memory
summary_llm
=
self
.
summary_llm
,
llm
=
self
.
configuration
.
llm
,
tools
=
self
.
configuration
.
tools
,
extra_prompt_messages
=
self
.
configuration
.
memory
.
buffer
,
# used for read chat histories memory
summary_llm
=
self
.
configuration
.
summary_llm
,
verbose
=
True
)
elif
strategy
==
PlanningStrategy
.
MULTI_FUNCTION_CALL
:
elif
s
elf
.
configuration
.
s
trategy
==
PlanningStrategy
.
MULTI_FUNCTION_CALL
:
agent
=
AutoSummarizingOpenMultiAIFunctionCallAgent
(
llm
=
llm
,
tools
=
tools
,
extra_prompt_messages
=
memory
.
buffer
,
# used for read chat histories memory
summary_llm
=
self
.
summary_llm
,
llm
=
self
.
configuration
.
llm
,
tools
=
self
.
configuration
.
tools
,
extra_prompt_messages
=
self
.
configuration
.
memory
.
buffer
,
# used for read chat histories memory
summary_llm
=
self
.
configuration
.
summary_llm
,
verbose
=
True
)
...
...
@@ -77,21 +69,15 @@ class AgentExecutor:
def
should_use_agent
(
self
,
query
:
str
)
->
bool
:
return
self
.
agent
.
should_use_agent
(
query
)
def
run
(
self
,
query
:
str
)
->
st
r
:
def
get_chain
(
self
)
->
AgentExecuto
r
:
agent_executor
=
LCAgentExecutor
.
from_agent_and_tools
(
agent
=
self
.
agent
,
tools
=
self
.
tools
,
memory
=
self
.
memory
,
max_iterations
=
self
.
max_iterations
,
max_execution_time
=
self
.
max_execution_time
,
early_stopping_method
=
self
.
early_stopping_method
,
tools
=
self
.
configuration
.
tools
,
memory
=
self
.
configuration
.
memory
,
max_iterations
=
self
.
configuration
.
max_iterations
,
max_execution_time
=
self
.
configuration
.
max_execution_time
,
early_stopping_method
=
self
.
configuration
.
early_stopping_method
,
verbose
=
True
)
# run agent
result
=
agent_executor
.
run
(
query
,
callbacks
=
self
.
callbacks
)
return
result
return
agent_executor
api/core/chain/main_chain_builder.py
View file @
c6c81164
from
typing
import
Optional
,
List
,
cast
from
typing
import
Optional
,
List
,
cast
,
Tuple
from
langchain.chains
import
SequentialChain
from
langchain.chains.base
import
Chain
...
...
@@ -6,44 +6,54 @@ from langchain.memory.chat_memory import BaseChatMemory
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.chain_builder
import
ChainBuilder
from
core.chain.multi_dataset_router_chain
import
MultiDatasetRouterChain
from
core.conversation_message_task
import
ConversationMessageTask
from
core.orchestrator_rule_parser
import
OrchestratorRuleParser
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
from
models.model
import
AppModelConfig
class
MainChainBuilder
:
@
classmethod
def
to_langchain_components
(
cls
,
tenant_id
:
str
,
agent_mode
:
dict
,
memory
:
Optional
[
BaseChatMemory
],
rest_tokens
:
int
,
conversation_message_task
:
ConversationMessageTask
):
def
get_chains
(
cls
,
tenant_id
:
str
,
app_model_config
:
AppModelConfig
,
memory
:
Optional
[
BaseChatMemory
],
rest_tokens
:
int
,
conversation_message_task
:
ConversationMessageTask
):
first_input_key
=
"input"
final_output_key
=
"output"
chains
=
[]
chain_callback_handler
=
MainChainGatherCallbackHandler
(
conversation_message_task
)
# init orchestrator rule parser
orchestrator_rule_parser
=
OrchestratorRuleParser
(
tenant_id
=
tenant_id
,
app_model_config
=
app_model_config
)
# agent mode
tool_chains
,
chains_output_key
=
cls
.
get_agent_chains
(
# parse sensitive_word_avoidance_chain
sensitive_word_avoidance_chain
=
orchestrator_rule_parser
.
to_sensitive_word_avoidance_chain
()
if
sensitive_word_avoidance_chain
:
chains
.
append
(
sensitive_word_avoidance_chain
)
# parse agent chain
agent_chain
=
cls
.
get_agent_chain
(
tenant_id
=
tenant_id
,
agent_mode
=
a
gent_mode
,
agent_mode
=
a
pp_model_config
.
agent_mode_dict
,
rest_tokens
=
rest_tokens
,
memory
=
memory
,
conversation_message_task
=
conversation_message_task
)
chains
+=
tool_chains
if
chains_output_key
:
final_output_key
=
chains_output_key
if
agent_chain
:
chains
.
append
(
agent_chain
)
final_output_key
=
agent_chain
.
output_keys
[
0
]
if
len
(
chains
)
==
0
:
return
None
chain_callback
=
MainChainGatherCallbackHandler
(
conversation_message_task
)
for
chain
in
chains
:
chain
=
cast
(
Chain
,
chain
)
chain
.
callbacks
.
append
(
chain_callback
_handler
)
chain
.
callbacks
.
append
(
chain_callback
)
# build main chain
overall_chain
=
SequentialChain
(
...
...
@@ -56,26 +66,20 @@ class MainChainBuilder:
return
overall_chain
@
classmethod
def
get_agent_chain
s
(
cls
,
tenant_id
:
str
,
agent_mode
:
dict
,
rest_tokens
:
int
,
memory
:
Optional
[
BaseChatMemory
],
conversation_message_task
:
ConversationMessageTask
)
:
def
get_agent_chain
(
cls
,
tenant_id
:
str
,
agent_mode
:
dict
,
rest_tokens
:
int
,
memory
:
Optional
[
BaseChatMemory
],
conversation_message_task
:
ConversationMessageTask
)
->
Chain
:
# agent mode
chain
s
=
[]
chain
=
None
if
agent_mode
and
agent_mode
.
get
(
'enabled'
):
tools
=
agent_mode
.
get
(
'tools'
,
[])
pre_fixed_chains
=
[]
# agent_tools = []
datasets
=
[]
for
tool
in
tools
:
tool_type
=
list
(
tool
.
keys
())[
0
]
tool_config
=
list
(
tool
.
values
())[
0
]
if
tool_type
==
'sensitive-word-avoidance'
:
chain
=
ChainBuilder
.
to_sensitive_word_avoidance_chain
(
tool_config
)
if
chain
:
pre_fixed_chains
.
append
(
chain
)
elif
tool_type
==
"dataset"
:
if
tool_type
==
"dataset"
:
# get dataset from dataset id
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
Dataset
.
tenant_id
==
tenant_id
,
...
...
@@ -85,9 +89,6 @@ class MainChainBuilder:
if
dataset
:
datasets
.
append
(
dataset
)
# add pre-fixed chains
chains
+=
pre_fixed_chains
if
len
(
datasets
)
>
0
:
# tool to chain
multi_dataset_router_chain
=
MultiDatasetRouterChain
.
from_datasets
(
...
...
@@ -97,14 +98,6 @@ class MainChainBuilder:
rest_tokens
=
rest_tokens
,
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
chains
.
append
(
multi_dataset_router_chain
)
final_output_key
=
cls
.
get_chains_output_key
(
chains
)
chain
=
multi_dataset_router_chain
return
chains
,
final_output_key
@
classmethod
def
get_chains_output_key
(
cls
,
chains
:
List
[
Chain
]):
if
len
(
chains
)
>
0
:
return
chains
[
-
1
]
.
output_keys
[
0
]
return
None
return
chain
api/core/completion.py
View file @
c6c81164
...
...
@@ -70,9 +70,9 @@ class Completion:
)
# build main chain include agent
main_chain
=
MainChainBuilder
.
to_langchain_component
s
(
main_chain
=
MainChainBuilder
.
get_chain
s
(
tenant_id
=
app
.
tenant_id
,
a
gent_mode
=
app_model_config
.
agent_mode_dict
,
a
pp_model_config
=
app_model_config
,
rest_tokens
=
rest_tokens_for_context_and_memory
,
memory
=
ReadOnlyConversationTokenDBStringBufferSharedMemory
(
memory
=
memory
)
if
memory
else
None
,
conversation_message_task
=
conversation_message_task
...
...
api/core/conversation_message_task.py
View file @
c6c81164
...
...
@@ -69,6 +69,7 @@ class ConversationMessageTask:
"suggested_questions"
:
self
.
app_model_config
.
suggested_questions_list
,
"suggested_questions_after_answer"
:
self
.
app_model_config
.
suggested_questions_after_answer_dict
,
"more_like_this"
:
self
.
app_model_config
.
more_like_this_dict
,
"sensitive_word_avoidance"
:
self
.
app_model_config
.
sensitive_word_avoidance_dict
,
"user_input_form"
:
self
.
app_model_config
.
user_input_form_list
,
}
...
...
api/core/orchestrator_rule_parser.py
0 → 100644
View file @
c6c81164
from
typing
import
Optional
from
langchain.callbacks.manager
import
Callbacks
from
langchain.chains.base
import
Chain
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.tools
import
BaseTool
,
Tool
from
core.agent.agent_executor
import
AgentExecutor
,
PlanningStrategy
,
AgentConfiguration
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.multi_dataset_router_chain
import
MultiDatasetRouterChain
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.conversation_message_task
import
ConversationMessageTask
from
core.llm.llm_builder
import
LLMBuilder
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
from
core.tool.provider.serpapi_provider
import
SerpAPIToolProvider
from
core.tool.serpapi_wrapper
import
OptimizedSerpAPIWrapper
from
core.tool.web_reader_tool
import
WebReaderTool
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
from
models.model
import
AppModelConfig
class
OrchestratorRuleParser
:
"""Parse the orchestrator rule to entities."""
def
__init__
(
self
,
tenant_id
:
str
,
app_model_config
:
AppModelConfig
):
self
.
tenant_id
=
tenant_id
self
.
app_model_config
=
app_model_config
def
to_agent_arguments
(
self
,
conversation_message_task
:
ConversationMessageTask
,
memory
:
Optional
[
BaseChatMemory
],
rest_tokens
:
int
,
callbacks
:
Callbacks
=
None
)
->
Optional
[
Chain
]:
if
not
self
.
app_model_config
.
agent_mode_dict
:
return
None
agent_mode_config
=
self
.
app_model_config
.
agent_mode_dict
chain
=
None
if
agent_mode_config
and
agent_mode_config
.
get
(
'enabled'
):
tool_configs
=
agent_mode_config
.
get
(
'tools'
,
[])
# use router chain if planning strategy is router or not set
if
not
agent_mode_config
.
get
(
'strategy'
)
or
agent_mode_config
.
get
(
'strategy'
)
==
'router'
:
return
self
.
to_router_chain
(
tool_configs
,
conversation_message_task
,
rest_tokens
)
agent_model_name
=
agent_mode_config
.
get
(
'model_name'
,
'gpt-4'
)
agent_llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
self
.
tenant_id
,
model_name
=
agent_model_name
,
temperature
=
0
,
max_tokens
=
800
,
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
summary_llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
self
.
tenant_id
,
model_name
=
"gpt-3.5-turbo-16k"
,
temperature
=
0
,
max_tokens
=
500
,
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
agent_configuration
=
AgentConfiguration
(
strategy
=
PlanningStrategy
(
agent_mode_config
.
get
(
'strategy'
)),
llm
=
agent_llm
,
tools
=
self
.
to_tools
(
tool_configs
,
conversation_message_task
),
summary_llm
=
summary_llm
,
memory
=
memory
,
callbacks
=
callbacks
,
max_iterations
=
6
,
max_execution_time
=
None
,
early_stopping_method
=
"generate"
)
agent_executor
=
AgentExecutor
(
agent_configuration
)
chain
=
agent_executor
.
get_chain
()
return
chain
def
to_router_chain
(
self
,
tool_configs
:
list
,
conversation_message_task
:
ConversationMessageTask
,
rest_tokens
:
int
)
->
Optional
[
Chain
]:
"""
Convert tool configs to router chain if planning strategy is router
:param tool_configs:
:param conversation_message_task:
:param rest_tokens:
:return:
"""
chain
=
None
datasets
=
[]
for
tool_config
in
tool_configs
:
tool_type
=
list
(
tool_config
.
keys
())[
0
]
tool_val
=
list
(
tool_config
.
values
())[
0
]
if
tool_type
==
"dataset"
:
# get dataset from dataset id
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
Dataset
.
tenant_id
==
self
.
tenant_id
,
Dataset
.
id
==
tool_val
.
get
(
"id"
)
)
.
first
()
if
dataset
:
datasets
.
append
(
dataset
)
if
len
(
datasets
)
>
0
:
# tool to chain
multi_dataset_router_chain
=
MultiDatasetRouterChain
.
from_datasets
(
tenant_id
=
self
.
tenant_id
,
datasets
=
datasets
,
conversation_message_task
=
conversation_message_task
,
rest_tokens
=
rest_tokens
,
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
chain
=
multi_dataset_router_chain
return
chain
def
to_sensitive_word_avoidance_chain
(
self
,
**
kwargs
)
->
Optional
[
SensitiveWordAvoidanceChain
]:
"""
Convert app sensitive word avoidance config to chain
:param kwargs:
:return:
"""
if
not
self
.
app_model_config
.
sensitive_word_avoidance_dict
:
return
None
sensitive_word_avoidance_config
=
self
.
app_model_config
.
sensitive_word_avoidance_dict
sensitive_words
=
sensitive_word_avoidance_config
.
get
(
"words"
,
""
)
if
sensitive_word_avoidance_config
.
get
(
"enabled"
,
False
)
and
sensitive_words
:
return
SensitiveWordAvoidanceChain
(
sensitive_words
=
sensitive_words
.
split
(
","
),
canned_response
=
sensitive_word_avoidance_config
.
get
(
"canned_response"
,
''
),
output_key
=
"sensitive_word_avoidance_output"
,
callbacks
=
[
DifyStdOutCallbackHandler
()],
**
kwargs
)
return
None
def
to_tools
(
self
,
tool_configs
:
list
,
conversation_message_task
:
ConversationMessageTask
)
->
list
[
BaseTool
]:
"""
Convert app agent tool configs to tools
:param tool_configs: app agent tool configs
:param conversation_message_task:
:return:
"""
tools
=
[]
for
tool_config
in
tool_configs
:
tool_type
=
list
(
tool_config
.
keys
())[
0
]
tool_val
=
list
(
tool_config
.
values
())[
0
]
if
not
tool_config
.
get
(
"enabled"
)
or
tool_config
.
get
(
"enabled"
)
is
not
True
:
continue
tool
=
None
if
tool_type
==
"dataset"
:
tool
=
self
.
to_dataset_retriever_tool
(
tool_val
,
conversation_message_task
)
elif
tool_type
==
"web_reader"
:
tool
=
self
.
to_web_reader_tool
()
if
tool
:
tools
.
append
(
tool
)
return
tools
def
to_dataset_retriever_tool
(
self
,
tool_config
:
dict
,
conversation_message_task
:
ConversationMessageTask
)
\
->
Optional
[
BaseTool
]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config:
:param conversation_message_task:
:return:
"""
# get dataset from dataset id
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
Dataset
.
tenant_id
==
self
.
tenant_id
,
Dataset
.
id
==
tool_config
.
get
(
"id"
)
)
.
first
()
if
dataset
and
dataset
.
available_document_count
==
0
and
dataset
.
available_document_count
==
0
:
return
None
description
=
dataset
.
description
if
not
description
:
description
=
'useful for when you want to answer queries about the '
+
dataset
.
name
tool
=
DatasetRetrieverTool
(
name
=
f
"dataset_retriever"
,
description
=
description
,
k
=
3
,
dataset
=
dataset
,
callbacks
=
[
DatasetToolCallbackHandler
(
conversation_message_task
),
DifyStdOutCallbackHandler
()]
)
return
tool
def
to_web_reader_tool
(
self
)
->
Optional
[
BaseTool
]:
"""
A tool for reading web pages
:return:
"""
summary_llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
self
.
tenant_id
,
model_name
=
"gpt-3.5-turbo-16k"
,
temperature
=
0
,
max_tokens
=
500
,
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
tool
=
WebReaderTool
(
llm
=
summary_llm
,
max_chunk_length
=
4000
,
continue_reading
=
True
,
callbacks
=
[
DifyStdOutCallbackHandler
()]
)
return
tool
def
to_google_search_tool
(
self
)
->
Optional
[
BaseTool
]:
tool_provider
=
SerpAPIToolProvider
(
tenant_id
=
self
.
tenant_id
)
func_kwargs
=
tool_provider
.
credentials_to_func_kwargs
()
if
not
func_kwargs
:
return
None
tool
=
Tool
(
name
=
"google_search"
,
description
=
"A tool for performing a Google search and extracting snippets and webpages "
"when you need to search for something you don't know or when your information is not up to date."
"Input should be a search query."
,
func
=
OptimizedSerpAPIWrapper
(
**
func_kwargs
)
.
run
,
callbacks
=
[
DifyStdOutCallbackHandler
]
)
return
tool
api/core/tool/dataset_index_tool.py
View file @
c6c81164
...
...
@@ -11,7 +11,10 @@ from models.dataset import Dataset
class
DatasetTool
(
BaseTool
):
"""Tool for querying a Dataset."""
"""
Tool for querying a Dataset.
Only use for router chain.
"""
dataset
:
Dataset
k
:
int
=
2
...
...
api/core/tool/dataset_retriever_tool.py
0 → 100644
View file @
c6c81164
from
flask
import
current_app
from
langchain.embeddings
import
OpenAIEmbeddings
from
langchain.tools
import
BaseTool
from
core.callback_handler.index_tool_callback_handler
import
DatasetIndexToolCallbackHandler
from
core.embedding.cached_embedding
import
CacheEmbedding
from
core.index.keyword_table_index.keyword_table_index
import
KeywordTableIndex
,
KeywordTableConfig
from
core.index.vector_index.vector_index
import
VectorIndex
from
core.llm.llm_builder
import
LLMBuilder
from
models.dataset
import
Dataset
class
DatasetRetrieverTool
(
BaseTool
):
"""Tool for querying a Dataset."""
# todo dataset id as tool argument
dataset
:
Dataset
k
:
int
=
2
def
_run
(
self
,
tool_input
:
str
)
->
str
:
if
self
.
dataset
.
indexing_technique
==
"economy"
:
# use keyword table query
kw_table_index
=
KeywordTableIndex
(
dataset
=
self
.
dataset
,
config
=
KeywordTableConfig
(
max_keywords_per_chunk
=
5
)
)
documents
=
kw_table_index
.
search
(
tool_input
,
search_kwargs
=
{
'k'
:
self
.
k
})
else
:
model_credentials
=
LLMBuilder
.
get_model_credentials
(
tenant_id
=
self
.
dataset
.
tenant_id
,
model_provider
=
LLMBuilder
.
get_default_provider
(
self
.
dataset
.
tenant_id
),
model_name
=
'text-embedding-ada-002'
)
embeddings
=
CacheEmbedding
(
OpenAIEmbeddings
(
**
model_credentials
))
vector_index
=
VectorIndex
(
dataset
=
self
.
dataset
,
config
=
current_app
.
config
,
embeddings
=
embeddings
)
documents
=
vector_index
.
search
(
tool_input
,
search_type
=
'similarity'
,
search_kwargs
=
{
'k'
:
self
.
k
}
)
hit_callback
=
DatasetIndexToolCallbackHandler
(
self
.
dataset
.
id
)
hit_callback
.
on_tool_end
(
documents
)
return
str
(
"
\n
"
.
join
([
document
.
page_content
for
document
in
documents
]))
async
def
_arun
(
self
,
tool_input
:
str
)
->
str
:
model_credentials
=
LLMBuilder
.
get_model_credentials
(
tenant_id
=
self
.
dataset
.
tenant_id
,
model_provider
=
LLMBuilder
.
get_default_provider
(
self
.
dataset
.
tenant_id
),
model_name
=
'text-embedding-ada-002'
)
embeddings
=
CacheEmbedding
(
OpenAIEmbeddings
(
**
model_credentials
))
vector_index
=
VectorIndex
(
dataset
=
self
.
dataset
,
config
=
current_app
.
config
,
embeddings
=
embeddings
)
documents
=
await
vector_index
.
asearch
(
tool_input
,
search_type
=
'similarity'
,
search_kwargs
=
{
'k'
:
10
}
)
hit_callback
=
DatasetIndexToolCallbackHandler
(
self
.
dataset
.
id
)
hit_callback
.
on_tool_end
(
documents
)
return
str
(
"
\n
"
.
join
([
document
.
page_content
for
document
in
documents
]))
api/core/tool/provider/base.py
0 → 100644
View file @
c6c81164
import
base64
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
from
extensions.ext_database
import
db
from
libs
import
rsa
from
models.account
import
Tenant
from
models.tool
import
ToolProvider
,
ToolProviderName
class
BaseToolProvider
(
ABC
):
def
__init__
(
self
,
tenant_id
:
str
):
self
.
tenant_id
=
tenant_id
@
abstractmethod
def
get_provider_name
(
self
)
->
ToolProviderName
:
raise
NotImplementedError
@
abstractmethod
def
get_credentials
(
self
,
obfuscated
:
bool
=
False
)
->
Optional
[
dict
]:
raise
NotImplementedError
@
abstractmethod
def
credentials_to_func_kwargs
(
self
)
->
Optional
[
dict
]:
raise
NotImplementedError
@
abstractmethod
def
credentials_validate
(
self
,
credentials
:
dict
):
raise
NotImplementedError
def
get_provider
(
self
,
must_enabled
:
bool
=
False
)
->
Optional
[
ToolProvider
]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
"""
query
=
db
.
session
.
query
(
ToolProvider
)
.
filter
(
ToolProvider
.
tenant_id
==
self
.
tenant_id
,
ToolProvider
.
provider_name
==
self
.
get_provider_name
()
)
if
must_enabled
:
query
=
query
.
filter
(
ToolProvider
.
is_enabled
==
True
)
return
query
.
first
()
def
encrypt_token
(
self
,
token
)
->
str
:
tenant
=
db
.
session
.
query
(
Tenant
)
.
filter
(
Tenant
.
id
==
self
.
tenant_id
)
.
first
()
encrypted_token
=
rsa
.
encrypt
(
token
,
tenant
.
encrypt_public_key
)
return
base64
.
b64encode
(
encrypted_token
)
.
decode
()
def
decrypt_token
(
self
,
token
:
str
,
obfuscated
:
bool
=
False
)
->
str
:
token
=
rsa
.
decrypt
(
base64
.
b64decode
(
token
),
self
.
tenant_id
)
if
obfuscated
:
return
self
.
_obfuscated_token
(
token
)
return
token
def
_obfuscated_token
(
self
,
token
:
str
)
->
str
:
return
token
[:
6
]
+
'*'
*
(
len
(
token
)
-
8
)
+
token
[
-
2
:]
api/core/tool/provider/errors.py
0 → 100644
View file @
c6c81164
class
ValidateFailedError
(
Exception
):
description
=
"Provider Validate failed"
api/core/tool/provider/serpapi_provider.py
0 → 100644
View file @
c6c81164
from
typing
import
Optional
from
core.llm.provider.errors
import
ValidateFailedError
from
core.tool.provider.base
import
BaseToolProvider
from
models.tool
import
ToolProviderName
class
SerpAPIToolProvider
(
BaseToolProvider
):
def
get_provider_name
(
self
)
->
ToolProviderName
:
"""
Returns the name of the provider.
:return:
"""
return
ToolProviderName
.
SERPAPI
def
get_credentials
(
self
,
obfuscated
:
bool
=
False
)
->
Optional
[
dict
]:
"""
Returns the credentials for SerpAPI as a dictionary.
:param obfuscated: obfuscate credentials if True
:return:
"""
tool_provider
=
self
.
get_provider
(
must_enabled
=
True
)
if
not
tool_provider
:
return
None
config
=
tool_provider
.
config
if
not
config
:
return
None
if
config
.
get
(
'api_key'
):
config
[
'api_key'
]
=
self
.
decrypt_token
(
config
.
get
(
'api_key'
),
obfuscated
)
return
config
def
credentials_to_func_kwargs
(
self
)
->
Optional
[
dict
]:
"""
Returns the credentials function kwargs as a dictionary.
:return:
"""
credentials
=
self
.
get_credentials
()
if
not
credentials
:
return
None
return
{
'serpapi_api_key'
:
credentials
.
get
(
'api_key'
)
}
def
credentials_validate
(
self
,
credentials
:
dict
):
"""
Validates the given credentials.
:param credentials:
:return:
"""
if
'api_key'
not
in
credentials
or
not
credentials
.
get
(
'api_key'
):
raise
ValidateFailedError
(
"SerpAPI api_key is required."
)
api/core/tool/provider/tool_provider_service.py
0 → 100644
View file @
c6c81164
from
typing
import
Optional
from
core.tool.provider.base
import
BaseToolProvider
from
core.tool.provider.serpapi_provider
import
SerpAPIToolProvider
class
ToolProviderService
:
def
__init__
(
self
,
tenant_id
:
str
,
provider_name
:
str
):
self
.
provider
=
self
.
_init_provider
(
tenant_id
,
provider_name
)
def
_init_provider
(
self
,
tenant_id
:
str
,
provider_name
:
str
)
->
BaseToolProvider
:
if
provider_name
==
'serpapi'
:
return
SerpAPIToolProvider
(
tenant_id
)
else
:
raise
Exception
(
'tool provider {} not found'
.
format
(
provider_name
))
def
get_credentials
(
self
,
obfuscated
:
bool
=
False
)
->
Optional
[
dict
]:
return
self
.
provider
.
get_credentials
(
obfuscated
)
def
credentials_validate
(
self
,
credentials
:
dict
):
"""
Validates the given credentials.
:param credentials:
:raises: ValidateFailedError
"""
return
self
.
provider
.
credentials_validate
(
credentials
)
api/core/tool/web_reader_tool.py
View file @
c6c81164
...
...
@@ -52,7 +52,7 @@ class WebReaderToolInput(BaseModel):
class
WebReaderTool
(
BaseTool
):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name
:
str
=
"
read_page
"
name
:
str
=
"
web_reader
"
args_schema
:
Type
[
BaseModel
]
=
WebReaderToolInput
description
:
str
=
"use this to read a website. "
\
"If you can answer the question based on the information provided, "
\
...
...
api/migrations/versions/
46c503018f11_add_tool_pt
oviders.py
→
api/migrations/versions/
7ce5a52e4eee_add_tool_pr
oviders.py
View file @
c6c81164
"""add tool p
t
oviders
"""add tool p
r
oviders
Revision ID:
46c503018f11
Revision ID:
7ce5a52e4eee
Revises: 2beac44e5f5f
Create Date: 2023-07-
07 16:35:32.97407
5
Create Date: 2023-07-
10 10:26:50.07451
5
"""
from
alembic
import
op
...
...
@@ -10,7 +10,7 @@ import sqlalchemy as sa
from
sqlalchemy.dialects
import
postgresql
# revision identifiers, used by Alembic.
revision
=
'
46c503018f11
'
revision
=
'
7ce5a52e4eee
'
down_revision
=
'2beac44e5f5f'
branch_labels
=
None
depends_on
=
None
...
...
@@ -23,16 +23,22 @@ def upgrade():
sa
.
Column
(
'tenant_id'
,
postgresql
.
UUID
(),
nullable
=
False
),
sa
.
Column
(
'tool_name'
,
sa
.
String
(
length
=
40
),
nullable
=
False
),
sa
.
Column
(
'encrypted_config'
,
sa
.
Text
(),
nullable
=
True
),
sa
.
Column
(
'is_
vali
d'
,
sa
.
Boolean
(),
server_default
=
sa
.
text
(
'false'
),
nullable
=
False
),
sa
.
Column
(
'is_
enable
d'
,
sa
.
Boolean
(),
server_default
=
sa
.
text
(
'false'
),
nullable
=
False
),
sa
.
Column
(
'created_at'
,
sa
.
DateTime
(),
server_default
=
sa
.
text
(
'CURRENT_TIMESTAMP(0)'
),
nullable
=
False
),
sa
.
Column
(
'updated_at'
,
sa
.
DateTime
(),
server_default
=
sa
.
text
(
'CURRENT_TIMESTAMP(0)'
),
nullable
=
False
),
sa
.
PrimaryKeyConstraint
(
'id'
,
name
=
'tool_provider_pkey'
),
sa
.
UniqueConstraint
(
'tenant_id'
,
'tool_name'
,
name
=
'unique_tool_provider_tool_name'
)
)
with
op
.
batch_alter_table
(
'app_model_configs'
,
schema
=
None
)
as
batch_op
:
batch_op
.
add_column
(
sa
.
Column
(
'sensitive_word_avoidance'
,
sa
.
Text
(),
nullable
=
True
))
# ### end Alembic commands ###
def
downgrade
():
# ### commands auto generated by Alembic - please adjust! ###
with
op
.
batch_alter_table
(
'app_model_configs'
,
schema
=
None
)
as
batch_op
:
batch_op
.
drop_column
(
'sensitive_word_avoidance'
)
op
.
drop_table
(
'tool_providers'
)
# ### end Alembic commands ###
api/models/model.py
View file @
c6c81164
...
...
@@ -88,6 +88,7 @@ class AppModelConfig(db.Model):
user_input_form
=
db
.
Column
(
db
.
Text
)
pre_prompt
=
db
.
Column
(
db
.
Text
)
agent_mode
=
db
.
Column
(
db
.
Text
)
sensitive_word_avoidance
=
db
.
Column
(
db
.
Text
)
@
property
def
app
(
self
):
...
...
@@ -116,6 +117,11 @@ class AppModelConfig(db.Model):
def
more_like_this_dict
(
self
)
->
dict
:
return
json
.
loads
(
self
.
more_like_this
)
if
self
.
more_like_this
else
{
"enabled"
:
False
}
@
property
def
sensitive_word_avoidance_dict
(
self
)
->
dict
:
return
json
.
loads
(
self
.
sensitive_word_avoidance
)
if
self
.
sensitive_word_avoidance
\
else
{
"enabled"
:
False
,
"words"
:
[],
"canned_response"
:
[]}
@
property
def
user_input_form_list
(
self
)
->
dict
:
return
json
.
loads
(
self
.
user_input_form
)
if
self
.
user_input_form
else
[]
...
...
@@ -235,6 +241,9 @@ class Conversation(db.Model):
if
'speech_to_text'
in
override_model_configs
else
{
"enabled"
:
False
}
model_config
[
'more_like_this'
]
=
override_model_configs
[
'more_like_this'
]
\
if
'more_like_this'
in
override_model_configs
else
{
"enabled"
:
False
}
model_config
[
'sensitive_word_avoidance'
]
=
override_model_configs
[
'sensitive_word_avoidance'
]
\
if
'sensitive_word_avoidance'
in
override_model_configs
\
else
{
"enabled"
:
False
,
"words"
:
[],
"canned_response"
:
[]}
model_config
[
'user_input_form'
]
=
override_model_configs
[
'user_input_form'
]
else
:
model_config
[
'configs'
]
=
override_model_configs
...
...
@@ -251,6 +260,7 @@ class Conversation(db.Model):
model_config
[
'suggested_questions_after_answer'
]
=
app_model_config
.
suggested_questions_after_answer_dict
model_config
[
'speech_to_text'
]
=
app_model_config
.
speech_to_text_dict
model_config
[
'more_like_this'
]
=
app_model_config
.
more_like_this_dict
model_config
[
'sensitive_word_avoidance'
]
=
app_model_config
.
sensitive_word_avoidance_dict
model_config
[
'user_input_form'
]
=
app_model_config
.
user_input_form_list
model_config
[
'model_id'
]
=
self
.
model_id
...
...
api/models/tool.py
View file @
c6c81164
import
json
from
enum
import
Enum
from
sqlalchemy.dialects.postgresql
import
UUID
from
extensions.ext_database
import
db
class
ToolProviderName
(
Enum
):
SERPAPI
=
'serpapi'
@
staticmethod
def
value_of
(
value
):
for
member
in
ToolProviderName
:
if
member
.
value
==
value
:
return
member
raise
ValueError
(
f
"No matching enum found for value '{value}'"
)
class
ToolProvider
(
db
.
Model
):
__tablename__
=
'tool_providers'
...
...
@@ -24,3 +37,10 @@ class ToolProvider(db.Model):
Returns True if the encrypted_config is not None, indicating that the token is set.
"""
return
self
.
encrypted_config
is
not
None
@
property
def
config
(
self
):
"""
Returns the decrypted config.
"""
return
json
.
loads
(
self
.
decrypt_config
())
if
self
.
encrypted_config
is
not
None
else
None
api/services/app_model_config_service.py
View file @
c6c81164
...
...
@@ -145,6 +145,33 @@ class AppModelConfigService:
if
not
isinstance
(
config
[
"more_like_this"
][
"enabled"
],
bool
):
raise
ValueError
(
"enabled in more_like_this must be of boolean type"
)
# sensitive_word_avoidance
if
'sensitive_word_avoidance'
not
in
config
or
not
config
[
"sensitive_word_avoidance"
]:
config
[
"sensitive_word_avoidance"
]
=
{
"enabled"
:
False
}
if
not
isinstance
(
config
[
"sensitive_word_avoidance"
],
dict
):
raise
ValueError
(
"sensitive_word_avoidance must be of dict type"
)
if
"enabled"
not
in
config
[
"sensitive_word_avoidance"
]
or
not
config
[
"sensitive_word_avoidance"
][
"enabled"
]:
config
[
"sensitive_word_avoidance"
][
"enabled"
]
=
False
if
not
isinstance
(
config
[
"sensitive_word_avoidance"
][
"enabled"
],
bool
):
raise
ValueError
(
"enabled in sensitive_word_avoidance must be of boolean type"
)
if
"words"
not
in
config
[
"sensitive_word_avoidance"
]
or
not
config
[
"sensitive_word_avoidance"
][
"words"
]:
config
[
"sensitive_word_avoidance"
][
"words"
]
=
""
if
not
isinstance
(
config
[
"sensitive_word_avoidance"
][
"words"
],
str
):
raise
ValueError
(
"words in sensitive_word_avoidance must be of string type"
)
if
"canned_response"
not
in
config
[
"sensitive_word_avoidance"
]
or
not
config
[
"sensitive_word_avoidance"
][
"canned_response"
]:
config
[
"sensitive_word_avoidance"
][
"canned_response"
]
=
""
if
not
isinstance
(
config
[
"sensitive_word_avoidance"
][
"canned_response"
],
str
):
raise
ValueError
(
"canned_response in sensitive_word_avoidance must be of string type"
)
# model
if
'model'
not
in
config
:
raise
ValueError
(
"model is required"
)
...
...
@@ -258,8 +285,8 @@ class AppModelConfigService:
for
tool
in
config
[
"agent_mode"
][
"tools"
]:
key
=
list
(
tool
.
keys
())[
0
]
if
key
not
in
[
"
sensitive-word-avoidance"
,
"
dataset"
]:
raise
ValueError
(
"Keys in agent_mode.tools list can only be '
sensitive-word-avoidance' or '
dataset'"
)
if
key
not
in
[
"dataset"
]:
raise
ValueError
(
"Keys in agent_mode.tools list can only be 'dataset'"
)
tool_item
=
tool
[
key
]
...
...
@@ -269,19 +296,7 @@ class AppModelConfigService:
if
not
isinstance
(
tool_item
[
"enabled"
],
bool
):
raise
ValueError
(
"enabled in agent_mode.tools must be of boolean type"
)
if
key
==
"sensitive-word-avoidance"
:
if
"words"
not
in
tool_item
or
not
tool_item
[
"words"
]:
tool_item
[
"words"
]
=
""
if
not
isinstance
(
tool_item
[
"words"
],
str
):
raise
ValueError
(
"words in sensitive-word-avoidance must be of string type"
)
if
"canned_response"
not
in
tool_item
or
not
tool_item
[
"canned_response"
]:
tool_item
[
"canned_response"
]
=
""
if
not
isinstance
(
tool_item
[
"canned_response"
],
str
):
raise
ValueError
(
"canned_response in sensitive-word-avoidance must be of string type"
)
elif
key
==
"dataset"
:
if
key
==
"dataset"
:
if
'id'
not
in
tool_item
:
raise
ValueError
(
"id is required in dataset"
)
...
...
@@ -300,6 +315,7 @@ class AppModelConfigService:
"suggested_questions_after_answer"
:
config
[
"suggested_questions_after_answer"
],
"speech_to_text"
:
config
[
"speech_to_text"
],
"more_like_this"
:
config
[
"more_like_this"
],
"sensitive_word_avoidance"
:
config
[
"sensitive_word_avoidance"
],
"model"
:
{
"provider"
:
config
[
"model"
][
"provider"
],
"name"
:
config
[
"model"
][
"name"
],
...
...
api/services/completion_service.py
View file @
c6c81164
...
...
@@ -140,6 +140,7 @@ class CompletionService:
suggested_questions
=
json
.
dumps
(
model_config
[
'suggested_questions'
]),
suggested_questions_after_answer
=
json
.
dumps
(
model_config
[
'suggested_questions_after_answer'
]),
more_like_this
=
json
.
dumps
(
model_config
[
'more_like_this'
]),
sensitive_word_avoidance
=
json
.
dumps
(
model_config
[
'sensitive_word_avoidance'
]),
model
=
json
.
dumps
(
model_config
[
'model'
]),
user_input_form
=
json
.
dumps
(
model_config
[
'user_input_form'
]),
pre_prompt
=
model_config
[
'pre_prompt'
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment