Commit 9f95d972 authored by takatost's avatar takatost

fix bugs

parent 0feabefd
...@@ -39,6 +39,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -39,6 +39,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt_messages = self._get_completion_model_prompt_messages( prompt_messages = self._get_completion_model_prompt_messages(
prompt_template_entity=prompt_template_entity, prompt_template_entity=prompt_template_entity,
inputs=inputs, inputs=inputs,
query=query,
files=files, files=files,
context=context, context=context,
memory=memory, memory=memory,
...@@ -60,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -60,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform):
def _get_completion_model_prompt_messages(self, def _get_completion_model_prompt_messages(self,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict, inputs: dict,
query: Optional[str],
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
...@@ -86,6 +88,9 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -86,6 +88,9 @@ class AdvancedPromptTransform(PromptTransform):
model_config=model_config model_config=model_config
) )
if query:
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
prompt = prompt_template.format( prompt = prompt_template.format(
prompt_inputs prompt_inputs
) )
...@@ -147,21 +152,30 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -147,21 +152,30 @@ class AdvancedPromptTransform(PromptTransform):
else: else:
prompt_messages.append(UserPromptMessage(content=query)) prompt_messages.append(UserPromptMessage(content=query))
elif files: elif files:
# get last message if not query:
last_message = prompt_messages[-1] if prompt_messages else None # get last message
if last_message and last_message.role == PromptMessageRole.USER: last_message = prompt_messages[-1] if prompt_messages else None
# get last user message content and add files if last_message and last_message.role == PromptMessageRole.USER:
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] # get last user message content and add files
for file in files: prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
prompt_message_contents.append(file.prompt_message_content) for file in files:
prompt_message_contents.append(file.prompt_message_content)
last_message.content = prompt_message_contents
last_message.content = prompt_message_contents
else:
prompt_message_contents = [TextPromptMessageContent(data='')] # not for query
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:
prompt_message_contents = [TextPromptMessageContent(data='')] # not for query prompt_message_contents = [TextPromptMessageContent(data=query)]
for file in files: for file in files:
prompt_message_contents.append(file.prompt_message_content) prompt_message_contents.append(file.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
elif query:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages return prompt_messages
...@@ -210,4 +224,4 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -210,4 +224,4 @@ class AdvancedPromptTransform(PromptTransform):
else: else:
prompt_inputs['#histories#'] = '' prompt_inputs['#histories#'] = ''
return prompt_inputs return prompt_inputs
...@@ -50,6 +50,7 @@ def test__get_completion_model_prompt_messages(): ...@@ -50,6 +50,7 @@ def test__get_completion_model_prompt_messages():
prompt_messages = prompt_transform._get_completion_model_prompt_messages( prompt_messages = prompt_transform._get_completion_model_prompt_messages(
prompt_template_entity=prompt_template_entity, prompt_template_entity=prompt_template_entity,
inputs=inputs, inputs=inputs,
query=None,
files=files, files=files,
context=context, context=context,
memory=memory, memory=memory,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment