Commit 937061cf authored by John Wang

fix: should use agent

parent c429005c
...@@ -38,7 +38,6 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio ...@@ -38,7 +38,6 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
**kwargs, **kwargs,
) )
def should_use_agent(self, query: str): def should_use_agent(self, query: str):
""" """
return should use agent return should use agent
...@@ -49,15 +48,18 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio ...@@ -49,15 +48,18 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
original_max_tokens = self.llm.max_tokens original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 6 self.llm.max_tokens = 6
agent_decision = self.plan( prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
intermediate_steps=[], messages = prompt.to_messages()
callbacks=None,
input=query predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
) )
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens self.llm.max_tokens = original_max_tokens
return isinstance(agent_decision, AgentAction) return True if function_call else False
def plan( def plan(
self, self,
......
...@@ -48,15 +48,18 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope ...@@ -48,15 +48,18 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
original_max_tokens = self.llm.max_tokens original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 6 self.llm.max_tokens = 6
agent_decision = self.plan( prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
intermediate_steps=[], messages = prompt.to_messages()
callbacks=None,
input=query predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
) )
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens self.llm.max_tokens = original_max_tokens
return isinstance(agent_decision, AgentAction) return True if function_call else False
def plan( def plan(
self, self,
...@@ -93,7 +96,8 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope ...@@ -93,7 +96,8 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
agent_decision = _parse_ai_message(predicted_message) agent_decision = _parse_ai_message(predicted_message)
return agent_decision return agent_decision
def get_system_message(self): @classmethod
def get_system_message(cls):
# get current time # get current time
current_time = datetime.now() current_time = datetime.now()
current_timezone = pytz.timezone('UTC') current_timezone = pytz.timezone('UTC')
......
...@@ -5,7 +5,7 @@ from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent, AgentE ...@@ -5,7 +5,7 @@ from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent, AgentE
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import BaseModel from pydantic import BaseModel, Extra
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
...@@ -26,13 +26,19 @@ class AgentConfiguration(BaseModel): ...@@ -26,13 +26,19 @@ class AgentConfiguration(BaseModel):
llm: BaseLanguageModel llm: BaseLanguageModel
tools: list[BaseTool] tools: list[BaseTool]
summary_llm: BaseLanguageModel summary_llm: BaseLanguageModel
memory: ReadOnlyConversationTokenDBBufferSharedMemory memory: ReadOnlyConversationTokenDBBufferSharedMemory = None
callbacks: Callbacks = None callbacks: Callbacks = None
max_iterations: int = 6 max_iterations: int = 6
max_execution_time: Optional[float] = None max_execution_time: Optional[float] = None
early_stopping_method: str = "generate" early_stopping_method: str = "generate"
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
class AgentExecutor: class AgentExecutor:
def __init__(self, configuration: AgentConfiguration): def __init__(self, configuration: AgentConfiguration):
...@@ -48,18 +54,18 @@ class AgentExecutor: ...@@ -48,18 +54,18 @@ class AgentExecutor:
verbose=True verbose=True
) )
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent( agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm, llm=self.configuration.llm,
tools=self.configuration.tools, tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer, # used for read chat histories memory extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm, summary_llm=self.configuration.summary_llm,
verbose=True verbose=True
) )
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL: elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent( agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm, llm=self.configuration.llm,
tools=self.configuration.tools, tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer, # used for read chat histories memory extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm, summary_llm=self.configuration.summary_llm,
verbose=True verbose=True
) )
...@@ -71,7 +77,7 @@ class AgentExecutor: ...@@ -71,7 +77,7 @@ class AgentExecutor:
def should_use_agent(self, query: str) -> bool: def should_use_agent(self, query: str) -> bool:
return self.agent.should_use_agent(query) return self.agent.should_use_agent(query)
def get_chain(self) -> AgentExecutor: def run(self, query: str) -> str:
agent_executor = LCAgentExecutor.from_agent_and_tools( agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent, agent=self.agent,
tools=self.configuration.tools, tools=self.configuration.tools,
...@@ -82,4 +88,4 @@ class AgentExecutor: ...@@ -82,4 +88,4 @@ class AgentExecutor:
verbose=True verbose=True
) )
return agent_executor return agent_executor.run(query)
from typing import Optional, List, cast, Tuple # from typing import Optional, List, cast, Tuple
#
from langchain.chains import SequentialChain # from langchain.chains import SequentialChain
from langchain.chains.base import Chain # from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory # from langchain.memory.chat_memory import BaseChatMemory
#
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler # from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler # from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain # from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask # from core.conversation_message_task import ConversationMessageTask
from core.orchestrator_rule_parser import OrchestratorRuleParser # from core.orchestrator_rule_parser import OrchestratorRuleParser
from extensions.ext_database import db # from extensions.ext_database import db
from models.dataset import Dataset # from models.dataset import Dataset
from models.model import AppModelConfig # from models.model import AppModelConfig
#
#
class MainChainBuilder: # class MainChainBuilder:
@classmethod # @classmethod
def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory], # def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory],
rest_tokens: int, conversation_message_task: ConversationMessageTask): # rest_tokens: int, conversation_message_task: ConversationMessageTask):
first_input_key = "input" # first_input_key = "input"
final_output_key = "output" # final_output_key = "output"
#
chains = [] # chains = []
#
# init orchestrator rule parser # # init orchestrator rule parser
orchestrator_rule_parser = OrchestratorRuleParser( # orchestrator_rule_parser = OrchestratorRuleParser(
tenant_id=tenant_id, # tenant_id=tenant_id,
app_model_config=app_model_config # app_model_config=app_model_config
) # )
#
# parse sensitive_word_avoidance_chain # # parse sensitive_word_avoidance_chain
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain() # sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain()
if sensitive_word_avoidance_chain: # if sensitive_word_avoidance_chain:
chains.append(sensitive_word_avoidance_chain) # chains.append(sensitive_word_avoidance_chain)
#
# parse agent chain # # parse agent chain
agent_chain = cls.get_agent_chain( # agent_executor = orchestrator_rule_parser.to_agent_executor(
tenant_id=tenant_id, # conversation_message_task=conversation_message_task,
agent_mode=app_model_config.agent_mode_dict, # memory=memory,
rest_tokens=rest_tokens, # rest_tokens=rest_tokens,
memory=memory, # callbacks=[DifyStdOutCallbackHandler()]
conversation_message_task=conversation_message_task # )
) #
# if agent_executor:
if agent_chain: # if isinstance(agent_executor, MultiDatasetRouterChain):
chains.append(agent_chain) # chains.append(agent_executor)
final_output_key = agent_chain.output_keys[0] # final_output_key = agent_executor.output_keys[0]
# chains.append(agent_chain)
if len(chains) == 0: # final_output_key = agent_chain.output_keys[0]
return None #
# if len(chains) == 0:
chain_callback = MainChainGatherCallbackHandler(conversation_message_task) # return None
for chain in chains: #
chain = cast(Chain, chain) # chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
chain.callbacks.append(chain_callback) # for chain in chains:
# chain = cast(Chain, chain)
# build main chain # chain.callbacks.append(chain_callback)
overall_chain = SequentialChain( #
chains=chains, # # build main chain
input_variables=[first_input_key], # overall_chain = SequentialChain(
output_variables=[final_output_key], # chains=chains,
memory=memory, # only for use the memory prompt input key # input_variables=[first_input_key],
) # output_variables=[final_output_key],
# memory=memory, # only for use the memory prompt input key
return overall_chain # )
#
@classmethod # return overall_chain
def get_agent_chain(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask) -> Chain:
# agent mode
chain = None
if agent_mode and agent_mode.get('enabled'):
tools = agent_mode.get('tools', [])
datasets = []
for tool in tools:
tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0]
if tool_type == "dataset":
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == tool_config.get("id")
).first()
if dataset:
datasets.append(dataset)
if len(datasets) > 0:
# tool to chain
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chain = multi_dataset_router_chain
return chain
import logging import logging
import time
from typing import Optional, List, Union, Tuple from typing import Optional, List, Union, Tuple
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
...@@ -8,6 +9,8 @@ from langchain.llms import BaseLLM ...@@ -8,6 +9,8 @@ from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError from requests.exceptions import ChunkedEncodingError
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.constant import llm_constant from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
...@@ -15,13 +18,13 @@ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCa ...@@ -15,13 +18,13 @@ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCa
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \ from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
ReadOnlyConversationTokenDBStringBufferSharedMemory ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
...@@ -69,28 +72,52 @@ class Completion: ...@@ -69,28 +72,52 @@ class Completion:
streaming=streaming streaming=streaming
) )
# build main chain include agent chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
main_chain = MainChainBuilder.get_chains(
# init orchestrator rule parser
orchestrator_rule_parser = OrchestratorRuleParser(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
app_model_config=app_model_config, app_model_config=app_model_config
rest_tokens=rest_tokens_for_context_and_memory, )
# parse sensitive_word_avoidance_chain
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
if sensitive_word_avoidance_chain:
query = sensitive_word_avoidance_chain.run(query)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None, memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task rest_tokens=rest_tokens_for_context_and_memory,
callbacks=[chain_callback]
) )
chain_output = '' # run agent executor
if main_chain: executor_output = ''
chain_output = main_chain.run(query) is_agent_output = False
if agent_executor:
if isinstance(agent_executor, MultiDatasetRouterChain):
executor_output = agent_executor.run(query)
else:
should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent:
executor_output = agent_executor.run(query)
is_agent_output = True
# run the final llm # run the final llm
try: try:
# if is_agent_output and not app_model_config.pre_prompt:
# # todo streaming flush the agent result to user, not call final llm
# pass
cls.run_final_llm( cls.run_final_llm(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
mode=app.mode, mode=app.mode,
app_model_config=app_model_config, app_model_config=app_model_config,
query=query, query=query,
inputs=inputs, inputs=inputs,
chain_output=chain_output, chain_output=executor_output,
conversation_message_task=conversation_message_task, conversation_message_task=conversation_message_task,
memory=memory, memory=memory,
streaming=streaming streaming=streaming
...@@ -137,6 +164,38 @@ class Completion: ...@@ -137,6 +164,38 @@ class Completion:
return response return response
# @classmethod
# def simulate_output(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
# conversation_message_task: ConversationMessageTask,
# executor_output: str, memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
# streaming: bool):
# final_llm = LLMBuilder.to_llm_from_model(
# tenant_id=tenant_id,
# model=app_model_config.model_dict,
# streaming=streaming
# )
#
# # get llm prompt
# prompt, stop_words = cls.get_main_llm_prompt(
# mode=mode,
# llm=final_llm,
# pre_prompt=app_model_config.pre_prompt,
# query=query,
# inputs=inputs,
# chain_output=executor_output,
# memory=memory
# )
#
# final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
#
# llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
# llm_callback_handler
# for token in executor_output:
# if token:
# sys.stdout.write(token)
# sys.stdout.flush()
# time.sleep(0.01)
@classmethod @classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
chain_output: Optional[str], chain_output: Optional[str],
......
from typing import Optional from typing import Optional, Union
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
...@@ -23,13 +24,15 @@ from models.model import AppModelConfig ...@@ -23,13 +24,15 @@ from models.model import AppModelConfig
class OrchestratorRuleParser: class OrchestratorRuleParser:
"""Parse the orchestrator rule to entities.""" """Parse the orchestrator rule to entities."""
def __init__(self, tenant_id: str, app_model_config: AppModelConfig): def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.app_model_config = app_model_config self.app_model_config = app_model_config
self.agent_summary_model_name = "gpt-3.5-turbo-16k" self.agent_summary_model_name = "gpt-3.5-turbo-16k"
def to_agent_chain(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, callbacks: Callbacks = None) -> Optional[Chain]: rest_tokens: int, callbacks: Callbacks = None) \
-> Optional[Union[AgentExecutor | MultiDatasetRouterChain]]:
if not self.app_model_config.agent_mode_dict: if not self.app_model_config.agent_mode_dict:
return None return None
...@@ -61,6 +64,11 @@ class OrchestratorRuleParser: ...@@ -61,6 +64,11 @@ class OrchestratorRuleParser:
callbacks=[DifyStdOutCallbackHandler()] callbacks=[DifyStdOutCallbackHandler()]
) )
tools = self.to_tools(tool_configs, conversation_message_task)
if len(tools) == 0:
return None
agent_configuration = AgentConfiguration( agent_configuration = AgentConfiguration(
strategy=PlanningStrategy(agent_mode_config.get('strategy')), strategy=PlanningStrategy(agent_mode_config.get('strategy')),
llm=agent_llm, llm=agent_llm,
...@@ -73,8 +81,7 @@ class OrchestratorRuleParser: ...@@ -73,8 +81,7 @@ class OrchestratorRuleParser:
early_stopping_method="generate" early_stopping_method="generate"
) )
agent_executor = AgentExecutor(agent_configuration) return AgentExecutor(agent_configuration)
chain = agent_executor.get_chain()
return chain return chain
...@@ -116,7 +123,8 @@ class OrchestratorRuleParser: ...@@ -116,7 +123,8 @@ class OrchestratorRuleParser:
return chain return chain
def to_sensitive_word_avoidance_chain(self, **kwargs) -> Optional[SensitiveWordAvoidanceChain]: def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
-> Optional[SensitiveWordAvoidanceChain]:
""" """
Convert app sensitive word avoidance config to chain Convert app sensitive word avoidance config to chain
...@@ -133,7 +141,7 @@ class OrchestratorRuleParser: ...@@ -133,7 +141,7 @@ class OrchestratorRuleParser:
sensitive_words=sensitive_words.split(","), sensitive_words=sensitive_words.split(","),
canned_response=sensitive_word_avoidance_config.get("canned_response", ''), canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output", output_key="sensitive_word_avoidance_output",
callbacks=[DifyStdOutCallbackHandler()], callbacks=callbacks,
**kwargs **kwargs
) )
...@@ -151,14 +159,19 @@ class OrchestratorRuleParser: ...@@ -151,14 +159,19 @@ class OrchestratorRuleParser:
for tool_config in tool_configs: for tool_config in tool_configs:
tool_type = list(tool_config.keys())[0] tool_type = list(tool_config.keys())[0]
tool_val = list(tool_config.values())[0] tool_val = list(tool_config.values())[0]
if not tool_config.get("enabled") or tool_config.get("enabled") is not True: if not tool_val.get("enabled") or tool_val.get("enabled") is not True:
continue continue
tool = None tool = None
if tool_type == "dataset": if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task) tool = None
# tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task)
elif tool_type == "web_reader": elif tool_type == "web_reader":
tool = self.to_web_reader_tool() tool = self.to_web_reader_tool()
elif tool_type == "google_search":
tool = self.to_google_search_tool()
elif tool_type == "wikipedia":
tool = self.to_wikipedia_tool()
if tool: if tool:
tools.append(tool) tools.append(tool)
...@@ -226,7 +239,13 @@ class OrchestratorRuleParser: ...@@ -226,7 +239,13 @@ class OrchestratorRuleParser:
"is not up to date." "is not up to date."
"Input should be a search query.", "Input should be a search query.",
func=OptimizedSerpAPIWrapper(**func_kwargs).run, func=OptimizedSerpAPIWrapper(**func_kwargs).run,
callbacks=[DifyStdOutCallbackHandler] callbacks=[DifyStdOutCallbackHandler()]
) )
return tool return tool
def to_wikipedia_tool(self) -> Optional[BaseTool]:
return WikipediaQueryRun(
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
callbacks=[DifyStdOutCallbackHandler()]
)
...@@ -16,7 +16,7 @@ from models.dataset import Dataset ...@@ -16,7 +16,7 @@ from models.dataset import Dataset
class DatasetRetrieverToolInput(BaseModel): class DatasetRetrieverToolInput(BaseModel):
dataset_id: str = Field(..., description="ID of dateset to be queried. MUST be UUID format.") dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
......
...@@ -97,7 +97,7 @@ class WebReaderTool(BaseTool): ...@@ -97,7 +97,7 @@ class WebReaderTool(BaseTool):
if self.continue_reading and len(page_contents) >= self.max_chunk_length: if self.continue_reading and len(page_contents) >= self.max_chunk_length:
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \ page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
f"THEN DIRECT ANSWER AND STOP INVOKING read_page TOOL, OTHERWISE USE " \ f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING." f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
return page_contents return page_contents
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment