Unverified commit 88545184, authored by John Wang, committed by GitHub

feat: support multi datasets router chain mode (#231)

parent 2c23caac
"""Base classes for LLM-powered router chains."""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.chains.base import Chain
from pydantic import root_validator
from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
class Route(NamedTuple):
    """Routing decision returned by ``LLMRouterChain.route``.

    Attributes:
        destination: Name of the selected target, or ``None`` when no
            specific target was selected.
        next_inputs: Inputs to hand to the selected target.
    """

    # Name of the target to route to; may be None.
    destination: Optional[str]
    # Keyword-style inputs for the target.
    next_inputs: Dict[str, Any]
class LLMRouterChain(Chain):
    """A router chain that uses an LLM chain to perform routing."""

    llm_chain: LLMChain
    """LLM chain used to perform routing"""

    @root_validator()
    def validate_prompt(cls, values: dict) -> dict:
        """Check that the wrapped LLM chain's prompt has an output parser.

        Raises:
            ValueError: If the prompt of ``llm_chain`` has no output parser.
        """
        llm_chain = values.get("llm_chain")
        if llm_chain is None:
            # `llm_chain` is absent from `values` when its own field
            # validation failed (post root validators still run in that
            # case). Let pydantic report that error instead of masking it
            # with a KeyError raised from here.
            return values
        if llm_chain.prompt.output_parser is None:
            raise ValueError(
                "LLMRouterChain requires base llm_chain prompt to have an output"
                " parser that converts LLM text output to a dictionary with keys"
                " 'destination' and 'next_inputs'. Received a prompt with no output"
                " parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the LLM chain prompt expects.

        :meta private:
        """
        return self.llm_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Keys present in the dictionary produced by this chain."""
        return ["destination", "next_inputs"]

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        """Run the base-class output check, then require a dict ``next_inputs``.

        Raises:
            ValueError: If ``outputs["next_inputs"]`` is not a ``dict``.
        """
        super()._validate_outputs(outputs)
        next_inputs = outputs["next_inputs"]
        if not isinstance(next_inputs, dict):
            raise ValueError(
                "LLMRouterChain expects 'next_inputs' to be a dict, "
                f"got {type(next_inputs).__name__}."
            )

    def _call(
        self,
        inputs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run the LLM chain on ``inputs`` and return its parsed output."""
        output = cast(
            Dict[str, Any],
            self.llm_chain.predict_and_parse(**inputs),
        )
        return output

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
    ) -> LLMRouterChain:
        """Convenience constructor.

        Wraps ``llm`` and ``prompt`` in an ``LLMChain`` and forwards any extra
        keyword arguments to the class constructor.
        """
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)

    def route(self, inputs: Dict[str, Any]) -> Route:
        """Invoke the chain on ``inputs`` and wrap the result in a ``Route``."""
        result = self(inputs)
        return Route(result["destination"], result["next_inputs"])
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
    """Parser for output of router chain in the multi-prompt chain."""

    default_destination: str = "DEFAULT"
    next_inputs_type: Type = str
    next_inputs_inner_key: str = "input"

    def parse_json_markdown(self, json_string: str) -> dict:
        """Strip markdown code fences from ``json_string`` and decode it as JSON.

        Raises:
            json.JSONDecodeError: If the remaining text is not valid JSON.
        """
        # Remove the triple backticks if present, then surrounding whitespace.
        cleaned = json_string.replace("```json", "").replace("```", "").strip()
        return json.loads(cleaned)

    def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
        """Decode ``text`` and verify it is a JSON object with ``expected_keys``.

        Raises:
            OutputParserException: If the text is not valid JSON, is not a
                JSON object, or lacks one of ``expected_keys``.
        """
        try:
            json_obj = self.parse_json_markdown(text)
        except json.JSONDecodeError as e:
            raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e

        # json.loads can also return a list, string, number, ...; the key
        # checks below are only meaningful for an object (dict).
        if not isinstance(json_obj, dict):
            raise OutputParserException(
                "Got invalid return object. Expected a JSON object, "
                f"but got {json_obj}"
            )

        for key in expected_keys:
            if key not in json_obj:
                raise OutputParserException(
                    f"Got invalid return object. Expected key `{key}` "
                    f"to be present, but got {json_obj}"
                )
        return json_obj

    def parse(self, text: str) -> Dict[str, Any]:
        """Parse router LLM output into ``destination`` / ``next_inputs``.

        ``next_inputs`` is wrapped in a dict under ``next_inputs_inner_key``.
        A destination equal (case-insensitively) to ``default_destination`` is
        replaced by ``None``; any other destination is whitespace-stripped.

        Raises:
            OutputParserException: On any failure while parsing or validating.
        """
        try:
            expected_keys = ["destination", "next_inputs"]
            parsed = self.parse_and_check_json_markdown(text, expected_keys)
            if not isinstance(parsed["destination"], str):
                raise ValueError("Expected 'destination' to be a string.")
            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
                raise ValueError(
                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
                )
            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}

            destination = parsed["destination"].strip()
            if destination.lower() == self.default_destination.lower():
                parsed["destination"] = None
            else:
                parsed["destination"] = destination
            return parsed
        except Exception as e:
            # Single failure type for callers; keep the original as the cause.
            raise OutputParserException(
                f"Parsing text\n{text}\n raised following error:\n{e}"
            ) from e
from typing import Optional, List from typing import Optional, List
from langchain.callbacks import SharedCallbackManager from langchain.callbacks import SharedCallbackManager, CallbackManager
from langchain.chains import SequentialChain from langchain.chains import SequentialChain
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from core.agent.agent_builder import AgentBuilder
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder from core.chain.chain_builder import ChainBuilder
from core.constant import llm_constant from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask from core.conversation_message_task import ConversationMessageTask
from core.tool.dataset_tool_builder import DatasetToolBuilder from extensions.ext_database import db
from models.dataset import Dataset
class MainChainBuilder: class MainChainBuilder:
...@@ -31,8 +31,7 @@ class MainChainBuilder: ...@@ -31,8 +31,7 @@ class MainChainBuilder:
tenant_id=tenant_id, tenant_id=tenant_id,
agent_mode=agent_mode, agent_mode=agent_mode,
memory=memory, memory=memory,
dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task), conversation_message_task=conversation_message_task
agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
) )
chains += tool_chains chains += tool_chains
...@@ -59,15 +58,15 @@ class MainChainBuilder: ...@@ -59,15 +58,15 @@ class MainChainBuilder:
@classmethod @classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory], def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler, conversation_message_task: ConversationMessageTask):
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
# agent mode # agent mode
chains = [] chains = []
if agent_mode and agent_mode.get('enabled'): if agent_mode and agent_mode.get('enabled'):
tools = agent_mode.get('tools', []) tools = agent_mode.get('tools', [])
pre_fixed_chains = [] pre_fixed_chains = []
agent_tools = [] # agent_tools = []
datasets = []
for tool in tools: for tool in tools:
tool_type = list(tool.keys())[0] tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0] tool_config = list(tool.values())[0]
...@@ -76,34 +75,27 @@ class MainChainBuilder: ...@@ -76,34 +75,27 @@ class MainChainBuilder:
if chain: if chain:
pre_fixed_chains.append(chain) pre_fixed_chains.append(chain)
elif tool_type == "dataset": elif tool_type == "dataset":
dataset_tool = DatasetToolBuilder.build_dataset_tool( # get dataset from dataset id
tenant_id=tenant_id, dataset = db.session.query(Dataset).filter(
dataset_id=tool_config.get("id"), Dataset.tenant_id == tenant_id,
response_mode='no_synthesizer', # "compact" Dataset.id == tool_config.get("id")
callback_handler=dataset_tool_callback_handler ).first()
)
if dataset_tool: if dataset:
agent_tools.append(dataset_tool) datasets.append(dataset)
# add pre-fixed chains # add pre-fixed chains
chains += pre_fixed_chains chains += pre_fixed_chains
if len(agent_tools) == 1: if len(datasets) > 0:
# tool to chain # tool to chain
tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output') multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
chains.append(tool_chain)
elif len(agent_tools) > 1:
# build agent config
agent_chain = AgentBuilder.to_agent_chain(
tenant_id=tenant_id, tenant_id=tenant_id,
tools=agent_tools, datasets=datasets,
memory=memory, conversation_message_task=conversation_message_task,
dataset_tool_callback_handler=dataset_tool_callback_handler, callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
) )
chains.append(multi_dataset_router_chain)
chains.append(agent_chain)
final_output_key = cls.get_chains_output_key(chains) final_output_key = cls.get_chains_output_key(chains)
......
from typing import Mapping, List, Dict, Any, Optional
from langchain import LLMChain, PromptTemplate, ConversationChain
from langchain.callbacks import CallbackManager
from langchain.chains.base import Chain
from langchain.schema import BaseLanguageModel
from pydantic import Extra
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_tool_builder import DatasetToolBuilder
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset
# Prompt used by the router LLM. The text is formatted in two stages:
# MultiDatasetRouterChain.from_datasets first calls str.format() to fill in
# {destinations}, and the result is then used as a PromptTemplate template
# whose only input variable is `input`. That is why `input` is written with
# doubled braces and the literal JSON braces are quadrupled.
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:
```json
{{{{
"destination": string \\ name of the prompt to use or "DEFAULT"
"next_inputs": string \\ a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT >>
"""
class MultiDatasetRouterChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains."""

    router_chain: LLMRouterChain
    """Chain for deciding a destination chain and the input to it."""
    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
    """Map of name to candidate chains that inputs can be routed to."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """This chain produces a single ``text`` output."""
        return ["text"]

    @classmethod
    def from_datasets(
            cls,
            tenant_id: str,
            datasets: List[Dataset],
            conversation_message_task: ConversationMessageTask,
            **kwargs: Any,
    ) -> "MultiDatasetRouterChain":
        """Build a router chain plus one query tool per dataset.

        Args:
            tenant_id: Tenant identifier passed to ``LLMBuilder.to_llm``.
            datasets: Datasets offered to the router as candidate destinations.
            conversation_message_task: Passed to each tool's
                ``DatasetToolCallbackHandler``.
            **kwargs: Forwarded to the class constructor.
        """
        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callback_manager=llm_callback_manager
        )

        destinations_str = "\n".join(
            f"{d.id}: {d.description}" for d in datasets
        )
        # The formatted template is parsed again by PromptTemplate, so any
        # brace coming from a dataset description must be doubled here or it
        # would be read as a prompt variable / malformed placeholder.
        destinations_str = destinations_str.replace("{", "{{").replace("}", "}}")

        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=destinations_str
        )
        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        dataset_tools = {}
        for dataset in datasets:
            dataset_tool = DatasetToolBuilder.build_dataset_tool(
                dataset=dataset,
                response_mode='no_synthesizer',  # "compact"
                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
            )
            # Key by the same string rendering of the id that the prompt shows
            # to the router, so the returned destination matches a key even if
            # `dataset.id` is not a plain str.
            dataset_tools[str(dataset.id)] = dataset_tool

        return cls(
            router_chain=router_chain,
            dataset_tools=dataset_tools,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Query the dataset tool selected for ``inputs``.

        Returns ``{"text": ''}`` when there are no tools or the router picks
        no destination. With exactly one tool the router is skipped and that
        tool is run on ``inputs['input']``.

        Raises:
            ValueError: If the router names a destination with no matching tool.
        """
        if not self.dataset_tools:
            return {"text": ''}

        if len(self.dataset_tools) == 1:
            # Nothing to choose between: query the only dataset directly.
            only_tool = next(iter(self.dataset_tools.values()))
            return {"text": only_tool.run(inputs['input'])}

        route = self.router_chain.route(inputs)
        if not route.destination:
            return {"text": ''}

        if route.destination in self.dataset_tools:
            selected_tool = self.dataset_tools[route.destination]
            return {"text": selected_tool.run(route.next_inputs['input'])}

        raise ValueError(
            f"Received invalid destination chain name '{route.destination}'"
        )
...@@ -10,24 +10,14 @@ from core.index.keyword_table_index import KeywordTableIndex ...@@ -10,24 +10,14 @@ from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset
class DatasetToolBuilder: class DatasetToolBuilder:
@classmethod @classmethod
def build_dataset_tool(cls, tenant_id: str, dataset_id: str, def build_dataset_tool(cls, dataset: Dataset,
response_mode: str = "no_synthesizer", response_mode: str = "no_synthesizer",
callback_handler: Optional[DatasetToolCallbackHandler] = None): callback_handler: Optional[DatasetToolCallbackHandler] = None):
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
return None
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
index = KeywordTableIndex(dataset=dataset).query_index index = KeywordTableIndex(dataset=dataset).query_index
...@@ -65,7 +55,7 @@ class DatasetToolBuilder: ...@@ -65,7 +55,7 @@ class DatasetToolBuilder:
index_tool_config = IndexToolConfig( index_tool_config = IndexToolConfig(
index=index, index=index,
name=f"dataset-{dataset_id}", name=f"dataset-{dataset.id}",
description=description, description=description,
index_query_kwargs=query_kwargs, index_query_kwargs=query_kwargs,
tool_kwargs={ tool_kwargs={
...@@ -75,7 +65,7 @@ class DatasetToolBuilder: ...@@ -75,7 +65,7 @@ class DatasetToolBuilder:
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
) )
index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset_id) index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
return EnhanceLlamaIndexTool.from_tool_config( return EnhanceLlamaIndexTool.from_tool_config(
tool_config=index_tool_config, tool_config=index_tool_config,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment