Commit c6c81164 authored by John Wang's avatar John Wang

feat: optimize tool providers and tool parse

parent d7712cf7
...@@ -24,6 +24,7 @@ model_config_fields = { ...@@ -24,6 +24,7 @@ model_config_fields = {
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'), 'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'), 'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'), 'user_input_form': fields.Raw(attribute='user_input_form_list'),
'pre_prompt': fields.String, 'pre_prompt': fields.String,
...@@ -148,6 +149,7 @@ class AppListApi(Resource): ...@@ -148,6 +149,7 @@ class AppListApi(Resource):
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']), suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']), speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']), more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']), model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']), user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'], pre_prompt=model_configuration['pre_prompt'],
...@@ -439,6 +441,7 @@ class AppCopy(Resource): ...@@ -439,6 +441,7 @@ class AppCopy(Resource):
suggested_questions_after_answer=app_config.suggested_questions_after_answer, suggested_questions_after_answer=app_config.suggested_questions_after_answer,
speech_to_text=app_config.speech_to_text, speech_to_text=app_config.speech_to_text,
more_like_this=app_config.more_like_this, more_like_this=app_config.more_like_this,
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
model=app_config.model, model=app_config.model,
user_input_form=app_config.user_input_form, user_input_form=app_config.user_input_form,
pre_prompt=app_config.pre_prompt, pre_prompt=app_config.pre_prompt,
......
...@@ -43,6 +43,7 @@ class ModelConfigResource(Resource): ...@@ -43,6 +43,7 @@ class ModelConfigResource(Resource):
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']), suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']), speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']), more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']), model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']), user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'], pre_prompt=model_configuration['pre_prompt'],
......
import enum import enum
from typing import Union, Optional from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import BaseModel
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
...@@ -20,55 +21,46 @@ class PlanningStrategy(str, enum.Enum): ...@@ -20,55 +21,46 @@ class PlanningStrategy(str, enum.Enum):
MULTI_FUNCTION_CALL = 'multi_function_call' MULTI_FUNCTION_CALL = 'multi_function_call'
class AgentExecutor: class AgentConfiguration(BaseModel):
def __init__(self, strategy: PlanningStrategy, llm: BaseLanguageModel, tools: list[BaseTool], strategy: PlanningStrategy
summary_llm: BaseLanguageModel, memory: ReadOnlyConversationTokenDBBufferSharedMemory, llm: BaseLanguageModel
callbacks: Callbacks = None, max_iterations: int = 6, max_execution_time: Optional[float] = None, tools: list[BaseTool]
early_stopping_method: str = "generate"): summary_llm: BaseLanguageModel
self.strategy = strategy memory: ReadOnlyConversationTokenDBBufferSharedMemory
self.llm = llm callbacks: Callbacks = None
self.tools = tools max_iterations: int = 6
self.summary_llm = summary_llm max_execution_time: Optional[float] = None
self.memory = memory early_stopping_method: str = "generate"
self.callbacks = callbacks # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
self.agent = self._init_agent(strategy, llm, tools, memory, callbacks)
self.max_iterations = max_iterations
self.max_execution_time = max_execution_time
self.early_stopping_method = early_stopping_method
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
# summary_llm: StreamableChatOpenAI = LLMBuilder.to_llm( class AgentExecutor:
# tenant_id=tenant_id, def __init__(self, configuration: AgentConfiguration):
# model_name='gpt-3.5-turbo-16k', self.configuration = configuration
# max_tokens=300 self.agent = self._init_agent()
# )
def _init_agent(self, strategy: PlanningStrategy, llm: BaseLanguageModel, tools: list[BaseTool], def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
memory: ReadOnlyConversationTokenDBBufferSharedMemory, callbacks: Callbacks = None) \ if self.configuration.strategy == PlanningStrategy.REACT:
-> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
if strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
llm=llm, llm=self.configuration.llm,
tools=tools, tools=self.configuration.tools,
summary_llm=self.summary_llm, summary_llm=self.configuration.summary_llm,
verbose=True verbose=True
) )
elif strategy == PlanningStrategy.FUNCTION_CALL: elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent( agent = AutoSummarizingOpenAIFunctionCallAgent(
llm=llm, llm=self.configuration.llm,
tools=tools, tools=self.configuration.tools,
extra_prompt_messages=memory.buffer, # used for read chat histories memory extra_prompt_messages=self.configuration.memory.buffer, # used for read chat histories memory
summary_llm=self.summary_llm, summary_llm=self.configuration.summary_llm,
verbose=True verbose=True
) )
elif strategy == PlanningStrategy.MULTI_FUNCTION_CALL: elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent( agent = AutoSummarizingOpenMultiAIFunctionCallAgent(
llm=llm, llm=self.configuration.llm,
tools=tools, tools=self.configuration.tools,
extra_prompt_messages=memory.buffer, # used for read chat histories memory extra_prompt_messages=self.configuration.memory.buffer, # used for read chat histories memory
summary_llm=self.summary_llm, summary_llm=self.configuration.summary_llm,
verbose=True verbose=True
) )
...@@ -77,21 +69,15 @@ class AgentExecutor: ...@@ -77,21 +69,15 @@ class AgentExecutor:
def should_use_agent(self, query: str) -> bool: def should_use_agent(self, query: str) -> bool:
return self.agent.should_use_agent(query) return self.agent.should_use_agent(query)
def run(self, query: str) -> str: def get_chain(self) -> AgentExecutor:
agent_executor = LCAgentExecutor.from_agent_and_tools( agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent, agent=self.agent,
tools=self.tools, tools=self.configuration.tools,
memory=self.memory, memory=self.configuration.memory,
max_iterations=self.max_iterations, max_iterations=self.configuration.max_iterations,
max_execution_time=self.max_execution_time, max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.early_stopping_method, early_stopping_method=self.configuration.early_stopping_method,
verbose=True verbose=True
) )
# run agent return agent_executor
result = agent_executor.run(
query,
callbacks=self.callbacks
)
return result
from typing import Optional, List, cast from typing import Optional, List, cast, Tuple
from langchain.chains import SequentialChain from langchain.chains import SequentialChain
from langchain.chains.base import Chain from langchain.chains.base import Chain
...@@ -6,44 +6,54 @@ from langchain.memory.chat_memory import BaseChatMemory ...@@ -6,44 +6,54 @@ from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask from core.conversation_message_task import ConversationMessageTask
from core.orchestrator_rule_parser import OrchestratorRuleParser
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset
from models.model import AppModelConfig
class MainChainBuilder: class MainChainBuilder:
@classmethod @classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory], def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory],
rest_tokens: int, rest_tokens: int, conversation_message_task: ConversationMessageTask):
conversation_message_task: ConversationMessageTask):
first_input_key = "input" first_input_key = "input"
final_output_key = "output" final_output_key = "output"
chains = [] chains = []
chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task) # init orchestrator rule parser
orchestrator_rule_parser = OrchestratorRuleParser(
tenant_id=tenant_id,
app_model_config=app_model_config
)
# agent mode # parse sensitive_word_avoidance_chain
tool_chains, chains_output_key = cls.get_agent_chains( sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain()
if sensitive_word_avoidance_chain:
chains.append(sensitive_word_avoidance_chain)
# parse agent chain
agent_chain = cls.get_agent_chain(
tenant_id=tenant_id, tenant_id=tenant_id,
agent_mode=agent_mode, agent_mode=app_model_config.agent_mode_dict,
rest_tokens=rest_tokens, rest_tokens=rest_tokens,
memory=memory, memory=memory,
conversation_message_task=conversation_message_task conversation_message_task=conversation_message_task
) )
chains += tool_chains
if chains_output_key: if agent_chain:
final_output_key = chains_output_key chains.append(agent_chain)
final_output_key = agent_chain.output_keys[0]
if len(chains) == 0: if len(chains) == 0:
return None return None
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
for chain in chains: for chain in chains:
chain = cast(Chain, chain) chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler) chain.callbacks.append(chain_callback)
# build main chain # build main chain
overall_chain = SequentialChain( overall_chain = SequentialChain(
...@@ -56,26 +66,20 @@ class MainChainBuilder: ...@@ -56,26 +66,20 @@ class MainChainBuilder:
return overall_chain return overall_chain
@classmethod @classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, def get_agent_chain(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int, rest_tokens: int,
memory: Optional[BaseChatMemory], memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask): conversation_message_task: ConversationMessageTask) -> Chain:
# agent mode # agent mode
chains = [] chain = None
if agent_mode and agent_mode.get('enabled'): if agent_mode and agent_mode.get('enabled'):
tools = agent_mode.get('tools', []) tools = agent_mode.get('tools', [])
pre_fixed_chains = []
# agent_tools = []
datasets = [] datasets = []
for tool in tools: for tool in tools:
tool_type = list(tool.keys())[0] tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0] tool_config = list(tool.values())[0]
if tool_type == 'sensitive-word-avoidance': if tool_type == "dataset":
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
if chain:
pre_fixed_chains.append(chain)
elif tool_type == "dataset":
# get dataset from dataset id # get dataset from dataset id
dataset = db.session.query(Dataset).filter( dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id, Dataset.tenant_id == tenant_id,
...@@ -85,9 +89,6 @@ class MainChainBuilder: ...@@ -85,9 +89,6 @@ class MainChainBuilder:
if dataset: if dataset:
datasets.append(dataset) datasets.append(dataset)
# add pre-fixed chains
chains += pre_fixed_chains
if len(datasets) > 0: if len(datasets) > 0:
# tool to chain # tool to chain
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets( multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
...@@ -97,14 +98,6 @@ class MainChainBuilder: ...@@ -97,14 +98,6 @@ class MainChainBuilder:
rest_tokens=rest_tokens, rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()] callbacks=[DifyStdOutCallbackHandler()]
) )
chains.append(multi_dataset_router_chain) chain = multi_dataset_router_chain
final_output_key = cls.get_chains_output_key(chains)
return chains, final_output_key return chain
@classmethod
def get_chains_output_key(cls, chains: List[Chain]):
if len(chains) > 0:
return chains[-1].output_keys[0]
return None
...@@ -70,9 +70,9 @@ class Completion: ...@@ -70,9 +70,9 @@ class Completion:
) )
# build main chain include agent # build main chain include agent
main_chain = MainChainBuilder.to_langchain_components( main_chain = MainChainBuilder.get_chains(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
agent_mode=app_model_config.agent_mode_dict, app_model_config=app_model_config,
rest_tokens=rest_tokens_for_context_and_memory, rest_tokens=rest_tokens_for_context_and_memory,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None, memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task conversation_message_task=conversation_message_task
......
...@@ -69,6 +69,7 @@ class ConversationMessageTask: ...@@ -69,6 +69,7 @@ class ConversationMessageTask:
"suggested_questions": self.app_model_config.suggested_questions_list, "suggested_questions": self.app_model_config.suggested_questions_list,
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict, "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
"more_like_this": self.app_model_config.more_like_this_dict, "more_like_this": self.app_model_config.more_like_this_dict,
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
"user_input_form": self.app_model_config.user_input_form_list, "user_input_form": self.app_model_config.user_input_form_list,
} }
......
from typing import Optional
from langchain.callbacks.manager import Callbacks
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import AppModelConfig
class OrchestratorRuleParser:
"""Parse the orchestrator rule to entities."""
def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
self.tenant_id = tenant_id
self.app_model_config = app_model_config
def to_agent_arguments(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, callbacks: Callbacks = None) -> Optional[Chain]:
if not self.app_model_config.agent_mode_dict:
return None
agent_mode_config = self.app_model_config.agent_mode_dict
chain = None
if agent_mode_config and agent_mode_config.get('enabled'):
tool_configs = agent_mode_config.get('tools', [])
# use router chain if planning strategy is router or not set
if not agent_mode_config.get('strategy') or agent_mode_config.get('strategy') == 'router':
return self.to_router_chain(tool_configs, conversation_message_task, rest_tokens)
agent_model_name = agent_mode_config.get('model_name', 'gpt-4')
agent_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=agent_model_name,
temperature=0,
max_tokens=800,
callbacks=[DifyStdOutCallbackHandler()]
)
summary_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name="gpt-3.5-turbo-16k",
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
)
agent_configuration = AgentConfiguration(
strategy=PlanningStrategy(agent_mode_config.get('strategy')),
llm=agent_llm,
tools=self.to_tools(tool_configs, conversation_message_task),
summary_llm=summary_llm,
memory=memory,
callbacks=callbacks,
max_iterations=6,
max_execution_time=None,
early_stopping_method="generate"
)
agent_executor = AgentExecutor(agent_configuration)
chain = agent_executor.get_chain()
return chain
def to_router_chain(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
rest_tokens: int) -> Optional[Chain]:
"""
Convert tool configs to router chain if planning strategy is router
:param tool_configs:
:param conversation_message_task:
:param rest_tokens:
:return:
"""
chain = None
datasets = []
for tool_config in tool_configs:
tool_type = list(tool_config.keys())[0]
tool_val = list(tool_config.values())[0]
if tool_type == "dataset":
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_val.get("id")
).first()
if dataset:
datasets.append(dataset)
if len(datasets) > 0:
# tool to chain
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
tenant_id=self.tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chain = multi_dataset_router_chain
return chain
def to_sensitive_word_avoidance_chain(self, **kwargs) -> Optional[SensitiveWordAvoidanceChain]:
"""
Convert app sensitive word avoidance config to chain
:param kwargs:
:return:
"""
if not self.app_model_config.sensitive_word_avoidance_dict:
return None
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
return SensitiveWordAvoidanceChain(
sensitive_words=sensitive_words.split(","),
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callbacks=[DifyStdOutCallbackHandler()],
**kwargs
)
return None
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
:param tool_configs: app agent tool configs
:param conversation_message_task:
:return:
"""
tools = []
for tool_config in tool_configs:
tool_type = list(tool_config.keys())[0]
tool_val = list(tool_config.values())[0]
if not tool_config.get("enabled") or tool_config.get("enabled") is not True:
continue
tool = None
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool()
if tool:
tools.append(tool)
return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config:
:param conversation_message_task:
:return:
"""
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
return None
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
tool = DatasetRetrieverTool(
name=f"dataset_retriever",
description=description,
k=3,
dataset=dataset,
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
)
return tool
def to_web_reader_tool(self) -> Optional[BaseTool]:
"""
A tool for reading web pages
:return:
"""
summary_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name="gpt-3.5-turbo-16k",
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
)
tool = WebReaderTool(
llm=summary_llm,
max_chunk_length=4000,
continue_reading=True,
callbacks=[DifyStdOutCallbackHandler()]
)
return tool
def to_google_search_tool(self) -> Optional[BaseTool]:
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
func_kwargs = tool_provider.credentials_to_func_kwargs()
if not func_kwargs:
return None
tool = Tool(
name="google_search",
description="A tool for performing a Google search and extracting snippets and webpages "
"when you need to search for something you don't know or when your information is not up to date."
"Input should be a search query.",
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
callbacks=[DifyStdOutCallbackHandler]
)
return tool
...@@ -11,7 +11,10 @@ from models.dataset import Dataset ...@@ -11,7 +11,10 @@ from models.dataset import Dataset
class DatasetTool(BaseTool): class DatasetTool(BaseTool):
"""Tool for querying a Dataset.""" """
Tool for querying a Dataset.
Only use for router chain.
"""
dataset: Dataset dataset: Dataset
k: int = 2 k: int = 2
......
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class DatasetRetrieverTool(BaseTool):
"""Tool for querying a Dataset."""
# todo dataset id as tool argument
dataset: Dataset
k: int = 2
def _run(self, tool_input: str) -> str:
if self.dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=self.dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
tool_input,
search_type='similarity',
search_kwargs={
'k': self.k
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = await vector_index.asearch(
tool_input,
search_type='similarity',
search_kwargs={
'k': 10
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
import base64
from abc import ABC, abstractmethod
from typing import Optional
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName
class BaseToolProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
@abstractmethod
def get_provider_name(self) -> ToolProviderName:
raise NotImplementedError
@abstractmethod
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_to_func_kwargs(self) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_validate(self, credentials: dict):
raise NotImplementedError
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
"""
query = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == self.tenant_id,
ToolProvider.provider_name == self.get_provider_name()
)
if must_enabled:
query = query.filter(ToolProvider.is_enabled == True)
return query.first()
def encrypt_token(self, token) -> str:
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
if obfuscated:
return self._obfuscated_token(token)
return token
def _obfuscated_token(self, token: str) -> str:
return token[:6] + '*' * (len(token) - 8) + token[-2:]
class ValidateFailedError(Exception):
description = "Provider Validate failed"
from typing import Optional
from core.llm.provider.errors import ValidateFailedError
from core.tool.provider.base import BaseToolProvider
from models.tool import ToolProviderName
class SerpAPIToolProvider(BaseToolProvider):
def get_provider_name(self) -> ToolProviderName:
"""
Returns the name of the provider.
:return:
"""
return ToolProviderName.SERPAPI
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for SerpAPI as a dictionary.
:param obfuscated: obfuscate credentials if True
:return:
"""
tool_provider = self.get_provider(must_enabled=True)
if not tool_provider:
return None
config = tool_provider.config
if not config:
return None
if config.get('api_key'):
config['api_key'] = self.decrypt_token(config.get('api_key'), obfuscated)
return config
def credentials_to_func_kwargs(self) -> Optional[dict]:
"""
Returns the credentials function kwargs as a dictionary.
:return:
"""
credentials = self.get_credentials()
if not credentials:
return None
return {
'serpapi_api_key': credentials.get('api_key')
}
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:return:
"""
if 'api_key' not in credentials or not credentials.get('api_key'):
raise ValidateFailedError("SerpAPI api_key is required.")
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
class ToolProviderService:
def __init__(self, tenant_id: str, provider_name: str):
self.provider = self._init_provider(tenant_id, provider_name)
def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
if provider_name == 'serpapi':
return SerpAPIToolProvider(tenant_id)
else:
raise Exception('tool provider {} not found'.format(provider_name))
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
return self.provider.get_credentials(obfuscated)
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:raises: ValidateFailedError
"""
return self.provider.credentials_validate(credentials)
...@@ -52,7 +52,7 @@ class WebReaderToolInput(BaseModel): ...@@ -52,7 +52,7 @@ class WebReaderToolInput(BaseModel):
class WebReaderTool(BaseTool): class WebReaderTool(BaseTool):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool.""" """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name: str = "read_page" name: str = "web_reader"
args_schema: Type[BaseModel] = WebReaderToolInput args_schema: Type[BaseModel] = WebReaderToolInput
description: str = "use this to read a website. " \ description: str = "use this to read a website. " \
"If you can answer the question based on the information provided, " \ "If you can answer the question based on the information provided, " \
......
"""add tool ptoviders """add tool providers
Revision ID: 46c503018f11 Revision ID: 7ce5a52e4eee
Revises: 2beac44e5f5f Revises: 2beac44e5f5f
Create Date: 2023-07-07 16:35:32.974075 Create Date: 2023-07-10 10:26:50.074515
""" """
from alembic import op from alembic import op
...@@ -10,7 +10,7 @@ import sqlalchemy as sa ...@@ -10,7 +10,7 @@ import sqlalchemy as sa
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '46c503018f11' revision = '7ce5a52e4eee'
down_revision = '2beac44e5f5f' down_revision = '2beac44e5f5f'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
...@@ -23,16 +23,22 @@ def upgrade(): ...@@ -23,16 +23,22 @@ def upgrade():
sa.Column('tenant_id', postgresql.UUID(), nullable=False), sa.Column('tenant_id', postgresql.UUID(), nullable=False),
sa.Column('tool_name', sa.String(length=40), nullable=False), sa.Column('tool_name', sa.String(length=40), nullable=False),
sa.Column('encrypted_config', sa.Text(), nullable=True), sa.Column('encrypted_config', sa.Text(), nullable=True),
sa.Column('is_valid', sa.Boolean(), server_default=sa.text('false'), nullable=False), sa.Column('is_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
) )
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade(): def downgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('sensitive_word_avoidance')
op.drop_table('tool_providers') op.drop_table('tool_providers')
# ### end Alembic commands ### # ### end Alembic commands ###
...@@ -88,6 +88,7 @@ class AppModelConfig(db.Model): ...@@ -88,6 +88,7 @@ class AppModelConfig(db.Model):
user_input_form = db.Column(db.Text) user_input_form = db.Column(db.Text)
pre_prompt = db.Column(db.Text) pre_prompt = db.Column(db.Text)
agent_mode = db.Column(db.Text) agent_mode = db.Column(db.Text)
sensitive_word_avoidance = db.Column(db.Text)
@property @property
def app(self): def app(self):
...@@ -116,6 +117,11 @@ class AppModelConfig(db.Model): ...@@ -116,6 +117,11 @@ class AppModelConfig(db.Model):
def more_like_this_dict(self) -> dict: def more_like_this_dict(self) -> dict:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@property
def sensitive_word_avoidance_dict(self) -> dict:
return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \
else {"enabled": False, "words": [], "canned_response": []}
@property @property
def user_input_form_list(self) -> dict: def user_input_form_list(self) -> dict:
return json.loads(self.user_input_form) if self.user_input_form else [] return json.loads(self.user_input_form) if self.user_input_form else []
...@@ -235,6 +241,9 @@ class Conversation(db.Model): ...@@ -235,6 +241,9 @@ class Conversation(db.Model):
if 'speech_to_text' in override_model_configs else {"enabled": False} if 'speech_to_text' in override_model_configs else {"enabled": False}
model_config['more_like_this'] = override_model_configs['more_like_this'] \ model_config['more_like_this'] = override_model_configs['more_like_this'] \
if 'more_like_this' in override_model_configs else {"enabled": False} if 'more_like_this' in override_model_configs else {"enabled": False}
model_config['sensitive_word_avoidance'] = override_model_configs['sensitive_word_avoidance'] \
if 'sensitive_word_avoidance' in override_model_configs \
else {"enabled": False, "words": [], "canned_response": []}
model_config['user_input_form'] = override_model_configs['user_input_form'] model_config['user_input_form'] = override_model_configs['user_input_form']
else: else:
model_config['configs'] = override_model_configs model_config['configs'] = override_model_configs
...@@ -251,6 +260,7 @@ class Conversation(db.Model): ...@@ -251,6 +260,7 @@ class Conversation(db.Model):
model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
model_config['speech_to_text'] = app_model_config.speech_to_text_dict model_config['speech_to_text'] = app_model_config.speech_to_text_dict
model_config['more_like_this'] = app_model_config.more_like_this_dict model_config['more_like_this'] = app_model_config.more_like_this_dict
model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
model_config['user_input_form'] = app_model_config.user_input_form_list model_config['user_input_form'] = app_model_config.user_input_form_list
model_config['model_id'] = self.model_id model_config['model_id'] = self.model_id
......
import json
from enum import Enum
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db from extensions.ext_database import db
class ToolProviderName(Enum):
SERPAPI = 'serpapi'
@staticmethod
def value_of(value):
for member in ToolProviderName:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ToolProvider(db.Model): class ToolProvider(db.Model):
__tablename__ = 'tool_providers' __tablename__ = 'tool_providers'
...@@ -24,3 +37,10 @@ class ToolProvider(db.Model): ...@@ -24,3 +37,10 @@ class ToolProvider(db.Model):
Returns True if the encrypted_config is not None, indicating that the token is set. Returns True if the encrypted_config is not None, indicating that the token is set.
""" """
return self.encrypted_config is not None return self.encrypted_config is not None
@property
def config(self):
"""
Returns the decrypted config.
"""
return json.loads(self.decrypt_config()) if self.encrypted_config is not None else None
...@@ -145,6 +145,33 @@ class AppModelConfigService: ...@@ -145,6 +145,33 @@ class AppModelConfigService:
if not isinstance(config["more_like_this"]["enabled"], bool): if not isinstance(config["more_like_this"]["enabled"], bool):
raise ValueError("enabled in more_like_this must be of boolean type") raise ValueError("enabled in more_like_this must be of boolean type")
# sensitive_word_avoidance
if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
config["sensitive_word_avoidance"] = {
"enabled": False
}
if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
config["sensitive_word_avoidance"]["enabled"] = False
if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool):
raise ValueError("enabled in sensitive_word_avoidance must be of boolean type")
if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]:
config["sensitive_word_avoidance"]["words"] = ""
if not isinstance(config["sensitive_word_avoidance"]["words"], str):
raise ValueError("words in sensitive_word_avoidance must be of string type")
if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]:
config["sensitive_word_avoidance"]["canned_response"] = ""
if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str):
raise ValueError("canned_response in sensitive_word_avoidance must be of string type")
# model # model
if 'model' not in config: if 'model' not in config:
raise ValueError("model is required") raise ValueError("model is required")
...@@ -258,8 +285,8 @@ class AppModelConfigService: ...@@ -258,8 +285,8 @@ class AppModelConfigService:
for tool in config["agent_mode"]["tools"]: for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0] key = list(tool.keys())[0]
if key not in ["sensitive-word-avoidance", "dataset"]: if key not in ["dataset"]:
raise ValueError("Keys in agent_mode.tools list can only be 'sensitive-word-avoidance' or 'dataset'") raise ValueError("Keys in agent_mode.tools list can only be 'dataset'")
tool_item = tool[key] tool_item = tool[key]
...@@ -269,19 +296,7 @@ class AppModelConfigService: ...@@ -269,19 +296,7 @@ class AppModelConfigService:
if not isinstance(tool_item["enabled"], bool): if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type") raise ValueError("enabled in agent_mode.tools must be of boolean type")
if key == "sensitive-word-avoidance": if key == "dataset":
if "words" not in tool_item or not tool_item["words"]:
tool_item["words"] = ""
if not isinstance(tool_item["words"], str):
raise ValueError("words in sensitive-word-avoidance must be of string type")
if "canned_response" not in tool_item or not tool_item["canned_response"]:
tool_item["canned_response"] = ""
if not isinstance(tool_item["canned_response"], str):
raise ValueError("canned_response in sensitive-word-avoidance must be of string type")
elif key == "dataset":
if 'id' not in tool_item: if 'id' not in tool_item:
raise ValueError("id is required in dataset") raise ValueError("id is required in dataset")
...@@ -300,6 +315,7 @@ class AppModelConfigService: ...@@ -300,6 +315,7 @@ class AppModelConfigService:
"suggested_questions_after_answer": config["suggested_questions_after_answer"], "suggested_questions_after_answer": config["suggested_questions_after_answer"],
"speech_to_text": config["speech_to_text"], "speech_to_text": config["speech_to_text"],
"more_like_this": config["more_like_this"], "more_like_this": config["more_like_this"],
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
"model": { "model": {
"provider": config["model"]["provider"], "provider": config["model"]["provider"],
"name": config["model"]["name"], "name": config["model"]["name"],
......
...@@ -140,6 +140,7 @@ class CompletionService: ...@@ -140,6 +140,7 @@ class CompletionService:
suggested_questions=json.dumps(model_config['suggested_questions']), suggested_questions=json.dumps(model_config['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']), suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
more_like_this=json.dumps(model_config['more_like_this']), more_like_this=json.dumps(model_config['more_like_this']),
sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']),
model=json.dumps(model_config['model']), model=json.dumps(model_config['model']),
user_input_form=json.dumps(model_config['user_input_form']), user_input_form=json.dumps(model_config['user_input_form']),
pre_prompt=model_config['pre_prompt'], pre_prompt=model_config['pre_prompt'],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment