Commit c6c81164 authored by John Wang

feat: optimize tool providers and tool parse

parent d7712cf7
......@@ -24,6 +24,7 @@ model_config_fields = {
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'pre_prompt': fields.String,
......@@ -148,6 +149,7 @@ class AppListApi(Resource):
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'],
......@@ -439,6 +441,7 @@ class AppCopy(Resource):
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
speech_to_text=app_config.speech_to_text,
more_like_this=app_config.more_like_this,
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
model=app_config.model,
user_input_form=app_config.user_input_form,
pre_prompt=app_config.pre_prompt,
......
......@@ -43,6 +43,7 @@ class ModelConfigResource(Resource):
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'],
......
import enum
from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
......@@ -20,55 +21,46 @@ class PlanningStrategy(str, enum.Enum):
MULTI_FUNCTION_CALL = 'multi_function_call'
class AgentExecutor:
def __init__(self, strategy: PlanningStrategy, llm: BaseLanguageModel, tools: list[BaseTool],
summary_llm: BaseLanguageModel, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
callbacks: Callbacks = None, max_iterations: int = 6, max_execution_time: Optional[float] = None,
early_stopping_method: str = "generate"):
self.strategy = strategy
self.llm = llm
self.tools = tools
self.summary_llm = summary_llm
self.memory = memory
self.callbacks = callbacks
self.agent = self._init_agent(strategy, llm, tools, memory, callbacks)
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
llm: BaseLanguageModel
tools: list[BaseTool]
summary_llm: BaseLanguageModel
memory: ReadOnlyConversationTokenDBBufferSharedMemory
callbacks: Callbacks = None
max_iterations: int = 6
max_execution_time: Optional[float] = None
early_stopping_method: str = "generate"
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
self.max_iterations = max_iterations
self.max_execution_time = max_execution_time
self.early_stopping_method = early_stopping_method
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
# summary_llm: StreamableChatOpenAI = LLMBuilder.to_llm(
# tenant_id=tenant_id,
# model_name='gpt-3.5-turbo-16k',
# max_tokens=300
# )
class AgentExecutor:
def __init__(self, configuration: AgentConfiguration):
self.configuration = configuration
self.agent = self._init_agent()
def _init_agent(self, strategy: PlanningStrategy, llm: BaseLanguageModel, tools: list[BaseTool],
memory: ReadOnlyConversationTokenDBBufferSharedMemory, callbacks: Callbacks = None) \
-> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
if strategy == PlanningStrategy.REACT:
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
llm=llm,
tools=tools,
summary_llm=self.summary_llm,
llm=self.configuration.llm,
tools=self.configuration.tools,
summary_llm=self.configuration.summary_llm,
verbose=True
)
elif strategy == PlanningStrategy.FUNCTION_CALL:
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent(
llm=llm,
tools=tools,
extra_prompt_messages=memory.buffer, # used for read chat histories memory
summary_llm=self.summary_llm,
llm=self.configuration.llm,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
verbose=True
)
elif strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent(
llm=llm,
tools=tools,
extra_prompt_messages=memory.buffer, # used for read chat histories memory
summary_llm=self.summary_llm,
llm=self.configuration.llm,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
verbose=True
)
......@@ -77,21 +69,15 @@ class AgentExecutor:
def should_use_agent(self, query: str) -> bool:
return self.agent.should_use_agent(query)
def run(self, query: str) -> str:
def get_chain(self) -> AgentExecutor:
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.tools,
memory=self.memory,
max_iterations=self.max_iterations,
max_execution_time=self.max_execution_time,
early_stopping_method=self.early_stopping_method,
tools=self.configuration.tools,
memory=self.configuration.memory,
max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method,
verbose=True
)
# run agent
result = agent_executor.run(
query,
callbacks=self.callbacks
)
return result
return agent_executor
from typing import Optional, List, cast
from typing import Optional, List, cast, Tuple
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
......@@ -6,44 +6,54 @@ from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask
from core.orchestrator_rule_parser import OrchestratorRuleParser
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import AppModelConfig
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
rest_tokens: int,
conversation_message_task: ConversationMessageTask):
def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory],
rest_tokens: int, conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
chains = []
chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
# init orchestrator rule parser
orchestrator_rule_parser = OrchestratorRuleParser(
tenant_id=tenant_id,
app_model_config=app_model_config
)
# agent mode
tool_chains, chains_output_key = cls.get_agent_chains(
# parse sensitive_word_avoidance_chain
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain()
if sensitive_word_avoidance_chain:
chains.append(sensitive_word_avoidance_chain)
# parse agent chain
agent_chain = cls.get_agent_chain(
tenant_id=tenant_id,
agent_mode=agent_mode,
agent_mode=app_model_config.agent_mode_dict,
rest_tokens=rest_tokens,
memory=memory,
conversation_message_task=conversation_message_task
)
chains += tool_chains
if chains_output_key:
final_output_key = chains_output_key
if agent_chain:
chains.append(agent_chain)
final_output_key = agent_chain.output_keys[0]
if len(chains) == 0:
return None
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
for chain in chains:
chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler)
chain.callbacks.append(chain_callback)
# build main chain
overall_chain = SequentialChain(
......@@ -56,26 +66,20 @@ class MainChainBuilder:
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
def get_agent_chain(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask) -> Chain:
# agent mode
chains = []
chain = None
if agent_mode and agent_mode.get('enabled'):
tools = agent_mode.get('tools', [])
pre_fixed_chains = []
# agent_tools = []
datasets = []
for tool in tools:
tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0]
if tool_type == 'sensitive-word-avoidance':
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
if chain:
pre_fixed_chains.append(chain)
elif tool_type == "dataset":
if tool_type == "dataset":
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
......@@ -85,9 +89,6 @@ class MainChainBuilder:
if dataset:
datasets.append(dataset)
# add pre-fixed chains
chains += pre_fixed_chains
if len(datasets) > 0:
# tool to chain
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
......@@ -97,14 +98,6 @@ class MainChainBuilder:
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)
final_output_key = cls.get_chains_output_key(chains)
chain = multi_dataset_router_chain
return chains, final_output_key
@classmethod
def get_chains_output_key(cls, chains: List[Chain]):
if len(chains) > 0:
return chains[-1].output_keys[0]
return None
return chain
......@@ -70,9 +70,9 @@ class Completion:
)
# build main chain include agent
main_chain = MainChainBuilder.to_langchain_components(
main_chain = MainChainBuilder.get_chains(
tenant_id=app.tenant_id,
agent_mode=app_model_config.agent_mode_dict,
app_model_config=app_model_config,
rest_tokens=rest_tokens_for_context_and_memory,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task
......
......@@ -69,6 +69,7 @@ class ConversationMessageTask:
"suggested_questions": self.app_model_config.suggested_questions_list,
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
"more_like_this": self.app_model_config.more_like_this_dict,
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
"user_input_form": self.app_model_config.user_input_form_list,
}
......
from typing import Optional
from langchain.callbacks.manager import Callbacks
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import AppModelConfig
class OrchestratorRuleParser:
    """Parse the orchestrator rule to entities."""

    def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
        self.tenant_id = tenant_id
        self.app_model_config = app_model_config

    def to_agent_arguments(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                           rest_tokens: int, callbacks: Callbacks = None) -> Optional[Chain]:
        """
        Convert the app agent mode config to a chain

        Returns None when agent mode is missing or not enabled, the router chain when the
        strategy is unset or 'router', otherwise the chain produced by the agent executor.

        :param conversation_message_task:
        :param memory: chat memory handed to the agent configuration
        :param rest_tokens: forwarded to the router chain
        :param callbacks: forwarded to the agent configuration
        :return:
        """
        agent_mode_config = self.app_model_config.agent_mode_dict
        if not agent_mode_config or not agent_mode_config.get('enabled'):
            return None

        tool_configs = agent_mode_config.get('tools', [])
        strategy = agent_mode_config.get('strategy')

        # use router chain if planning strategy is router or not set
        if not strategy or strategy == 'router':
            return self.to_router_chain(tool_configs, conversation_message_task, rest_tokens)

        agent_llm = LLMBuilder.to_llm(
            tenant_id=self.tenant_id,
            model_name=agent_mode_config.get('model_name', 'gpt-4'),
            temperature=0,
            max_tokens=800,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        summary_llm = LLMBuilder.to_llm(
            tenant_id=self.tenant_id,
            model_name="gpt-3.5-turbo-16k",
            temperature=0,
            max_tokens=500,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        agent_configuration = AgentConfiguration(
            # an unknown strategy value raises ValueError from the enum constructor
            strategy=PlanningStrategy(strategy),
            llm=agent_llm,
            tools=self.to_tools(tool_configs, conversation_message_task),
            summary_llm=summary_llm,
            memory=memory,
            callbacks=callbacks,
            max_iterations=6,
            max_execution_time=None,
            early_stopping_method="generate"
        )

        return AgentExecutor(agent_configuration).get_chain()

    def to_router_chain(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
                        rest_tokens: int) -> Optional[Chain]:
        """
        Convert tool configs to router chain if planning strategy is router

        :param tool_configs:
        :param conversation_message_task:
        :param rest_tokens:
        :return: the multi dataset router chain, or None when no configured dataset is found
        """
        datasets = []
        for tool_config in tool_configs:
            # each tool config is a single-entry dict: {tool_type: tool_val}
            tool_type = list(tool_config.keys())[0]
            tool_val = list(tool_config.values())[0]
            if tool_type != "dataset":
                continue

            # get dataset from dataset id
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
                Dataset.id == tool_val.get("id")
            ).first()

            if dataset:
                datasets.append(dataset)

        if not datasets:
            return None

        # tool to chain
        return MultiDatasetRouterChain.from_datasets(
            tenant_id=self.tenant_id,
            datasets=datasets,
            conversation_message_task=conversation_message_task,
            rest_tokens=rest_tokens,
            callbacks=[DifyStdOutCallbackHandler()]
        )

    def to_sensitive_word_avoidance_chain(self, **kwargs) -> Optional[SensitiveWordAvoidanceChain]:
        """
        Convert app sensitive word avoidance config to chain

        :param kwargs: extra keyword arguments passed through to the chain constructor
        :return: the chain, or None when the feature is disabled or no words are configured
        """
        sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
        if not sensitive_word_avoidance_config:
            return None

        sensitive_words = sensitive_word_avoidance_config.get("words", "")
        if not sensitive_word_avoidance_config.get("enabled", False) or not sensitive_words:
            return None

        return SensitiveWordAvoidanceChain(
            # words are stored as one comma-separated string
            sensitive_words=sensitive_words.split(","),
            canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
            output_key="sensitive_word_avoidance_output",
            callbacks=[DifyStdOutCallbackHandler()],
            **kwargs
        )

    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask) -> list[BaseTool]:
        """
        Convert app agent tool configs to tools

        :param tool_configs: app agent tool configs
        :param conversation_message_task:
        :return: tools built from the enabled "dataset" and "web_reader" configs
        """
        tools = []
        for tool_config in tool_configs:
            # each tool config is a single-entry dict: {tool_type: tool_val}
            tool_type = list(tool_config.keys())[0]
            tool_val = list(tool_config.values())[0]

            # the `enabled` flag lives in the inner config, not on the outer wrapper;
            # only a literal True enables the tool
            if tool_val.get("enabled") is not True:
                continue

            tool = None
            if tool_type == "dataset":
                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool()

            if tool:
                tools.append(tool)

        return tools

    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask) \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset

        :param tool_config:
        :param conversation_message_task:
        :return: the tool, or None when the dataset is missing or has no available documents
        """
        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        # a missing dataset must short-circuit too, otherwise the attribute reads below fail
        if not dataset or dataset.available_document_count == 0:
            return None

        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        return DatasetRetrieverTool(
            name="dataset_retriever",
            description=description,
            k=3,
            dataset=dataset,
            callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
        )

    def to_web_reader_tool(self) -> Optional[BaseTool]:
        """
        A tool for reading web pages

        :return:
        """
        summary_llm = LLMBuilder.to_llm(
            tenant_id=self.tenant_id,
            model_name="gpt-3.5-turbo-16k",
            temperature=0,
            max_tokens=500,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return WebReaderTool(
            llm=summary_llm,
            max_chunk_length=4000,
            continue_reading=True,
            callbacks=[DifyStdOutCallbackHandler()]
        )

    def to_google_search_tool(self) -> Optional[BaseTool]:
        """
        A tool for performing a Google search through SerpAPI

        :return: the tool, or None when the tool provider yields no credentials
        """
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        return Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            # callbacks must be handler instances, not the handler class itself
            callbacks=[DifyStdOutCallbackHandler()]
        )
......@@ -11,7 +11,10 @@ from models.dataset import Dataset
class DatasetTool(BaseTool):
"""Tool for querying a Dataset."""
"""
Tool for querying a Dataset.
Only use for router chain.
"""
dataset: Dataset
k: int = 2
......
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""

    # todo dataset id as tool argument
    dataset: Dataset  # dataset whose index is searched
    k: int = 2  # number of results requested from the index

    def _run(self, tool_input: str) -> str:
        """
        Search the dataset and return the hit contents joined by newlines

        :param tool_input: the search query
        :return:
        """
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
        else:
            documents = self._build_vector_index().search(
                tool_input,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

        return self._report_and_join(documents)

    async def _arun(self, tool_input: str) -> str:
        """
        Async variant of the similarity search

        :param tool_input: the search query
        :return:
        """
        # NOTE(review): unlike _run, this path always uses the vector index, even for
        # "economy" datasets - confirm whether a keyword-table async path is needed
        documents = await self._build_vector_index().asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={
                # honour the configured result count, same as the sync path
                'k': self.k
            }
        )

        return self._report_and_join(documents)

    def _build_vector_index(self) -> VectorIndex:
        """Create the vector index for this dataset, backed by cached OpenAI embeddings."""
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(
            **model_credentials
        ))

        return VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

    def _report_and_join(self, documents) -> str:
        """Notify the index hit callback, then join the page contents with newlines."""
        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return "\n".join([document.page_content for document in documents])
import base64
from abc import ABC, abstractmethod
from typing import Optional
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName
class BaseToolProvider(ABC):
    """Base class for tenant-scoped tool providers: provider lookup plus token encryption helpers."""

    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    @abstractmethod
    def get_provider_name(self) -> "ToolProviderName":
        """Return the provider name used to look up the stored ToolProvider row."""
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """Return the provider credentials, obfuscated when requested."""
        raise NotImplementedError

    @abstractmethod
    def credentials_to_func_kwargs(self) -> Optional[dict]:
        """Return the credentials as keyword arguments for the tool function."""
        raise NotImplementedError

    @abstractmethod
    def credentials_validate(self, credentials: dict):
        """Validate the given credentials."""
        raise NotImplementedError

    def get_provider(self, must_enabled: bool = False) -> "Optional[ToolProvider]":
        """
        Returns the Provider instance for the given tenant_id and provider_name.

        :param must_enabled: only match a provider whose is_enabled flag is set
        :return: the first matching row, or None
        """
        query = db.session.query(ToolProvider).filter(
            ToolProvider.tenant_id == self.tenant_id,
            ToolProvider.provider_name == self.get_provider_name()
        )

        if must_enabled:
            # `== True` is intentional: it builds a SQL expression, not a Python comparison
            query = query.filter(ToolProvider.is_enabled == True)  # noqa: E712

        return query.first()

    def encrypt_token(self, token) -> str:
        """
        Encrypt the token with the tenant's public key and return it base64 encoded.

        :param token:
        :raises ValueError: when the tenant does not exist
        :return:
        """
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        if tenant is None:
            raise ValueError(f"Tenant '{self.tenant_id}' not found")

        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
        """
        Decrypt a base64 encoded token for this tenant.

        :param token: base64 encoded encrypted token
        :param obfuscated: mask the decrypted token if True
        :return:
        """
        token = rsa.decrypt(base64.b64decode(token), self.tenant_id)

        if obfuscated:
            return self._obfuscated_token(token)

        return token

    def _obfuscated_token(self, token: str) -> str:
        """
        Mask the middle of the token, keeping the first 6 and last 2 characters.

        Tokens of 8 characters or fewer are masked completely: keeping the 6+2 visible
        characters would otherwise reveal the whole token.
        """
        if len(token) <= 8:
            return '*' * len(token)

        return token[:6] + '*' * (len(token) - 8) + token[-2:]
class ValidateFailedError(Exception):
    """Raised when provider credentials fail validation."""

    description = "Provider Validate failed"
from typing import Optional
from core.llm.provider.errors import ValidateFailedError
from core.tool.provider.base import BaseToolProvider
from models.tool import ToolProviderName
class SerpAPIToolProvider(BaseToolProvider):
    """Tool provider for SerpAPI credentials."""

    def get_provider_name(self) -> ToolProviderName:
        """
        Name under which this provider is stored.

        :return:
        """
        return ToolProviderName.SERPAPI

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Load the enabled provider's config and decrypt its api_key.

        :param obfuscated: mask the decrypted api_key if True
        :return: the config dict, or None when there is no enabled provider or no config
        """
        provider = self.get_provider(must_enabled=True)
        if not provider:
            return None

        credentials = provider.config
        if not credentials:
            return None

        # only decrypt when an api_key is actually present
        if credentials.get('api_key'):
            decrypted = self.decrypt_token(credentials.get('api_key'), obfuscated)
            credentials['api_key'] = decrypted

        return credentials

    def credentials_to_func_kwargs(self) -> Optional[dict]:
        """
        Map the stored credentials to the keyword arguments the tool function takes.

        :return: None when no credentials are available
        """
        stored = self.get_credentials()
        if stored:
            return {
                'serpapi_api_key': stored.get('api_key')
            }

        return None

    def credentials_validate(self, credentials: dict):
        """
        Check that a non-empty api_key is present.

        :param credentials:
        :raises: ValidateFailedError
        """
        has_api_key = 'api_key' in credentials and credentials.get('api_key')
        if not has_api_key:
            raise ValidateFailedError("SerpAPI api_key is required.")
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
class ToolProviderService:
    """Facade that resolves a tool provider by name and delegates credential calls to it."""

    def __init__(self, tenant_id: str, provider_name: str):
        self.provider = self._init_provider(tenant_id, provider_name)

    def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
        """
        Instantiate the provider registered under provider_name.

        :param tenant_id:
        :param provider_name: only 'serpapi' is known
        :return:
        """
        # reject anything that is not a known provider up front
        if provider_name != 'serpapi':
            raise Exception('tool provider {} not found'.format(provider_name))

        return SerpAPIToolProvider(tenant_id)

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Fetch the credentials from the underlying provider.

        :param obfuscated: forwarded to the provider
        :return:
        """
        credentials = self.provider.get_credentials(obfuscated)
        return credentials

    def credentials_validate(self, credentials: dict):
        """
        Validates the given credentials.

        :param credentials:
        :raises: ValidateFailedError
        """
        result = self.provider.credentials_validate(credentials)
        return result
......@@ -52,7 +52,7 @@ class WebReaderToolInput(BaseModel):
class WebReaderTool(BaseTool):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name: str = "read_page"
name: str = "web_reader"
args_schema: Type[BaseModel] = WebReaderToolInput
description: str = "use this to read a website. " \
"If you can answer the question based on the information provided, " \
......
"""add tool ptoviders
"""add tool providers
Revision ID: 46c503018f11
Revision ID: 7ce5a52e4eee
Revises: 2beac44e5f5f
Create Date: 2023-07-07 16:35:32.974075
Create Date: 2023-07-10 10:26:50.074515
"""
from alembic import op
......@@ -10,7 +10,7 @@ import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '46c503018f11'
revision = '7ce5a52e4eee'
down_revision = '2beac44e5f5f'
branch_labels = None
depends_on = None
......@@ -23,16 +23,22 @@ def upgrade():
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
sa.Column('tool_name', sa.String(length=40), nullable=False),
sa.Column('encrypted_config', sa.Text(), nullable=True),
sa.Column('is_valid', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('is_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))
# ### end Alembic commands ###
def downgrade():
    """Revert this migration: drop the sensitive_word_avoidance column, then the tool_providers table."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.drop_column('sensitive_word_avoidance')

    op.drop_table('tool_providers')
    # ### end Alembic commands ###
......@@ -88,6 +88,7 @@ class AppModelConfig(db.Model):
user_input_form = db.Column(db.Text)
pre_prompt = db.Column(db.Text)
agent_mode = db.Column(db.Text)
sensitive_word_avoidance = db.Column(db.Text)
@property
def app(self):
......@@ -116,6 +117,11 @@ class AppModelConfig(db.Model):
def more_like_this_dict(self) -> dict:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@property
def sensitive_word_avoidance_dict(self) -> dict:
    """Return the stored sensitive word avoidance JSON decoded, or the disabled default when unset."""
    if self.sensitive_word_avoidance:
        return json.loads(self.sensitive_word_avoidance)

    # nothing stored: fall back to the disabled default
    return {"enabled": False, "words": [], "canned_response": []}
@property
def user_input_form_list(self) -> list:
    """
    Return the stored user input form JSON decoded, or an empty list when unset.

    NOTE(review): the stored value is presumably a JSON array, as the property name and
    the empty-list fallback suggest - confirm against the writer.
    """
    return json.loads(self.user_input_form) if self.user_input_form else []
......@@ -235,6 +241,9 @@ class Conversation(db.Model):
if 'speech_to_text' in override_model_configs else {"enabled": False}
model_config['more_like_this'] = override_model_configs['more_like_this'] \
if 'more_like_this' in override_model_configs else {"enabled": False}
model_config['sensitive_word_avoidance'] = override_model_configs['sensitive_word_avoidance'] \
if 'sensitive_word_avoidance' in override_model_configs \
else {"enabled": False, "words": [], "canned_response": []}
model_config['user_input_form'] = override_model_configs['user_input_form']
else:
model_config['configs'] = override_model_configs
......@@ -251,6 +260,7 @@ class Conversation(db.Model):
model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
model_config['speech_to_text'] = app_model_config.speech_to_text_dict
model_config['more_like_this'] = app_model_config.more_like_this_dict
model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
model_config['user_input_form'] = app_model_config.user_input_form_list
model_config['model_id'] = self.model_id
......
import json
from enum import Enum
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db
class ToolProviderName(Enum):
    """Names of the supported tool providers."""

    SERPAPI = 'serpapi'

    @staticmethod
    def value_of(value):
        """
        Look up the member whose value equals the given value.

        :param value: raw provider name
        :raises ValueError: when no member matches
        :return: the matching member
        """
        found = next((candidate for candidate in ToolProviderName if candidate.value == value), None)
        if found is None:
            raise ValueError(f"No matching enum found for value '{value}'")

        return found
class ToolProvider(db.Model):
__tablename__ = 'tool_providers'
......@@ -24,3 +37,10 @@ class ToolProvider(db.Model):
Returns True if the encrypted_config is not None, indicating that the token is set.
"""
return self.encrypted_config is not None
@property
def config(self):
    """
    Returns the decrypted config parsed from JSON, or None when no encrypted config is stored.
    """
    if self.encrypted_config is None:
        return None

    decrypted = self.decrypt_config()
    return json.loads(decrypted)
......@@ -145,6 +145,33 @@ class AppModelConfigService:
if not isinstance(config["more_like_this"]["enabled"], bool):
raise ValueError("enabled in more_like_this must be of boolean type")
# sensitive_word_avoidance
if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
config["sensitive_word_avoidance"] = {
"enabled": False
}
if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
config["sensitive_word_avoidance"]["enabled"] = False
if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool):
raise ValueError("enabled in sensitive_word_avoidance must be of boolean type")
if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]:
config["sensitive_word_avoidance"]["words"] = ""
if not isinstance(config["sensitive_word_avoidance"]["words"], str):
raise ValueError("words in sensitive_word_avoidance must be of string type")
if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]:
config["sensitive_word_avoidance"]["canned_response"] = ""
if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str):
raise ValueError("canned_response in sensitive_word_avoidance must be of string type")
# model
if 'model' not in config:
raise ValueError("model is required")
......@@ -258,8 +285,8 @@ class AppModelConfigService:
for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0]
if key not in ["sensitive-word-avoidance", "dataset"]:
raise ValueError("Keys in agent_mode.tools list can only be 'sensitive-word-avoidance' or 'dataset'")
if key not in ["dataset"]:
raise ValueError("Keys in agent_mode.tools list can only be 'dataset'")
tool_item = tool[key]
......@@ -269,19 +296,7 @@ class AppModelConfigService:
if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if key == "sensitive-word-avoidance":
if "words" not in tool_item or not tool_item["words"]:
tool_item["words"] = ""
if not isinstance(tool_item["words"], str):
raise ValueError("words in sensitive-word-avoidance must be of string type")
if "canned_response" not in tool_item or not tool_item["canned_response"]:
tool_item["canned_response"] = ""
if not isinstance(tool_item["canned_response"], str):
raise ValueError("canned_response in sensitive-word-avoidance must be of string type")
elif key == "dataset":
if key == "dataset":
if 'id' not in tool_item:
raise ValueError("id is required in dataset")
......@@ -300,6 +315,7 @@ class AppModelConfigService:
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
"speech_to_text": config["speech_to_text"],
"more_like_this": config["more_like_this"],
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
"model": {
"provider": config["model"]["provider"],
"name": config["model"]["name"],
......
......@@ -140,6 +140,7 @@ class CompletionService:
suggested_questions=json.dumps(model_config['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
more_like_this=json.dumps(model_config['more_like_this']),
sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']),
model=json.dumps(model_config['model']),
user_input_form=json.dumps(model_config['user_input_form']),
pre_prompt=model_config['pre_prompt'],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment