Commit 464a3615 authored by John Wang

feat: router agent instead of router chain

parent 937061cf
from typing import Tuple, List, Any, Union, Sequence, Optional

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.tools import BaseTool

from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    """A multi-dataset retriever agent driven by a router."""

    def should_use_agent(self, query: str) -> bool:
        """
        Return whether the agent should be used for the given query.

        :param query: user query
        :return: always True for this agent
        """
        return True

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        if len(self.tools) == 0:
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.tools) == 1:
            rst = next(iter(self.tools)).run(kwargs['input'])
            return AgentFinish(return_values={"output": rst}, log=rst)

        if intermediate_steps:
            _, observation = intermediate_steps[-1]
            return AgentFinish(return_values={"output": observation}, log=observation)

        return super().plan(intermediate_steps, callbacks, **kwargs)

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        raise NotImplementedError()

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        tools = [t for t in tools if isinstance(t, DatasetRetrieverTool)]
        llm.model_name = 'gpt-3.5-turbo'
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
            **kwargs,
        )
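A minimal sketch of the `plan` short-circuits above; the `llm`, `tools`, and query values here are assumptions for illustration, not code from the commit:

```python
# Assumed setup: `llm` is a chat model and `tools` contains DatasetRetrieverTool
# instances; both stand in for whatever the caller already has.
agent = MultiDatasetRouterAgent.from_llm_and_tools(llm=llm, tools=tools)

# 0 tools  -> AgentFinish with empty output (no LLM call at all)
# 1 tool   -> the tool is run directly and its result returned (no routing call)
# 2+ tools -> one OpenAI function-call round picks the dataset; the first
#             observation is returned as the final answer instead of iterating.
decision = agent.plan(intermediate_steps=[], input="What does the pricing dataset say about refunds?")
```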
@@ -4,9 +4,11 @@ from typing import Union, Optional
 from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
+from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra

+from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
 from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
@@ -16,6 +18,7 @@ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
 class PlanningStrategy(str, enum.Enum):
+    ROUTER = 'router'
     REACT = 'react'
     FUNCTION_CALL = 'function_call'
     MULTI_FUNCTION_CALL = 'multi_function_call'
@@ -26,7 +29,7 @@ class AgentConfiguration(BaseModel):
     llm: BaseLanguageModel
     tools: list[BaseTool]
     summary_llm: BaseLanguageModel
-    memory: ReadOnlyConversationTokenDBBufferSharedMemory = None
+    memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
     max_execution_time: Optional[float] = None
@@ -69,6 +72,13 @@ class AgentExecutor:
                 summary_llm=self.configuration.summary_llm,
                 verbose=True
             )
+        elif self.configuration.strategy == PlanningStrategy.ROUTER:
+            agent = MultiDatasetRouterAgent.from_llm_and_tools(
+                llm=self.configuration.llm,
+                tools=self.configuration.tools,
+                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
+                verbose=True
+            )
         else:
             raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
...
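For reference, a hedged sketch of how the new `ROUTER` branch gets exercised through `AgentConfiguration`; names like `llm`, `summary_llm`, `dataset_tools`, and `memory` are assumed to be in scope, and the `AgentExecutor` constructor shape is inferred from the surrounding class:

```python
config = AgentConfiguration(
    strategy=PlanningStrategy.ROUTER,
    llm=llm,                  # from_llm_and_tools pins model_name to 'gpt-3.5-turbo' anyway
    tools=dataset_tools,      # non-DatasetRetrieverTool entries are filtered out by the agent
    summary_llm=summary_llm,
    memory=memory,            # any BaseChatMemory is accepted after this change
)
executor = AgentExecutor(config)  # constructor shape assumed, not shown in this diff
```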
+import json
 import logging
 from typing import Any, Dict, List, Union, Optional
@@ -43,8 +44,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
         input_str: str,
         **kwargs: Any,
     ) -> None:
-        tool_name = serialized.get('name')
-        dataset_id = tool_name[len("dataset-"):]
+        # tool_name = serialized.get('name')
+        input_dict = json.loads(input_str.replace("'", "\""))
+        dataset_id = input_dict.get('dataset_id')
         self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))

     def on_tool_end(
...
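The dataset id now comes from the tool's function-call arguments rather than a `dataset-<id>` tool name. The `replace("'", '"')` trick converts a Python-dict repr into JSON but breaks if the query itself contains quotes; a more defensive parse (a hypothetical helper, not part of the commit) could look like this:

```python
import ast
import json

def parse_tool_input(input_str: str) -> dict:
    """Accept either JSON or a Python-dict repr as tool input."""
    try:
        return json.loads(input_str)
    except json.JSONDecodeError:
        # literal_eval handles single quotes and apostrophes inside
        # values without string surgery.
        return ast.literal_eval(input_str)

# Both forms yield {'dataset_id': 'abc', 'query': "what's new?"}
parse_tool_input('{"dataset_id": "abc", "query": "what\'s new?"}')
parse_tool_input("{'dataset_id': 'abc', 'query': \"what's new?\"}")
```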
"""Base classes for LLM-powered router chains."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import root_validator
from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown
class Route(NamedTuple):
destination: Optional[str]
next_inputs: Dict[str, Any]
class LLMRouterChain(Chain):
"""A router chain that uses an LLM chain to perform routing."""
llm_chain: LLMChain
"""LLM chain used to perform routing"""
@root_validator()
def validate_prompt(cls, values: dict) -> dict:
prompt = values["llm_chain"].prompt
if prompt.output_parser is None:
raise ValueError(
"LLMRouterChain requires base llm_chain prompt to have an output"
" parser that converts LLM text output to a dictionary with keys"
" 'destination' and 'next_inputs'. Received a prompt with no output"
" parser."
)
return values
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the LLM chain prompt expects.
:meta private:
"""
return self.llm_chain.input_keys
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
super()._validate_outputs(outputs)
if not isinstance(outputs["next_inputs"], dict):
raise ValueError
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
output = cast(
Dict[str, Any],
self.llm_chain.predict_and_parse(**inputs),
)
return output
@classmethod
def from_llm(
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
) -> LLMRouterChain:
"""Convenience constructor."""
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
@property
def output_keys(self) -> List[str]:
return ["destination", "next_inputs"]
def route(self, inputs: Dict[str, Any]) -> Route:
result = self(inputs)
return Route(result["destination"], result["next_inputs"])
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
"""Parser for output of router chain int he multi-prompt chain."""
default_destination: str = "DEFAULT"
next_inputs_type: Type = str
next_inputs_inner_key: str = "input"
def parse(self, text: str) -> Dict[str, Any]:
try:
expected_keys = ["destination", "next_inputs"]
parsed = parse_and_check_json_markdown(text, expected_keys)
if not isinstance(parsed["destination"], str):
raise ValueError("Expected 'destination' to be a string.")
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
raise ValueError(
f"Expected 'next_inputs' to be {self.next_inputs_type}."
)
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
if (
parsed["destination"].strip().lower()
== self.default_destination.lower()
):
parsed["destination"] = None
else:
parsed["destination"] = parsed["destination"].strip()
return parsed
except Exception as e:
raise OutputParserException(
f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
)
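To illustrate the contract: `RouterOutputParser.parse` takes the model's fenced JSON and normalizes it for `LLMRouterChain.route`. The values below are illustrative only:

```python
text = "```json\n{\n    \"destination\": \"DEFAULT\",\n    \"next_inputs\": \"What is our refund policy?\"\n}\n```"

parsed = RouterOutputParser().parse(text)
# -> {'destination': None, 'next_inputs': {'input': 'What is our refund policy?'}}
# "DEFAULT" is folded to None so callers can branch on "no suitable destination",
# and next_inputs is wrapped under the 'input' key the destination chain expects.
```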
# from typing import Optional, List, cast, Tuple
#
# from langchain.chains import SequentialChain
# from langchain.chains.base import Chain
# from langchain.memory.chat_memory import BaseChatMemory
#
# from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
# from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
# from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
# from core.conversation_message_task import ConversationMessageTask
# from core.orchestrator_rule_parser import OrchestratorRuleParser
# from extensions.ext_database import db
# from models.dataset import Dataset
# from models.model import AppModelConfig
#
#
# class MainChainBuilder:
#     @classmethod
#     def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory],
#                    rest_tokens: int, conversation_message_task: ConversationMessageTask):
#         first_input_key = "input"
#         final_output_key = "output"
#
#         chains = []
#
#         # init orchestrator rule parser
#         orchestrator_rule_parser = OrchestratorRuleParser(
#             tenant_id=tenant_id,
#             app_model_config=app_model_config
#         )
#
#         # parse sensitive_word_avoidance_chain
#         sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain()
#         if sensitive_word_avoidance_chain:
#             chains.append(sensitive_word_avoidance_chain)
#
#         # parse agent chain
#         agent_executor = orchestrator_rule_parser.to_agent_executor(
#             conversation_message_task=conversation_message_task,
#             memory=memory,
#             rest_tokens=rest_tokens,
#             callbacks=[DifyStdOutCallbackHandler()]
#         )
#
#         if agent_executor:
#             if isinstance(agent_executor, MultiDatasetRouterChain):
#                 chains.append(agent_executor)
#                 final_output_key = agent_executor.output_keys[0]
#             chains.append(agent_chain)
#             final_output_key = agent_chain.output_keys[0]
#
#         if len(chains) == 0:
#             return None
#
#         chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
#         for chain in chains:
#             chain = cast(Chain, chain)
#             chain.callbacks.append(chain_callback)
#
#         # build main chain
#         overall_chain = SequentialChain(
#             chains=chains,
#             input_variables=[first_input_key],
#             output_variables=[final_output_key],
#             memory=memory,  # only for use the memory prompt input key
#         )
#
#         return overall_chain
import math
import re
from typing import Mapping, List, Dict, Any, Optional

from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import Extra

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset, DatasetProcessRule

DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model, select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.

<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like the example \
below, with no other text outside the markdown code snippet:
```json
{{{{
    "destination": string \\ name of the prompt to use or "DEFAULT"
    "next_inputs": string \\ a potentially modified version of the original input
}}}}
```

REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.

<< CANDIDATE PROMPTS >>
{destinations}

<< INPUT >>
{{input}}

<< OUTPUT >>
"""
class MultiDatasetRouterChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: LLMRouterChain
    """Chain for deciding a destination chain and the input to it."""
    dataset_tools: Mapping[str, DatasetTool]
    """Map of name to candidate chains that inputs can be routed to."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        return ["text"]

    @classmethod
    def from_datasets(
        cls,
        tenant_id: str,
        datasets: List[Dataset],
        conversation_message_task: ConversationMessageTask,
        rest_tokens: int,
        **kwargs: Any,
    ):
        """Convenience constructor for instantiating from destination prompts."""
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                        else ('useful for when you want to answer queries about the ' + d.name))
                        for d in datasets]
        destinations_str = "\n".join(destinations)
        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=destinations_str
        )
        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        dataset_tools = {}
        for dataset in datasets:
            # skip datasets with no available documents
            if dataset.available_document_count == 0:
                continue

            # fall back to a generated description when it is empty
            description = dataset.description
            if not description:
                description = 'useful for when you want to answer queries about the ' + dataset.name

            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
            if k == 0:
                continue

            dataset_tool = DatasetTool(
                name=f"dataset-{dataset.id}",
                description=description,
                k=k,
                dataset=dataset,
                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
            )

            dataset_tools[str(dataset.id)] = dataset_tool

        return cls(
            router_chain=router_chain,
            dataset_tools=dataset_tools,
            **kwargs,
        )
    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens is less than the default context tokens, shrink k
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than the default context tokens, use DEFAULT_K
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # expand k when there's still room left in the 30% rest-tokens budget
        return context_limit_tokens // segment_max_tokens
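A worked example of the sizing rule above, with assumed numbers (segment_max_tokens = 500, DEFAULT_K = 2):

```python
import math

def calc_k(rest_tokens: int, segment_max_tokens: int = 500, default_k: int = 2) -> int:
    # Mirrors _dynamic_calc_retrieve_k's arithmetic, for illustration only.
    if rest_tokens < segment_max_tokens * default_k:
        return rest_tokens // segment_max_tokens
    context_limit_tokens = math.floor(rest_tokens * 0.3)
    if context_limit_tokens <= segment_max_tokens * default_k:
        return default_k
    return context_limit_tokens // segment_max_tokens

assert calc_k(800) == 1       # budget below two segments: shrink k (possibly to 0)
assert calc_k(3_000) == 2     # 30% slice is 900 <= 1_000: stay at DEFAULT_K
assert calc_k(10_000) == 6    # 30% slice is 3_000: expand k to 3_000 // 500
```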
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        if len(self.dataset_tools) == 0:
            return {"text": ''}
        elif len(self.dataset_tools) == 1:
            return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}

        route = self.router_chain.route(inputs)

        destination = ''
        if route.destination:
            pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
            match = re.search(pattern, route.destination, re.IGNORECASE)
            if match:
                destination = match.group()

        if not destination:
            return {"text": ''}
        elif destination in self.dataset_tools:
            return {"text": self.dataset_tools[destination].run(
                route.next_inputs['input']
            )}
        else:
            raise ValueError(
                f"Received invalid destination chain name '{destination}'"
            )
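The UUID regex makes routing tolerant of decoration around the dataset id, such as the `[[...]]` markers used when building the destinations list. A quick illustration with a made-up id:

```python
import re

pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
raw_destination = "[[b5e1d2c3-0a1b-4c2d-8e3f-123456789abc]]"
match = re.search(pattern, raw_destination, re.IGNORECASE)
assert match.group() == "b5e1d2c3-0a1b-4c2d-8e3f-123456789abc"
# The bare id is then looked up in dataset_tools; anything unmatched
# falls through to the empty-answer branch.
```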
@@ -10,7 +10,6 @@ from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError

 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
-from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
@@ -22,8 +21,6 @@ from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
-from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
-    ReadOnlyConversationTokenDBStringBufferSharedMemory
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
@@ -88,26 +85,21 @@ class Completion:
         # get agent executor
         agent_executor = orchestrator_rule_parser.to_agent_executor(
             conversation_message_task=conversation_message_task,
-            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
+            memory=memory,
             rest_tokens=rest_tokens_for_context_and_memory,
             callbacks=[chain_callback]
         )

         # run agent executor
         executor_output = ''
-        is_agent_output = False
         if agent_executor:
-            if isinstance(agent_executor, MultiDatasetRouterChain):
-                executor_output = agent_executor.run(query)
-            else:
-                should_use_agent = agent_executor.should_use_agent(query)
-                if should_use_agent:
-                    executor_output = agent_executor.run(query)
-                    is_agent_output = True
+            should_use_agent = agent_executor.should_use_agent(query)
+            if should_use_agent:
+                executor_output = agent_executor.run(query)

         # run the final llm
         try:
-            # if is_agent_output and not app_model_config.pre_prompt:
+            # if executor_output and not app_model_config.pre_prompt:
             #     # todo streaming flush the agent result to user, not call final llm
             #     pass
...
-from typing import Optional, Union
+import math
+from typing import Optional

 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
-from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool, Tool, WikipediaQueryRun

 from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
@@ -18,7 +17,7 @@ from core.tool.provider.serpapi_provider import SerpAPIToolProvider
 from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
 from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
@@ -32,7 +31,7 @@ class OrchestratorRuleParser:
     def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                           rest_tokens: int, callbacks: Callbacks = None) \
-            -> Optional[Union[AgentExecutor | MultiDatasetRouterChain]]:
+            -> Optional[AgentExecutor]:
         if not self.app_model_config.agent_mode_dict:
             return None
@@ -41,11 +40,6 @@
         chain = None
         if agent_mode_config and agent_mode_config.get('enabled'):
             tool_configs = agent_mode_config.get('tools', [])
-
-            # use router chain if planning strategy is router or not set
-            if not agent_mode_config.get('strategy') or agent_mode_config.get('strategy') == 'router':
-                return self.to_router_chain(tool_configs, conversation_message_task, rest_tokens)
-
             agent_model_name = agent_mode_config.get('model_name', 'gpt-4')

             agent_llm = LLMBuilder.to_llm(
@@ -64,15 +58,15 @@
                 callbacks=[DifyStdOutCallbackHandler()]
             )

-            tools = self.to_tools(tool_configs, conversation_message_task)
+            tools = self.to_tools(tool_configs, conversation_message_task, rest_tokens)

             if len(tools) == 0:
                 return None

             agent_configuration = AgentConfiguration(
-                strategy=PlanningStrategy(agent_mode_config.get('strategy')),
+                strategy=PlanningStrategy(agent_mode_config.get('strategy', 'router')),
                 llm=agent_llm,
-                tools=self.to_tools(tool_configs, conversation_message_task),
+                tools=tools,
                 summary_llm=summary_llm,
                 memory=memory,
                 callbacks=callbacks,
@@ -85,44 +79,6 @@
         return chain

-    def to_router_chain(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
-                        rest_tokens: int) -> Optional[Chain]:
-        """
-        Convert tool configs to router chain if planning strategy is router
-
-        :param tool_configs:
-        :param conversation_message_task:
-        :param rest_tokens:
-        :return:
-        """
-        chain = None
-        datasets = []
-        for tool_config in tool_configs:
-            tool_type = list(tool_config.keys())[0]
-            tool_val = list(tool_config.values())[0]
-            if tool_type == "dataset":
-                # get dataset from dataset id
-                dataset = db.session.query(Dataset).filter(
-                    Dataset.tenant_id == self.tenant_id,
-                    Dataset.id == tool_val.get("id")
-                ).first()
-
-                if dataset:
-                    datasets.append(dataset)
-
-        if len(datasets) > 0:
-            # tool to chain
-            multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
-                tenant_id=self.tenant_id,
-                datasets=datasets,
-                conversation_message_task=conversation_message_task,
-                rest_tokens=rest_tokens,
-                callbacks=[DifyStdOutCallbackHandler()]
-            )
-            chain = multi_dataset_router_chain
-
-        return chain
-
     def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
             -> Optional[SensitiveWordAvoidanceChain]:
         """
@@ -147,10 +103,12 @@
         return None

-    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask) -> list[BaseTool]:
+    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
+                 rest_tokens: int) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools

+        :param rest_tokens:
         :param tool_configs: app agent tool configs
         :param conversation_message_task:
         :return:
@@ -164,8 +122,7 @@
             tool = None
             if tool_type == "dataset":
-                tool = None
-                # tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task)
+                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
             elif tool_type == "web_reader":
                 tool = self.to_web_reader_tool()
             elif tool_type == "google_search":
@@ -178,10 +135,12 @@
         return tools

-    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask) \
+    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
+                                  rest_tokens: int) \
             -> Optional[BaseTool]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset

+        :param rest_tokens:
         :param tool_config:
         :param conversation_message_task:
         :return:
@@ -195,9 +154,10 @@
         if dataset and dataset.available_document_count == 0:
             return None

+        k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
         tool = DatasetRetrieverTool.from_dataset(
             dataset=dataset,
-            k=3,
+            k=k,
             callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
         )
@@ -249,3 +209,34 @@
             api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
             callbacks=[DifyStdOutCallbackHandler()]
         )
+
+    @classmethod
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
+        DEFAULT_K = 2
+        CONTEXT_TOKENS_PERCENT = 0.3
+
+        processing_rule = dataset.latest_process_rule
+        if not processing_rule:
+            return DEFAULT_K
+
+        if processing_rule.mode == "custom":
+            rules = processing_rule.rules_dict
+            if not rules:
+                return DEFAULT_K
+
+            segmentation = rules["segmentation"]
+            segment_max_tokens = segmentation["max_tokens"]
+        else:
+            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
+
+        # when rest_tokens is less than the default context tokens, shrink k
+        if rest_tokens < segment_max_tokens * DEFAULT_K:
+            return rest_tokens // segment_max_tokens
+
+        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
+
+        # when context_limit_tokens is less than the default context tokens, use DEFAULT_K
+        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
+            return DEFAULT_K
+
+        # expand k when there's still room left in the 30% rest-tokens budget
+        return context_limit_tokens // segment_max_tokens
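For context, a hypothetical `agent_mode` config showing the new default path: with no `strategy` key, `agent_mode_config.get('strategy', 'router')` resolves to `PlanningStrategy.ROUTER`, replacing the old early return into `to_router_chain`. Only the dataset entry's shape is taken from `to_tools`; the rest is an assumption:

```python
agent_mode_config = {
    "enabled": True,
    # no "strategy" key -> PlanningStrategy('router') is used after this commit
    "tools": [
        {"dataset": {"id": "b5e1d2c3-0a1b-4c2d-8e3f-123456789abc"}},  # tool_type/tool_val shape as in to_tools
    ],
}
```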
@@ -83,6 +83,7 @@ class DatasetRetrieverTool(BaseTool):
             embeddings=embeddings
         )

+        if self.k > 0:
             documents = vector_index.search(
                 query,
                 search_type='similarity',
@@ -90,6 +91,8 @@ class DatasetRetrieverTool(BaseTool):
                     'k': self.k
                 }
             )
+        else:
+            documents = []

         hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
         hit_callback.on_tool_end(documents)
...
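The `self.k > 0` guard pairs with `_dynamic_calc_retrieve_k`, which returns 0 when the remaining token budget is smaller than a single segment; in that case retrieval is skipped and the hit callback receives an empty list. A condensed sketch of the guarded call, assuming the surrounding tool code:

```python
# self.k == 0 means "no room in the context window for even one segment"
documents = (
    vector_index.search(query, search_type='similarity', search_kwargs={'k': self.k})
    if self.k > 0
    else []
)
```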
@@ -88,7 +88,6 @@ class WebReaderTool(BaseTool):
         if len(docs) > 10:
             docs = docs[:10]

-        print("summary docs: ", docs)
         chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
         page_contents = chain.run(docs)

         # todo use cache
...