Commit df66cd22 authored by takatost

fix prompt transform bugs

parent a44d3c3e
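The diff below makes three kinds of fixes: the class typo `AdvancePromptTransform` is renamed to `AdvancedPromptTransform`; the helper methods (`_set_context_variable`, `_set_query_variable`, `_set_histories_variable`, `_append_chat_histories`) now return the containers they fill in instead of declaring `-> None`, with every call site reassigning the result; and `_get_chat_model_prompt_messages` accepts an optional `query`, seeding a files-only user message with an empty text part. A minimal sketch of the return-and-reassign pattern, using simplified hypothetical names rather than the actual classes:

```python
from typing import Optional

# Hypothetical reduction of the pattern this commit applies: each helper
# returns the dict it fills in, and the caller reassigns the result, so the
# call sites stay correct even if a helper later rebuilds the dict instead
# of mutating the one it was given.

def set_context_variable(context: Optional[str], variable_keys: list,
                         prompt_inputs: dict) -> dict:
    # Only fill '#context#' if the template actually references it.
    if '#context#' in variable_keys:
        prompt_inputs['#context#'] = context or ''
    return prompt_inputs  # returned so the caller can reassign


def build_prompt(template: str, variable_keys: list, inputs: dict,
                 context: Optional[str]) -> str:
    prompt_inputs = {k: v for k, v in inputs.items() if k in variable_keys}
    # Before the fix: a bare call whose return value was implicitly None.
    # After the fix: reassign, matching the new '-> dict' signature.
    prompt_inputs = set_context_variable(context, variable_keys, prompt_inputs)
    for key, value in prompt_inputs.items():
        template = template.replace('{{%s}}' % key, str(value))
    return template


# Prints: "Context:\nI am superman.\n\nyou are John."
print(build_prompt("Context:\n{{#context#}}\n\nyou are {{name}}.",
                   ['#context#', 'name'], {'name': 'John'}, 'I am superman.'))
```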
@@ -20,7 +20,7 @@ from core.prompt.prompt_transform import PromptTransform
from core.prompt.simple_prompt_transform import ModelMode
-class AdvancePromptTransform(PromptTransform):
+class AdvancedPromptTransform(PromptTransform):
"""
Advanced Prompt Transform for Workflow LLM Node.
"""
@@ -74,10 +74,10 @@ class AdvancePromptTransform(PromptTransform):
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-self._set_context_variable(context, prompt_template, prompt_inputs)
+prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix
-self._set_histories_variable(
+prompt_inputs = self._set_histories_variable(
memory=memory,
raw_prompt=raw_prompt,
role_prefix=role_prefix,
@@ -104,7 +104,7 @@ class AdvancePromptTransform(PromptTransform):
def _get_chat_model_prompt_messages(self,
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
-query: str,
+query: Optional[str],
files: list[FileObj],
context: Optional[str],
memory: Optional[TokenBufferMemory],
@@ -122,7 +122,7 @@ class AdvancePromptTransform(PromptTransform):
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-self._set_context_variable(context, prompt_template, prompt_inputs)
+prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
prompt = prompt_template.format(
prompt_inputs
@@ -136,7 +136,7 @@ class AdvancePromptTransform(PromptTransform):
prompt_messages.append(AssistantPromptMessage(content=prompt))
if memory:
-self._append_chat_histories(memory, prompt_messages, model_config)
+prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config)
if files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
@@ -157,7 +157,7 @@ class AdvancePromptTransform(PromptTransform):
last_message.content = prompt_message_contents
else:
-prompt_message_contents = [TextPromptMessageContent(data=query)]
+prompt_message_contents = [TextPromptMessageContent(data='')]  # not for query
for file in files:
prompt_message_contents.append(file.prompt_message_content)
@@ -165,26 +165,30 @@ class AdvancePromptTransform(PromptTransform):
return prompt_messages
-def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
if '#context#' in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''
-def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+return prompt_inputs
+def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
if '#query#' in prompt_template.variable_keys:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''
+return prompt_inputs
def _set_histories_variable(self, memory: TokenBufferMemory,
raw_prompt: str,
role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
prompt_template: PromptTemplateParser,
prompt_inputs: dict,
-model_config: ModelConfigEntity) -> None:
+model_config: ModelConfigEntity) -> dict:
if '#histories#' in prompt_template.variable_keys:
if memory:
inputs = {'#histories#': '', **prompt_inputs}
@@ -205,3 +209,5 @@ class AdvancePromptTransform(PromptTransform):
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''
+return prompt_inputs
@@ -10,12 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
class PromptTransform:
def _append_chat_histories(self, memory: TokenBufferMemory,
prompt_messages: list[PromptMessage],
-model_config: ModelConfigEntity) -> None:
+model_config: ModelConfigEntity) -> list[PromptMessage]:
if memory:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories)
+return prompt_messages
def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> int:
rest_tokens = 2000
......
@@ -177,7 +177,7 @@ class SimplePromptTransform(PromptTransform):
if prompt:
prompt_messages.append(SystemPromptMessage(content=prompt))
-self._append_chat_histories(
+prompt_messages = self._append_chat_histories(
memory=memory,
prompt_messages=prompt_messages,
model_config=model_config
......
from unittest.mock import MagicMock
import pytest
from core.entities.application_entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \
ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity
from core.file.file_obj import FileObj, FileType, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.prompt_template import PromptTemplateParser
from models.model import Conversation
def test__get_completion_model_prompt_messages():
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-3.5-turbo-instruct'
prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
prompt_template_entity = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
prompt=prompt_template,
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
user="Human",
assistant="Assistant"
)
)
)
inputs = {
"name": "John"
}
files = []
context = "I am superman."
memory = TokenBufferMemory(
conversation=Conversation(),
model_instance=model_config_mock
)
history_prompt_messages = [
UserPromptMessage(content="Hi"),
AssistantPromptMessage(content="Hello")
]
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
prompt_messages = prompt_transform._get_completion_model_prompt_messages(
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
context=context,
memory=memory,
model_config=model_config_mock
)
assert len(prompt_messages) == 1
assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({
"#context#": context,
"#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: "
f"{prompt.content}" for prompt in history_prompt_messages]),
**inputs,
})
def test__get_chat_model_prompt_messages(get_chat_model_args):
model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
files = []
query = "Hi2."
memory = TokenBufferMemory(
conversation=Conversation(),
model_instance=model_config_mock
)
history_prompt_messages = [
UserPromptMessage(content="Hi1."),
AssistantPromptMessage(content="Hello1!")
]
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config_mock
)
assert len(prompt_messages) == 6
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
assert prompt_messages[0].content == PromptTemplateParser(
template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
).format({**inputs, "#context#": context})
assert prompt_messages[5].content == query
def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
files = []
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=None,
files=files,
context=context,
memory=None,
model_config=model_config_mock
)
assert len(prompt_messages) == 3
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
assert prompt_messages[0].content == PromptTemplateParser(
template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
).format({**inputs, "#context#": context})
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
files = [
FileObj(
id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
url="https://example.com/image1.jpg",
file_config={
"image": {
"detail": "high",
}
}
)
]
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=None,
files=files,
context=context,
memory=None,
model_config=model_config_mock
)
assert len(prompt_messages) == 4
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
assert prompt_messages[0].content == PromptTemplateParser(
template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
).format({**inputs, "#context#": context})
assert isinstance(prompt_messages[3].content, list)
assert len(prompt_messages[3].content) == 2
assert prompt_messages[3].content[1].data == files[0].url
@pytest.fixture
def get_chat_model_args():
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-4'
prompt_template_entity = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
messages=[
AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
role=PromptMessageRole.SYSTEM),
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
]
)
)
inputs = {
"name": "John"
}
context = "I am superman."
return model_config_mock, prompt_template_entity, inputs, context
......
from unittest.mock import MagicMock
from core.entities.application_entities import ModelConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from core.prompt.simple_prompt_transform import SimplePromptTransform
-from models.model import AppMode
+from models.model import AppMode, Conversation
def test_get_common_chat_app_prompt_template_with_pcqm():
@@ -141,7 +143,16 @@ def test__get_chat_model_prompt_messages():
model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-4'
+memory_mock = MagicMock(spec=TokenBufferMemory)
+history_prompt_messages = [
+UserPromptMessage(content="Hi"),
+AssistantPromptMessage(content="Hello")
+]
+memory_mock.get_history_prompt_messages.return_value = history_prompt_messages
prompt_transform = SimplePromptTransform()
+prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
pre_prompt = "You are a helpful assistant {{name}}."
inputs = {
"name": "John"
@@ -154,7 +165,7 @@ def test__get_chat_model_prompt_messages():
query=query,
files=[],
context=context,
-memory=None,
+memory=memory_mock,
model_config=model_config_mock
)
@@ -171,9 +182,11 @@ def test__get_chat_model_prompt_messages():
full_inputs = {**inputs, '#context#': context}
real_system_prompt = prompt_template['prompt_template'].format(full_inputs)
-assert len(prompt_messages) == 2
+assert len(prompt_messages) == 4
assert prompt_messages[0].content == real_system_prompt
-assert prompt_messages[1].content == query
+assert prompt_messages[1].content == history_prompt_messages[0].content
+assert prompt_messages[2].content == history_prompt_messages[1].content
+assert prompt_messages[3].content == query
def test__get_completion_model_prompt_messages():
@@ -181,7 +194,19 @@ def test__get_completion_model_prompt_messages():
model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-3.5-turbo-instruct'
+memory = TokenBufferMemory(
+conversation=Conversation(),
+model_instance=model_config_mock
+)
+history_prompt_messages = [
+UserPromptMessage(content="Hi"),
+AssistantPromptMessage(content="Hello")
+]
+memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
prompt_transform = SimplePromptTransform()
+prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
pre_prompt = "You are a helpful assistant {{name}}."
inputs = {
"name": "John"
@@ -194,7 +219,7 @@ def test__get_completion_model_prompt_messages():
query=query,
files=[],
context=context,
-memory=None,
+memory=memory,
model_config=model_config_mock
)
@@ -205,12 +230,17 @@ def test__get_completion_model_prompt_messages():
pre_prompt=pre_prompt,
has_context=True,
query_in_prompt=True,
-with_memory_prompt=False,
+with_memory_prompt=True,
)
-full_inputs = {**inputs, '#context#': context, '#query#': query}
+prompt_rules = prompt_template['prompt_rules']
+full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
+max_token_limit=2000,
+ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+)}
real_prompt = prompt_template['prompt_template'].format(full_inputs)
assert len(prompt_messages) == 1
-assert stops == prompt_template['prompt_rules'].get('stops')
+assert stops == prompt_rules.get('stops')
assert prompt_messages[0].content == real_prompt
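A note on the files-without-query branch in `_get_chat_model_prompt_messages` above: the user message that only carries attachments previously reused `TextPromptMessageContent(data=query)` even when `query` was `None`; the fix seeds it with an empty text part instead. A hedged sketch of that branch under simplified, hypothetical types:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TextPart:
    data: str


@dataclass
class ImagePart:
    url: str


def build_user_message_contents(query: Optional[str], file_urls: list,
                                is_query_message: bool) -> list:
    # Hypothetical reduction of the fixed branch: only the message that
    # actually carries the user's query embeds its text; a files-only
    # message starts from an empty text part instead of a None payload.
    text = (query or '') if is_query_message else ''
    contents = [TextPart(data=text)]
    contents.extend(ImagePart(url=u) for u in file_urls)
    return contents


# Files-only message: empty text part plus one image part.
print(build_user_message_contents(None, ['https://example.com/image1.jpg'], False))
```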