ai-tech / dify · Commits

Commit 937061cf authored Jul 10, 2023 by John Wang
parent c429005c

fix: should use agent

Showing 8 changed files with 201 additions and 146 deletions (+201 -146)

api/core/agent/agent/openai_function_call.py         +8   -6
api/core/agent/agent/openai_multi_function_call.py   +10  -6
api/core/agent/agent_executor.py                     +14  -8
api/core/chain/main_chain_builder.py                 +68  -103
api/core/completion.py                               +69  -10
api/core/orchestrator_rule_parser.py                 +30  -11
api/core/tool/dataset_retriever_tool.py              +1   -1
api/core/tool/web_reader_tool.py                     +1   -1

api/core/agent/agent/openai_function_call.py

@@ -38,7 +38,6 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
             **kwargs,
         )

     def should_use_agent(self, query: str):
         """
         return should use agent
@@ -49,15 +48,18 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         original_max_tokens = self.llm.max_tokens
         self.llm.max_tokens = 6

-        agent_decision = self.plan(
-            intermediate_steps=[],
-            callbacks=None,
-            input=query
-        )
+        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
+        messages = prompt.to_messages()
+
+        predicted_message = self.llm.predict_messages(
+            messages, functions=self.functions, callbacks=None
+        )
+
+        function_call = predicted_message.additional_kwargs.get("function_call", {})

         self.llm.max_tokens = original_max_tokens

-        return isinstance(agent_decision, AgentAction)
+        return True if function_call else False

     def plan(
         self,

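Review note: the old probe ran a full plan() call and checked whether the result was an AgentAction; the new probe caps the completion budget, asks the model once, and checks only whether a function_call shows up in the response metadata. A minimal standalone sketch of the same trick, assuming a LangChain ChatOpenAI-style client whose max_tokens is mutable and whose predict_messages() forwards a functions kwarg (the llm client and functions schema here are stand-ins, not the project's real objects):

    from langchain.chat_models import ChatOpenAI
    from langchain.schema import HumanMessage

    def should_use_agent(llm: ChatOpenAI, functions: list[dict], query: str) -> bool:
        original_max_tokens = llm.max_tokens
        llm.max_tokens = 6  # tiny budget: we only need to see IF a tool is requested
        try:
            predicted = llm.predict_messages([HumanMessage(content=query)],
                                             functions=functions)
        finally:
            llm.max_tokens = original_max_tokens  # restore the real budget
        # a non-empty function_call means the model wants a tool, i.e. run the agent
        return bool(predicted.additional_kwargs.get("function_call", {}))

Restoring max_tokens in a finally block, as in the sketch, would also survive an API error; the commit restores it on the success path only.
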
api/core/agent/agent/openai_multi_function_call.py

@@ -48,15 +48,18 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
         original_max_tokens = self.llm.max_tokens
         self.llm.max_tokens = 6

-        agent_decision = self.plan(
-            intermediate_steps=[],
-            callbacks=None,
-            input=query
-        )
+        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
+        messages = prompt.to_messages()
+
+        predicted_message = self.llm.predict_messages(
+            messages, functions=self.functions, callbacks=None
+        )
+
+        function_call = predicted_message.additional_kwargs.get("function_call", {})

         self.llm.max_tokens = original_max_tokens

-        return isinstance(agent_decision, AgentAction)
+        return True if function_call else False

     def plan(
         self,
@@ -93,7 +96,8 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision

-    def get_system_message(self):
+    @classmethod
+    def get_system_message(cls):
         # get current time
         current_time = datetime.now()
         current_timezone = pytz.timezone('UTC')

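Review note: turning get_system_message into a @classmethod means the time-stamped system prompt can be produced without constructing an agent first. A toy illustration of that calling pattern (PromptClock is a made-up name, not the project's class):

    from datetime import datetime

    import pytz

    class PromptClock:
        @classmethod
        def get_system_message(cls) -> str:
            # timezone-aware "now" in UTC, so the prompt is host-independent
            current_time = datetime.now(pytz.timezone('UTC'))
            return f"Current time: {current_time.isoformat()}"

    print(PromptClock.get_system_message())  # no instance needed
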
api/core/agent/agent_executor.py

@@ -5,7 +5,7 @@ from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent, AgentE
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.tools import BaseTool
-from pydantic import BaseModel
+from pydantic import BaseModel, Extra

 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
 from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
@@ -26,13 +26,19 @@ class AgentConfiguration(BaseModel):
     llm: BaseLanguageModel
     tools: list[BaseTool]
     summary_llm: BaseLanguageModel
-    memory: ReadOnlyConversationTokenDBBufferSharedMemory
+    memory: ReadOnlyConversationTokenDBBufferSharedMemory = None
     callbacks: Callbacks = None
     max_iterations: int = 6
     max_execution_time: Optional[float] = None
     early_stopping_method: str = "generate"
     # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+

 class AgentExecutor:
     def __init__(self, configuration: AgentConfiguration):
@@ -48,18 +54,18 @@ class AgentExecutor:
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
-            agent = AutoSummarizingOpenAIFunctionCallAgent(
+            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                 llm=self.configuration.llm,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer,  # used for read chat histories memory
+                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
                 summary_llm=self.configuration.summary_llm,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
-            agent = AutoSummarizingOpenMultiAIFunctionCallAgent(
+            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
                 llm=self.configuration.llm,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer,  # used for read chat histories memory
+                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
                 summary_llm=self.configuration.summary_llm,
                 verbose=True
             )
@@ -71,7 +77,7 @@ class AgentExecutor:
     def should_use_agent(self, query: str) -> bool:
         return self.agent.should_use_agent(query)

-    def get_chain(self) -> AgentExecutor:
+    def run(self, query: str) -> str:
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
@@ -82,4 +88,4 @@ class AgentExecutor:
             verbose=True
         )

-        return agent_executor
+        return agent_executor.run(query)

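Review note: the new inner Config class is standard pydantic v1 hardening. extra = Extra.forbid turns unknown constructor keywords into validation errors, and arbitrary_types_allowed = True is what lets non-pydantic field types such as BaseLanguageModel and BaseTool appear on the model at all. A self-contained sketch of the behavior, assuming pydantic v1 as LangChain used at the time (DemoConfiguration is illustrative):

    from typing import Optional

    from pydantic import BaseModel, Extra, ValidationError

    class DemoConfiguration(BaseModel):
        max_iterations: int = 6
        max_execution_time: Optional[float] = None
        early_stopping_method: str = "generate"

        class Config:
            extra = Extra.forbid            # unknown fields raise instead of being dropped
            arbitrary_types_allowed = True  # permit non-pydantic field types

    DemoConfiguration(max_iterations=3)     # fine
    try:
        DemoConfiguration(max_iteratons=3)  # note the typo in the field name
    except ValidationError as err:
        print(err)                          # "extra fields not permitted"
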
api/core/chain/main_chain_builder.py

@@ -1,103 +1,68 @@
-from typing import Optional, List, cast, Tuple
-
-from langchain.chains import SequentialChain
-from langchain.chains.base import Chain
-from langchain.memory.chat_memory import BaseChatMemory
-
-from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
-from core.conversation_message_task import ConversationMessageTask
-from core.orchestrator_rule_parser import OrchestratorRuleParser
-from extensions.ext_database import db
-from models.dataset import Dataset
-from models.model import AppModelConfig
-
-
-class MainChainBuilder:
-    @classmethod
-    def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory],
-                   rest_tokens: int, conversation_message_task: ConversationMessageTask):
-        first_input_key = "input"
-        final_output_key = "output"
-
-        chains = []
-
-        # init orchestrator rule parser
-        orchestrator_rule_parser = OrchestratorRuleParser(
-            tenant_id=tenant_id,
-            app_model_config=app_model_config
-        )
-
-        # parse sensitive_word_avoidance_chain
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain()
-        if sensitive_word_avoidance_chain:
-            chains.append(sensitive_word_avoidance_chain)
-
-        # parse agent chain
-        agent_chain = cls.get_agent_chain(
-            tenant_id=tenant_id,
-            agent_mode=app_model_config.agent_mode_dict,
-            rest_tokens=rest_tokens,
-            memory=memory,
-            conversation_message_task=conversation_message_task
-        )
-        if agent_chain:
-            chains.append(agent_chain)
-            final_output_key = agent_chain.output_keys[0]
-
-        if len(chains) == 0:
-            return None
-
-        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        for chain in chains:
-            chain = cast(Chain, chain)
-            chain.callbacks.append(chain_callback)
-
-        # build main chain
-        overall_chain = SequentialChain(
-            chains=chains,
-            input_variables=[first_input_key],
-            output_variables=[final_output_key],
-            memory=memory,  # only for use the memory prompt input key
-        )
-
-        return overall_chain
-
-    @classmethod
-    def get_agent_chain(cls, tenant_id: str, agent_mode: dict, rest_tokens: int,
-                        memory: Optional[BaseChatMemory],
-                        conversation_message_task: ConversationMessageTask) -> Chain:
-        # agent mode
-        chain = None
-        if agent_mode and agent_mode.get('enabled'):
-            tools = agent_mode.get('tools', [])
-
-            datasets = []
-            for tool in tools:
-                tool_type = list(tool.keys())[0]
-                tool_config = list(tool.values())[0]
-                if tool_type == "dataset":
-                    # get dataset from dataset id
-                    dataset = db.session.query(Dataset).filter(
-                        Dataset.tenant_id == tenant_id,
-                        Dataset.id == tool_config.get("id")
-                    ).first()
-
-                    if dataset:
-                        datasets.append(dataset)
-
-            if len(datasets) > 0:
-                # tool to chain
-                multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
-                    tenant_id=tenant_id,
-                    datasets=datasets,
-                    conversation_message_task=conversation_message_task,
-                    rest_tokens=rest_tokens,
-                    callbacks=[DifyStdOutCallbackHandler()]
-                )
-                chain = multi_dataset_router_chain
-
-        return chain
+# from typing import Optional, List, cast, Tuple
+#
+# from langchain.chains import SequentialChain
+# from langchain.chains.base import Chain
+# from langchain.memory.chat_memory import BaseChatMemory
+#
+# from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
+# from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
+# from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
+# from core.conversation_message_task import ConversationMessageTask
+# from core.orchestrator_rule_parser import OrchestratorRuleParser
+# from extensions.ext_database import db
+# from models.dataset import Dataset
+# from models.model import AppModelConfig
+#
+#
+# class MainChainBuilder:
+#     @classmethod
+#     def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory],
+#                    rest_tokens: int, conversation_message_task: ConversationMessageTask):
+#         first_input_key = "input"
+#         final_output_key = "output"
+#
+#         chains = []
+#
+#         # init orchestrator rule parser
+#         orchestrator_rule_parser = OrchestratorRuleParser(
+#             tenant_id=tenant_id,
+#             app_model_config=app_model_config
+#         )
+#
+#         # parse sensitive_word_avoidance_chain
+#         sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain()
+#         if sensitive_word_avoidance_chain:
+#             chains.append(sensitive_word_avoidance_chain)
+#
+#         # parse agent chain
+#         agent_executor = orchestrator_rule_parser.to_agent_executor(
+#             conversation_message_task=conversation_message_task,
+#             memory=memory,
+#             rest_tokens=rest_tokens,
+#             callbacks=[DifyStdOutCallbackHandler()]
+#         )
+#
+#         if agent_executor:
+#             if isinstance(agent_executor, MultiDatasetRouterChain):
+#                 chains.append(agent_executor)
+#                 final_output_key = agent_executor.output_keys[0]
+#             # chains.append(agent_chain)
+#             # final_output_key = agent_chain.output_keys[0]
+#
+#         if len(chains) == 0:
+#             return None
+#
+#         chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+#         for chain in chains:
+#             chain = cast(Chain, chain)
+#             chain.callbacks.append(chain_callback)
+#
+#         # build main chain
+#         overall_chain = SequentialChain(
+#             chains=chains,
+#             input_variables=[first_input_key],
+#             output_variables=[final_output_key],
+#             memory=memory,  # only for use the memory prompt input key
+#         )
+#
+#         return overall_chain

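Review note: the whole MainChainBuilder (the SequentialChain assembly) is commented out rather than deleted; its responsibilities move into Completion plus OrchestratorRuleParser.to_agent_executor, as the next file shows. The replacement dispatch, sketched here with stand-in classes (the real ones live in core.chain and core.agent):

    class MultiDatasetRouterChain:  # stand-in for core.chain.multi_dataset_router_chain
        def run(self, query: str) -> str:
            return f"[dataset answer for: {query!r}]"

    class AgentExecutor:  # stand-in for core.agent.agent_executor
        def should_use_agent(self, query: str) -> bool:
            return "search" in query  # the real probe asks the LLM (see openai_function_call.py)

        def run(self, query: str) -> str:
            return f"[agent answer for: {query!r}]"

    def pre_llm_step(agent_executor, query: str) -> tuple[str, bool]:
        executor_output, is_agent_output = '', False
        if agent_executor:
            if isinstance(agent_executor, MultiDatasetRouterChain):
                # dataset routing always runs; it is cheap and deterministic
                executor_output = agent_executor.run(query)
            elif agent_executor.should_use_agent(query):
                # a full agent run happens only when the probe approves it
                executor_output, is_agent_output = agent_executor.run(query), True
        return executor_output, is_agent_output

    print(pre_llm_step(AgentExecutor(), "search the release notes"))
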
api/core/completion.py

 import logging
+import time
 from typing import Optional, List, Union, Tuple

 from langchain.base_language import BaseLanguageModel
@@ -8,6 +9,8 @@ from langchain.llms import BaseLLM
 from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError

+from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
+from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
@@ -15,13 +18,13 @@ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCa
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
 from core.llm.llm_builder import LLMBuilder
-from core.chain.main_chain_builder import MainChainBuilder
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
     ReadOnlyConversationTokenDBStringBufferSharedMemory
+from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
@@ -69,28 +72,52 @@ class Completion:
             streaming=streaming
         )

-        # build main chain include agent
-        main_chain = MainChainBuilder.get_chains(
-            tenant_id=app.tenant_id,
-            app_model_config=app_model_config,
-            rest_tokens=rest_tokens_for_context_and_memory,
-            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
-            conversation_message_task=conversation_message_task
-        )
+        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+
+        # init orchestrator rule parser
+        orchestrator_rule_parser = OrchestratorRuleParser(
+            tenant_id=app.tenant_id,
+            app_model_config=app_model_config
+        )
+
+        # parse sensitive_word_avoidance_chain
+        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+        if sensitive_word_avoidance_chain:
+            query = sensitive_word_avoidance_chain.run(query)
+
+        # get agent executor
+        agent_executor = orchestrator_rule_parser.to_agent_executor(
+            conversation_message_task=conversation_message_task,
+            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
+            rest_tokens=rest_tokens_for_context_and_memory,
+            callbacks=[chain_callback]
+        )

-        chain_output = ''
-        if main_chain:
-            chain_output = main_chain.run(query)
+        # run agent executor
+        executor_output = ''
+        is_agent_output = False
+        if agent_executor:
+            if isinstance(agent_executor, MultiDatasetRouterChain):
+                executor_output = agent_executor.run(query)
+            else:
+                should_use_agent = agent_executor.should_use_agent(query)
+                if should_use_agent:
+                    executor_output = agent_executor.run(query)
+                    is_agent_output = True

         # run the final llm
         try:
+            # if is_agent_output and not app_model_config.pre_prompt:
+            #     # todo streaming flush the agent result to user, not call final llm
+            #     pass
+
             cls.run_final_llm(
                 tenant_id=app.tenant_id,
                 mode=app.mode,
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
-                chain_output=chain_output,
+                chain_output=executor_output,
                 conversation_message_task=conversation_message_task,
                 memory=memory,
                 streaming=streaming
@@ -137,6 +164,38 @@ class Completion:

         return response

+    # @classmethod
+    # def simulate_output(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
+    #                     conversation_message_task: ConversationMessageTask,
+    #                     executor_output: str, memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
+    #                     streaming: bool):
+    #     final_llm = LLMBuilder.to_llm_from_model(
+    #         tenant_id=tenant_id,
+    #         model=app_model_config.model_dict,
+    #         streaming=streaming
+    #     )
+    #
+    #     # get llm prompt
+    #     prompt, stop_words = cls.get_main_llm_prompt(
+    #         mode=mode,
+    #         llm=final_llm,
+    #         pre_prompt=app_model_config.pre_prompt,
+    #         query=query,
+    #         inputs=inputs,
+    #         chain_output=executor_output,
+    #         memory=memory
+    #     )
+    #
+    #     final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
+    #
+    #     llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
+    #     llm_callback_handler
+    #     for token in executor_output:
+    #         if token:
+    #             sys.stdout.write(token)
+    #             sys.stdout.flush()
+    #             time.sleep(0.01)
+
     @classmethod
     def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                             chain_output: Optional[str],

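Review note: completion.py now runs the sensitive-word pass on the query itself (query = sensitive_word_avoidance_chain.run(query)) before any agent work, then feeds executor_output into the final LLM where chain_output used to go. A stand-in sketch of the shape of that pre-pass, assuming the real SensitiveWordAvoidanceChain swaps in the canned response on a hit (this demo class is illustrative, not the project's implementation):

    class SensitiveWordAvoidanceDemo:
        def __init__(self, sensitive_words: list[str], canned_response: str):
            self.sensitive_words = [w.strip() for w in sensitive_words if w.strip()]
            self.canned_response = canned_response

        def run(self, query: str) -> str:
            # on a hit, the canned response replaces the query downstream
            if any(word in query for word in self.sensitive_words):
                return self.canned_response
            return query

    chain = SensitiveWordAvoidanceDemo(["secret"], "Sorry, I can't discuss that.")
    print(chain.run("tell me the secret"))  # canned response
    print(chain.run("hello"))               # passes through unchanged
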
api/core/orchestrator_rule_parser.py

-from typing import Optional
+from typing import Optional, Union

+from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.tools import BaseTool, Tool
+from langchain.tools import BaseTool, Tool, WikipediaQueryRun

 from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -23,13 +24,15 @@ from models.model import AppModelConfig

 class OrchestratorRuleParser:
     """Parse the orchestrator rule to entities."""

     def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
         self.tenant_id = tenant_id
         self.app_model_config = app_model_config
         self.agent_summary_model_name = "gpt-3.5-turbo-16k"

-    def to_agent_chain(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
-                       rest_tokens: int, callbacks: Callbacks = None) -> Optional[Chain]:
+    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
+                          rest_tokens: int, callbacks: Callbacks = None) \
+            -> Optional[Union[AgentExecutor | MultiDatasetRouterChain]]:
         if not self.app_model_config.agent_mode_dict:
             return None
@@ -61,6 +64,11 @@ class OrchestratorRuleParser:
                 callbacks=[DifyStdOutCallbackHandler()]
             )

+            tools = self.to_tools(tool_configs, conversation_message_task)
+
+            if len(tools) == 0:
+                return None
+
             agent_configuration = AgentConfiguration(
                 strategy=PlanningStrategy(agent_mode_config.get('strategy')),
                 llm=agent_llm,
@@ -73,8 +81,7 @@ class OrchestratorRuleParser:
                 early_stopping_method="generate"
             )

-            agent_executor = AgentExecutor(agent_configuration)
-            chain = agent_executor.get_chain()
-            return chain
+            return AgentExecutor(agent_configuration)
@@ -116,7 +123,8 @@ class OrchestratorRuleParser:
         return chain

-    def to_sensitive_word_avoidance_chain(self, **kwargs) -> Optional[SensitiveWordAvoidanceChain]:
+    def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
+            -> Optional[SensitiveWordAvoidanceChain]:
         """
         Convert app sensitive word avoidance config to chain
@@ -133,7 +141,7 @@ class OrchestratorRuleParser:
             sensitive_words=sensitive_words.split(","),
             canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
             output_key="sensitive_word_avoidance_output",
-            callbacks=[DifyStdOutCallbackHandler()],
+            callbacks=callbacks,
             **kwargs
         )
@@ -151,14 +159,19 @@ class OrchestratorRuleParser:
         for tool_config in tool_configs:
             tool_type = list(tool_config.keys())[0]
             tool_val = list(tool_config.values())[0]
-            if not tool_config.get("enabled") or tool_config.get("enabled") is not True:
+            if not tool_val.get("enabled") or tool_val.get("enabled") is not True:
                 continue

             tool = None
             if tool_type == "dataset":
-                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task)
+                tool = None
+                # tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task)
             elif tool_type == "web_reader":
                 tool = self.to_web_reader_tool()
+            elif tool_type == "google_search":
+                tool = self.to_google_search_tool()
+            elif tool_type == "wikipedia":
+                tool = self.to_wikipedia_tool()

             if tool:
                 tools.append(tool)
@@ -226,7 +239,13 @@ class OrchestratorRuleParser:
                         "is not up to date."
                         "Input should be a search query.",
             func=OptimizedSerpAPIWrapper(**func_kwargs).run,
-            callbacks=[DifyStdOutCallbackHandler]
+            callbacks=[DifyStdOutCallbackHandler()]
         )

         return tool

+    def to_wikipedia_tool(self) -> Optional[BaseTool]:
+        return WikipediaQueryRun(
+            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
+            callbacks=[DifyStdOutCallbackHandler()]
+        )

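Review note: to_tools now supports google_search and wikipedia alongside web_reader, while the dataset retriever branch is stubbed out to None for now. The new wikipedia tool is plain LangChain; a standalone version, assuming the mid-2023 langchain API shown in the diff plus the wikipedia package installed:

    from langchain import WikipediaAPIWrapper
    from langchain.tools import WikipediaQueryRun

    def build_wikipedia_tool() -> WikipediaQueryRun:
        # doc_content_chars_max bounds the article text returned to the agent,
        # keeping context consumption predictable
        return WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000))

    tool = build_wikipedia_tool()
    print(tool.run("Python (programming language)")[:200])
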
api/core/tool/dataset_retriever_tool.py

@@ -16,7 +16,7 @@ from models.dataset import Dataset

 class DatasetRetrieverToolInput(BaseModel):
-    dataset_id: str = Field(..., description="ID of dateset to be queried. MUST be UUID format.")
+    dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
     query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")

api/core/tool/web_reader_tool.py

@@ -97,7 +97,7 @@ class WebReaderTool(BaseTool):
         if self.continue_reading and len(page_contents) >= self.max_chunk_length:
             page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
-                             f"THEN DIRECT ANSWER AND STOP INVOKING read_page TOOL, OTHERWISE USE " \
+                             f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
                              f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."

         return page_contents

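Review note: the tool-name fix matters because the truncation notice is an instruction to the model; telling it to stop invoking a tool named read_page when the tool is registered as web_reader would never match. The cursor arithmetic in the notice is a simple resume protocol, sketched below (read_chunk is illustrative, not the project's function):

    def read_chunk(full_text: str, cursor: int, max_chunk_length: int = 1000) -> str:
        page_contents = full_text[cursor:cursor + max_chunk_length]
        if cursor + len(page_contents) < len(full_text):
            # same arithmetic as the tool: the hint is computed before it is appended,
            # so CURSOR points at the first unread character
            page_contents += (f"\nPAGE WAS TRUNCATED. "
                              f"USE CURSOR={cursor + len(page_contents)} TO CONTINUE READING.")
        return page_contents

    text = "x" * 2500
    print(read_chunk(text, 0).splitlines()[-1])     # ... CURSOR=1000 ...
    print(read_chunk(text, 1000).splitlines()[-1])  # ... CURSOR=2000 ...
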