Unverified commit 90150a6c, authored by John Wang, committed by GitHub

Feat/optimize chat prompt (#158)

parent 7722a7c5
@@ -39,7 +39,8 @@ class Completion:
         memory = cls.get_memory_from_conversation(
             tenant_id=app.tenant_id,
             app_model_config=app_model_config,
-            conversation=conversation
+            conversation=conversation,
+            return_messages=False
         )

         inputs = conversation.inputs
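The new return_messages=False argument mirrors LangChain's memory flag: the buffer then exposes the chat history as one formatted transcript string instead of a list of message objects, which is what the string-based prompt assembly below expects. A minimal illustration using LangChain's stock ConversationBufferMemory (chosen for demonstration; ReadOnlyConversationTokenDBBufferSharedMemory is Dify's own subclass, presumed here to honor the same flag):

from langchain.memory import ConversationBufferMemory

# With return_messages=False the history comes back as a single string;
# with return_messages=True it would be [HumanMessage(...), AIMessage(...)].
memory = ConversationBufferMemory(return_messages=False)
memory.save_context({"input": "hi"}, {"output": "hello"})
print(memory.load_memory_variables({})["history"])
# -> Human: hi
#    AI: hello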
@@ -119,7 +120,8 @@ class Completion:
         return response

     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                            chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Union[str | List[BaseMessage]]:
         pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
             "query": query
         }

-        human_message_prompt = "{query}"
+        human_message_prompt = ""
+
+        if pre_prompt:
+            pre_prompt_inputs = {k: inputs[k] for k in
+                                 OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                 if k in inputs}
+
+            if pre_prompt_inputs:
+                human_inputs.update(pre_prompt_inputs)

         if chain_output:
             human_inputs['context'] = chain_output
-            human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+            human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {context}
 [END CONTEXT]
@@ -176,39 +186,33 @@ When answer to user:
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """

-            if pre_prompt:
-                extra_inputs = {k: inputs[k] for k in
-                                OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                if k in inputs}
-                if extra_inputs:
-                    human_inputs.update(extra_inputs)
-                human_message_instruction += pre_prompt + "\n"
-
-            human_message_prompt = human_message_instruction + "Q:{query}\nA:"
-        else:
-            if pre_prompt:
-                extra_inputs = {k: inputs[k] for k in
-                                OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                if k in inputs}
-                if extra_inputs:
-                    human_inputs.update(extra_inputs)
-                human_message_prompt = pre_prompt + "\n" + human_message_prompt
-
-        # construct main prompt
-        human_message = PromptBuilder.to_human_message(
-            prompt_content=human_message_prompt,
-            inputs=human_inputs
-        )
+        if pre_prompt:
+            human_message_prompt += pre_prompt
+
+        query_prompt = "\nHuman: {query}\nAI: "

         if memory:
             # append chat histories
-            tmp_messages = messages.copy() + [human_message]
-            curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-            rest_tokens = llm_constant.max_context_token_length[
-                              memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
+            tmp_human_message = PromptBuilder.to_human_message(
+                prompt_content=human_message_prompt + query_prompt,
+                inputs=human_inputs
+            )
+
+            curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+            rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+                          - memory.llm.max_tokens - curr_message_tokens
             rest_tokens = max(rest_tokens, 0)
             history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-            messages += history_messages
+
+            human_message_prompt += "\n\n" + history_messages
+
+        human_message_prompt += query_prompt
+
+        # construct main prompt
+        human_message = PromptBuilder.to_human_message(
+            prompt_content=human_message_prompt,
+            inputs=human_inputs
+        )

         messages.append(human_message)
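Taken together, these two hunks replace the old layout (history appended as separate messages, query wrapped in a "Q:/A:" scaffold) with one consolidated human message: context instructions first, then the pre-prompt, then the trimmed history transcript, then a "Human: {query}\nAI: " scaffold. The history budget is whatever remains of the model's context window after reserving room for the completion and for the prompt built so far. A sketch with illustrative numbers (none of the constants or prompt text below are from the PR):

# Sketch only: the single human message this change produces, plus the
# history token budget. Constants and pre_prompt text are assumptions.
MAX_CONTEXT_TOKENS = 4096    # stands in for llm_constant.max_context_token_length[model]
MAX_COMPLETION_TOKENS = 512  # stands in for memory.llm.max_tokens

def history_token_budget(prompt_tokens: int) -> int:
    # rest_tokens above: context window minus the reserved completion minus
    # the prompt built so far (context instructions + pre_prompt + scaffold).
    return max(MAX_CONTEXT_TOKENS - MAX_COMPLETION_TOKENS - prompt_tokens, 0)

final_prompt = (
    "Use the following CONTEXT as your learned knowledge.\n"
    "[CONTEXT]\n{context}\n[END CONTEXT]\n"
    "...\n"                             # rest of the context instructions
    "You are a helpful translator."     # hypothetical pre_prompt
    "\n\nHuman: hi\nAI: hello"          # history transcript, trimmed to budget
    "\nHuman: {query}\nAI: "            # query_prompt scaffold
)

Because the history is trimmed against the budget before being inlined, a long pre-prompt or chain output directly shrinks how much history survives in the final message.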
@@ -216,7 +220,8 @@ And answer according to the language of the user's question.
     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                                 streaming: bool,
+                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
             callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> \
-            List[BaseMessage]:
+            str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
...
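The helper's return type changes from List[BaseMessage] to str to match the string-based assembly above: the trimmed history is concatenated into human_message_prompt rather than appended as separate messages. The diff cuts off mid-body, so the following is a hypothetical sketch of the complete helper under that assumption:

def get_history_messages_from_memory(memory, max_token_limit: int) -> str:
    # Hypothetical continuation (not shown in the diff): with
    # return_messages=False the memory variable already holds a
    # "Human: ... / AI: ..." transcript string, so it can be
    # returned as-is instead of as a List[BaseMessage].
    memory.max_token_limit = max_token_limit
    memory_key = memory.memory_variables[0]
    return memory.load_memory_variables({})[memory_key]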