Commit f11f66ed authored by John Wang

Merge branch 'feat/universal-chat' into deploy/dev

parents 27401bc2 e062b1ee
import re
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain import BasePromptTemplate
@@ -7,6 +8,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
@@ -121,6 +123,35 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
    @classmethod
    def from_llm_and_tools(
        cls,
......
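A minimal standalone sketch (not part of this diff) of how the brace escaping in `create_prompt` behaves: each `{` in a tool's args schema becomes `{{{{` and each `}` becomes `}}}}`, which collapse back to literal braces after two passes of `str.format` (the assumption being that the template text is formatted twice before it reaches the model).

```python
import re

def escape_braces(args_schema: str) -> str:
    # same substitution as in create_prompt: "{" -> "{{{{", "}" -> "}}}}"
    return re.sub("}", "}}}}", re.sub("{", "{{{{", args_schema))

args = str({"query": {"type": "string"}})   # "{'query': {'type': 'string'}}"
escaped = escape_braces(args)

once = escaped.format()     # first format pass: "{{'query': {{'type': 'string'}}}}"
twice = once.format()       # second pass restores the original literal braces
assert twice == args
```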
@@ -47,7 +47,8 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
        # tool_name = serialized.get('name')
        input_dict = json.loads(input_str.replace("'", "\""))
        dataset_id = input_dict.get('dataset_id')
        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))
        query = input_dict.get('query')
        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))

    def on_tool_end(
        self,
......
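A short sketch with hypothetical input values of what the callback change amounts to: the `query` field is parsed out of the tool's JSON input and passed to `on_dataset_query_end`, rather than forwarding the raw `input_str`.

```python
import json

# hypothetical tool input, shaped like the input_str the handler receives
input_str = "{'dataset_id': 'a1b2c3', 'query': 'how does universal chat work'}"
input_dict = json.loads(input_str.replace("'", "\""))

dataset_id = input_dict.get('dataset_id')
query = input_dict.get('query')   # previously the whole raw input_str was used as the query

print(dataset_id, query)
```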
@@ -2,7 +2,7 @@ from typing import List, Optional, Any, Dict
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
@@ -44,3 +44,16 @@ class StreamableChatAnthropic(ChatAnthropic):
        del params['presence_penalty']

        return params

    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = f"{self.HUMAN_PROMPT} {message.content}"
        elif isinstance(message, AIMessage):
            message_text = f"{self.AI_PROMPT} {message.content}"
        elif isinstance(message, SystemMessage):
            message_text = f"<admin>{message.content}</admin>"
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text
\ No newline at end of file
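A minimal standalone sketch of the message-to-text mapping added in `_convert_one_message_to_text`, keyed by a plain role string instead of message classes. The `HUMAN_PROMPT`/`AI_PROMPT` values are hard-coded here to the strings the anthropic SDK ships; treat the exact values as an assumption.

```python
HUMAN_PROMPT = "\n\nHuman:"       # assumed value of anthropic's HUMAN_PROMPT constant
AI_PROMPT = "\n\nAssistant:"      # assumed value of anthropic's AI_PROMPT constant

def convert_one_message_to_text(role: str, content: str) -> str:
    if role == "human":
        return f"{HUMAN_PROMPT} {content}"
    if role == "ai":
        return f"{AI_PROMPT} {content}"
    if role == "system":
        # system content is wrapped in <admin> tags instead of a role prefix
        return f"<admin>{content}</admin>"
    # generic chat roles fall back to "Role: content"
    return f"\n\n{role.capitalize()}: {content}"

prompt = "".join([
    convert_one_message_to_text("system", "You are a helpful assistant."),
    convert_one_message_to_text("human", "Hello!"),
    convert_one_message_to_text("ai", ""),
])
print(prompt)
```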
@@ -158,7 +158,7 @@ class OrchestratorRuleParser:
                tool = self.to_wikipedia_tool()

            if tool:
                tool.callbacks = callbacks
                tool.callbacks.extend(callbacks)
                tools.append(tool)

        return tools
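A small sketch, with plain lists standing in for callback handler lists, of the difference between the two lines above: extending keeps whatever handlers the tool already carries, while assignment replaces them.

```python
existing_handlers = ["dataset_tool_handler"]   # handlers attached when the tool was built
agent_callbacks = ["agent_loop_handler"]       # callbacks passed in by the agent setup

# assignment: the tool ends up with only the agent callbacks
assigned = agent_callbacks

# extend: the tool keeps its own handlers and gains the agent callbacks
extended = list(existing_handlers)
extended.extend(agent_callbacks)

assert assigned == ["agent_loop_handler"]
assert extended == ["dataset_tool_handler", "agent_loop_handler"]
```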
@@ -186,7 +186,7 @@ class OrchestratorRuleParser:
        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            k=k,
            callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
            callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
        )

        return tool
......
@@ -32,7 +32,7 @@ class DatasetRetrieverTool(BaseTool):
    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description
        description = dataset.description.replace('\n', '').replace('\r', '')
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name
......
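A minimal sketch, with a hypothetical dataset name and description, of the sanitization above: newlines are stripped so the tool description stays on one line, with the existing fallback when the description is empty.

```python
from typing import Optional

def build_tool_description(name: str, description: Optional[str]) -> str:
    # hypothetical helper mirroring from_dataset: flatten to one line, then fall back
    description = (description or '').replace('\n', '').replace('\r', '')
    if not description:
        description = 'useful for when you want to answer queries about the ' + name
    return description

print(build_tool_description("product-faq", "Answers questions\nabout the product.\r\n"))
print(build_tool_description("product-faq", None))
```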