Commit d7712cf7 authored by John Wang's avatar John Wang

feat: optimize agents

parent 02a42a7f
from typing import List, Tuple, Any, Union from datetime import datetime
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain.agents import OpenAIFunctionsAgent import pytz
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
_format_intermediate_steps _format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.schema import AgentAction, AgentFinish from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
...@@ -12,6 +17,48 @@ from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunction ...@@ -12,6 +17,48 @@ from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunction
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin): class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
) -> BaseSingleActionAgent:
    """Construct the agent through the parent class factory.

    NOTE(review): the caller-supplied ``system_message`` argument is
    ignored — ``cls.get_system_message()`` is always substituted so the
    prompt carries a current timestamp. Confirm this override is intended.

    :param llm: language model used for planning
    :param tools: tools exposed to the agent
    :param callback_manager: optional callback manager forwarded to the parent
    :param extra_prompt_messages: extra prompt messages (e.g. chat history)
    :param system_message: accepted for signature compatibility; overridden below
    :return: the constructed single-action agent
    """
    return super().from_llm_and_tools(
        llm=llm,
        tools=tools,
        callback_manager=callback_manager,
        extra_prompt_messages=extra_prompt_messages,
        # always use the timestamped system message, not the argument
        system_message=cls.get_system_message(),
        **kwargs,
    )
def should_use_agent(self, query: str) -> bool:
    """Cheaply probe the LLM to decide whether the agent should handle *query*.

    ``max_tokens`` is temporarily capped at 6 — just enough output for the
    model to start (or not start) a function call — so the probe is cheap.

    Fix: the cap is now restored in a ``finally`` block; previously an
    exception in ``plan()`` left the LLM permanently capped at 6 tokens.

    :param query: user input to test
    :return: True if the model decided to call a function (AgentAction)
    """
    original_max_tokens = self.llm.max_tokens
    self.llm.max_tokens = 6  # cheap probe: enough to reveal a function call only
    try:
        agent_decision = self.plan(
            intermediate_steps=[],
            callbacks=None,
            input=query
        )
    finally:
        # restore even if plan() raises, so the LLM isn't left capped
        self.llm.max_tokens = original_max_tokens
    return isinstance(agent_decision, AgentAction)
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: List[Tuple[AgentAction, str]],
...@@ -46,3 +93,15 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio ...@@ -46,3 +93,15 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
) )
agent_decision = _parse_ai_message(predicted_message) agent_decision = _parse_ai_message(predicted_message)
return agent_decision return agent_decision
@classmethod
def get_system_message(cls) -> SystemMessage:
    """Build the system prompt, stamped with the current UTC time.

    Fix: the previous code localized a *naive local* ``datetime.now()``
    into UTC with ``pytz.timezone('UTC').localize(...)``, which mislabels
    local wall-clock time as UTC. Asking for the UTC time directly gives
    a correctly-labelled timestamp.

    :return: SystemMessage containing the assistant instructions and time
    """
    current_time = datetime.now(pytz.UTC)
    return SystemMessage(content="You are a helpful AI assistant.\n"
                                 "Current time: {}\n"
                                 "Respond directly if appropriate.".format(
                                     current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))
from typing import List, Tuple, Any, Union from datetime import datetime
from typing import List, Tuple, Any, Union, Sequence, Optional
import pytz
from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
_parse_ai_message _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.schema import AgentAction, AgentFinish from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
...@@ -11,6 +17,47 @@ from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunction ...@@ -11,6 +17,47 @@ from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunction
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin): class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
) -> BaseMultiActionAgent:
    """Construct the multi-function-call agent through the parent factory.

    NOTE(review): the caller-supplied ``system_message`` argument is
    ignored — ``cls.get_system_message()`` is always substituted so the
    prompt carries a current timestamp. Confirm this override is intended.

    :param llm: language model used for planning
    :param tools: tools exposed to the agent
    :param callback_manager: optional callback manager forwarded to the parent
    :param extra_prompt_messages: extra prompt messages (e.g. chat history)
    :param system_message: accepted for signature compatibility; overridden below
    :return: the constructed multi-action agent
    """
    return super().from_llm_and_tools(
        llm=llm,
        tools=tools,
        callback_manager=callback_manager,
        extra_prompt_messages=extra_prompt_messages,
        # always use the timestamped system message, not the argument
        system_message=cls.get_system_message(),
        **kwargs,
    )
def should_use_agent(self, query: str) -> bool:
    """Cheaply probe the LLM to decide whether the agent should handle *query*.

    ``max_tokens`` is temporarily capped at 6 — just enough output for the
    model to start (or not start) a function call — so the probe is cheap.

    Fix: the cap is now restored in a ``finally`` block; previously an
    exception in ``plan()`` left the LLM permanently capped at 6 tokens.

    :param query: user input to test
    :return: True if the model decided to call a function (AgentAction)
    """
    original_max_tokens = self.llm.max_tokens
    self.llm.max_tokens = 6  # cheap probe: enough to reveal a function call only
    try:
        agent_decision = self.plan(
            intermediate_steps=[],
            callbacks=None,
            input=query
        )
    finally:
        # restore even if plan() raises, so the LLM isn't left capped
        self.llm.max_tokens = original_max_tokens
    return isinstance(agent_decision, AgentAction)
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: List[Tuple[AgentAction, str]],
...@@ -45,3 +92,14 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope ...@@ -45,3 +92,14 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
) )
agent_decision = _parse_ai_message(predicted_message) agent_decision = _parse_ai_message(predicted_message)
return agent_decision return agent_decision
@classmethod
def get_system_message(cls) -> SystemMessage:
    """Build the system prompt, stamped with the current UTC time.

    Fixes:
    - Made a ``@classmethod``: ``from_llm_and_tools`` invokes it as
      ``cls.get_system_message()``, which raised TypeError (missing
      ``self``) when this was an instance method. This also matches the
      single-function-call agent's implementation. Instance calls still work.
    - The previous code localized a *naive local* ``datetime.now()`` into
      UTC, mislabelling local wall-clock time as UTC; ask for UTC directly.

    :return: SystemMessage containing the assistant instructions and time
    """
    current_time = datetime.now(pytz.UTC)
    return SystemMessage(content="You are a helpful AI assistant.\n"
                                 "Current time: {}\n"
                                 "Respond directly if appropriate.".format(
                                     current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))
...@@ -14,6 +14,17 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): ...@@ -14,6 +14,17 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_index: int = 0 moving_summary_index: int = 0
summary_llm: BaseLanguageModel summary_llm: BaseLanguageModel
def should_use_agent(self, query: str) -> bool:
    """Always route *query* through the agent.

    Probing with a full ReACT round-trip just to decide whether an agent
    is needed costs about as much as running the agent itself, so the
    check is skipped and the agent reasons directly — the cheaper option.

    :param query: user input (not inspected)
    :return: always True
    """
    return True
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: List[Tuple[AgentAction, str]],
......
import enum import enum
from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMemory
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
class PlanningStrategy(str, enum.Enum):
    """Planning strategies an AgentExecutor can be configured with."""
    ROUTER = 'router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
    MULTI_FUNCTION_CALL = 'multi_function_call'
class AgentExecutor:
    """Facade that builds the planning agent for a chosen strategy and runs it
    through a langchain ``AgentExecutor``."""

    def __init__(self, strategy: PlanningStrategy, llm: BaseLanguageModel, tools: list[BaseTool],
                 summary_llm: BaseLanguageModel, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                 callbacks: Callbacks = None, max_iterations: int = 6,
                 max_execution_time: Optional[float] = None,
                 early_stopping_method: str = "generate"):
        """
        :param strategy: which planning strategy to build an agent for
        :param llm: main language model used for planning
        :param tools: tools exposed to the agent
        :param summary_llm: model used by the summarize mixins to compress history
        :param memory: shared conversation memory (read-only buffer)
        :param callbacks: callbacks forwarded to each run
        :param max_iterations: iteration cap for the langchain executor
        :param max_execution_time: optional wall-clock cap per run, in seconds
        :param early_stopping_method: `generate` will continue to complete the
            last inference after reaching the iteration limit or request time limit
        """
        self.strategy = strategy
        self.llm = llm
        self.tools = tools
        # must be assigned before _init_agent(), which reads self.summary_llm
        self.summary_llm = summary_llm
        self.memory = memory
        self.callbacks = callbacks
        self.agent = self._init_agent(strategy, llm, tools, memory, callbacks)
        self.max_iterations = max_iterations
        self.max_execution_time = max_execution_time
        self.early_stopping_method = early_stopping_method

    def _init_agent(self, strategy: PlanningStrategy, llm: BaseLanguageModel, tools: list[BaseTool],
                    memory: ReadOnlyConversationTokenDBBufferSharedMemory, callbacks: Callbacks = None) \
            -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
        """Build the concrete agent for *strategy*.

        Fixes:
        - All branches now go through the ``from_llm_and_tools`` factory;
          the function-call branches previously invoked the (pydantic) agent
          constructors directly with ``extra_prompt_messages``/``summary_llm``/
          ``verbose`` keywords, which are factory arguments, not model fields.
        - Unknown strategies (e.g. ``ROUTER``, declared in PlanningStrategy but
          not implemented here) now raise instead of hitting UnboundLocalError.

        :raises NotImplementedError: if *strategy* has no agent implementation
        """
        if strategy == PlanningStrategy.REACT:
            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                llm=llm,
                tools=tools,
                summary_llm=self.summary_llm,
                verbose=True
            )
        elif strategy == PlanningStrategy.FUNCTION_CALL:
            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                llm=llm,
                tools=tools,
                extra_prompt_messages=memory.buffer,  # used for read chat histories memory
                summary_llm=self.summary_llm,
                verbose=True
            )
        elif strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
            agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
                llm=llm,
                tools=tools,
                extra_prompt_messages=memory.buffer,  # used for read chat histories memory
                summary_llm=self.summary_llm,
                verbose=True
            )
        else:
            # ROUTER (and any future strategy) has no agent implementation yet;
            # fail loudly rather than returning an unbound local.
            raise NotImplementedError(f"unsupported planning strategy: {strategy}")
        return agent

    def should_use_agent(self, query: str) -> bool:
        """Delegate the should-use-agent decision to the underlying agent.

        :param query: user input
        :return: True if the agent should handle the query
        """
        return self.agent.should_use_agent(query)

    def run(self, query: str) -> str:
        """Execute *query* through a langchain AgentExecutor and return its answer.

        :param query: user input
        :return: the agent's final textual answer
        """
        agent_executor = LCAgentExecutor.from_agent_and_tools(
            agent=self.agent,
            tools=self.tools,
            memory=self.memory,
            max_iterations=self.max_iterations,
            max_execution_time=self.max_execution_time,
            early_stopping_method=self.early_stopping_method,
            verbose=True
        )
        # run agent
        result = agent_executor.run(
            query,
            callbacks=self.callbacks
        )
        return result
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment