Unverified Commit ba3dc8ca authored by John Wang's avatar John Wang Committed by GitHub

feat: fix dataset retrieve agent llm not support error (#656)

parent ae7c0380
...@@ -73,7 +73,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -73,7 +73,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
), ),
**kwargs: Any, **kwargs: Any,
) -> BaseSingleActionAgent: ) -> BaseSingleActionAgent:
llm.model_name = 'gpt-3.5-turbo'
return super().from_llm_and_tools( return super().from_llm_and_tools(
llm=llm, llm=llm,
tools=tools, tools=tools,
......
...@@ -31,6 +31,7 @@ class AgentConfiguration(BaseModel): ...@@ -31,6 +31,7 @@ class AgentConfiguration(BaseModel):
llm: BaseLanguageModel llm: BaseLanguageModel
tools: list[BaseTool] tools: list[BaseTool]
summary_llm: BaseLanguageModel summary_llm: BaseLanguageModel
dataset_llm: BaseLanguageModel
memory: Optional[BaseChatMemory] = None memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None callbacks: Callbacks = None
max_iterations: int = 6 max_iterations: int = 6
...@@ -84,7 +85,7 @@ class AgentExecutor: ...@@ -84,7 +85,7 @@ class AgentExecutor:
elif self.configuration.strategy == PlanningStrategy.ROUTER: elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools( agent = MultiDatasetRouterAgent.from_llm_and_tools(
llm=self.configuration.llm, llm=self.configuration.dataset_llm,
tools=self.configuration.tools, tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True verbose=True
......
...@@ -32,6 +32,7 @@ class OrchestratorRuleParser: ...@@ -32,6 +32,7 @@ class OrchestratorRuleParser:
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.app_model_config = app_model_config self.app_model_config = app_model_config
self.agent_summary_model_name = "gpt-3.5-turbo-16k" self.agent_summary_model_name = "gpt-3.5-turbo-16k"
self.dataset_retrieve_model_name = "gpt-3.5-turbo"
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \ rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
...@@ -89,11 +90,20 @@ class OrchestratorRuleParser: ...@@ -89,11 +90,20 @@ class OrchestratorRuleParser:
if len(tools) == 0: if len(tools) == 0:
return None return None
dataset_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=self.dataset_retrieve_model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
)
agent_configuration = AgentConfiguration( agent_configuration = AgentConfiguration(
strategy=planning_strategy, strategy=planning_strategy,
llm=agent_llm, llm=agent_llm,
tools=tools, tools=tools,
summary_llm=summary_llm, summary_llm=summary_llm,
dataset_llm=dataset_llm,
memory=memory, memory=memory,
callbacks=[chain_callback, agent_callback], callbacks=[chain_callback, agent_callback],
max_iterations=10, max_iterations=10,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment