Commit a44d3c3e authored by takatost

fix bugs and add unit tests

parent 297b33aa
@@ -133,7 +133,7 @@ class ModelPropertyKey(Enum):
     DEFAULT_VOICE = "default_voice"
     VOICES = "voices"
     WORD_LIMIT = "word_limit"
-    AUDOI_TYPE = "audio_type"
+    AUDIO_TYPE = "audio_type"
     MAX_WORKERS = "max_workers"
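The hunk above is a straight typo fix: the misspelled enum member AUDOI_TYPE becomes AUDIO_TYPE while its string value "audio_type" is unchanged, so schema entries keyed on the raw string keep resolving; only Python call sites that referenced the misspelled member need updating (the TTS hunk below is exactly that). A minimal sketch of the distinction; the "mp3" property value is an illustrative assumption, not from this commit:

    from enum import Enum

    class ModelPropertyKey(Enum):
        AUDIO_TYPE = "audio_type"  # value unchanged by the rename

    # Call sites address properties by enum member, so the member name is the
    # only thing the rename touches; "mp3" here is a made-up example value.
    model_properties = {ModelPropertyKey.AUDIO_TYPE: "mp3"}
    assert model_properties[ModelPropertyKey.AUDIO_TYPE] == "mp3"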
......
@@ -94,8 +94,8 @@ class TTSModel(AIModel):
         """
         model_schema = self.get_model_schema(model, credentials)
 
-        if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE]
+        if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
+            return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
 
     def _get_model_word_limit(self, model: str, credentials: dict) -> int:
         """
......
@@ -45,6 +45,7 @@ class SimplePromptTransform(PromptTransform):
     """
     Simple Prompt Transform for Chatbot App Basic Mode.
     """
+
     def get_prompt(self,
                    prompt_template_entity: PromptTemplateEntity,
                    inputs: dict,
@@ -154,12 +155,12 @@ class SimplePromptTransform(PromptTransform):
         }
 
     def _get_chat_model_prompt_messages(self, pre_prompt: str,
-                                        inputs: dict,
-                                        query: str,
-                                        context: Optional[str],
-                                        files: list[FileObj],
-                                        memory: Optional[TokenBufferMemory],
-                                        model_config: ModelConfigEntity) \
+                                        inputs: dict,
+                                        query: str,
+                                        context: Optional[str],
+                                        files: list[FileObj],
+                                        memory: Optional[TokenBufferMemory],
+                                        model_config: ModelConfigEntity) \
             -> tuple[list[PromptMessage], Optional[list[str]]]:
         prompt_messages = []
@@ -169,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
             model_config=model_config,
             pre_prompt=pre_prompt,
             inputs=inputs,
-            query=query,
+            query=None,
             context=context
         )
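Replacing query=query with query=None means the system prompt for chat models is rendered without the user query embedded; the query travels as its own message instead, which test__get_chat_model_prompt_messages near the bottom of this commit pins down (two messages, the second being the raw query). A minimal sketch of the resulting layout, assuming the message class names used in this commit's tests:

    prompt_messages = [
        SystemPromptMessage(content=system_prompt),  # pre_prompt + context only; no query
        UserPromptMessage(content=query),            # the query as a standalone message
    ]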
@@ -187,12 +188,12 @@ class SimplePromptTransform(PromptTransform):
         return prompt_messages, None
 
     def _get_completion_model_prompt_messages(self, pre_prompt: str,
-                                              inputs: dict,
-                                              query: str,
-                                              context: Optional[str],
-                                              files: list[FileObj],
-                                              memory: Optional[TokenBufferMemory],
-                                              model_config: ModelConfigEntity) \
+                                              inputs: dict,
+                                              query: str,
+                                              context: Optional[str],
+                                              files: list[FileObj],
+                                              memory: Optional[TokenBufferMemory],
+                                              model_config: ModelConfigEntity) \
             -> tuple[list[PromptMessage], Optional[list[str]]]:
         # get prompt
         prompt, prompt_rules = self.get_prompt_str_and_rules(
@@ -259,7 +260,7 @@ class SimplePromptTransform(PromptTransform):
             provider=provider,
             model=model
         )
 
+        # Check if the prompt file is already loaded
         if prompt_file_name in prompt_file_contents:
             return prompt_file_contents[prompt_file_name]
@@ -267,14 +268,16 @@ class SimplePromptTransform(PromptTransform):
         # Get the absolute path of the subdirectory
         prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts')
         json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
 
+        # Open the JSON file and read its content
         with open(json_file_path, encoding='utf-8') as json_file:
             content = json.load(json_file)
 
+        # Store the content of the prompt file
         prompt_file_contents[prompt_file_name] = content
 
         return content
 
     def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
         # baichuan
         is_baichuan = False
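The block above is a plain module-level memo: each prompt-rules JSON file is read from disk once, then served from the prompt_file_contents dict for the life of the process. A self-contained sketch of the same pattern, with names taken from the hunk:

    import json
    import os

    prompt_file_contents: dict = {}

    def load_prompt_rules(prompt_file_name: str) -> dict:
        # Serve from the in-process cache when the file was already read
        if prompt_file_name in prompt_file_contents:
            return prompt_file_contents[prompt_file_name]

        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts')
        json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')

        with open(json_file_path, encoding='utf-8') as json_file:
            content = json.load(json_file)

        prompt_file_contents[prompt_file_name] = content
        return content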
......
@@ -5,7 +5,6 @@ from sqlalchemy.dialects.postgresql import UUID
 
 from extensions.ext_database import db
 from models.account import Account
-from models.model import AppMode
 
 
 class WorkflowType(Enum):
@@ -29,13 +28,14 @@ class WorkflowType(Enum):
         raise ValueError(f'invalid workflow type value {value}')
 
     @classmethod
-    def from_app_mode(cls, app_mode: Union[str, AppMode]) -> 'WorkflowType':
+    def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType':
         """
         Get workflow type from app mode.
 
         :param app_mode: app mode
         :return: workflow type
         """
+        from models.model import AppMode
+
         app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode)
         return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
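Dropping the top-level "from models.model import AppMode" in favor of a deferred import inside from_app_mode (with the annotation quoted as 'AppMode') is the usual fix for a circular import: the import only runs at call time, after both modules have finished initializing. A structural sketch of the cycle being broken; that models.model imports from models.workflow is an assumption, since the other side of the cycle is not shown in this diff:

    # models/model.py -- assumed to import from models.workflow (not shown here)
    from models.workflow import WorkflowType

    # models/workflow.py -- the deferred import below never runs at module load,
    # so the module-level cycle above cannot break initialization
    @classmethod
    def from_app_mode(cls, app_mode):
        from models.model import AppMode  # resolved lazily, on first call
        app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode)
        return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT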
......
.env.test
\ No newline at end of file
import os

# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))

# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
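PROJECT_DIR resolves two directory levels above this file, i.e. the project root relative to the test package. A hedged example of how such constants are typically consumed; the python-dotenv call and the file location are assumptions, since the diff only shows the constants and the .env.test entry above:

    import os
    from dotenv import load_dotenv  # assumed dependency; not shown in this commit

    load_dotenv(os.path.join(PROJECT_DIR, '.env.test'))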
from unittest.mock import MagicMock

from core.entities.application_entities import ModelConfigEntity
from core.entities.provider_configuration import ProviderModelBundle
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform


def test__calculate_rest_token():
    model_schema_mock = MagicMock(spec=AIModelEntity)
    parameter_rule_mock = MagicMock(spec=ParameterRule)
    parameter_rule_mock.name = 'max_tokens'
    model_schema_mock.parameter_rules = [
        parameter_rule_mock
    ]
    model_schema_mock.model_properties = {
        ModelPropertyKey.CONTEXT_SIZE: 62
    }

    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
    large_language_model_mock.get_num_tokens.return_value = 6

    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
    provider_model_bundle_mock.model_type_instance = large_language_model_mock

    model_config_mock = MagicMock(spec=ModelConfigEntity)
    model_config_mock.model = 'gpt-4'
    model_config_mock.credentials = {}
    model_config_mock.parameters = {
        'max_tokens': 50
    }
    model_config_mock.model_schema = model_schema_mock
    model_config_mock.provider_model_bundle = provider_model_bundle_mock

    prompt_transform = PromptTransform()
    prompt_messages = [UserPromptMessage(content="Hello, how are you?")]
    rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)

    # Validate based on the mock configuration and expected logic:
    # context size (62) - max_tokens (50) - prompt tokens (6) = 6
    expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
                            - model_config_mock.parameters['max_tokens']
                            - large_language_model_mock.get_num_tokens.return_value)
    assert rest_tokens == expected_rest_tokens
    assert rest_tokens == 6
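The two assertions pin the same number from both directions: symbolically from the mocks and as a literal. The budget arithmetic being exercised (the method body itself is not visible in this commit) is rest_tokens = context_size - max_tokens - prompt_tokens = 62 - 50 - 6 = 6.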
from unittest.mock import MagicMock

from core.entities.application_entities import ModelConfigEntity
from core.prompt.simple_prompt_transform import SimplePromptTransform
from models.model import AppMode
def test_get_common_chat_app_prompt_template_with_pcqm():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + pre_prompt + '\n'
                                                           + prompt_rules['histories_prompt']
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']


def test_get_baichuan_chat_app_prompt_template_with_pcqm():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="baichuan",
        model="Baichuan2-53B",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + pre_prompt + '\n'
                                                           + prompt_rules['histories_prompt']
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']


def test_get_common_completion_app_prompt_template_with_pcq():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.WORKFLOW,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + pre_prompt + '\n'
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']


def test_get_baichuan_completion_app_prompt_template_with_pcq():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.WORKFLOW,
        provider="baichuan",
        model="Baichuan2-53B",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + pre_prompt + '\n'
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
def test_get_common_chat_app_prompt_template_with_q():
    prompt_transform = SimplePromptTransform()
    pre_prompt = ""
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=False,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == prompt_rules['query_prompt']
    assert prompt_template['special_variable_keys'] == ['#query#']


def test_get_common_chat_app_prompt_template_with_cq():
    prompt_transform = SimplePromptTransform()
    pre_prompt = ""
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']


def test_get_common_chat_app_prompt_template_with_p():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "you are {{name}}"
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=False,
        query_in_prompt=False,
        with_memory_prompt=False,
    )
    assert prompt_template['prompt_template'].template == pre_prompt + '\n'
    assert prompt_template['custom_variable_keys'] == ['name']
    assert prompt_template['special_variable_keys'] == []
def test__get_chat_model_prompt_messages():
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    model_config_mock.provider = 'openai'
    model_config_mock.model = 'gpt-4'

    prompt_transform = SimplePromptTransform()

    pre_prompt = "You are a helpful assistant {{name}}."
    inputs = {
        "name": "John"
    }
    context = "yes or no."
    query = "How are you?"
    prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
        pre_prompt=pre_prompt,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory=None,
        model_config=model_config_mock
    )

    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider=model_config_mock.provider,
        model=model_config_mock.model,
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=False,
        with_memory_prompt=False,
    )

    full_inputs = {**inputs, '#context#': context}
    real_system_prompt = prompt_template['prompt_template'].format(full_inputs)

    assert len(prompt_messages) == 2
    assert prompt_messages[0].content == real_system_prompt
    assert prompt_messages[1].content == query


def test__get_completion_model_prompt_messages():
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    model_config_mock.provider = 'openai'
    model_config_mock.model = 'gpt-3.5-turbo-instruct'

    prompt_transform = SimplePromptTransform()

    pre_prompt = "You are a helpful assistant {{name}}."
    inputs = {
        "name": "John"
    }
    context = "yes or no."
    query = "How are you?"
    prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
        pre_prompt=pre_prompt,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory=None,
        model_config=model_config_mock
    )

    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider=model_config_mock.provider,
        model=model_config_mock.model,
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )

    full_inputs = {**inputs, '#context#': context, '#query#': query}
    real_prompt = prompt_template['prompt_template'].format(full_inputs)

    assert len(prompt_messages) == 1
    assert stops == prompt_template['prompt_rules'].get('stops')
    assert prompt_messages[0].content == real_prompt
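All of the above are plain pytest functions built on mocked model configs, so they run without network access or provider credentials. A hedged invocation, assuming the files live in the API service's unit-test tree (the exact paths are not shown in this view):

    pytest tests/unit_tests -k "prompt_transform"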