Commit 7e3229fe authored by John Wang's avatar John Wang

feat: completed agent event pub

parent 7497b47e
...@@ -98,7 +98,7 @@ class AgentExecutor: ...@@ -98,7 +98,7 @@ class AgentExecutor:
max_iterations=self.configuration.max_iterations, max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time, max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method, early_stopping_method=self.configuration.early_stopping_method,
verbose=True callbacks=self.configuration.callbacks
) )
output = agent_executor.run(query) output = agent_executor.run(query)
......
import json
import logging import logging
import time import time
from typing import Any, Dict, List, Union, Optional from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
...@@ -20,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -20,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self.conversation_message_task = conversation_message_task self.conversation_message_task = conversation_message_task
self._agent_loops = [] self._agent_loops = []
self._current_loop = None self._current_loop = None
self._message_agent_thought = None
self.current_chain = None self.current_chain = None
@property @property
...@@ -29,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -29,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
def clear_agent_loops(self) -> None: def clear_agent_loops(self) -> None:
self._agent_loops = [] self._agent_loops = []
self._current_loop = None self._current_loop = None
self._message_agent_thought = None
@property @property
def always_verbose(self) -> bool: def always_verbose(self) -> bool:
...@@ -62,7 +66,12 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -62,7 +66,12 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if self._current_loop and self._current_loop.status == 'llm_started': if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end' self._current_loop.status = 'llm_end'
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
self._current_loop.completion = response.generations[0][0].text completion_message = response.generations[0][0].message
if 'function_call' in completion_message.additional_kwargs:
self._current_loop.completion \
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
else:
self._current_loop.completion = response.generations[0][0].text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
def on_llm_error( def on_llm_error(
...@@ -71,6 +80,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -71,6 +80,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
logging.error(error) logging.error(error)
self._agent_loops = [] self._agent_loops = []
self._current_loop = None self._current_loop = None
self._message_agent_thought = None
def on_tool_start( def on_tool_start(
self, self,
...@@ -90,14 +100,27 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -90,14 +100,27 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Run on agent action.""" """Run on agent action."""
tool = action.tool tool = action.tool
tool_input = action.tool_input tool_input = action.tool_input
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1 completion = None
thought = action.log[:action_name_position].strip() if action.log else '' if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
thought = action.log.strip()
completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
else:
action_name_position = action.log.index("Action:") if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
if self._current_loop and self._current_loop.status == 'llm_end': if self._current_loop and self._current_loop.status == 'llm_end':
self._current_loop.status = 'agent_action' self._current_loop.status = 'agent_action'
self._current_loop.thought = thought self._current_loop.thought = thought
self._current_loop.tool_name = tool self._current_loop.tool_name = tool
self._current_loop.tool_input = tool_input self._current_loop.tool_input = tool_input
if completion is not None:
self._current_loop.completion = completion
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
def on_tool_end( def on_tool_end(
self, self,
...@@ -120,10 +143,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -120,10 +143,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed_at = time.perf_counter() self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop) self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
)
self._agent_loops.append(self._current_loop) self._agent_loops.append(self._current_loop)
self._current_loop = None self._current_loop = None
self._message_agent_thought = None
def on_tool_error( def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
...@@ -132,6 +158,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -132,6 +158,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
logging.error(error) logging.error(error)
self._agent_loops = [] self._agent_loops = []
self._current_loop = None self._current_loop = None
self._message_agent_thought = None
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end.""" """Run on agent end."""
...@@ -141,10 +168,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -141,10 +168,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed = True self._current_loop.completed = True
self._current_loop.completed_at = time.perf_counter() self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self._current_loop.thought = '[DONE]'
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop) self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
)
self._agent_loops.append(self._current_loop) self._agent_loops.append(self._current_loop)
self._current_loop = None self._current_loop = None
self._message_agent_thought = None
elif not self._current_loop and self._agent_loops: elif not self._current_loop and self._agent_loops:
self._agent_loops[-1].status = 'agent_finish' self._agent_loops[-1].status = 'agent_finish'
...@@ -20,15 +20,13 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler): ...@@ -20,15 +20,13 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
self._current_chain_result = None self._current_chain_result = None
self._current_chain_message = None self._current_chain_message = None
self.conversation_message_task = conversation_message_task self.conversation_message_task = conversation_message_task
self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler( self.agent_callback = None
llm_constant.agent_model_name,
conversation_message_task
)
def clear_chain_results(self) -> None: def clear_chain_results(self) -> None:
self._current_chain_result = None self._current_chain_result = None
self._current_chain_message = None self._current_chain_message = None
self.agent_loop_gather_callback_handler.current_chain = None if self.agent_callback:
self.agent_callback.current_chain = None
@property @property
def always_verbose(self) -> bool: def always_verbose(self) -> bool:
...@@ -58,7 +56,8 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler): ...@@ -58,7 +56,8 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
started_at=time.perf_counter() started_at=time.perf_counter()
) )
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result) self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message if self.agent_callback:
self.agent_callback.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.""" """Print out that we finished a chain."""
......
...@@ -88,7 +88,7 @@ class Completion: ...@@ -88,7 +88,7 @@ class Completion:
conversation_message_task=conversation_message_task, conversation_message_task=conversation_message_task,
memory=memory, memory=memory,
rest_tokens=rest_tokens_for_context_and_memory, rest_tokens=rest_tokens_for_context_and_memory,
callbacks=[chain_callback] chain_callback=chain_callback
) )
# run agent executor # run agent executor
......
...@@ -52,7 +52,7 @@ class ConversationMessageTask: ...@@ -52,7 +52,7 @@ class ConversationMessageTask:
message=self.message, message=self.message,
conversation=self.conversation, conversation=self.conversation,
chain_pub=False, # disabled currently chain_pub=False, # disabled currently
agent_thought_pub=False # disabled currently agent_thought_pub=True
) )
def init(self): def init(self):
...@@ -207,7 +207,28 @@ class ConversationMessageTask: ...@@ -207,7 +207,28 @@ class ConversationMessageTask:
self._pub_handler.pub_chain(message_chain) self._pub_handler.pub_chain(message_chain)
def on_agent_end(self, message_chain: MessageChain, agent_model_name: str, def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
message_agent_thought = MessageAgentThought(
message_id=self.message.id,
message_chain_id=message_chain.id,
position=agent_loop.position,
thought=agent_loop.thought,
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
message=agent_loop.prompt,
answer=agent_loop.completion,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_agent_thought)
db.session.flush()
self._pub_handler.pub_agent_thought(message_agent_thought)
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
agent_loop: AgentLoop): agent_loop: AgentLoop):
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt'] agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion'] agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
...@@ -222,34 +243,18 @@ class ConversationMessageTask: ...@@ -222,34 +243,18 @@ class ConversationMessageTask:
agent_answer_unit_price agent_answer_unit_price
) )
message_agent_loop = MessageAgentThought( message_agent_thought.observation = agent_loop.tool_output
message_id=self.message.id, message_agent_thought.tool_process_data = '' # currently not support
message_chain_id=message_chain.id, message_agent_thought.message_token = loop_message_tokens
position=agent_loop.position, message_agent_thought.message_unit_price = agent_message_unit_price
thought=agent_loop.thought, message_agent_thought.answer_token = loop_answer_tokens
tool=agent_loop.tool_name, message_agent_thought.answer_unit_price = agent_answer_unit_price
tool_input=agent_loop.tool_input, message_agent_thought.latency = agent_loop.latency
observation=agent_loop.tool_output, message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
tool_process_data='', # currently not support message_agent_thought.total_price = loop_total_price
message=agent_loop.prompt, message_agent_thought.currency = llm_constant.model_currency
message_token=loop_message_tokens,
message_unit_price=agent_message_unit_price,
answer=agent_loop.completion,
answer_token=loop_answer_tokens,
answer_unit_price=agent_answer_unit_price,
latency=agent_loop.latency,
tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
total_price=loop_total_price,
currency=llm_constant.model_currency,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_agent_loop)
db.session.flush() db.session.flush()
self._pub_handler.pub_agent_thought(message_agent_loop)
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj): def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
dataset_query = DatasetQuery( dataset_query = DatasetQuery(
dataset_id=dataset_query_obj.dataset_id, dataset_id=dataset_query_obj.dataset_id,
...@@ -346,16 +351,14 @@ class PubHandler: ...@@ -346,16 +351,14 @@ class PubHandler:
content = { content = {
'event': 'agent_thought', 'event': 'agent_thought',
'data': { 'data': {
'id': message_agent_thought.id,
'task_id': self._task_id, 'task_id': self._task_id,
'message_id': self._message.id, 'message_id': self._message.id,
'chain_id': message_agent_thought.message_chain_id, 'chain_id': message_agent_thought.message_chain_id,
'agent_thought_id': message_agent_thought.id,
'position': message_agent_thought.position, 'position': message_agent_thought.position,
'thought': message_agent_thought.thought, 'thought': message_agent_thought.thought,
'tool': message_agent_thought.tool, 'tool': message_agent_thought.tool,
'tool_input': message_agent_thought.tool_input, 'tool_input': message_agent_thought.tool_input,
'observation': message_agent_thought.observation,
'answer': message_agent_thought.answer,
'mode': self._conversation.mode, 'mode': self._conversation.mode,
'conversation_id': self._conversation.id 'conversation_id': self._conversation.id
} }
......
...@@ -2,12 +2,15 @@ import math ...@@ -2,12 +2,15 @@ import math
from typing import Optional from typing import Optional
from langchain import WikipediaAPIWrapper from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask from core.conversation_message_task import ConversationMessageTask
...@@ -31,7 +34,7 @@ class OrchestratorRuleParser: ...@@ -31,7 +34,7 @@ class OrchestratorRuleParser:
self.agent_summary_model_name = "gpt-3.5-turbo-16k" self.agent_summary_model_name = "gpt-3.5-turbo-16k"
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, callbacks: Callbacks = None) \ rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
-> Optional[AgentExecutor]: -> Optional[AgentExecutor]:
if not self.app_model_config.agent_mode_dict: if not self.app_model_config.agent_mode_dict:
return None return None
...@@ -43,12 +46,20 @@ class OrchestratorRuleParser: ...@@ -43,12 +46,20 @@ class OrchestratorRuleParser:
tool_configs = agent_mode_config.get('tools', []) tool_configs = agent_mode_config.get('tools', [])
agent_model_name = agent_mode_config.get('model_name', 'gpt-4') agent_model_name = agent_mode_config.get('model_name', 'gpt-4')
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_name=agent_model_name,
conversation_message_task=conversation_message_task
)
chain_callback.agent_callback = agent_callback
agent_llm = LLMBuilder.to_llm( agent_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
model_name=agent_model_name, model_name=agent_model_name,
temperature=0, temperature=0,
max_tokens=800, max_tokens=1000,
callbacks=[DifyStdOutCallbackHandler()] callbacks=[agent_callback, DifyStdOutCallbackHandler()]
) )
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router')) planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
...@@ -66,7 +77,12 @@ class OrchestratorRuleParser: ...@@ -66,7 +77,12 @@ class OrchestratorRuleParser:
callbacks=[DifyStdOutCallbackHandler()] callbacks=[DifyStdOutCallbackHandler()]
) )
tools = self.to_tools(tool_configs, conversation_message_task, rest_tokens) tools = self.to_tools(
tool_configs=tool_configs,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
)
if len(tools) == 0: if len(tools) == 0:
return None return None
...@@ -77,7 +93,7 @@ class OrchestratorRuleParser: ...@@ -77,7 +93,7 @@ class OrchestratorRuleParser:
tools=tools, tools=tools,
summary_llm=summary_llm, summary_llm=summary_llm,
memory=memory, memory=memory,
callbacks=callbacks, callbacks=[chain_callback, agent_callback],
max_iterations=6, max_iterations=6,
max_execution_time=None, max_execution_time=None,
early_stopping_method="generate" early_stopping_method="generate"
...@@ -112,13 +128,14 @@ class OrchestratorRuleParser: ...@@ -112,13 +128,14 @@ class OrchestratorRuleParser:
return None return None
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask, def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
rest_tokens: int) -> list[BaseTool]: rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
""" """
Convert app agent tool configs to tools Convert app agent tool configs to tools
:param rest_tokens: :param rest_tokens:
:param tool_configs: app agent tool configs :param tool_configs: app agent tool configs
:param conversation_message_task: :param conversation_message_task:
:param callbacks:
:return: :return:
""" """
tools = [] tools = []
...@@ -139,6 +156,7 @@ class OrchestratorRuleParser: ...@@ -139,6 +156,7 @@ class OrchestratorRuleParser:
tool = self.to_wikipedia_tool() tool = self.to_wikipedia_tool()
if tool: if tool:
tool.callbacks = callbacks
tools.append(tool) tools.append(tool)
return tools return tools
......
...@@ -468,16 +468,14 @@ class CompletionService: ...@@ -468,16 +468,14 @@ class CompletionService:
def get_agent_thought_response_data(cls, data: dict): def get_agent_thought_response_data(cls, data: dict):
response_data = { response_data = {
'event': 'agent_thought', 'event': 'agent_thought',
'id': data.get('agent_thought_id'), 'id': data.get('id'),
'chain_id': data.get('chain_id'), 'chain_id': data.get('chain_id'),
'task_id': data.get('task_id'), 'task_id': data.get('task_id'),
'message_id': data.get('message_id'), 'message_id': data.get('message_id'),
'position': data.get('position'), 'position': data.get('position'),
'thought': data.get('thought'), 'thought': data.get('thought'),
'tool': data.get('tool'), # todo use real dataset obj replace it 'tool': data.get('tool'),
'tool_input': data.get('tool_input'), 'tool_input': data.get('tool_input'),
'observation': data.get('observation'),
'answer': data.get('answer') if not data.get('thought') else '',
'created_at': int(time.time()) 'created_at': int(time.time())
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment