Unverified commit 90150a6c, authored by John Wang and committed by GitHub

Feat/optimize chat prompt (#158)

parent 7722a7c5
......@@ -39,7 +39,8 @@ class Completion:
memory = cls.get_memory_from_conversation(
tenant_id=app.tenant_id,
app_model_config=app_model_config,
conversation=conversation
conversation=conversation,
return_messages=False
)
inputs = conversation.inputs
......@@ -119,7 +120,8 @@ class Completion:
return response
@classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
chain_output: Optional[str],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Union[str | List[BaseMessage]]:
pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
......@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
"query": query
}
human_message_prompt = "{query}"
human_message_prompt = ""
if pre_prompt:
pre_prompt_inputs = {k: inputs[k] for k in
OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
if k in inputs}
if pre_prompt_inputs:
human_inputs.update(pre_prompt_inputs)
if chain_output:
human_inputs['context'] = chain_output
human_message_instruction = """Use the following CONTEXT as your learned knowledge.
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
{context}
[END CONTEXT]
......@@ -176,39 +186,33 @@ When answer to user:
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""
if pre_prompt:
extra_inputs = {k: inputs[k] for k in
OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
if k in inputs}
if extra_inputs:
human_inputs.update(extra_inputs)
human_message_instruction += pre_prompt + "\n"
human_message_prompt = human_message_instruction + "Q:{query}\nA:"
else:
if pre_prompt:
extra_inputs = {k: inputs[k] for k in
OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
if k in inputs}
if extra_inputs:
human_inputs.update(extra_inputs)
human_message_prompt = pre_prompt + "\n" + human_message_prompt
# construct main prompt
human_message = PromptBuilder.to_human_message(
prompt_content=human_message_prompt,
inputs=human_inputs
)
if pre_prompt:
human_message_prompt += pre_prompt
query_prompt = "\nHuman: {query}\nAI: "
if memory:
# append chat histories
tmp_messages = messages.copy() + [human_message]
curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
rest_tokens = llm_constant.max_context_token_length[
memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=human_message_prompt + query_prompt,
inputs=human_inputs
)
curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
- memory.llm.max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
messages += history_messages
human_message_prompt += "\n\n" + history_messages
human_message_prompt += query_prompt
# construct main prompt
human_message = PromptBuilder.to_human_message(
prompt_content=human_message_prompt,
inputs=human_inputs
)
messages.append(human_message)
......@@ -216,7 +220,8 @@ And answer according to the language of the user's question.
@classmethod
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
streaming: bool,
conversation_message_task: ConversationMessageTask) -> CallbackManager:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming:
callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
......@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
max_token_limit: int) -> \
List[BaseMessage]:
str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment