Commit aa13b8db authored by John Wang's avatar John Wang

fix: agent parse error

parent afd8996b
......@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
......@@ -66,12 +66,17 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
completion_message = response.generations[0][0].message
completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message
if 'function_call' in completion_message.additional_kwargs:
self._current_loop.completion \
= json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
else:
self._current_loop.completion = response.generations[0][0].text
else:
self._current_loop.completion = completion_generation.text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
def on_llm_error(
......
......@@ -30,7 +30,6 @@ class OrchestratorRuleParser:
def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
self.tenant_id = tenant_id
self.app_model_config = app_model_config
self.agent_summary_model_name = "gpt-3.5-turbo-16k"
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
......@@ -71,7 +70,7 @@ class OrchestratorRuleParser:
summary_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=self.agent_summary_model_name,
model_name=agent_model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
......
......@@ -316,7 +316,7 @@ class AppModelConfigService:
if "strategy" not in config["agent_mode"] or not config["agent_mode"]["strategy"]:
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in PlanningStrategy.__members__:
if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")
if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment