Unverified Commit 42a5b3ec authored by Garfield Dai's avatar Garfield Dai Committed by GitHub

feat: advanced prompt backend (#1301)

Co-authored-by: 's avatartakatost <takatost@gmail.com>
parent 2d1cb076
......@@ -31,6 +31,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
......@@ -81,6 +82,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
......@@ -137,10 +139,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
pre_prompt="Please translate the following text into {{target_language}}:\n",
pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
......@@ -169,6 +172,13 @@ demo_model_templates = {
'Italian',
]
}
},{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
......@@ -200,6 +210,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
......@@ -255,10 +266,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
......@@ -287,6 +299,13 @@ demo_model_templates = {
"意大利语",
]
}
},{
"paragraph": {
"label": "文本内容",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
......@@ -318,6 +337,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
......
......@@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import setup, version, apikey, admin
# Import app controllers
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
......
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateList(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('app_mode', type=str, required=True, location='args')
parser.add_argument('model_mode', type=str, required=True, location='args')
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
parser.add_argument('model_name', type=str, required=True, location='args')
args = parser.parse_args()
service = AdvancedPromptTemplateService()
return service.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
\ No newline at end of file
......@@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
class IntroductionGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('prompt_template', type=str, required=True, location='json')
args = parser.parse_args()
account = current_user
try:
answer = LLMGenerator.generate_introduction(
account.current_tenant_id,
args['prompt_template']
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
return {'introduction': answer}
class RuleGenerateApi(Resource):
@setup_required
@login_required
......@@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
return rules
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
api.add_resource(RuleGenerateApi, '/rule-generate')
......@@ -329,7 +329,7 @@ class MessageApi(Resource):
message_id = str(message_id)
# get app info
app_model = _get_app(app_id, 'chat')
app_model = _get_app(app_id)
message = db.session.query(Message).filter(
Message.id == message_id,
......
......@@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource):
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
......
import json
import logging
from typing import Optional, List, Union
......@@ -16,10 +15,8 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
from core.prompt.prompt_template import PromptTemplateParser
from models.model import App, AppModelConfig, Account, Conversation, EndUser
class Completion:
......@@ -30,7 +27,7 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
query = PromptBuilder.process_template(query)
query = PromptTemplateParser.remove_template_variables(query)
memory = None
if conversation:
......@@ -160,14 +157,28 @@ class Completion:
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
else:
prompt_messages = model_instance.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
model_config = app_model_config.model_dict
completion_params = model_config.get("completion_params", {})
stop_words = completion_params.get("stop", [])
cls.recale_llm_max_tokens(
model_instance=model_instance,
......@@ -176,7 +187,7 @@ class Completion:
response = model_instance.run(
messages=prompt_messages,
stop=stop_words,
stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
......@@ -266,52 +277,3 @@ class Completion:
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model_config=app_model_config.model_dict,
streaming=streaming
)
# get llm prompt
old_prompt_messages, _ = final_model_instance.get_prompt(
mode='completion',
pre_prompt=pre_prompt,
inputs=message.inputs,
query=message.query,
context=None,
memory=None
)
original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
prompt_messages = [PromptMessage(content=prompt)]
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
inputs=message.inputs,
query=message.query,
is_override=True if message.override_model_configs else False,
streaming=streaming,
model_instance=final_model_instance
)
cls.recale_llm_max_tokens(
model_instance=final_model_instance,
prompt_messages=prompt_messages
)
final_model_instance.run(
messages=prompt_messages,
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
)
......@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
......@@ -74,10 +74,10 @@ class ConversationMessageTask:
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
introduction = prompt_template.format(**prompt_inputs)
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
......@@ -150,12 +150,12 @@ class ConversationMessageTask:
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
......@@ -163,7 +163,7 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(
self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
......@@ -226,15 +226,15 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
......
......@@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
class LLMGenerator:
......@@ -44,78 +43,19 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
max_context_token_length = max_context_token_length if max_context_token_length else 1500
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
if not message.answer:
continue
if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query
if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
prompt = JinjaPromptTemplate(
template="{{histories}}\n{{format_instructions}}\nquestions:\n",
input_variables=["histories"],
partial_variables={"format_instructions": format_instructions}
prompt_template = PromptTemplateParser(
template="{{histories}}\n{{format_instructions}}\nquestions:\n"
)
_input = prompt.format_prompt(histories=histories)
prompt = prompt_template.format({
"histories": histories,
"format_instructions": format_instructions
})
try:
model_instance = ModelFactory.get_text_generation_model(
......@@ -128,10 +68,10 @@ class LLMGenerator:
except ProviderTokenNotInitError:
return []
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
questions = output_parser.parse(output.content)
except LLMError:
questions = []
......@@ -145,19 +85,21 @@ class LLMGenerator:
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
output_parser = RuleConfigGeneratorOutputParser()
prompt = OutLinePromptTemplate(
template=output_parser.get_format_instructions(),
input_variables=["audiences", "hoping_to_solve"],
partial_variables={
"variable": '{variable}',
"lanA": '{lanA}',
"lanB": '{lanB}',
"topic": '{topic}'
},
validate_template=False
prompt_template = PromptTemplateParser(
template=output_parser.get_format_instructions()
)
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
prompt = prompt_template.format(
inputs={
"audiences": audiences,
"hoping_to_solve": hoping_to_solve,
"variable": "{{variable}}",
"lanA": "{{lanA}}",
"lanB": "{{lanB}}",
"topic": "{{topic}}"
},
remove_template_variables=False
)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
......@@ -167,10 +109,10 @@ class LLMGenerator:
)
)
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
rule_config = output_parser.parse(output.content)
except LLMError as e:
raise e
......
......@@ -286,7 +286,7 @@ class IndexingRunner:
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
......@@ -383,7 +383,7 @@ class IndexingRunner:
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
......
......@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:
......
......@@ -13,13 +13,13 @@ class LLMRunResult(BaseModel):
class MessageType(enum.Enum):
HUMAN = 'human'
USER = 'user'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
type: MessageType = MessageType.USER
content: str = ''
function_call: dict = None
......@@ -27,7 +27,7 @@ class PromptMessage(BaseModel):
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
additional_kwargs = {}
......@@ -44,7 +44,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content,
......@@ -58,7 +58,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
return prompt_messages
......
......@@ -18,7 +18,7 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM
import logging
......@@ -232,7 +232,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return:
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
......@@ -250,7 +250,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
......@@ -265,7 +265,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.000001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
......@@ -330,6 +330,85 @@ class BaseLLM(BaseProviderModel):
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
def get_advanced_prompt(self, app_mode: str,
app_model_config: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
model_mode = app_model_config.model_dict['mode']
conversation_histories_role = {}
raw_prompt_list = []
prompt_messages = []
if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
else:
raise Exception("app_mode or model_mode not support")
for prompt_item in raw_prompt_list:
prompt = prompt_item['text']
# set prompt template variables
prompt_template = PromptTemplateParser(template=prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
if '#context#' in prompt:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''
if '#query#' in prompt:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''
if '#histories#' in prompt:
if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, 2000)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''
prompt = prompt_template.format(
prompt_inputs
)
prompt = re.sub(r'<\|.*?\|>', '', prompt)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, 2000)
prompt_messages.extend(histories)
if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
return prompt_messages
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
......@@ -342,17 +421,17 @@ class BaseLLM(BaseProviderModel):
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
context=context
{'context': context}
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
**prompt_inputs
prompt_inputs
)
prompt = ''
......@@ -385,10 +464,8 @@ class BaseLLM(BaseProviderModel):
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format(
histories=histories
)
prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format({'histories': histories})
prompt = ''
for order in prompt_rules['system_prompt_orders']:
......@@ -399,10 +476,8 @@ class BaseLLM(BaseProviderModel):
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
query_prompt_content = prompt_template.format(
query=query
)
prompt_template = PromptTemplateParser(template=query_prompt)
query_prompt_content = prompt_template.format({'query': query})
prompt += query_prompt_content
......@@ -433,6 +508,16 @@ class BaseLLM(BaseProviderModel):
external_context = memory.load_memory_variables({})
return external_context[memory_key]
def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
if not model_mode:
......
......@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
......@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
},
{
'id': 'claude-2',
'name': 'claude-2',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
......@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -12,7 +12,7 @@ from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
......@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
}
credentials = json.loads(provider_model.encrypted_config)
if provider_model.model_type == ModelType.TEXT_GENERATION.value:
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
if credentials['base_model_name'] in [
'gpt-4',
'gpt-4-32k',
......@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
return model_list
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name == 'text-davinci-003':
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
......@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
......@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
......@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
......@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]
......
......@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
......@@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider):
Returns the name of a provider.
"""
return 'baichuan'
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
......@@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider):
{
'id': 'baichuan2-53b',
'name': 'Baichuan2-53B',
'mode': ModelMode.CHAT.value,
}
]
else:
......
......@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
return [{
'id': provider_model.model_name,
'name': provider_model.model_name
} for provider_model in provider_models]
provider_model_list = []
for provider_model in provider_models:
provider_model_dict = {
'id': provider_model.model_name,
'name': provider_model.model_name
}
if model_type == ModelType.TEXT_GENERATION:
provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
provider_model_list.append(provider_model_dict)
return provider_model_list
@abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
......@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
"""
raise NotImplementedError
@abstractmethod
def _get_text_generation_model_mode(self, model_name) -> str:
"""
get text generation model mode.
:param model_name:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_class(self, model_type: ModelType) -> Type:
"""
......
......@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
......@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -5,7 +5,7 @@ import requests
from huggingface_hub import HfApi
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
......@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
......@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
if credentials['completion_type'] == 'chat_completion':
return ModelMode.CHAT.value
else:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
......@@ -29,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
{
'id': 'abab5.5-chat',
'name': 'abab5.5-chat',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'abab5-chat',
'name': 'abab5-chat',
'mode': ModelMode.COMPLETION.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
......@@ -45,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
......@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
......@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
......@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
......@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
......@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]
......@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name in COMPLETION_MODELS:
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -3,7 +3,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
......@@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -6,7 +6,8 @@ import replicate
from replicate.exceptions import ReplicateError
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
......@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
......@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider):
{
'id': 'spark',
'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
......@@ -26,15 +26,20 @@ class TongyiProvider(BaseModelProvider):
{
'id': 'qwen-turbo',
'name': 'qwen-turbo',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'qwen-plus',
'name': 'qwen-plus',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin
......@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider):
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -6,7 +6,7 @@ from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
......@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
......@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
......@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider):
{
'id': 'chatglm_pro',
'name': 'chatglm_pro',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_std',
'name': 'chatglm_std',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite',
'name': 'chatglm_lite',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k',
'mode': ModelMode.CHAT.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
......@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
......
import math
from typing import Optional
from langchain import WikipediaAPIWrapper
......@@ -50,6 +49,7 @@ class OrchestratorRuleParser:
tool_configs = agent_mode_config.get('tools', [])
agent_provider_name = model_dict.get('provider', 'openai')
agent_model_name = model_dict.get('name', 'gpt-4')
dataset_configs = self.app_model_config.dataset_configs_dict
agent_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
......@@ -96,13 +96,14 @@ class OrchestratorRuleParser:
summary_model_instance = None
tools = self.to_tools(
agent_model_instance=agent_model_instance,
tool_configs=tool_configs,
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
agent_model_instance=agent_model_instance,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
return_resource=return_resource,
retriever_from=retriever_from
retriever_from=retriever_from,
dataset_configs=dataset_configs
)
if len(tools) == 0:
......@@ -170,20 +171,12 @@ class OrchestratorRuleParser:
return None
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
conversation_message_task: ConversationMessageTask,
rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
retriever_from: str = 'dev') -> list[BaseTool]:
def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
:param agent_model_instance:
:param rest_tokens:
:param tool_configs: app agent tool configs
:param conversation_message_task:
:param callbacks:
:param return_resource:
:param retriever_from:
:return:
"""
tools = []
......@@ -195,15 +188,15 @@ class OrchestratorRuleParser:
tool = None
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool(agent_model_instance)
tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
elif tool_type == "google_search":
tool = self.to_google_search_tool()
tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
elif tool_type == "wikipedia":
tool = self.to_wikipedia_tool()
tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
elif tool_type == "current_datetime":
tool = self.to_current_datetime_tool()
tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)
if tool:
if tool.callbacks is not None:
......@@ -215,12 +208,15 @@ class OrchestratorRuleParser:
return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
dataset_configs: dict, rest_tokens: int,
return_resource: bool = False, retriever_from: str = 'dev',
**kwargs) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
:param dataset_configs:
:param conversation_message_task:
:param return_resource:
:param retriever_from:
......@@ -238,10 +234,20 @@ class OrchestratorRuleParser:
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
return None
k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
top_k = dataset_configs.get("top_k", 2)
# dynamically adjust top_k when the remaining token number is not enough to support top_k
top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
score_threshold = None
score_threshold_config = dataset_configs.get("score_threshold")
if score_threshold_config and score_threshold_config.get("enable"):
score_threshold = score_threshold_config.get("value")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
k=k,
top_k=top_k,
score_threshold=score_threshold,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
......@@ -250,7 +256,7 @@ class OrchestratorRuleParser:
return tool
def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
"""
A tool for reading web pages
......@@ -278,7 +284,7 @@ class OrchestratorRuleParser:
return tool
def to_google_search_tool(self) -> Optional[BaseTool]:
def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
func_kwargs = tool_provider.credentials_to_func_kwargs()
if not func_kwargs:
......@@ -296,12 +302,12 @@ class OrchestratorRuleParser:
return tool
def to_current_datetime_tool(self) -> Optional[BaseTool]:
def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
tool = DatetimeTool()
return tool
def to_wikipedia_tool(self) -> Optional[BaseTool]:
def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
......@@ -312,22 +318,18 @@ class OrchestratorRuleParser:
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MAX_K = 10
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
if rest_tokens == -1:
return DEFAULT_K
return top_k
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
return top_k
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
return top_k
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
......@@ -335,14 +337,7 @@ class OrchestratorRuleParser:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
if rest_tokens < segment_max_tokens * top_k:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
return min(context_limit_tokens // segment_max_tokens, MAX_K)
return min(top_k, 10)
CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n"
BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"
CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
},
"conversation_histories_role": {
"user_prefix": "Human",
"assistant_prefix": "Assistant"
}
}
}
CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
}
COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
}
COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
}
}
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
},
"conversation_histories_role": {
"user_prefix": "用户",
"assistant_prefix": "助手"
}
}
}
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
}
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
}
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
}
}
import re
from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
from langchain.schema import BaseMessage
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
class PromptBuilder:
@classmethod
def parse_prompt(cls, prompt: str, inputs: dict) -> str:
prompt_template = PromptTemplateParser(prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt = prompt_template.format(prompt_inputs)
return prompt
@classmethod
def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
system_message = system_prompt_template.format(**prompt_inputs)
return system_message
return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
ai_message = ai_prompt_template.format(**prompt_inputs)
return ai_message
return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
human_message = human_prompt_template.format(**inputs)
return human_message
@classmethod
def process_template(cls, template: str):
processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
# processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
# processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
return processed_template
return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))
import re
from typing import Any
from jinja2 import Environment, meta
from langchain import PromptTemplate
from langchain.formatting import StrictFormatter
REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}")
class JinjaPromptTemplate(PromptTemplate):
template_format: str = "jinja2"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
class PromptTemplateParser:
"""
Rules:
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
env = Environment()
template = template.replace("{{}}", "{}")
ast = env.parse(template)
input_variables = meta.find_undeclared_variables(ast)
if "partial_variables" in kwargs:
partial_variables = kwargs["partial_variables"]
input_variables = {
var for var in input_variables if var not in partial_variables
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
class OutLinePromptTemplate(PromptTemplate):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
input_variables = {
v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
1. Template variables must be enclosed in `{{}}`.
2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters,
and can only start with letters and underscores.
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
4. In addition to the above, 3 types of special template variable Keys are accepted:
`{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
"""
Args:
kwargs: Any arguments to be passed to the prompt template.
def __init__(self, template: str):
self.template = template
self.variable_keys = self.extract()
Returns:
A formatted string.
def extract(self) -> list:
# Regular expression to match the template rules
return re.findall(REGEX, self.template)
Example:
def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
.. code-block:: python
if remove_template_variables:
return PromptTemplateParser.remove_template_variables(value)
return value
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return OneLineFormatter().format(self.template, **kwargs)
return re.sub(REGEX, replacer, self.template)
class OneLineFormatter(StrictFormatter):
def parse(self, format_string):
last_end = 0
results = []
for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
field_name = match.group(1)
start, end = match.span()
literal_text = format_string[last_end:start]
last_end = end
results.append((literal_text, field_name, '', None))
remaining_literal_text = format_string[last_end:]
if remaining_literal_text:
results.append((remaining_literal_text, None, None, None))
return results
@classmethod
def remove_template_variables(cls, text: str):
return re.sub(REGEX, r'{\1}', text)
......@@ -61,36 +61,6 @@ User Input: yo, 你今天咋样?
User Input:
"""
CONVERSATION_SUMMARY_PROMPT = (
"Please generate a short summary of the following conversation.\n"
"If the following conversation communicating in English, you should only return an English summary.\n"
"If the following conversation communicating in Chinese, you should only return a Chinese summary.\n"
"[Conversation Start]\n"
"{context}\n"
"[Conversation End]\n\n"
"summary:"
)
INTRODUCTION_GENERATE_PROMPT = (
"I am designing a product for users to interact with an AI through dialogue. "
"The Prompt given to the AI before the conversation is:\n\n"
"```\n{prompt}\n```\n\n"
"Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
"Do not reveal the developer's motivation or deep logic behind the Prompt, "
"but focus on building a relationship with the user:\n"
)
MORE_LIKE_THIS_GENERATE_PROMPT = (
"-----\n"
"{original_completion}\n"
"-----\n\n"
"Please use the above content as a sample for generating the result, "
"and include key information points related to the original sample in the result. "
"Try to rephrase this information in different ways and predict according to the rules below.\n\n"
"-----\n"
"{prompt}\n"
)
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
......@@ -157,10 +127,10 @@ and fill in variables, with a welcome sentence, and keep TLDR.
```
<< MY INTENDED AUDIENCES >>
{audiences}
{{audiences}}
<< HOPING TO SOLVE >>
{hoping_to_solve}
{{hoping_to_solve}}
<< OUTPUT >>
"""
\ No newline at end of file
import json
from typing import Type
from typing import Type, Optional
from flask import current_app
from langchain.tools import BaseTool
......@@ -28,7 +28,8 @@ class DatasetRetrieverTool(BaseTool):
tenant_id: str
dataset_id: str
k: int = 3
top_k: int = 2
score_threshold: Optional[float] = None
conversation_message_task: ConversationMessageTask
return_resource: bool
retriever_from: str
......@@ -66,7 +67,7 @@ class DatasetRetrieverTool(BaseTool):
)
)
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
return str("\n".join([document.page_content for document in documents]))
else:
......@@ -80,20 +81,21 @@ class DatasetRetrieverTool(BaseTool):
return ''
except ProviderTokenNotInitError:
return ''
embeddings = CacheEmbedding(embedding_model)
embeddings = CacheEmbedding(embedding_model)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
if self.k > 0:
if self.top_k > 0:
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': self.k,
'k': self.top_k,
'score_threshold': self.score_threshold,
'filter': {
'group_id': [dataset.id]
}
......
......@@ -4,5 +4,4 @@ from .clean_when_document_deleted import handle
from .clean_when_dataset_deleted import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .generate_conversation_name_when_first_message_created import handle
from .generate_conversation_summary_when_few_message_created import handle
from .create_document_index import handle
from events.message_event import message_was_created
from tasks.generate_conversation_summary_task import generate_conversation_summary_task
@message_was_created.connect
def handle(sender, **kwargs):
message = sender
conversation = kwargs.get('conversation')
is_first_message = kwargs.get('is_first_message')
if not is_first_message and conversation.mode == 'chat' and not conversation.summary:
history_message_count = conversation.message_count
if history_message_count >= 5:
generate_conversation_summary_task.delay(conversation.id)
......@@ -28,6 +28,10 @@ model_config_fields = {
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
'prompt_type': fields.String,
'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
}
app_detail_fields = {
......
......@@ -123,6 +123,7 @@ conversation_with_summary_fields = {
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String,
'from_account_id': fields.String,
'name': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
'created_at': TimestampField,
......
"""add advanced prompt templates
Revision ID: b3a09c049e8e
Revises: 2e9819ca5b28
Create Date: 2023-10-10 15:23:23.395420
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b3a09c049e8e'
down_revision = '2e9819ca5b28'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('dataset_configs')
batch_op.drop_column('completion_prompt_config')
batch_op.drop_column('chat_prompt_config')
batch_op.drop_column('prompt_type')
# ### end Alembic commands ###
......@@ -93,6 +93,10 @@ class AppModelConfig(db.Model):
agent_mode = db.Column(db.Text)
sensitive_word_avoidance = db.Column(db.Text)
retriever_resource = db.Column(db.Text)
prompt_type = db.Column(db.String(255), nullable=False, default='simple')
chat_prompt_config = db.Column(db.Text)
completion_prompt_config = db.Column(db.Text)
dataset_configs = db.Column(db.Text)
@property
def app(self):
......@@ -139,6 +143,18 @@ class AppModelConfig(db.Model):
def agent_mode_dict(self) -> dict:
return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []}
@property
def chat_prompt_config_dict(self) -> dict:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
@property
def completion_prompt_config_dict(self) -> dict:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
@property
def dataset_configs_dict(self) -> dict:
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
def to_dict(self) -> dict:
return {
"provider": "",
......@@ -155,7 +171,11 @@ class AppModelConfig(db.Model):
"user_input_form": self.user_input_form_list,
"dataset_query_variable": self.dataset_query_variable,
"pre_prompt": self.pre_prompt,
"agent_mode": self.agent_mode_dict
"agent_mode": self.agent_mode_dict,
"prompt_type": self.prompt_type,
"chat_prompt_config": self.chat_prompt_config_dict,
"completion_prompt_config": self.completion_prompt_config_dict,
"dataset_configs": self.dataset_configs_dict
}
def from_model_config_dict(self, model_config: dict):
......@@ -177,6 +197,13 @@ class AppModelConfig(db.Model):
self.agent_mode = json.dumps(model_config['agent_mode'])
self.retriever_resource = json.dumps(model_config['retriever_resource']) \
if model_config.get('retriever_resource') else None
self.prompt_type = model_config.get('prompt_type', 'simple')
self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \
if model_config.get('chat_prompt_config') else None
self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \
if model_config.get('completion_prompt_config') else None
self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
if model_config.get('dataset_configs') else None
return self
def copy(self):
......@@ -197,7 +224,11 @@ class AppModelConfig(db.Model):
dataset_query_variable=self.dataset_query_variable,
pre_prompt=self.pre_prompt,
agent_mode=self.agent_mode,
retriever_resource=self.retriever_resource
retriever_resource=self.retriever_resource,
prompt_type=self.prompt_type,
chat_prompt_config=self.chat_prompt_config,
completion_prompt_config=self.completion_prompt_config,
dataset_configs=self.dataset_configs
)
return new_app_model_config
......
import copy
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
class AdvancedPromptTemplateService:
def get_prompt(self, args: dict) -> dict:
app_mode = args['app_mode']
model_mode = args['model_mode']
model_name = args['model_name']
has_context = args['has_context']
if 'baichuan' in model_name:
return self.get_baichuan_prompt(app_mode, model_mode, has_context)
else:
return self.get_common_prompt(app_mode, model_mode, has_context)
def get_common_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict:
if app_mode == 'chat':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
elif app_mode == 'completion':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
def get_completion_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict:
if has_context == 'true':
prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
return prompt_template
def get_chat_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict:
if has_context == 'true':
prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
return prompt_template
def get_baichuan_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict:
if app_mode == 'chat':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
elif app_mode == 'completion':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
\ No newline at end of file
......@@ -3,7 +3,7 @@ import uuid
from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.models.entity.model_params import ModelType, ModelMode
from models.account import Account
from services.dataset_service import DatasetService
......@@ -34,40 +34,28 @@ class AppModelConfigService:
# max_tokens
if 'max_tokens' not in cp:
cp["max_tokens"] = 512
#
# if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
# llm_constant.max_context_token_length[model_name]:
# raise ValueError(
# "max_tokens must be an integer greater than 0 "
# "and not exceeding the maximum value of the corresponding model")
#
# temperature
if 'temperature' not in cp:
cp["temperature"] = 1
#
# if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
# raise ValueError("temperature must be a float between 0 and 2")
#
# top_p
if 'top_p' not in cp:
cp["top_p"] = 1
# if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
# raise ValueError("top_p must be a float between 0 and 2")
#
# presence_penalty
if 'presence_penalty' not in cp:
cp["presence_penalty"] = 0
# if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
# raise ValueError("presence_penalty must be a float between -2 and 2")
#
# presence_penalty
if 'frequency_penalty' not in cp:
cp["frequency_penalty"] = 0
# if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2:
# raise ValueError("frequency_penalty must be a float between -2 and 2")
# stop
if 'stop' not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")
# Filter out extra parameters
filtered_cp = {
......@@ -75,7 +63,8 @@ class AppModelConfigService:
"temperature": cp["temperature"],
"top_p": cp["top_p"],
"presence_penalty": cp["presence_penalty"],
"frequency_penalty": cp["frequency_penalty"]
"frequency_penalty": cp["frequency_penalty"],
"stop": cp["stop"]
}
return filtered_cp
......@@ -211,6 +200,10 @@ class AppModelConfigService:
model_ids = [m['id'] for m in model_list]
if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list")
# model.mode
if 'mode' not in config['model'] or not config['model']["mode"]:
config['model']["mode"] = ""
# model.completion_params
if 'completion_params' not in config["model"]:
......@@ -339,6 +332,9 @@ class AppModelConfigService:
# dataset_query_variable
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
# advanced prompt validation
AppModelConfigService.is_advanced_prompt_valid(config, mode)
# Filter out extra parameters
filtered_config = {
"opening_statement": config["opening_statement"],
......@@ -351,12 +347,17 @@ class AppModelConfigService:
"model": {
"provider": config["model"]["provider"],
"name": config["model"]["name"],
"mode": config['model']["mode"],
"completion_params": config["model"]["completion_params"]
},
"user_input_form": config["user_input_form"],
"dataset_query_variable": config.get('dataset_query_variable'),
"pre_prompt": config["pre_prompt"],
"agent_mode": config["agent_mode"]
"agent_mode": config["agent_mode"],
"prompt_type": config["prompt_type"],
"chat_prompt_config": config["chat_prompt_config"],
"completion_prompt_config": config["completion_prompt_config"],
"dataset_configs": config["dataset_configs"]
}
return filtered_config
......@@ -375,4 +376,51 @@ class AppModelConfigService:
if dataset_exists and not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist")
@staticmethod
def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
# prompt_type
if 'prompt_type' not in config or not config["prompt_type"]:
config["prompt_type"] = "simple"
if config['prompt_type'] not in ['simple', 'advanced']:
raise ValueError("prompt_type must be in ['simple', 'advanced']")
# chat_prompt_config
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
config["chat_prompt_config"] = {}
if not isinstance(config["chat_prompt_config"], dict):
raise ValueError("chat_prompt_config must be of object type")
# completion_prompt_config
if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
config["completion_prompt_config"] = {}
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
# dataset_configs
if 'dataset_configs' not in config or not config["dataset_configs"]:
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if config['prompt_type'] == 'advanced':
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
if not user_prefix:
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
......@@ -244,7 +244,8 @@ class CompletionService:
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
message_id: str, streaming: bool = True) -> Union[dict | Generator]:
message_id: str, streaming: bool = True,
retriever_from: str = 'dev') -> Union[dict | Generator]:
if not user:
raise ValueError('user cannot be None')
......@@ -266,14 +267,11 @@ class CompletionService:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
if message.override_model_configs:
override_model_configs = json.loads(message.override_model_configs)
pre_prompt = override_model_configs.get("pre_prompt", '')
elif app_model_config:
pre_prompt = app_model_config.pre_prompt
else:
raise AppModelConfigBrokenError()
model_dict = app_model_config.model_dict
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
app_model_config.model = json.dumps(model_dict)
generate_task_id = str(uuid.uuid4())
......@@ -282,58 +280,28 @@ class CompletionService:
user = cls.get_real_user_instead_of_proxy_obj(user)
generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id,
'detached_app_model': app_model,
'app_model_config': app_model_config,
'detached_message': message,
'pre_prompt': pre_prompt,
'query': message.query,
'inputs': message.inputs,
'detached_user': user,
'streaming': streaming
'detached_conversation': None,
'streaming': streaming,
'is_model_config_override': True,
'retriever_from': retriever_from
})
generate_worker_thread.start()
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
# wait for 10 minutes to close the thread
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
generate_task_id)
return cls.compact_response(pubsub, streaming)
@classmethod
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
detached_user: Union[Account, EndUser], streaming: bool):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
app_model = db.session.merge(detached_app_model)
message = db.session.merge(detached_message)
try:
# run
Completion.generate_more_like_this(
task_id=generate_task_id,
app=app_model,
user=user,
message=message,
pre_prompt=pre_prompt,
app_model_config=app_model_config,
streaming=streaming
)
except ConversationTaskStoppedException:
pass
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
finally:
db.session.commit()
@classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
if user_inputs is None:
......
......@@ -482,6 +482,9 @@ class ProviderService:
'features': []
}
if 'mode' in model:
valid_model_dict['model_mode'] = model['mode']
if 'features' in model:
valid_model_dict['features'] = model['features']
......
import logging
import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.generator.llm_generator import LLMGenerator
from core.model_providers.error import LLMError, ProviderTokenNotInitError
from extensions.ext_database import db
from models.model import Conversation, Message
@shared_task(queue='generation')
def generate_conversation_summary_task(conversation_id: str):
"""
Async Generate conversation summary
:param conversation_id:
Usage: generate_conversation_summary_task.delay(conversation_id)
"""
logging.info(click.style('Start generate conversation summary: {}'.format(conversation_id), fg='green'))
start_at = time.perf_counter()
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
raise NotFound('Conversation not found')
try:
# get conversation messages count
history_message_count = conversation.message_count
if history_message_count >= 5 and not conversation.summary:
app_model = conversation.app
if not app_model:
return
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
.order_by(Message.created_at.asc()).all()
conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages)
db.session.add(conversation)
db.session.commit()
except (LLMError, ProviderTokenNotInitError):
conversation.summary = '[No Summary]'
db.session.commit()
pass
except Exception as e:
conversation.summary = '[No Summary]'
db.session.commit()
logging.exception(e)
end_at = time.perf_counter()
logging.info(
click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at),
fg='green'))
......@@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('claude-2')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 6
......
......@@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 22
......
......@@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('baichuan2-53b')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst > 0
......@@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker):
model = get_mock_model('baichuan2-53b')
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
......@@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
model = get_mock_model('baichuan2-53b', streaming=True)
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages
......
......@@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock
mocker
)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
......@@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
mocker
)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
......
......@@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('abab5.5-chat')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
......
......@@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('gpt-3.5-turbo')
rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 22
......
......@@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('facebook/opt-125m', mocker)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
......
......@@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 7
......
......@@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('spark')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 6
......
......@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('qwen-turbo')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
......
......@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt):
model = get_mock_model('ernie-bot')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
......
......@@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('llama-2-chat', mocker)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst == 5
......
......@@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('chatglm_lite')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
])
assert rst > 0
......@@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker):
model = get_mock_model('chatglm_lite')
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
......@@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
model = get_mock_model('chatglm_lite', streaming=True)
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages
......
from typing import Type
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.providers.base import BaseModelProvider
......@@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider):
return 'fake'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [{'id': 'test_model', 'name': 'Test Model'}]
return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
return OpenAIModel
......
......@@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker):
provider = FakeModelProvider(provider=Provider())
result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
assert result == [{'id': 'test_model', 'name': 'test_model'}]
assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}]
def test_check_quota_over_limit(mocker):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment