Commit c429005c authored by John Wang

feat: optimize dataset_retriever tool

parent c6c81164
......@@ -63,6 +63,8 @@ class AgentExecutor:
summary_llm=self.configuration.summary_llm,
verbose=True
)
else:
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
return agent
......
from typing import Optional
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain
class ChainBuilder:
    """Factory helpers that wrap tools / tool configs into runnable chains."""

    @classmethod
    def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
        """Wrap a single tool in a ToolChain with stdout callback logging.

        Optional kwargs ``input_key`` / ``output_key`` override the chain's
        default I/O keys ('input' / 'tool_output').
        """
        in_key = kwargs.get('input_key', 'input')
        out_key = kwargs.get('output_key', 'tool_output')
        return ToolChain(
            tool=tool,
            input_key=in_key,
            output_key=out_key,
            callbacks=[DifyStdOutCallbackHandler()]
        )

    @classmethod
    def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[
        SensitiveWordAvoidanceChain]:
        """Build a SensitiveWordAvoidanceChain from a tool config dict.

        Returns None when the feature is disabled or no words are configured.
        """
        words = tool_config.get("words", "")
        # Guard clause: nothing to build unless enabled AND words present.
        if not (tool_config.get("enabled", False) and words):
            return None
        return SensitiveWordAvoidanceChain(
            sensitive_words=words.split(","),
            canned_response=tool_config.get("canned_response", ''),
            output_key="sensitive_word_avoidance_output",
            callbacks=[DifyStdOutCallbackHandler()],
            **kwargs
        )
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool
class ToolChain(Chain):
    """Adapter exposing a single LangChain tool as a Chain.

    Reads ``inputs[input_key]``, forwards it to ``tool.run`` /
    ``tool.arun``, and returns the result under ``output_key``.
    """

    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    # The wrapped tool this chain delegates to.
    tool: BaseTool

    @property
    def _chain_type(self) -> str:
        return "tool_chain"

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the tool synchronously and return its output."""
        # Renamed from `input` to avoid shadowing the builtin.
        tool_input = inputs[self.input_key]
        # Second positional arg of BaseTool.run is `verbose`; pass by
        # keyword for clarity and robustness against signature changes.
        output = self.tool.run(tool_input, verbose=self.verbose)
        return {self.output_key: output}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""
        tool_input = inputs[self.input_key]
        output = await self.tool.arun(tool_input, verbose=self.verbose)
        return {self.output_key: output}
......@@ -26,8 +26,9 @@ class OrchestratorRuleParser:
def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
self.tenant_id = tenant_id
self.app_model_config = app_model_config
self.agent_summary_model_name = "gpt-3.5-turbo-16k"
def to_agent_arguments(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
def to_agent_chain(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, callbacks: Callbacks = None) -> Optional[Chain]:
if not self.app_model_config.agent_mode_dict:
return None
......@@ -54,7 +55,7 @@ class OrchestratorRuleParser:
summary_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name="gpt-3.5-turbo-16k",
model_name=self.agent_summary_model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
......@@ -181,15 +182,9 @@ class OrchestratorRuleParser:
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
return None
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
tool = DatasetRetrieverTool(
name=f"dataset_retriever",
description=description,
k=3,
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
k=3,
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
)
......@@ -227,7 +222,8 @@ class OrchestratorRuleParser:
tool = Tool(
name="google_search",
description="A tool for performing a Google search and extracting snippets and webpages "
"when you need to search for something you don't know or when your information is not up to date."
"when you need to search for something you don't know or when your information "
"is not up to date."
"Input should be a search query.",
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
callbacks=[DifyStdOutCallbackHandler]
......
import re
from typing import Type
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from pydantic import Field, BaseModel
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset
class DatasetRetrieverToolInput(BaseModel):
    """Argument schema the agent must supply when calling DatasetRetrieverTool."""

    # Fixed typo in the description ("dateset" -> "dataset"): this text is
    # shown to the LLM as tool-argument guidance, so it should be correct.
    dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(BaseTool):
"""Tool for querying a Dataset."""
# todo dataset id as tool argument
name: str = "dataset_retriever"
args_schema: Type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
tenant_id: str
k: int = 3
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
description += '\nID of dataset MUST be ' + dataset.id
return cls(
tenant_id=dataset.tenant_id,
description=description,
**kwargs
)
dataset: Dataset
k: int = 2
def _run(self, dataset_id: str, query: str) -> str:
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
match = re.search(pattern, dataset_id, re.IGNORECASE)
if match:
dataset_id = match.group()
def _run(self, tool_input: str) -> str:
if self.dataset.indexing_technique == "economy":
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
return f'[{self.name} failed to find dataset with id {dataset_id}.]'
if dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=self.dataset,
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
)
......@@ -40,49 +78,23 @@ class DatasetRetrieverTool(BaseTool):
))
vector_index = VectorIndex(
dataset=self.dataset,
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
tool_input,
query,
search_type='similarity',
search_kwargs={
'k': self.k
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = await vector_index.asearch(
tool_input,
search_type='similarity',
search_kwargs={
'k': 10
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
raise NotImplementedError()
class ValidateFailedError(Exception):
    """Raised when validation of a provider's credentials fails."""
    description = "Provider Validate failed"
class ToolValidateFailedError(Exception):
    """Raised when validation of a tool provider's credentials fails."""
    description = "Tool Provider Validate failed"
from typing import Optional
from core.llm.provider.errors import ValidateFailedError
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from models.tool import ToolProviderName
......@@ -56,4 +56,4 @@ class SerpAPIToolProvider(BaseToolProvider):
:return:
"""
if 'api_key' not in credentials or not credentials.get('api_key'):
raise ValidateFailedError("SerpAPI api_key is required.")
raise ToolValidateFailedError("SerpAPI api_key is required.")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment