Unverified Commit fe14130b authored by Garfield Dai, committed by GitHub

refactor advanced prompt core. (#1350)

parent 52ebffa8
@@ -16,6 +16,7 @@ from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_template import PromptTemplateParser
+from core.prompt.prompt_transform import PromptTransform
 from models.model import App, AppModelConfig, Account, Conversation, EndUser
@@ -156,24 +157,28 @@ class Completion:
                       conversation_message_task: ConversationMessageTask,
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                       fake_response: Optional[str]):
+        prompt_transform = PromptTransform()
+
         # get llm prompt
         if app_model_config.prompt_type == 'simple':
-            prompt_messages, stop_words = model_instance.get_prompt(
+            prompt_messages, stop_words = prompt_transform.get_prompt(
                 mode=mode,
                 pre_prompt=app_model_config.pre_prompt,
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )
         else:
-            prompt_messages = model_instance.get_advanced_prompt(
+            prompt_messages = prompt_transform.get_advanced_prompt(
                 app_mode=mode,
                 app_model_config=app_model_config,
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )

         model_config = app_model_config.model_dict
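Prompt assembly now goes through the new `core.prompt.prompt_transform` module instead of each model instance. Its implementation is not shown in full in this diff; judging only from the call sites above, its surface looks roughly like the sketch below (signatures inferred from the keyword arguments, bodies and exact types are assumptions).

```python
# Sketch of PromptTransform as inferred from the call sites in completion.py.
# Bodies and exact type annotations are assumptions, not code from this commit.
from typing import List, Optional, Tuple


class PromptTransform:
    def get_prompt(self, mode: str, pre_prompt: str, inputs: dict, query: str,
                   context: Optional[str], memory, model_instance) -> Tuple[list, List[str]]:
        """Build prompt messages for apps with prompt_type == 'simple'.

        Returns (prompt_messages, stop_words); passing model_instance lets the
        transform pick model-specific templates (e.g. Baichuan) and count tokens.
        """
        raise NotImplementedError  # real logic lives in core/prompt/prompt_transform.py

    def get_advanced_prompt(self, app_mode: str, app_model_config, inputs: dict,
                            query: str, context: Optional[str], memory,
                            model_instance) -> list:
        """Build prompt messages from the app's advanced prompt configuration."""
        raise NotImplementedError
```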
@@ -238,14 +243,29 @@ class Completion:
         if max_tokens is None:
             max_tokens = 0

+        prompt_transform = PromptTransform()
+        prompt_messages = []
+
         # get prompt without memory and context
-        prompt_messages, _ = model_instance.get_prompt(
-            mode=mode,
-            pre_prompt=app_model_config.pre_prompt,
-            inputs=inputs,
-            query=query,
-            context=None,
-            memory=None
-        )
+        if app_model_config.prompt_type == 'simple':
+            prompt_messages, _ = prompt_transform.get_prompt(
+                mode=mode,
+                pre_prompt=app_model_config.pre_prompt,
+                inputs=inputs,
+                query=query,
+                context=None,
+                memory=None,
+                model_instance=model_instance
+            )
+        else:
+            prompt_messages = prompt_transform.get_advanced_prompt(
+                app_mode=mode,
+                app_model_config=app_model_config,
+                inputs=inputs,
+                query=query,
+                context=None,
+                memory=None,
+                model_instance=model_instance
+            )

         prompt_tokens = model_instance.get_num_tokens(prompt_messages)
...
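This second call site rebuilds the prompt with context and memory stripped, so a baseline token count can be taken before the model call. A hedged illustration of how such a baseline is typically combined with `max_tokens` (the variable names and the exception are stand-ins, not code from this commit):

```python
# Illustration only: validate the static prompt plus the completion budget
# against a model's context window. All values below are hypothetical.
model_token_limit = 4096   # assumed context window of the target model
max_tokens = 512           # completion budget configured on the app
prompt_tokens = 1800       # e.g. model_instance.get_num_tokens(prompt_messages)

if prompt_tokens + max_tokens > model_token_limit:
    raise ValueError("prompt and max_tokens exceed the model's context window; "
                     "shorten the prompt or lower max_tokens")
```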
@@ -37,12 +37,6 @@ class BaichuanModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)

-    def prompt_file_name(self, mode: str) -> str:
-        if mode == 'completion':
-            return 'baichuan_completion'
-        else:
-            return 'baichuan_chat'
-
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
         get num tokens of prompt messages.
...
@@ -66,15 +66,6 @@ class HuggingfaceHubModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.get_num_tokens(prompts)

-    def prompt_file_name(self, mode: str) -> str:
-        if 'baichuan' in self.name.lower():
-            if mode == 'completion':
-                return 'baichuan_completion'
-            else:
-                return 'baichuan_chat'
-        else:
-            return super().prompt_file_name(mode)
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.model_kwargs = provider_model_kwargs
...
@@ -49,15 +49,6 @@ class OpenLLMModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def prompt_file_name(self, mode: str) -> str:
-        if 'baichuan' in self.name.lower():
-            if mode == 'completion':
-                return 'baichuan_completion'
-            else:
-                return 'baichuan_chat'
-        else:
-            return super().prompt_file_name(mode)
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         pass
...
@@ -59,15 +59,6 @@ class XinferenceModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def prompt_file_name(self, mode: str) -> str:
-        if 'baichuan' in self.name.lower():
-            if mode == 'completion':
-                return 'baichuan_completion'
-            else:
-                return 'baichuan_chat'
-        else:
-            return super().prompt_file_name(mode)
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         pass
...
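The same Baichuan-specific `prompt_file_name` override is deleted from BaichuanModel, HuggingfaceHubModel, OpenLLMModel and XinferenceModel; the per-model template selection presumably moves into the shared prompt layer added by this commit. A sketch of what that centralized selection could look like (the function name and the non-Baichuan template names are assumptions for illustration):

```python
# Hypothetical centralized version of the removed per-model overrides;
# not the actual code from core/prompt/prompt_transform.py.
def prompt_file_name(mode: str, model_name: str) -> str:
    # Baichuan models keep their dedicated prompt template files; every other
    # model falls back to the common templates (names assumed here).
    if 'baichuan' in model_name.lower():
        return 'baichuan_completion' if mode == 'completion' else 'baichuan_chat'
    return 'common_completion' if mode == 'completion' else 'common_chat'
```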
 import copy
+from core.model_providers.models.entity.model_params import ModelMode
+from core.prompt.prompt_transform import AppMode
 from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
     BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
@@ -13,7 +15,7 @@ class AdvancedPromptTemplateService:
         model_name = args['model_name']
         has_context = args['has_context']

-        if 'baichuan' in model_name:
+        if 'baichuan' in model_name.lower():
             return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
         else:
             return cls.get_common_prompt(app_mode, model_mode, has_context)
@@ -22,15 +24,15 @@ class AdvancedPromptTemplateService:
     def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
         context_prompt = copy.deepcopy(CONTEXT)

-        if app_mode == 'chat':
-            if model_mode == 'completion':
+        if app_mode == AppMode.CHAT.value:
+            if model_mode == ModelMode.COMPLETION.value:
                 return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
-            elif model_mode == 'chat':
+            elif model_mode == ModelMode.CHAT.value:
                 return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
-        elif app_mode == 'completion':
-            if model_mode == 'completion':
+        elif app_mode == AppMode.COMPLETION.value:
+            if model_mode == ModelMode.COMPLETION.value:
                 return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
-            elif model_mode == 'chat':
+            elif model_mode == ModelMode.CHAT.value:
                 return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)

     @classmethod
@@ -51,13 +53,13 @@ class AdvancedPromptTemplateService:
     def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
         baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

-        if app_mode == 'chat':
-            if model_mode == 'completion':
+        if app_mode == AppMode.CHAT.value:
+            if model_mode == ModelMode.COMPLETION.value:
                 return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
-            elif model_mode == 'chat':
+            elif model_mode == ModelMode.CHAT.value:
                 return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
-        elif app_mode == 'completion':
-            if model_mode == 'completion':
+        elif app_mode == AppMode.COMPLETION.value:
+            if model_mode == ModelMode.COMPLETION.value:
                 return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
-            elif model_mode == 'chat':
+            elif model_mode == ModelMode.CHAT.value:
                 return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
\ No newline at end of file
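Replacing the bare 'chat' / 'completion' literals with the `AppMode` and `ModelMode` enums keeps the four-way dispatch (app mode x model mode) typo-safe, and the added `.lower()` makes the Baichuan check case-insensitive, so names like "Baichuan2-53B" also pick up the Baichuan templates. The same dispatch can be read as a small lookup table; the sketch below is for clarity only, assuming the enums carry the 'chat' / 'completion' values seen in this diff.

```python
from enum import Enum

# Assumed to mirror the imported enums; only the 'chat'/'completion' values
# that appear in this diff are taken as given.
class AppMode(Enum):
    CHAT = 'chat'
    COMPLETION = 'completion'

class ModelMode(Enum):
    CHAT = 'chat'
    COMPLETION = 'completion'

# (app_mode, model_mode) -> prompt config name used by get_common_prompt above.
COMMON_PROMPT_CONFIGS = {
    (AppMode.CHAT.value, ModelMode.COMPLETION.value): 'CHAT_APP_COMPLETION_PROMPT_CONFIG',
    (AppMode.CHAT.value, ModelMode.CHAT.value): 'CHAT_APP_CHAT_PROMPT_CONFIG',
    (AppMode.COMPLETION.value, ModelMode.COMPLETION.value): 'COMPLETION_APP_COMPLETION_PROMPT_CONFIG',
    (AppMode.COMPLETION.value, ModelMode.CHAT.value): 'COMPLETION_APP_CHAT_PROMPT_CONFIG',
}
```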
 import re
 import uuid
+from core.prompt.prompt_transform import AppMode
 from core.agent.agent_executor import PlanningStrategy
 from core.model_providers.model_provider_factory import ModelProviderFactory
 from core.model_providers.models.entity.model_params import ModelType, ModelMode
@@ -418,7 +419,7 @@ class AppModelConfigService:
             if config['model']["mode"] not in ['chat', 'completion']:
                 raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")

-            if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
+            if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
                 user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
                 assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
@@ -427,3 +428,10 @@ class AppModelConfigService:

                 if not assistant_prefix:
                     config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
+
+            if config['model']["mode"] == ModelMode.CHAT.value:
+                prompt_list = config['chat_prompt_config']['prompt']
+
+                if len(prompt_list) > 10:
+                    raise ValueError("prompt messages must be less than 10")
\ No newline at end of file
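The new validation caps advanced chat-mode prompts at 10 messages. A minimal example of a config the added check rejects (the message dict shape is an assumption for illustration; only the keys referenced in the diff are taken as given):

```python
# Illustration of the new limit: more than 10 entries in
# chat_prompt_config['prompt'] now raises. Running this raises ValueError.
config = {
    'model': {'mode': 'chat'},
    'chat_prompt_config': {
        # 'role'/'text' are an assumed message shape, not taken from this diff.
        'prompt': [{'role': 'user', 'text': f'message {i}'} for i in range(11)],
    },
}

prompt_list = config['chat_prompt_config']['prompt']
if len(prompt_list) > 10:
    raise ValueError("prompt messages must be less than 10")
```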