Commit 6921ee5d authored by John Wang's avatar John Wang

fix: dataset query when single dataset

parent 77146b50
from typing import Tuple, List, Any, Union, Sequence, Optional from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
...@@ -42,7 +42,9 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -42,7 +42,9 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
if len(self.tools) == 0: if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='') return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1: elif len(self.tools) == 1:
rst = next(iter(self.tools)).run(kwargs['input']) tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst) return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps: if intermediate_steps:
......
...@@ -103,6 +103,7 @@ class Completion: ...@@ -103,6 +103,7 @@ class Completion:
# # todo streaming flush the agent result to user, not call final llm # # todo streaming flush the agent result to user, not call final llm
# pass # pass
# todo or use fake llm
cls.run_final_llm( cls.run_final_llm(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
mode=app.mode, mode=app.mode,
......
...@@ -27,6 +27,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -27,6 +27,7 @@ class DatasetRetrieverTool(BaseTool):
description: str = "use this to retrieve a dataset. " description: str = "use this to retrieve a dataset. "
tenant_id: str tenant_id: str
dataset_id: str
k: int = 3 k: int = 3
@classmethod @classmethod
...@@ -38,6 +39,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -38,6 +39,7 @@ class DatasetRetrieverTool(BaseTool):
description += '\nID of dataset MUST be ' + dataset.id description += '\nID of dataset MUST be ' + dataset.id
return cls( return cls(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description, description=description,
**kwargs **kwargs
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment