Unverified Commit 2dfb3e95 authored by takatost's avatar takatost Committed by GitHub

feat: optimize error record in agent (#869)

parent f207e180
...@@ -59,7 +59,11 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -59,7 +59,11 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
_, observation = intermediate_steps[-1] _, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation) return AgentFinish(return_values={"output": observation}, log=observation)
return super().plan(intermediate_steps, callbacks, **kwargs) try:
return super().plan(intermediate_steps, callbacks, **kwargs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
async def aplan( async def aplan(
self, self,
......
...@@ -50,9 +50,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio ...@@ -50,9 +50,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[]) prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages() messages = prompt.to_messages()
predicted_message = self.llm.predict_messages( try:
messages, functions=self.functions, callbacks=None predicted_message = self.llm.predict_messages(
) messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {}) function_call = predicted_message.additional_kwargs.get("function_call", {})
......
...@@ -50,9 +50,13 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope ...@@ -50,9 +50,13 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[]) prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages() messages = prompt.to_messages()
predicted_message = self.llm.predict_messages( try:
messages, functions=self.functions, callbacks=None predicted_message = self.llm.predict_messages(
) messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {}) function_call = predicted_message.additional_kwargs.get("function_call", {})
......
...@@ -94,7 +94,12 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): ...@@ -94,7 +94,12 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
return AgentFinish(return_values={"output": rst}, log=rst) return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
try: try:
return self.output_parser.parse(full_output) return self.output_parser.parse(full_output)
......
...@@ -89,8 +89,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): ...@@ -89,8 +89,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
Action specifying what tool to use. Action specifying what tool to use.
""" """
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)]) prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
messages = [] messages = []
if prompts: if prompts:
messages = prompts[0].to_messages() messages = prompts[0].to_messages()
...@@ -99,7 +99,11 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): ...@@ -99,7 +99,11 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if rest_tokens < 0: if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs) full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
try: try:
return self.output_parser.parse(full_output) return self.output_parser.parse(full_output)
......
...@@ -85,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -85,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
def on_llm_error( def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None:
logging.exception(error) logging.debug("Agent on_llm_error: %s", error)
self._agent_loops = [] self._agent_loops = []
self._current_loop = None self._current_loop = None
self._message_agent_thought = None self._message_agent_thought = None
...@@ -164,7 +164,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -164,7 +164,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None:
"""Do nothing.""" """Do nothing."""
logging.exception(error) logging.debug("Agent on_tool_error: %s", error)
self._agent_loops = [] self._agent_loops = []
self._current_loop = None self._current_loop = None
self._message_agent_thought = None self._message_agent_thought = None
......
...@@ -68,4 +68,4 @@ class DatasetToolCallbackHandler(BaseCallbackHandler): ...@@ -68,4 +68,4 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None:
"""Do nothing.""" """Do nothing."""
logging.exception(error) logging.debug("Dataset tool on_llm_error: %s", error)
...@@ -72,5 +72,5 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler): ...@@ -72,5 +72,5 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
def on_chain_error( def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None:
logging.exception(error) logging.debug("Dataset tool on_chain_error: %s", error)
self.clear_chain_results() self.clear_chain_results()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment