Commit e062b1ee authored by John Wang

fix: dataset query parse error

parent 693b7531
import re
from typing import List, Tuple, Any, Union, Sequence, Optional from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain import BasePromptTemplate from langchain import BasePromptTemplate
...@@ -7,6 +8,7 @@ from langchain.base_language import BaseLanguageModel ...@@ -7,6 +8,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin from langchain.memory.summary import SummarizerMixin
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
...@@ -121,6 +123,35 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): ...@@ -121,6 +123,35 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
return self.get_full_inputs([intermediate_steps[-1]], **kwargs) return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
@classmethod
def create_prompt(
    cls,
    tools: Sequence[BaseTool],
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
    format_instructions: str = FORMAT_INSTRUCTIONS,
    input_variables: Optional[List[str]] = None,
    memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
    """Assemble the chat prompt used by the agent.

    The system message is ``prefix``, one description line per tool, the
    format instructions (with ``tool_names`` filled in) and ``suffix``,
    joined by blank lines. Any ``memory_prompts`` follow the system
    message, and the human message template comes last.

    Args:
        tools: Tools whose name, description and args are listed in the
            system message.
        prefix: Text placed before the tool listing.
        suffix: Text placed after the format instructions.
        human_message_template: Template for the final human message.
        format_instructions: Template formatted with ``tool_names``.
        input_variables: Prompt input variables; defaults to
            ``["input", "agent_scratchpad"]`` when ``None``.
        memory_prompts: Optional prompt templates inserted between the
            system and human messages.

    Returns:
        A ``ChatPromptTemplate`` built from the messages above.
    """

    def _describe(tool: BaseTool) -> str:
        # Every brace in the args schema is quadrupled; presumably so the
        # schema is not read as template placeholders later on -- confirm
        # against the prompt template's formatting rules.
        escaped_args = str(tool.args).replace("{", "{{{{").replace("}", "}}}}")
        return f"{tool.name}: {tool.description}, args: {escaped_args}"

    formatted_tools = "\n".join(_describe(tool) for tool in tools)
    tool_names = ", ".join(f'"{tool.name}"' for tool in tools)
    filled_instructions = format_instructions.format(tool_names=tool_names)
    system_template = "\n\n".join(
        [prefix, formatted_tools, filled_instructions, suffix]
    )

    if input_variables is None:
        input_variables = ["input", "agent_scratchpad"]

    messages = [SystemMessagePromptTemplate.from_template(system_template)]
    messages.extend(memory_prompts or [])
    messages.append(HumanMessagePromptTemplate.from_template(human_message_template))
    return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
......
...@@ -47,7 +47,8 @@ class DatasetToolCallbackHandler(BaseCallbackHandler): ...@@ -47,7 +47,8 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
# tool_name = serialized.get('name') # tool_name = serialized.get('name')
input_dict = json.loads(input_str.replace("'", "\"")) input_dict = json.loads(input_str.replace("'", "\""))
dataset_id = input_dict.get('dataset_id') dataset_id = input_dict.get('dataset_id')
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str)) query = input_dict.get('query')
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
def on_tool_end( def on_tool_end(
self, self,
......
...@@ -2,7 +2,7 @@ from typing import List, Optional, Any, Dict ...@@ -2,7 +2,7 @@ from typing import List, Optional, Any, Dict
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator from pydantic import root_validator
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
...@@ -44,3 +44,16 @@ class StreamableChatAnthropic(ChatAnthropic): ...@@ -44,3 +44,16 @@ class StreamableChatAnthropic(ChatAnthropic):
del params['presence_penalty'] del params['presence_penalty']
return params return params
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
    """Render one chat message as prompt text.

    The message type is checked in this order:

    * ``ChatMessage``: two newlines, the capitalized role, a colon and
      the content.
    * ``HumanMessage``: ``self.HUMAN_PROMPT`` followed by the content.
    * ``AIMessage``: ``self.AI_PROMPT`` followed by the content.
    * ``SystemMessage``: the content wrapped in ``<admin>`` tags, with
      no role prefix.

    Raises:
        ValueError: If ``message`` is none of the types above.
    """
    if isinstance(message, ChatMessage):
        return f"\n\n{message.role.capitalize()}: {message.content}"
    if isinstance(message, HumanMessage):
        return f"{self.HUMAN_PROMPT} {message.content}"
    if isinstance(message, AIMessage):
        return f"{self.AI_PROMPT} {message.content}"
    if isinstance(message, SystemMessage):
        return f"<admin>{message.content}</admin>"
    raise ValueError(f"Got unknown type {message}")
\ No newline at end of file
...@@ -158,7 +158,7 @@ class OrchestratorRuleParser: ...@@ -158,7 +158,7 @@ class OrchestratorRuleParser:
tool = self.to_wikipedia_tool() tool = self.to_wikipedia_tool()
if tool: if tool:
tool.callbacks = callbacks tool.callbacks.extend(callbacks)
tools.append(tool) tools.append(tool)
return tools return tools
...@@ -186,7 +186,7 @@ class OrchestratorRuleParser: ...@@ -186,7 +186,7 @@ class OrchestratorRuleParser:
tool = DatasetRetrieverTool.from_dataset( tool = DatasetRetrieverTool.from_dataset(
dataset=dataset, dataset=dataset,
k=k, k=k,
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()] callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
) )
return tool return tool
......
...@@ -32,7 +32,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -32,7 +32,7 @@ class DatasetRetrieverTool(BaseTool):
@classmethod @classmethod
def from_dataset(cls, dataset: Dataset, **kwargs): def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description description = dataset.description.replace('\n', '').replace('\r', '')
if not description: if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name description = 'useful for when you want to answer queries about the ' + dataset.name
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment