Commit 937061cf authored by John Wang

fix: should use agent

parent c429005c
@@ -38,7 +38,6 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
             **kwargs,
         )
 
     def should_use_agent(self, query: str):
         """
         return should use agent
@@ -49,15 +48,18 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         original_max_tokens = self.llm.max_tokens
         self.llm.max_tokens = 6
 
-        agent_decision = self.plan(
-            intermediate_steps=[],
-            callbacks=None,
-            input=query
-        )
+        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
+        messages = prompt.to_messages()
+        predicted_message = self.llm.predict_messages(
+            messages, functions=self.functions, callbacks=None
+        )
+        function_call = predicted_message.additional_kwargs.get("function_call", {})
 
         self.llm.max_tokens = original_max_tokens
 
-        return isinstance(agent_decision, AgentAction)
+        return True if function_call else False
 
     def plan(
         self,
......
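The change above replaces a full `plan()` call with a cheaper probe: the prompt is formatted as usual, but `max_tokens` is capped at 6 so the model can only begin its reply, and that truncated reply is checked for a `function_call`. A minimal, self-contained sketch of the pattern (`probe_should_use_agent`, `llm`, and `functions` are illustrative stand-ins, not the project's API):

```python
def probe_should_use_agent(llm, prompt_messages, functions) -> bool:
    """Cheaply ask an OpenAI-style chat model whether it would call a function."""
    original_max_tokens = llm.max_tokens
    llm.max_tokens = 6  # a handful of tokens is enough to reveal a function_call
    try:
        predicted = llm.predict_messages(
            prompt_messages, functions=functions, callbacks=None
        )
        # OpenAI function-calling models surface the decision here
        return bool(predicted.additional_kwargs.get("function_call"))
    finally:
        llm.max_tokens = original_max_tokens  # restore the budget even on error
```

Unlike the committed code, this sketch restores `max_tokens` in a `finally` block, so an exception in `predict_messages` cannot leave the limit stuck at 6.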
@@ -48,15 +48,18 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
         original_max_tokens = self.llm.max_tokens
         self.llm.max_tokens = 6
 
-        agent_decision = self.plan(
-            intermediate_steps=[],
-            callbacks=None,
-            input=query
-        )
+        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
+        messages = prompt.to_messages()
+        predicted_message = self.llm.predict_messages(
+            messages, functions=self.functions, callbacks=None
+        )
+        function_call = predicted_message.additional_kwargs.get("function_call", {})
 
         self.llm.max_tokens = original_max_tokens
 
-        return isinstance(agent_decision, AgentAction)
+        return True if function_call else False
 
     def plan(
         self,
@@ -93,7 +96,8 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
 
-    def get_system_message(self):
+    @classmethod
+    def get_system_message(cls):
         # get current time
         current_time = datetime.now()
         current_timezone = pytz.timezone('UTC')
......
@@ -5,7 +5,7 @@ from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent, AgentE
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.tools import BaseTool
-from pydantic import BaseModel
+from pydantic import BaseModel, Extra
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
 from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
@@ -26,13 +26,19 @@ class AgentConfiguration(BaseModel):
     llm: BaseLanguageModel
     tools: list[BaseTool]
     summary_llm: BaseLanguageModel
-    memory: ReadOnlyConversationTokenDBBufferSharedMemory
+    memory: ReadOnlyConversationTokenDBBufferSharedMemory = None
     callbacks: Callbacks = None
     max_iterations: int = 6
     max_execution_time: Optional[float] = None
     early_stopping_method: str = "generate"
     # `generate` completes one final inference after the iteration limit or execution time limit is reached
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
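`Extra.forbid` makes mistyped keyword arguments fail fast with a `ValidationError` instead of being silently dropped, and `arbitrary_types_allowed` is needed because fields such as `llm` and `memory` are plain classes rather than pydantic models. A small sketch of the same pattern, assuming pydantic v1 (`FakeLLM` and `ExampleConfiguration` are illustrative names, not project classes):

```python
from pydantic import BaseModel, Extra

class FakeLLM:  # an arbitrary, non-pydantic type, like BaseLanguageModel
    max_tokens = 256

class ExampleConfiguration(BaseModel):
    llm: FakeLLM
    max_iterations: int = 6

    class Config:
        extra = Extra.forbid            # unknown kwargs raise ValidationError
        arbitrary_types_allowed = True  # permit non-pydantic field types

ExampleConfiguration(llm=FakeLLM())                 # ok
# ExampleConfiguration(llm=FakeLLM(), tyop="oops")  # ValidationError
```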
 class AgentExecutor:
     def __init__(self, configuration: AgentConfiguration):
@@ -48,18 +54,18 @@ class AgentExecutor:
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
-            agent = AutoSummarizingOpenAIFunctionCallAgent(
+            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                 llm=self.configuration.llm,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer,  # used to read chat history from memory
+                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to read chat history from memory
                 summary_llm=self.configuration.summary_llm,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
-            agent = AutoSummarizingOpenMultiAIFunctionCallAgent(
+            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
                 llm=self.configuration.llm,
                 tools=self.configuration.tools,
-                extra_prompt_messages=self.configuration.memory.buffer,  # used to read chat history from memory
+                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to read chat history from memory
                 summary_llm=self.configuration.summary_llm,
                 verbose=True
             )
@@ -71,7 +77,7 @@ class AgentExecutor:
     def should_use_agent(self, query: str) -> bool:
         return self.agent.should_use_agent(query)
 
-    def get_chain(self) -> AgentExecutor:
+    def run(self, query: str) -> str:
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
@@ -82,4 +88,4 @@ class AgentExecutor:
             verbose=True
         )
 
-        return agent_executor
+        return agent_executor.run(query)
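With `get_chain()` replaced by `run(query)`, callers no longer receive a LangChain chain to embed elsewhere; they hold the wrapper and drive it directly. A hedged sketch of the resulting call pattern (variable names are illustrative):

```python
executor = AgentExecutor(agent_configuration)

output = ''
if executor.should_use_agent(query):  # the cheap max_tokens=6 probe shown above
    output = executor.run(query)      # builds and runs the LangChain executor internally
```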
from typing import Optional, List, cast, Tuple
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask
from core.orchestrator_rule_parser import OrchestratorRuleParser
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import AppModelConfig
class MainChainBuilder:
    @classmethod
    def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory],
                   rest_tokens: int, conversation_message_task: ConversationMessageTask):
        first_input_key = "input"
        final_output_key = "output"

        chains = []

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=tenant_id,
            app_model_config=app_model_config
        )

        # parse sensitive_word_avoidance_chain
        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain()
        if sensitive_word_avoidance_chain:
            chains.append(sensitive_word_avoidance_chain)

        # parse agent chain
        agent_chain = cls.get_agent_chain(
            tenant_id=tenant_id,
            agent_mode=app_model_config.agent_mode_dict,
            rest_tokens=rest_tokens,
            memory=memory,
            conversation_message_task=conversation_message_task
        )

        if agent_chain:
            chains.append(agent_chain)
            final_output_key = agent_chain.output_keys[0]

        if len(chains) == 0:
            return None

        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
        for chain in chains:
            chain = cast(Chain, chain)
            chain.callbacks.append(chain_callback)

        # build main chain
        overall_chain = SequentialChain(
            chains=chains,
            input_variables=[first_input_key],
            output_variables=[final_output_key],
            memory=memory,  # used only for the memory prompt input key
        )

        return overall_chain

    @classmethod
    def get_agent_chain(cls, tenant_id: str, agent_mode: dict,
                        rest_tokens: int,
                        memory: Optional[BaseChatMemory],
                        conversation_message_task: ConversationMessageTask) -> Chain:
        # agent mode
        chain = None
        if agent_mode and agent_mode.get('enabled'):
            tools = agent_mode.get('tools', [])

            datasets = []
            for tool in tools:
                tool_type = list(tool.keys())[0]
                tool_config = list(tool.values())[0]
                if tool_type == "dataset":
                    # get dataset from dataset id
                    dataset = db.session.query(Dataset).filter(
                        Dataset.tenant_id == tenant_id,
                        Dataset.id == tool_config.get("id")
                    ).first()

                    if dataset:
                        datasets.append(dataset)

            if len(datasets) > 0:
                # tool to chain
                multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
                    tenant_id=tenant_id,
                    datasets=datasets,
                    conversation_message_task=conversation_message_task,
                    rest_tokens=rest_tokens,
                    callbacks=[DifyStdOutCallbackHandler()]
                )
                chain = multi_dataset_router_chain

        return chain
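The wiring above hinges on matching keys: `SequentialChain` feeds `first_input_key` into the first chain, each chain's output key must satisfy the inputs of the next, and `final_output_key` is taken from the agent chain when one is present. A self-contained sketch of that key-threading, using `TransformChain` stand-ins so it runs without an LLM (illustrative only):

```python
from langchain.chains import SequentialChain, TransformChain

upper = TransformChain(
    input_variables=["input"],   # consumes first_input_key
    output_variables=["upper"],
    transform=lambda d: {"upper": d["input"].upper()},
)
exclaim = TransformChain(
    input_variables=["upper"],   # fed by the previous chain's output key
    output_variables=["output"],
    transform=lambda d: {"output": d["upper"] + "!"},
)

overall = SequentialChain(
    chains=[upper, exclaim],
    input_variables=["input"],    # first_input_key
    output_variables=["output"],  # final_output_key
)
print(overall.run("hello"))  # -> "HELLO!"
```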
# from typing import Optional, List, cast, Tuple
#
# from langchain.chains import SequentialChain
# from langchain.chains.base import Chain
# from langchain.memory.chat_memory import BaseChatMemory
#
# from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
# from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
# from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
# from core.conversation_message_task import ConversationMessageTask
# from core.orchestrator_rule_parser import OrchestratorRuleParser
# from extensions.ext_database import db
# from models.dataset import Dataset
# from models.model import AppModelConfig
#
#
# class MainChainBuilder:
#     @classmethod
#     def get_chains(cls, tenant_id: str, app_model_config: AppModelConfig, memory: Optional[BaseChatMemory],
#                    rest_tokens: int, conversation_message_task: ConversationMessageTask):
#         first_input_key = "input"
#         final_output_key = "output"
#
#         chains = []
#
#         # init orchestrator rule parser
#         orchestrator_rule_parser = OrchestratorRuleParser(
#             tenant_id=tenant_id,
#             app_model_config=app_model_config
#         )
#
#         # parse sensitive_word_avoidance_chain
#         sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain()
#         if sensitive_word_avoidance_chain:
#             chains.append(sensitive_word_avoidance_chain)
#
#         # parse agent chain
#         agent_executor = orchestrator_rule_parser.to_agent_executor(
#             conversation_message_task=conversation_message_task,
#             memory=memory,
#             rest_tokens=rest_tokens,
#             callbacks=[DifyStdOutCallbackHandler()]
#         )
#
#         if agent_executor:
#             if isinstance(agent_executor, MultiDatasetRouterChain):
#                 chains.append(agent_executor)
#                 final_output_key = agent_executor.output_keys[0]
#             chains.append(agent_chain)
#             final_output_key = agent_chain.output_keys[0]
#
#         if len(chains) == 0:
#             return None
#
#         chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
#         for chain in chains:
#             chain = cast(Chain, chain)
#             chain.callbacks.append(chain_callback)
#
#         # build main chain
#         overall_chain = SequentialChain(
#             chains=chains,
#             input_variables=[first_input_key],
#             output_variables=[final_output_key],
#             memory=memory,  # used only for the memory prompt input key
#         )
#
#         return overall_chain
import logging
import time
from typing import Optional, List, Union, Tuple

from langchain.base_language import BaseLanguageModel
@@ -8,6 +9,8 @@ from langchain.llms import BaseLLM
 from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError
 
+from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
+from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
@@ -15,13 +18,13 @@ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCa
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
 from core.llm.llm_builder import LLMBuilder
-from core.chain.main_chain_builder import MainChainBuilder
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
     ReadOnlyConversationTokenDBStringBufferSharedMemory
+from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
@@ -69,28 +72,52 @@ class Completion:
             streaming=streaming
         )
 
-        # build main chain include agent
-        main_chain = MainChainBuilder.get_chains(
+        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+
+        # init orchestrator rule parser
+        orchestrator_rule_parser = OrchestratorRuleParser(
             tenant_id=app.tenant_id,
-            app_model_config=app_model_config,
-            rest_tokens=rest_tokens_for_context_and_memory,
+            app_model_config=app_model_config
+        )
+
+        # parse sensitive_word_avoidance_chain
+        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+        if sensitive_word_avoidance_chain:
+            query = sensitive_word_avoidance_chain.run(query)
+
+        # get agent executor
+        agent_executor = orchestrator_rule_parser.to_agent_executor(
+            conversation_message_task=conversation_message_task,
             memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
-            conversation_message_task=conversation_message_task
+            rest_tokens=rest_tokens_for_context_and_memory,
+            callbacks=[chain_callback]
         )
 
-        chain_output = ''
-        if main_chain:
-            chain_output = main_chain.run(query)
+        # run agent executor
+        executor_output = ''
+        is_agent_output = False
+        if agent_executor:
+            if isinstance(agent_executor, MultiDatasetRouterChain):
+                executor_output = agent_executor.run(query)
+            else:
+                should_use_agent = agent_executor.should_use_agent(query)
+                if should_use_agent:
+                    executor_output = agent_executor.run(query)
+                    is_agent_output = True
 
         # run the final llm
        try:
+            # if is_agent_output and not app_model_config.pre_prompt:
+            #     # todo streaming flush the agent result to user, not call final llm
+            #     pass
+
             cls.run_final_llm(
                 tenant_id=app.tenant_id,
                 mode=app.mode,
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
-                chain_output=chain_output,
+                chain_output=executor_output,
                 conversation_message_task=conversation_message_task,
                 memory=memory,
                 streaming=streaming
@@ -137,6 +164,38 @@ class Completion:
 
         return response
+    # @classmethod
+    # def simulate_output(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
+    #                     conversation_message_task: ConversationMessageTask,
+    #                     executor_output: str, memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
+    #                     streaming: bool):
+    #     final_llm = LLMBuilder.to_llm_from_model(
+    #         tenant_id=tenant_id,
+    #         model=app_model_config.model_dict,
+    #         streaming=streaming
+    #     )
+    #
+    #     # get llm prompt
+    #     prompt, stop_words = cls.get_main_llm_prompt(
+    #         mode=mode,
+    #         llm=final_llm,
+    #         pre_prompt=app_model_config.pre_prompt,
+    #         query=query,
+    #         inputs=inputs,
+    #         chain_output=executor_output,
+    #         memory=memory
+    #     )
+    #
+    #     final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
+    #
+    #     llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
+    #     llm_callback_handler
+    #     for token in executor_output:
+    #         if token:
+    #             sys.stdout.write(token)
+    #             sys.stdout.flush()
+    #             time.sleep(0.01)
     @classmethod
     def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                             chain_output: Optional[str],
......
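Taken together, the new `Completion` flow replaces `MainChainBuilder` with direct orchestration: sanitize the query, build one executor, route by its type, and pass whatever it produced to the final LLM as `chain_output`. A condensed paraphrase of that control flow (the function and its parameters are illustrative; `MultiDatasetRouterChain` is assumed imported as in the diff above):

```python
def orchestrate(query: str, parser, chain_callback, **executor_kwargs) -> tuple:
    """Condensed paraphrase of Completion.generate's new middle section."""
    # 1. optionally rewrite the query through the sensitive-word chain
    swa_chain = parser.to_sensitive_word_avoidance_chain([chain_callback])
    if swa_chain:
        query = swa_chain.run(query)

    # 2. the parser returns either an agent wrapper or a dataset router chain
    agent_executor = parser.to_agent_executor(**executor_kwargs)

    # 3. router chains always run; agents run only when the probe fires
    executor_output, is_agent_output = '', False
    if agent_executor:
        if isinstance(agent_executor, MultiDatasetRouterChain):
            executor_output = agent_executor.run(query)
        elif agent_executor.should_use_agent(query):
            executor_output = agent_executor.run(query)
            is_agent_output = True

    return query, executor_output, is_agent_output  # output becomes chain_output
```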
-from typing import Optional
+from typing import Optional, Union
 
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.tools import BaseTool, Tool
+from langchain.tools import BaseTool, Tool, WikipediaQueryRun
 
 from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -23,13 +24,15 @@ from models.model import AppModelConfig
 
 class OrchestratorRuleParser:
     """Parse the orchestrator rule to entities."""
 
     def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
         self.tenant_id = tenant_id
         self.app_model_config = app_model_config
         self.agent_summary_model_name = "gpt-3.5-turbo-16k"
 
-    def to_agent_chain(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
-                       rest_tokens: int, callbacks: Callbacks = None) -> Optional[Chain]:
+    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
+                          rest_tokens: int, callbacks: Callbacks = None) \
+            -> Optional[Union[AgentExecutor | MultiDatasetRouterChain]]:
         if not self.app_model_config.agent_mode_dict:
             return None
@@ -61,6 +64,11 @@ class OrchestratorRuleParser:
             callbacks=[DifyStdOutCallbackHandler()]
         )
 
         tools = self.to_tools(tool_configs, conversation_message_task)
+
+        if len(tools) == 0:
+            return None
+
         agent_configuration = AgentConfiguration(
             strategy=PlanningStrategy(agent_mode_config.get('strategy')),
             llm=agent_llm,
@@ -73,8 +81,7 @@ class OrchestratorRuleParser:
             early_stopping_method="generate"
         )
 
-        agent_executor = AgentExecutor(agent_configuration)
-        chain = agent_executor.get_chain()
-
-        return chain
+        return AgentExecutor(agent_configuration)
@@ -116,7 +123,8 @@ class OrchestratorRuleParser:
         return chain
 
-    def to_sensitive_word_avoidance_chain(self, **kwargs) -> Optional[SensitiveWordAvoidanceChain]:
+    def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
+            -> Optional[SensitiveWordAvoidanceChain]:
         """
         Convert app sensitive word avoidance config to chain
@@ -133,7 +141,7 @@ class OrchestratorRuleParser:
             sensitive_words=sensitive_words.split(","),
             canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
             output_key="sensitive_word_avoidance_output",
-            callbacks=[DifyStdOutCallbackHandler()],
+            callbacks=callbacks,
             **kwargs
         )
@@ -151,14 +159,19 @@ class OrchestratorRuleParser:
         for tool_config in tool_configs:
             tool_type = list(tool_config.keys())[0]
             tool_val = list(tool_config.values())[0]
-            if not tool_config.get("enabled") or tool_config.get("enabled") is not True:
+            if not tool_val.get("enabled") or tool_val.get("enabled") is not True:
                 continue
 
             tool = None
             if tool_type == "dataset":
-                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task)
+                tool = None
+                # tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task)
             elif tool_type == "web_reader":
                 tool = self.to_web_reader_tool()
             elif tool_type == "google_search":
                 tool = self.to_google_search_tool()
+            elif tool_type == "wikipedia":
+                tool = self.to_wikipedia_tool()
 
             if tool:
                 tools.append(tool)
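The `tool_config` → `tool_val` fix matters because each entry in `tool_configs` is a single-key dict mapping the tool type to its settings, so `enabled` lives one level down; the old check read the outer dict and never found the flag. An illustrative shape (example values, not taken from the commit):

```python
tool_configs = [
    {"web_reader": {"enabled": True}},
    {"google_search": {"enabled": False}},  # skipped by the fixed check
    {"wikipedia": {"enabled": True}},
]

for tool_config in tool_configs:
    tool_type = list(tool_config.keys())[0]   # e.g. "web_reader"
    tool_val = list(tool_config.values())[0]  # e.g. {"enabled": True}
    if tool_val.get("enabled") is not True:   # the flag is on tool_val
        continue
    print(f"enable tool: {tool_type}")
```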
@@ -226,7 +239,13 @@ class OrchestratorRuleParser:
                       "is not up to date."
                       "Input should be a search query.",
             func=OptimizedSerpAPIWrapper(**func_kwargs).run,
-            callbacks=[DifyStdOutCallbackHandler]
+            callbacks=[DifyStdOutCallbackHandler()]
         )
 
         return tool
 
+    def to_wikipedia_tool(self) -> Optional[BaseTool]:
+        return WikipediaQueryRun(
+            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
+            callbacks=[DifyStdOutCallbackHandler()]
+        )
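The new wikipedia tool wraps LangChain's `WikipediaQueryRun`, which in turn requires the `wikipedia` PyPI package; `doc_content_chars_max=4000` truncates each article so results stay within the prompt budget. A hedged standalone usage sketch:

```python
from langchain import WikipediaAPIWrapper
from langchain.tools import WikipediaQueryRun

wiki = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000))
print(wiki.run("LangChain"))  # summary text, capped at 4000 characters
```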
@@ -16,7 +16,7 @@ from models.dataset import Dataset
 
 class DatasetRetrieverToolInput(BaseModel):
-    dataset_id: str = Field(..., description="ID of dateset to be queried. MUST be UUID format.")
+    dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
     query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
......
@@ -97,7 +97,7 @@ class WebReaderTool(BaseTool):
         if self.continue_reading and len(page_contents) >= self.max_chunk_length:
             page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
-                             f"THEN DIRECT ANSWER AND STOP INVOKING read_page TOOL, OTHERWISE USE " \
+                             f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
                              f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
 
         return page_contents
......