Unverified Commit fe14130b authored by Garfield Dai, committed by GitHub

refactor advanced prompt core. (#1350)

parent 52ebffa8
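
Before the per-file hunks, a minimal sketch of the call pattern this refactor introduces: prompt construction moves out of the model classes and into a standalone PromptTransform that receives the model instance as an argument. Only the get_prompt / get_advanced_prompt signatures are taken from the diff below; the wrapper function and its parameters are illustrative, not part of the commit.

# Illustrative wrapper only; not part of the commit. The get_prompt /
# get_advanced_prompt signatures mirror the diff below.
from core.prompt.prompt_transform import PromptTransform

def build_prompt_messages(mode, app_model_config, inputs, query, context,
                          memory, model_instance):
    prompt_transform = PromptTransform()
    if app_model_config.prompt_type == 'simple':
        # Simple apps keep the JSON-rule-file prompt; stop words are returned with it.
        return prompt_transform.get_prompt(
            mode=mode,
            pre_prompt=app_model_config.pre_prompt,
            inputs=inputs,
            query=query,
            context=context,
            memory=memory,
            model_instance=model_instance,  # the model is now a parameter, not the owner of the logic
        )
    # Advanced apps dispatch on app mode x model mode inside PromptTransform.
    prompt_messages = prompt_transform.get_advanced_prompt(
        app_mode=mode,
        app_model_config=app_model_config,
        inputs=inputs,
        query=query,
        context=context,
        memory=memory,
        model_instance=model_instance,
    )
    return prompt_messages, None
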
@@ -16,6 +16,7 @@ from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser
@@ -156,24 +157,28 @@ class Completion:
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
prompt_transform = PromptTransform()
# get llm prompt
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = model_instance.get_prompt(
prompt_messages, stop_words = prompt_transform.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
memory=memory,
model_instance=model_instance
)
else:
prompt_messages = model_instance.get_advanced_prompt(
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
memory=memory,
model_instance=model_instance
)
model_config = app_model_config.model_dict
@@ -238,14 +243,29 @@ class Completion:
if max_tokens is None:
max_tokens = 0
prompt_transform = PromptTransform()
prompt_messages = []
# get prompt without memory and context
prompt_messages, _ = model_instance.get_prompt(
if app_model_config.prompt_type == 'simple':
prompt_messages, _ = prompt_transform.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=None,
memory=None
memory=None,
model_instance=model_instance
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=None,
memory=None,
model_instance=model_instance
)
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
......
@@ -37,12 +37,6 @@ class BaichuanModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
......
import json
import os
import re
import time
from abc import abstractmethod
from typing import List, Optional, Any, Union, Tuple
from typing import List, Optional, Any, Union
import decimal
import logging
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from langchain.schema import LLMResult, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
to_lc_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM
import logging
from extensions.ext_database import db
logger = logging.getLogger(__name__)
@@ -320,206 +310,8 @@ class BaseLLM(BaseProviderModel):
def support_streaming(self):
return False
def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
def get_advanced_prompt(self, app_mode: str,
app_model_config: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
model_mode = app_model_config.model_dict['mode']
conversation_histories_role = {}
raw_prompt_list = []
prompt_messages = []
if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
else:
raise Exception("app_mode or model_mode not support")
for prompt_item in raw_prompt_list:
prompt = prompt_item['text']
# set prompt template variables
prompt_template = PromptTemplateParser(template=prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
if '#context#' in prompt:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''
if '#query#' in prompt:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''
if '#histories#' in prompt:
if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, 2000)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''
prompt = prompt_template.format(
prompt_inputs
)
prompt = re.sub(r'<\|.*?\|>', '', prompt)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, 2000)
prompt_messages.extend(histories)
if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
return prompt_messages
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
else:
return 'common_chat'
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
{'context': context}
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
prompt_inputs
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += pre_prompt_content
query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
if memory and 'histories_prompt' in prompt_rules:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=prompt + query_prompt,
inputs={
'query': query
}
)
if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format({'histories': histories})
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = PromptTemplateParser(template=query_prompt)
query_prompt_content = prompt_template.format({'query': query})
prompt += query_prompt_content
prompt = re.sub(r'<\|.*?\|>', '', prompt)
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
return prompt, stops
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
'prompt/generate_prompts')
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]
def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
if not model_mode:
model_mode = self.model_mode
......
@@ -66,15 +66,6 @@ class HuggingfaceHubModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs
......
@@ -49,15 +49,6 @@ class OpenLLMModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
......
@@ -59,15 +59,6 @@ class XinferenceModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
......
import json
import os
import re
import enum
from typing import List, Optional, Tuple
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
class AppMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'
class PromptTransform:
def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
return [PromptMessage(content=prompt)], stops
def get_advanced_prompt(self,
app_mode: str,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
model_mode = app_model_config.model_dict['mode']
app_mode_enum = AppMode(app_mode)
model_mode_enum = ModelMode(model_mode)
prompt_messages = []
if app_mode_enum == AppMode.CHAT:
if model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
elif model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
elif app_mode_enum == AppMode.COMPLETION:
if model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
elif model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)
return prompt_messages
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]
def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])
def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
# baichuan
if isinstance(model_instance, BaichuanModel):
return self._prompt_file_name_for_baichuan(mode)
baichuan_model_hosted_platforms = (HuggingfaceHubModel, OpenLLMModel, XinferenceModel)
if isinstance(model_instance, baichuan_model_hosted_platforms) and 'baichuan' in model_instance.name.lower():
return self._prompt_file_name_for_baichuan(mode)
# common
if mode == 'completion':
return 'common_completion'
else:
return 'common_chat'
def _prompt_file_name_for_baichuan(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
'generate_prompts')
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
{'context': context}
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
prompt_inputs
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += pre_prompt_content
query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
if memory and 'histories_prompt' in prompt_rules:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=prompt + query_prompt,
inputs={
'query': query
}
)
rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format({'histories': histories})
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = PromptTemplateParser(template=query_prompt)
query_prompt_content = prompt_template.format({'query': query})
prompt += query_prompt_content
prompt = re.sub(r'<\|.*?\|>', '', prompt)
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
return prompt, stops
def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
if '#context#' in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''
def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
if '#query#' in prompt_template.variable_keys:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''
def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None:
if '#histories#' in prompt_template.variable_keys:
if memory:
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=raw_prompt,
inputs={ '#histories#': '', **prompt_inputs }
)
rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''
def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None:
if memory:
rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)
memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories)
def _calculate_rest_token(self, prompt_messages: BaseMessage, model_instance: BaseLLM) -> int:
rest_tokens = 2000
if model_instance.model_rules.max_tokens.max:
curr_message_tokens = model_instance.get_num_tokens(to_prompt_messages(prompt_messages))
max_tokens = model_instance.model_kwargs.max_tokens
rest_tokens = model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str:
prompt = prompt_template.format(
prompt_inputs
)
prompt = re.sub(r'<\|.*?\|>', '', prompt)
return prompt
def _get_chat_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
prompt_messages = []
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
self._set_context_variable(context, prompt_template, prompt_inputs)
self._set_query_variable(query, prompt_template, prompt_inputs)
self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
return prompt_messages
def _get_chat_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
prompt_messages = []
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
self._set_context_variable(context, prompt_template, prompt_inputs)
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
self._append_chat_histories(memory, prompt_messages, model_instance)
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
return prompt_messages
def _get_completion_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
prompt_messages = []
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
self._set_context_variable(context, prompt_template, prompt_inputs)
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
return prompt_messages
def _get_completion_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
prompt_messages = []
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
self._set_context_variable(context, prompt_template, prompt_inputs)
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
return prompt_messages
\ No newline at end of file
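
For orientation, a hedged example of the config shapes get_advanced_prompt reads. The key names (prompt, text, role, conversation_histories_role, user_prefix, assistant_prefix) and the special template variables (#context#, #query#, #histories#) come from the class above; the literal template strings and role values are invented for illustration.

# Shapes inferred from the attribute accesses above; the literal values are
# illustrative only and are not taken from the commit.

# Used when app_mode == 'chat' and the model runs in completion mode.
example_completion_prompt_config = {
    'prompt': {
        # {{variable}} placeholders are filled from `inputs`; #context#, #query#
        # and #histories# are filled by PromptTransform itself.
        'text': "{{#histories#}}\nContext: {{#context#}}\nHuman: {{#query#}}\nAssistant: "
    },
    'conversation_histories_role': {
        'user_prefix': 'Human',
        'assistant_prefix': 'Assistant'
    }
}

# Used when the model runs in chat mode: a list of role/text messages.
example_chat_prompt_config = {
    'prompt': [
        # role strings must map to MessageType values; 'user' appears in the
        # diff above, other roles are an assumption.
        {'role': 'user', 'text': "Use this context to answer: {{#context#}}"}
    ]
}
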
import copy
from core.model_providers.models.entity.model_params import ModelMode
from core.prompt.prompt_transform import AppMode
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
@@ -13,7 +15,7 @@ class AdvancedPromptTemplateService:
model_name = args['model_name']
has_context = args['has_context']
if 'baichuan' in model_name:
if 'baichuan' in model_name.lower():
return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
else:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@@ -22,15 +24,15 @@ class AdvancedPromptTemplateService:
def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == 'chat':
if model_mode == 'completion':
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == 'completion':
if model_mode == 'completion':
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
@classmethod
@@ -51,13 +53,13 @@ class AdvancedPromptTemplateService:
def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == 'chat':
if model_mode == 'completion':
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif app_mode == 'completion':
if model_mode == 'completion':
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
\ No newline at end of file
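
A hedged usage sketch of the enum-based branching this hunk switches to. The service's public entry point that unpacks args is not fully shown in the hunk, so the template getters are called directly here; the import path of the service module and the argument values are assumptions.

# Example values only; the service import path is assumed, the ModelMode and
# AppMode imports match the ones added in this file.
from core.model_providers.models.entity.model_params import ModelMode
from core.prompt.prompt_transform import AppMode
from services.advanced_prompt_template_service import AdvancedPromptTemplateService  # path assumed

baichuan_template = AdvancedPromptTemplateService.get_baichuan_prompt(
    app_mode=AppMode.CHAT.value,            # 'chat'
    model_mode=ModelMode.COMPLETION.value,  # 'completion'
    has_context='true'                      # str flag, as typed in the signature above
)

common_template = AdvancedPromptTemplateService.get_common_prompt(
    app_mode=AppMode.COMPLETION.value,      # 'completion'
    model_mode=ModelMode.CHAT.value,        # 'chat'
    has_context='false'
)
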
import re
import uuid
from core.prompt.prompt_transform import AppMode
from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType, ModelMode
@@ -418,7 +419,7 @@ class AppModelConfigService:
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
@@ -427,3 +428,10 @@ class AppModelConfigService:
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']
if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")
\ No newline at end of file
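
Finally, a hedged sketch of what the new chat-prompt length check rejects; the config keys follow the hunk above, and the example values are invented.

# Minimal config slice based on the keys read in the hunk above; values invented.
config = {
    'model': {'mode': 'chat'},
    'chat_prompt_config': {
        'prompt': [{'role': 'user', 'text': f'message {i}'} for i in range(11)]
    }
}
# With 11 entries, len(prompt_list) > 10 is true and validation raises ValueError.
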