Commit 7e3229fe authored by John Wang's avatar John Wang

feat: completed agent event pub

parent 7497b47e
......@@ -98,7 +98,7 @@ class AgentExecutor:
max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method,
verbose=True
callbacks=self.configuration.callbacks
)
output = agent_executor.run(query)
......
import json
import logging
import time
from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
......@@ -20,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
self.current_chain = None
@property
......@@ -29,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
def clear_agent_loops(self) -> None:
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
@property
def always_verbose(self) -> bool:
......@@ -62,6 +66,11 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
completion_message = response.generations[0][0].message
if 'function_call' in completion_message.additional_kwargs:
self._current_loop.completion \
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
else:
self._current_loop.completion = response.generations[0][0].text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
......@@ -71,6 +80,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
logging.error(error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
def on_tool_start(
self,
......@@ -90,7 +100,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Run on agent action."""
tool = action.tool
tool_input = action.tool_input
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
completion = None
if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
thought = action.log.strip()
completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
else:
action_name_position = action.log.index("Action:") if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
if self._current_loop and self._current_loop.status == 'llm_end':
......@@ -98,6 +114,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.thought = thought
self._current_loop.tool_name = tool
self._current_loop.tool_input = tool_input
if completion is not None:
self._current_loop.completion = completion
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
def on_tool_end(
self,
......@@ -120,10 +143,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
)
self._agent_loops.append(self._current_loop)
self._current_loop = None
self._message_agent_thought = None
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
......@@ -132,6 +158,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
logging.error(error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
......@@ -141,10 +168,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completed = True
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self._current_loop.thought = '[DONE]'
self._message_agent_thought = self.conversation_message_task.on_agent_start(
self.current_chain,
self._current_loop
)
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
)
self._agent_loops.append(self._current_loop)
self._current_loop = None
self._message_agent_thought = None
elif not self._current_loop and self._agent_loops:
self._agent_loops[-1].status = 'agent_finish'
......@@ -20,15 +20,13 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
self._current_chain_result = None
self._current_chain_message = None
self.conversation_message_task = conversation_message_task
self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
llm_constant.agent_model_name,
conversation_message_task
)
self.agent_callback = None
def clear_chain_results(self) -> None:
self._current_chain_result = None
self._current_chain_message = None
self.agent_loop_gather_callback_handler.current_chain = None
if self.agent_callback:
self.agent_callback.current_chain = None
@property
def always_verbose(self) -> bool:
......@@ -58,7 +56,8 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
if self.agent_callback:
self.agent_callback.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
......
......@@ -88,7 +88,7 @@ class Completion:
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
callbacks=[chain_callback]
chain_callback=chain_callback
)
# run agent executor
......
......@@ -52,7 +52,7 @@ class ConversationMessageTask:
message=self.message,
conversation=self.conversation,
chain_pub=False, # disabled currently
agent_thought_pub=False # disabled currently
agent_thought_pub=True
)
def init(self):
......@@ -207,7 +207,28 @@ class ConversationMessageTask:
self._pub_handler.pub_chain(message_chain)
def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
message_agent_thought = MessageAgentThought(
message_id=self.message.id,
message_chain_id=message_chain.id,
position=agent_loop.position,
thought=agent_loop.thought,
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
message=agent_loop.prompt,
answer=agent_loop.completion,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_agent_thought)
db.session.flush()
self._pub_handler.pub_agent_thought(message_agent_thought)
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
agent_loop: AgentLoop):
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
......@@ -222,34 +243,18 @@ class ConversationMessageTask:
agent_answer_unit_price
)
message_agent_loop = MessageAgentThought(
message_id=self.message.id,
message_chain_id=message_chain.id,
position=agent_loop.position,
thought=agent_loop.thought,
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
observation=agent_loop.tool_output,
tool_process_data='', # currently not support
message=agent_loop.prompt,
message_token=loop_message_tokens,
message_unit_price=agent_message_unit_price,
answer=agent_loop.completion,
answer_token=loop_answer_tokens,
answer_unit_price=agent_answer_unit_price,
latency=agent_loop.latency,
tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
total_price=loop_total_price,
currency=llm_constant.model_currency,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_agent_loop)
message_agent_thought.observation = agent_loop.tool_output
message_agent_thought.tool_process_data = '' # currently not support
message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = agent_message_unit_price
message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = agent_answer_unit_price
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = llm_constant.model_currency
db.session.flush()
self._pub_handler.pub_agent_thought(message_agent_loop)
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
dataset_query = DatasetQuery(
dataset_id=dataset_query_obj.dataset_id,
......@@ -346,16 +351,14 @@ class PubHandler:
content = {
'event': 'agent_thought',
'data': {
'id': message_agent_thought.id,
'task_id': self._task_id,
'message_id': self._message.id,
'chain_id': message_agent_thought.message_chain_id,
'agent_thought_id': message_agent_thought.id,
'position': message_agent_thought.position,
'thought': message_agent_thought.thought,
'tool': message_agent_thought.tool,
'tool_input': message_agent_thought.tool_input,
'observation': message_agent_thought.observation,
'answer': message_agent_thought.answer,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
......
......@@ -2,12 +2,15 @@ import math
from typing import Optional
from langchain import WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
......@@ -31,7 +34,7 @@ class OrchestratorRuleParser:
self.agent_summary_model_name = "gpt-3.5-turbo-16k"
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, callbacks: Callbacks = None) \
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
-> Optional[AgentExecutor]:
if not self.app_model_config.agent_mode_dict:
return None
......@@ -43,12 +46,20 @@ class OrchestratorRuleParser:
tool_configs = agent_mode_config.get('tools', [])
agent_model_name = agent_mode_config.get('model_name', 'gpt-4')
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_name=agent_model_name,
conversation_message_task=conversation_message_task
)
chain_callback.agent_callback = agent_callback
agent_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=agent_model_name,
temperature=0,
max_tokens=800,
callbacks=[DifyStdOutCallbackHandler()]
max_tokens=1000,
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
)
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
......@@ -66,7 +77,12 @@ class OrchestratorRuleParser:
callbacks=[DifyStdOutCallbackHandler()]
)
tools = self.to_tools(tool_configs, conversation_message_task, rest_tokens)
tools = self.to_tools(
tool_configs=tool_configs,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
)
if len(tools) == 0:
return None
......@@ -77,7 +93,7 @@ class OrchestratorRuleParser:
tools=tools,
summary_llm=summary_llm,
memory=memory,
callbacks=callbacks,
callbacks=[chain_callback, agent_callback],
max_iterations=6,
max_execution_time=None,
early_stopping_method="generate"
......@@ -112,13 +128,14 @@ class OrchestratorRuleParser:
return None
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
rest_tokens: int) -> list[BaseTool]:
rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
:param rest_tokens:
:param tool_configs: app agent tool configs
:param conversation_message_task:
:param callbacks:
:return:
"""
tools = []
......@@ -139,6 +156,7 @@ class OrchestratorRuleParser:
tool = self.to_wikipedia_tool()
if tool:
tool.callbacks = callbacks
tools.append(tool)
return tools
......
......@@ -468,16 +468,14 @@ class CompletionService:
def get_agent_thought_response_data(cls, data: dict):
response_data = {
'event': 'agent_thought',
'id': data.get('agent_thought_id'),
'id': data.get('id'),
'chain_id': data.get('chain_id'),
'task_id': data.get('task_id'),
'message_id': data.get('message_id'),
'position': data.get('position'),
'thought': data.get('thought'),
'tool': data.get('tool'), # todo use real dataset obj replace it
'tool': data.get('tool'),
'tool_input': data.get('tool_input'),
'observation': data.get('observation'),
'answer': data.get('answer') if not data.get('thought') else '',
'created_at': int(time.time())
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment