Unverified Commit 42a5b3ec authored by Garfield Dai, committed by GitHub

feat: advanced prompt backend (#1301)

Co-authored-by: takatost <takatost@gmail.com>
parent 2d1cb076
...@@ -31,6 +31,7 @@ model_templates = { ...@@ -31,6 +31,7 @@ model_templates = {
'model': json.dumps({ 'model': json.dumps({
"provider": "openai", "provider": "openai",
"name": "gpt-3.5-turbo-instruct", "name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": { "completion_params": {
"max_tokens": 512, "max_tokens": 512,
"temperature": 1, "temperature": 1,
...@@ -81,6 +82,7 @@ model_templates = { ...@@ -81,6 +82,7 @@ model_templates = {
'model': json.dumps({ 'model': json.dumps({
"provider": "openai", "provider": "openai",
"name": "gpt-3.5-turbo", "name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": { "completion_params": {
"max_tokens": 512, "max_tokens": 512,
"temperature": 1, "temperature": 1,
...@@ -137,10 +139,11 @@ demo_model_templates = { ...@@ -137,10 +139,11 @@ demo_model_templates = {
}, },
opening_statement='', opening_statement='',
suggested_questions=None, suggested_questions=None,
pre_prompt="Please translate the following text into {{target_language}}:\n", pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({ model=json.dumps({
"provider": "openai", "provider": "openai",
"name": "gpt-3.5-turbo-instruct", "name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": { "completion_params": {
"max_tokens": 1000, "max_tokens": 1000,
"temperature": 0, "temperature": 0,
...@@ -169,6 +172,13 @@ demo_model_templates = { ...@@ -169,6 +172,13 @@ demo_model_templates = {
'Italian', 'Italian',
] ]
} }
},{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
} }
]) ])
) )
...@@ -200,6 +210,7 @@ demo_model_templates = { ...@@ -200,6 +210,7 @@ demo_model_templates = {
model=json.dumps({ model=json.dumps({
"provider": "openai", "provider": "openai",
"name": "gpt-3.5-turbo", "name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": { "completion_params": {
"max_tokens": 300, "max_tokens": 300,
"temperature": 0.8, "temperature": 0.8,
...@@ -255,10 +266,11 @@ demo_model_templates = { ...@@ -255,10 +266,11 @@ demo_model_templates = {
}, },
opening_statement='', opening_statement='',
suggested_questions=None, suggested_questions=None,
pre_prompt="请将以下文本翻译为{{target_language}}:\n", pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
model=json.dumps({ model=json.dumps({
"provider": "openai", "provider": "openai",
"name": "gpt-3.5-turbo-instruct", "name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": { "completion_params": {
"max_tokens": 1000, "max_tokens": 1000,
"temperature": 0, "temperature": 0,
...@@ -287,6 +299,13 @@ demo_model_templates = { ...@@ -287,6 +299,13 @@ demo_model_templates = {
"意大利语", "意大利语",
] ]
} }
},{
"paragraph": {
"label": "文本内容",
"variable": "query",
"required": True,
"default": ""
}
} }
]) ])
) )
...@@ -318,6 +337,7 @@ demo_model_templates = { ...@@ -318,6 +337,7 @@ demo_model_templates = {
model=json.dumps({ model=json.dumps({
"provider": "openai", "provider": "openai",
"name": "gpt-3.5-turbo", "name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": { "completion_params": {
"max_tokens": 300, "max_tokens": 300,
"temperature": 0.8, "temperature": 0.8,
......
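Note on the template change above: the demo completion templates now pipe both {{target_language}} and the new required "Query" paragraph input through pre_prompt. A minimal sketch of how such double-brace variables get substituted (the render helper below is hypothetical and only mirrors the behaviour the templates rely on; the real parsing lives in core.prompt.prompt_template in this PR):

```python
import re

# Hypothetical helper mirroring {{variable}} substitution as the templates expect it.
def render(template: str, inputs: dict) -> str:
    # Replace each {{key}} with its input value, leaving missing keys empty.
    return re.sub(r'\{\{(\w+)\}\}', lambda m: str(inputs.get(m.group(1), '')), template)

pre_prompt = "Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:"
print(render(pre_prompt, {"target_language": "Italian", "query": "Good morning"}))
# Please translate the following text into Italian:
# Good morning
# translate:
```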
@@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import setup, version, apikey, admin
# Import app controllers
- from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
+ from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
......
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateList(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('app_mode', type=str, required=True, location='args')
parser.add_argument('model_mode', type=str, required=True, location='args')
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
parser.add_argument('model_name', type=str, required=True, location='args')
args = parser.parse_args()
service = AdvancedPromptTemplateService()
return service.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
\ No newline at end of file
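The new AdvancedPromptTemplateList resource is mounted at '/app/prompt-templates' on the console API and validates four query-string arguments. A hypothetical request against it (host, the '/console/api' prefix and the auth header are assumptions, not part of this diff):

```python
import requests

# Hypothetical call; only the path and query arguments come from the resource above.
resp = requests.get(
    "http://localhost:5001/console/api/app/prompt-templates",
    params={
        "app_mode": "chat",                       # required
        "model_mode": "completion",               # required
        "has_context": "true",                    # optional, defaults to 'true'
        "model_name": "gpt-3.5-turbo-instruct",   # required
    },
    headers={"Authorization": "Bearer <console-session-token>"},  # assumed auth scheme
)
print(resp.json())  # whatever AdvancedPromptTemplateService.get_prompt(args) returns
```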
@@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, \
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
class IntroductionGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('prompt_template', type=str, required=True, location='json')
args = parser.parse_args()
account = current_user
try:
answer = LLMGenerator.generate_introduction(
account.current_tenant_id,
args['prompt_template']
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
return {'introduction': answer}
class RuleGenerateApi(Resource):
@setup_required
@login_required
@@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
return rules
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
api.add_resource(RuleGenerateApi, '/rule-generate')
@@ -329,7 +329,7 @@ class MessageApi(Resource):
message_id = str(message_id)
# get app info
- app_model = _get_app(app_id, 'chat')
+ app_model = _get_app(app_id)
message = db.session.query(Message).filter(
Message.id == message_id,
......
@@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource):
streaming = args['response_mode'] == 'streaming'
try:
- response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
+ response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
......
- import json
import logging
from typing import Optional, List, Union
@@ -16,10 +15,8 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
- from core.prompt.prompt_builder import PromptBuilder
- from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
- from models.dataset import DocumentSegment, Dataset, Document
- from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
+ from core.prompt.prompt_template import PromptTemplateParser
+ from models.model import App, AppModelConfig, Account, Conversation, EndUser

class Completion:
@@ -30,7 +27,7 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
- query = PromptBuilder.process_template(query)
+ query = PromptTemplateParser.remove_template_variables(query)
memory = None
if conversation:
@@ -160,14 +157,28 @@ class Completion:
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
- prompt_messages, stop_words = model_instance.get_prompt(
-     mode=mode,
-     pre_prompt=app_model_config.pre_prompt,
-     inputs=inputs,
-     query=query,
-     context=agent_execute_result.output if agent_execute_result else None,
-     memory=memory
- )
+ if app_model_config.prompt_type == 'simple':
+     prompt_messages, stop_words = model_instance.get_prompt(
+         mode=mode,
+         pre_prompt=app_model_config.pre_prompt,
+         inputs=inputs,
+         query=query,
+         context=agent_execute_result.output if agent_execute_result else None,
+         memory=memory
+     )
+ else:
+     prompt_messages = model_instance.get_advanced_prompt(
+         app_mode=mode,
+         app_model_config=app_model_config,
+         inputs=inputs,
+         query=query,
+         context=agent_execute_result.output if agent_execute_result else None,
+         memory=memory
+     )
+
+     model_config = app_model_config.model_dict
+     completion_params = model_config.get("completion_params", {})
+     stop_words = completion_params.get("stop", [])

cls.recale_llm_max_tokens(
model_instance=model_instance,
@@ -176,7 +187,7 @@ class Completion:
response = model_instance.run(
messages=prompt_messages,
- stop=stop_words,
+ stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
...@@ -266,52 +277,3 @@ class Completion: ...@@ -266,52 +277,3 @@ class Completion:
model_kwargs = model_instance.get_model_kwargs() model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs) model_instance.set_model_kwargs(model_kwargs)
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model_config=app_model_config.model_dict,
streaming=streaming
)
# get llm prompt
old_prompt_messages, _ = final_model_instance.get_prompt(
mode='completion',
pre_prompt=pre_prompt,
inputs=message.inputs,
query=message.query,
context=None,
memory=None
)
original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
prompt_messages = [PromptMessage(content=prompt)]
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
inputs=message.inputs,
query=message.query,
is_override=True if message.override_model_configs else False,
streaming=streaming,
model_instance=final_model_instance
)
cls.recale_llm_max_tokens(
model_instance=final_model_instance,
prompt_messages=prompt_messages
)
final_model_instance.run(
messages=prompt_messages,
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
)
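For orientation, the advanced branch above (and get_advanced_prompt further down) reads prompt_type, model_dict, chat_prompt_config_dict and completion_prompt_config_dict off the app model config. An illustrative shape of such a config; only the key names used by this PR are taken from the code, the exact persisted schema is an assumption:

```python
# Illustrative only: key names follow the attributes read in this PR;
# the full stored JSON schema is not shown in this diff.
advanced_app_model_config = {
    "prompt_type": "advanced",            # anything but 'simple' takes the advanced path
    "model": {
        "provider": "openai",
        "name": "gpt-3.5-turbo-instruct",
        "mode": "completion",
        "completion_params": {"max_tokens": 512, "stop": ["\nHuman:"]},  # source of stop_words
    },
    "completion_prompt_config": {
        "prompt": {"text": "{{#context#}}\n\n{{#histories#}}\nHuman: {{#query#}}\nAssistant: "},
        "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
    },
    "chat_prompt_config": {
        "prompt": [
            {"role": "system", "text": "You are a helpful translator.\n{{#context#}}"}
        ]
    },
}
```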
@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
- from core.prompt.prompt_template import JinjaPromptTemplate
+ from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
...@@ -74,10 +74,10 @@ class ConversationMessageTask: ...@@ -74,10 +74,10 @@ class ConversationMessageTask:
if self.mode == 'chat': if self.mode == 'chat':
introduction = self.app_model_config.opening_statement introduction = self.app_model_config.opening_statement
if introduction: if introduction:
- prompt_template = JinjaPromptTemplate.from_template(template=introduction)
- prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
+ prompt_template = PromptTemplateParser(template=introduction)
+ prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
- introduction = prompt_template.format(**prompt_inputs)
+ introduction = prompt_template.format(prompt_inputs)
except KeyError: except KeyError:
pass pass
...@@ -150,12 +150,12 @@ class ConversationMessageTask: ...@@ -150,12 +150,12 @@ class ConversationMessageTask:
message_tokens = llm_message.prompt_tokens message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens answer_tokens = llm_message.completion_tokens
- message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
- message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
+ message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
+ message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
- message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+ message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price total_price = message_total_price + answer_total_price
...@@ -163,7 +163,7 @@ class ConversationMessageTask: ...@@ -163,7 +163,7 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit self.message.message_price_unit = message_price_unit
- self.message.answer = PromptBuilder.process_template(
+ self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else '' llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price self.message.answer_unit_price = answer_unit_price
...@@ -226,15 +226,15 @@ class ConversationMessageTask: ...@@ -226,15 +226,15 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM, def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop): agent_loop: AgentLoop):
- agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
- agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+ agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
+ agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
- loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+ loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT) loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price loop_total_price = loop_message_total_price + loop_answer_total_price
......
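The opening-statement code above now drives the project's own PromptTemplateParser instead of LangChain's JinjaPromptTemplate: variables are listed via variable_keys and format() takes a plain dict. A short usage sketch of that call pattern (the parser's internals are not part of this diff):

```python
from core.prompt.prompt_template import PromptTemplateParser

# Sketch of the call pattern used for the opening statement above.
introduction = "Hi {{name}}, ask me anything about {{topic}}."
prompt_template = PromptTemplateParser(template=introduction)

inputs = {"name": "Alice", "topic": "billing", "unused": "ignored"}
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

print(prompt_template.format(prompt_inputs))
# Hi Alice, ask me anything about billing.
```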
@@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
- from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
- from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
-     GENERATOR_QA_PROMPT
+ from core.prompt.prompt_template import PromptTemplateParser
+ from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT

class LLMGenerator:
@@ -44,78 +43,19 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
max_context_token_length = max_context_token_length if max_context_token_length else 1500
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
if not message.answer:
continue
if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query
if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod @classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
output_parser = SuggestedQuestionsAfterAnswerOutputParser() output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions() format_instructions = output_parser.get_format_instructions()
- prompt = JinjaPromptTemplate(
- template="{{histories}}\n{{format_instructions}}\nquestions:\n",
- input_variables=["histories"],
- partial_variables={"format_instructions": format_instructions}
+ prompt_template = PromptTemplateParser(
+ template="{{histories}}\n{{format_instructions}}\nquestions:\n"
)
- _input = prompt.format_prompt(histories=histories)
+ prompt = prompt_template.format({
+ "histories": histories,
+ "format_instructions": format_instructions
+ })
try: try:
model_instance = ModelFactory.get_text_generation_model( model_instance = ModelFactory.get_text_generation_model(
...@@ -128,10 +68,10 @@ class LLMGenerator: ...@@ -128,10 +68,10 @@ class LLMGenerator:
except ProviderTokenNotInitError: except ProviderTokenNotInitError:
return [] return []
- prompts = [PromptMessage(content=_input.to_string())]
+ prompt_messages = [PromptMessage(content=prompt)]
try: try:
- output = model_instance.run(prompts)
+ output = model_instance.run(prompt_messages)
questions = output_parser.parse(output.content) questions = output_parser.parse(output.content)
except LLMError: except LLMError:
questions = [] questions = []
...@@ -145,19 +85,21 @@ class LLMGenerator: ...@@ -145,19 +85,21 @@ class LLMGenerator:
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict: def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
output_parser = RuleConfigGeneratorOutputParser() output_parser = RuleConfigGeneratorOutputParser()
- prompt = OutLinePromptTemplate(
- template=output_parser.get_format_instructions(),
- input_variables=["audiences", "hoping_to_solve"],
- partial_variables={
- "variable": '{variable}',
- "lanA": '{lanA}',
- "lanB": '{lanB}',
- "topic": '{topic}'
- },
- validate_template=False
+ prompt_template = PromptTemplateParser(
+ template=output_parser.get_format_instructions()
)
- _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
+ prompt = prompt_template.format(
+ inputs={
+ "audiences": audiences,
+ "hoping_to_solve": hoping_to_solve,
+ "variable": "{{variable}}",
+ "lanA": "{{lanA}}",
+ "lanB": "{{lanB}}",
+ "topic": "{{topic}}"
+ },
+ remove_template_variables=False
+ )
model_instance = ModelFactory.get_text_generation_model( model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id, tenant_id=tenant_id,
...@@ -167,10 +109,10 @@ class LLMGenerator: ...@@ -167,10 +109,10 @@ class LLMGenerator:
) )
) )
- prompts = [PromptMessage(content=_input.to_string())]
+ prompt_messages = [PromptMessage(content=prompt)]
try: try:
- output = model_instance.run(prompts)
+ output = model_instance.run(prompt_messages)
rule_config = output_parser.parse(output.content) rule_config = output_parser.parse(output.content)
except LLMError as e: except LLMError as e:
raise e raise e
......
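In generate_rule_config the old OutLinePromptTemplate's partial_variables trick is replaced by passing the literal "{{variable}}"-style strings as inputs together with remove_template_variables=False, presumably so the marker syntax survives into the final prompt. A hedged sketch of that call, with the signature taken from the code above:

```python
from core.prompt.prompt_template import PromptTemplateParser

parser = PromptTemplateParser(
    template="Audience: {{audiences}}\nGoal: {{hoping_to_solve}}\nWrap each field in {{variable}} markers."
)

# remove_template_variables=False is assumed to leave the double braces of the substituted
# values intact, so the model still sees the literal {{variable}} placeholder.
prompt = parser.format(
    inputs={
        "audiences": "support agents",
        "hoping_to_solve": "triage tickets faster",
        "variable": "{{variable}}",
    },
    remove_template_variables=False,
)
print(prompt)
```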
...@@ -286,7 +286,7 @@ class IndexingRunner: ...@@ -286,7 +286,7 @@ class IndexingRunner:
"total_segments": total_segments * 20, "total_segments": total_segments * 20,
"tokens": total_segments * 2000, "tokens": total_segments * 2000,
"total_price": '{:f}'.format( "total_price": '{:f}'.format(
- text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
+ text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(), "currency": embedding_model.get_currency(),
"qa_preview": document_qa_list, "qa_preview": document_qa_list,
"preview": preview_texts "preview": preview_texts
...@@ -383,7 +383,7 @@ class IndexingRunner: ...@@ -383,7 +383,7 @@ class IndexingRunner:
"total_segments": total_segments * 20, "total_segments": total_segments * 20,
"tokens": total_segments * 2000, "tokens": total_segments * 2000,
"total_price": '{:f}'.format( "total_price": '{:f}'.format(
- text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
+ text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(), "currency": embedding_model.get_currency(),
"qa_preview": document_qa_list, "qa_preview": document_qa_list,
"preview": preview_texts "preview": preview_texts
......
...@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): ...@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
chat_messages: List[PromptMessage] = [] chat_messages: List[PromptMessage] = []
for message in messages: for message in messages:
- chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
+ chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT)) chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages: if not chat_messages:
......
@@ -13,13 +13,13 @@ class LLMRunResult(BaseModel):
class MessageType(enum.Enum):
- HUMAN = 'human'
+ USER = 'user'
ASSISTANT = 'assistant'
SYSTEM = 'system'

class PromptMessage(BaseModel):
- type: MessageType = MessageType.HUMAN
+ type: MessageType = MessageType.USER
content: str = ''
function_call: dict = None
...@@ -27,7 +27,7 @@ class PromptMessage(BaseModel): ...@@ -27,7 +27,7 @@ class PromptMessage(BaseModel):
def to_lc_messages(messages: list[PromptMessage]): def to_lc_messages(messages: list[PromptMessage]):
lc_messages = [] lc_messages = []
for message in messages: for message in messages:
- if message.type == MessageType.HUMAN:
+ if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content)) lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT: elif message.type == MessageType.ASSISTANT:
additional_kwargs = {} additional_kwargs = {}
...@@ -44,7 +44,7 @@ def to_prompt_messages(messages: list[BaseMessage]): ...@@ -44,7 +44,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = [] prompt_messages = []
for message in messages: for message in messages:
if isinstance(message, HumanMessage): if isinstance(message, HumanMessage):
- prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
+ prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage): elif isinstance(message, AIMessage):
message_kwargs = { message_kwargs = {
'content': message.content, 'content': message.content,
...@@ -58,7 +58,7 @@ def to_prompt_messages(messages: list[BaseMessage]): ...@@ -58,7 +58,7 @@ def to_prompt_messages(messages: list[BaseMessage]):
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM)) prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage): elif isinstance(message, FunctionMessage):
- prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
+ prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
return prompt_messages return prompt_messages
......
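MessageType.HUMAN becomes MessageType.USER (value 'human' → 'user') and is the new default for PromptMessage, while the LangChain converters keep mapping USER to HumanMessage and back. A quick round-trip sketch based on the functions in this hunk:

```python
from langchain.schema import HumanMessage
from core.model_providers.models.entity.message import (
    PromptMessage, MessageType, to_lc_messages, to_prompt_messages,
)

prompt_messages = [
    PromptMessage(content="What changed in this PR?", type=MessageType.USER),
    PromptMessage(content="The advanced prompt backend.", type=MessageType.ASSISTANT),
]

lc_messages = to_lc_messages(prompt_messages)   # USER -> HumanMessage, ASSISTANT -> AIMessage
round_trip = to_prompt_messages(lc_messages)    # HumanMessage -> USER again

assert isinstance(lc_messages[0], HumanMessage)
assert round_trip[0].type == MessageType.USER
```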
...@@ -18,7 +18,7 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp ...@@ -18,7 +18,7 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_builder import PromptBuilder
- from core.prompt.prompt_template import JinjaPromptTemplate
+ from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM from core.third_party.langchain.llms.fake import FakeLLM
import logging import logging
...@@ -232,7 +232,7 @@ class BaseLLM(BaseProviderModel): ...@@ -232,7 +232,7 @@ class BaseLLM(BaseProviderModel):
:param message_type: :param message_type:
:return: :return:
""" """
- if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+ if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt'] unit_price = self.price_config['prompt']
else: else:
unit_price = self.price_config['completion'] unit_price = self.price_config['completion']
...@@ -250,7 +250,7 @@ class BaseLLM(BaseProviderModel): ...@@ -250,7 +250,7 @@ class BaseLLM(BaseProviderModel):
:param message_type: :param message_type:
:return: decimal.Decimal('0.0001') :return: decimal.Decimal('0.0001')
""" """
- if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+ if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt'] unit_price = self.price_config['prompt']
else: else:
unit_price = self.price_config['completion'] unit_price = self.price_config['completion']
...@@ -265,7 +265,7 @@ class BaseLLM(BaseProviderModel): ...@@ -265,7 +265,7 @@ class BaseLLM(BaseProviderModel):
:param message_type: :param message_type:
:return: decimal.Decimal('0.000001') :return: decimal.Decimal('0.000001')
""" """
- if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+ if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit'] price_unit = self.price_config['unit']
else: else:
price_unit = self.price_config['unit'] price_unit = self.price_config['unit']
...@@ -330,6 +330,85 @@ class BaseLLM(BaseProviderModel): ...@@ -330,6 +330,85 @@ class BaseLLM(BaseProviderModel):
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory) prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops return [PromptMessage(content=prompt)], stops
def get_advanced_prompt(self, app_mode: str,
app_model_config: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
model_mode = app_model_config.model_dict['mode']
conversation_histories_role = {}
raw_prompt_list = []
prompt_messages = []
if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
else:
raise Exception("app_mode or model_mode not support")
for prompt_item in raw_prompt_list:
prompt = prompt_item['text']
# set prompt template variables
prompt_template = PromptTemplateParser(template=prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
if '#context#' in prompt:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''
if '#query#' in prompt:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''
if '#histories#' in prompt:
if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, 2000)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''
prompt = prompt_template.format(
prompt_inputs
)
prompt = re.sub(r'<\|.*?\|>', '', prompt)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, 2000)
prompt_messages.extend(histories)
if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
return prompt_messages
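get_advanced_prompt resolves three reserved placeholders, #context#, #query# and #histories#, on top of the app's own variables, and only fills #histories# for a chat app running a completion-mode model (using the configured user/assistant prefixes). An illustrative walk-through of one loop iteration in that case, with invented values:

```python
from core.prompt.prompt_template import PromptTemplateParser

# Illustrative values for a chat app driving a completion-mode model.
prompt_item = {
    'role': 'user',
    'text': "{{#context#}}\n\n{{#histories#}}\n{{instruction}}\nHuman: {{#query#}}\nAssistant: ",
}

prompt_template = PromptTemplateParser(template=prompt_item['text'])
prompt_inputs = {'instruction': 'Answer politely.'}            # from the app's own inputs
prompt_inputs['#context#'] = 'Dify is an LLM app platform.'    # agent/dataset context, if any
prompt_inputs['#query#'] = 'What is Dify?'
prompt_inputs['#histories#'] = 'Human: hi\nAssistant: Hello!'  # rendered with user_prefix/assistant_prefix

prompt = prompt_template.format(prompt_inputs)
# -> becomes a single PromptMessage(type=MessageType.USER, content=prompt) in prompt_messages
```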
def prompt_file_name(self, mode: str) -> str: def prompt_file_name(self, mode: str) -> str:
if mode == 'completion': if mode == 'completion':
return 'common_completion' return 'common_completion'
...@@ -342,17 +421,17 @@ class BaseLLM(BaseProviderModel): ...@@ -342,17 +421,17 @@ class BaseLLM(BaseProviderModel):
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]: memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = '' context_prompt_content = ''
if context and 'context_prompt' in prompt_rules: if context and 'context_prompt' in prompt_rules:
- prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
+ prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
- context=context
+ {'context': context}
)
pre_prompt_content = ''
if pre_prompt:
- prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
- prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
+ prompt_template = PromptTemplateParser(template=pre_prompt)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
- **prompt_inputs
+ prompt_inputs
)
prompt = '' prompt = ''
...@@ -385,10 +464,8 @@ class BaseLLM(BaseProviderModel): ...@@ -385,10 +464,8 @@ class BaseLLM(BaseProviderModel):
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens) histories = self._get_history_messages_from_memory(memory, rest_tokens)
- prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
- histories_prompt_content = prompt_template.format(
- histories=histories
- )
+ prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
+ histories_prompt_content = prompt_template.format({'histories': histories})
prompt = '' prompt = ''
for order in prompt_rules['system_prompt_orders']: for order in prompt_rules['system_prompt_orders']:
...@@ -399,10 +476,8 @@ class BaseLLM(BaseProviderModel): ...@@ -399,10 +476,8 @@ class BaseLLM(BaseProviderModel):
elif order == 'histories_prompt': elif order == 'histories_prompt':
prompt += histories_prompt_content prompt += histories_prompt_content
- prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
- query_prompt_content = prompt_template.format(
- query=query
- )
+ prompt_template = PromptTemplateParser(template=query_prompt)
+ query_prompt_content = prompt_template.format({'query': query})
prompt += query_prompt_content prompt += query_prompt_content
...@@ -433,6 +508,16 @@ class BaseLLM(BaseProviderModel): ...@@ -433,6 +508,16 @@ class BaseLLM(BaseProviderModel):
external_context = memory.load_memory_variables({}) external_context = memory.load_memory_variables({})
return external_context[memory_key] return external_context[memory_key]
def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])
def _get_prompt_from_messages(self, messages: List[PromptMessage], def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]: model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
if not model_mode: if not model_mode:
......
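PromptTemplateParser itself is not included in this diff; from its call sites it needs a variable_keys attribute, a format(inputs, remove_template_variables=...) method and a remove_template_variables() class-level helper, all working on {{key}} syntax (including the reserved #context#-style keys). A minimal stand-in inferred from those call sites, offered as an illustration rather than the project's actual implementation:

```python
import re


class PromptTemplateParserSketch:
    """Minimal stand-in inferred from the call sites in this PR."""

    VARIABLE_PATTERN = re.compile(r'\{\{([a-zA-Z0-9_#]+)\}\}')

    def __init__(self, template: str):
        self.template = template
        self.variable_keys = self.VARIABLE_PATTERN.findall(template)

    def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
        def replace(match: re.Match) -> str:
            value = str(inputs.get(match.group(1), ''))
            # Strip braces from substituted values so user input cannot smuggle in new variables.
            return self.remove_template_variables(value) if remove_template_variables else value
        return self.VARIABLE_PATTERN.sub(replace, self.template)

    @classmethod
    def remove_template_variables(cls, text: str) -> str:
        # Turn any remaining {{key}} into {key} so it no longer acts as a template variable.
        return cls.VARIABLE_PATTERN.sub(r'{\1}', text)
```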
...@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage ...@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
- from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
+ from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType from core.model_providers.models.llm.base import ModelType
...@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider): ...@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
{ {
'id': 'claude-instant-1', 'id': 'claude-instant-1',
'name': 'claude-instant-1', 'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'claude-2', 'id': 'claude-2',
'name': 'claude-2', 'name': 'claude-2',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider): ...@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -12,7 +12,7 @@ from core.helper import encrypter ...@@ -12,7 +12,7 @@ from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \ from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION AZURE_OPENAI_API_VERSION
- from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
+ from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
} }
credentials = json.loads(provider_model.encrypted_config) credentials = json.loads(provider_model.encrypted_config)
if provider_model.model_type == ModelType.TEXT_GENERATION.value:
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
if credentials['base_model_name'] in [ if credentials['base_model_name'] in [
'gpt-4', 'gpt-4',
'gpt-4-32k', 'gpt-4-32k',
...@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
return model_list return model_list
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name == 'text-davinci-003':
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION: if model_type == ModelType.TEXT_GENERATION:
models = [ models = [
{ {
'id': 'gpt-3.5-turbo', 'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo', 'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-3.5-turbo-16k', 'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k', 'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-4', 'id': 'gpt-4',
'name': 'gpt-4', 'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-4-32k', 'id': 'gpt-4-32k',
'name': 'gpt-4-32k', 'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider): ...@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{ {
'id': 'text-davinci-003', 'id': 'text-davinci-003',
'name': 'text-davinci-003', 'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
} }
] ]
......
...@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage ...@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
- from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+ from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.baichuan_model import BaichuanModel from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
...@@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider): ...@@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider):
Returns the name of a provider. Returns the name of a provider.
""" """
return 'baichuan' return 'baichuan'
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION: if model_type == ModelType.TEXT_GENERATION:
...@@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider): ...@@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider):
{ {
'id': 'baichuan2-53b', 'id': 'baichuan2-53b',
'name': 'Baichuan2-53B', 'name': 'Baichuan2-53B',
'mode': ModelMode.CHAT.value,
} }
] ]
else: else:
......
@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
- return [{
-     'id': provider_model.model_name,
-     'name': provider_model.model_name
- } for provider_model in provider_models]
+ provider_model_list = []
+ for provider_model in provider_models:
+     provider_model_dict = {
+         'id': provider_model.model_name,
+         'name': provider_model.model_name
+     }
+
+     if model_type == ModelType.TEXT_GENERATION:
+         provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
+
+     provider_model_list.append(provider_model_dict)
+
+ return provider_model_list
@abstractmethod @abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
...@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC): ...@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def _get_text_generation_model_mode(self, model_name) -> str:
"""
get text generation model mode.
:param model_name:
:return:
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def get_model_class(self, model_type: ModelType) -> Type: def get_model_class(self, model_type: ModelType) -> Type:
""" """
......
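With the new abstract hook, custom (database-registered) models get a 'mode' field the same way the hard-coded lists do. A partial sketch of how a concrete provider plugs into the base-class loop above; ExampleProvider is hypothetical and omits the other abstract methods:

```python
# Hypothetical provider; only the two methods relevant to the 'mode' field are shown.
class ExampleProvider(BaseModelProvider):
    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {'id': 'example-chat-1', 'name': 'example-chat-1', 'mode': ModelMode.CHAT.value}
            ]
        return []

    def _get_text_generation_model_mode(self, model_name) -> str:
        # Called per model name by the base-class loop above when listing
        # custom text-generation models, so they also carry a 'mode'.
        return ModelMode.CHAT.value
```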
...@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM ...@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
- from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+ from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.chatglm_model import ChatGLMModel from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType from models.provider import ProviderType
...@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider): ...@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
{ {
'id': 'chatglm2-6b', 'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B', 'name': 'ChatGLM2-6B',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'chatglm-6b', 'id': 'chatglm-6b',
'name': 'ChatGLM-6B', 'name': 'ChatGLM-6B',
'mode': ModelMode.COMPLETION.value,
} }
] ]
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -5,7 +5,7 @@ import requests ...@@ -5,7 +5,7 @@ import requests
from huggingface_hub import HfApi from huggingface_hub import HfApi
from core.helper import encrypter from core.helper import encrypter
- from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+ from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider): ...@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage ...@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
- from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
+ from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
from core.model_providers.models.llm.localai_model import LocalAIModel from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider): ...@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
if credentials['completion_type'] == 'chat_completion':
return ModelMode.CHAT.value
else:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage ...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
- from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
+ from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.minimax_model import MinimaxModel from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
...@@ -29,10 +29,12 @@ class MinimaxProvider(BaseModelProvider): ...@@ -29,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
{ {
'id': 'abab5.5-chat', 'id': 'abab5.5-chat',
'name': 'abab5.5-chat', 'name': 'abab5.5-chat',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'abab5-chat', 'id': 'abab5-chat',
'name': 'abab5-chat', 'name': 'abab5-chat',
'mode': ModelMode.COMPLETION.value,
} }
] ]
elif model_type == ModelType.EMBEDDINGS: elif model_type == ModelType.EMBEDDINGS:
...@@ -45,6 +47,9 @@ class MinimaxProvider(BaseModelProvider): ...@@ -45,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature ...@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
- from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
- from core.model_providers.models.llm.openai_model import OpenAIModel
+ from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
+ from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers from core.model_providers.providers.hosted import hosted_model_providers
...@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider): ...@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-3.5-turbo', 'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo', 'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider): ...@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-3.5-turbo-instruct', 'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct', 'name': 'GPT-3.5-Turbo-Instruct',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'gpt-3.5-turbo-16k', 'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k', 'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider): ...@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-4', 'id': 'gpt-4',
'name': 'gpt-4', 'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider): ...@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'gpt-4-32k', 'id': 'gpt-4-32k',
'name': 'gpt-4-32k', 'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [ 'features': [
ModelFeature.AGENT_THOUGHT.value ModelFeature.AGENT_THOUGHT.value
] ]
...@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider): ...@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
{ {
'id': 'text-davinci-003', 'id': 'text-davinci-003',
'name': 'text-davinci-003', 'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
} }
] ]
...@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider): ...@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name in COMPLETION_MODELS:
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
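Each provider in this change both tags its fixed model list with a mode and exposes a _get_text_generation_model_mode hook, so the rest of the backend can tell chat models from completion models when assembling advanced prompts. A minimal, self-contained sketch of the same dispatch (the ModelMode enum and COMPLETION_MODELS list below are illustrative stand-ins for the real definitions in core.model_providers, and the function name is hypothetical):

# Illustrative stand-ins; not the actual core.model_providers definitions.
from enum import Enum

class ModelMode(Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'

# Assumed to mirror the spirit of COMPLETION_MODELS from the OpenAI model module.
COMPLETION_MODELS = ['gpt-3.5-turbo-instruct', 'text-davinci-003']

def get_text_generation_model_mode(model_name: str) -> str:
    """Return 'completion' for legacy completion models, otherwise 'chat'."""
    if model_name in COMPLETION_MODELS:
        return ModelMode.COMPLETION.value
    return ModelMode.CHAT.value

assert get_text_generation_model_mode('gpt-3.5-turbo-instruct') == 'completion'
assert get_text_generation_model_mode('gpt-4') == 'chat'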
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -3,7 +3,7 @@ from typing import Type ...@@ -3,7 +3,7 @@ from typing import Type
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.openllm_model import OpenLLMModel from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider): ...@@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -6,7 +6,8 @@ import replicate ...@@ -6,7 +6,8 @@ import replicate
from replicate.exceptions import ReplicateError from replicate.exceptions import ReplicateError
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider): ...@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage ...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.spark_model import SparkModel from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark from core.third_party.langchain.llms.spark import ChatSpark
...@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider): ...@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider):
{ {
'id': 'spark', 'id': 'spark',
'name': 'Spark V1.5', 'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'spark-v2', 'id': 'spark-v2',
'name': 'Spark V2.0', 'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value,
} }
] ]
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -4,7 +4,7 @@ from typing import Type ...@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.tongyi_model import TongyiModel from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
...@@ -26,15 +26,20 @@ class TongyiProvider(BaseModelProvider): ...@@ -26,15 +26,20 @@ class TongyiProvider(BaseModelProvider):
{ {
'id': 'qwen-turbo', 'id': 'qwen-turbo',
'name': 'qwen-turbo', 'name': 'qwen-turbo',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'qwen-plus', 'id': 'qwen-plus',
'name': 'qwen-plus', 'name': 'qwen-plus',
'mode': ModelMode.COMPLETION.value,
} }
] ]
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -4,7 +4,7 @@ from typing import Type ...@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.wenxin_model import WenxinModel from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin from core.third_party.langchain.llms.wenxin import Wenxin
...@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider): ...@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider):
{ {
'id': 'ernie-bot', 'id': 'ernie-bot',
'name': 'ERNIE-Bot', 'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'ernie-bot-turbo', 'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo', 'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
}, },
{ {
'id': 'bloomz-7b', 'id': 'bloomz-7b',
'name': 'BLOOMZ-7B', 'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
} }
] ]
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -6,7 +6,7 @@ from langchain.embeddings import XinferenceEmbeddings ...@@ -6,7 +6,7 @@ from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.xinference_model import XinferenceModel from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
...@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider): ...@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage ...@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
...@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider): ...@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider):
{ {
'id': 'chatglm_pro', 'id': 'chatglm_pro',
'name': 'chatglm_pro', 'name': 'chatglm_pro',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'chatglm_std', 'id': 'chatglm_std',
'name': 'chatglm_std', 'name': 'chatglm_std',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'chatglm_lite', 'id': 'chatglm_lite',
'name': 'chatglm_lite', 'name': 'chatglm_lite',
'mode': ModelMode.CHAT.value,
}, },
{ {
'id': 'chatglm_lite_32k', 'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k', 'name': 'chatglm_lite_32k',
'mode': ModelMode.CHAT.value,
} }
] ]
elif model_type == ModelType.EMBEDDINGS: elif model_type == ModelType.EMBEDDINGS:
...@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider): ...@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider):
else: else:
return [] return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
""" """
Returns the model class. Returns the model class.
......
import math
from typing import Optional from typing import Optional
from langchain import WikipediaAPIWrapper from langchain import WikipediaAPIWrapper
...@@ -50,6 +49,7 @@ class OrchestratorRuleParser: ...@@ -50,6 +49,7 @@ class OrchestratorRuleParser:
tool_configs = agent_mode_config.get('tools', []) tool_configs = agent_mode_config.get('tools', [])
agent_provider_name = model_dict.get('provider', 'openai') agent_provider_name = model_dict.get('provider', 'openai')
agent_model_name = model_dict.get('name', 'gpt-4') agent_model_name = model_dict.get('name', 'gpt-4')
dataset_configs = self.app_model_config.dataset_configs_dict
agent_model_instance = ModelFactory.get_text_generation_model( agent_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
...@@ -96,13 +96,14 @@ class OrchestratorRuleParser: ...@@ -96,13 +96,14 @@ class OrchestratorRuleParser:
summary_model_instance = None summary_model_instance = None
tools = self.to_tools( tools = self.to_tools(
agent_model_instance=agent_model_instance,
tool_configs=tool_configs, tool_configs=tool_configs,
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
agent_model_instance=agent_model_instance,
conversation_message_task=conversation_message_task, conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens, rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
return_resource=return_resource, return_resource=return_resource,
retriever_from=retriever_from retriever_from=retriever_from,
dataset_configs=dataset_configs
) )
if len(tools) == 0: if len(tools) == 0:
...@@ -170,20 +171,12 @@ class OrchestratorRuleParser: ...@@ -170,20 +171,12 @@ class OrchestratorRuleParser:
return None return None
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
conversation_message_task: ConversationMessageTask,
rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
retriever_from: str = 'dev') -> list[BaseTool]:
""" """
Convert app agent tool configs to tools Convert app agent tool configs to tools
:param agent_model_instance:
:param rest_tokens:
:param tool_configs: app agent tool configs :param tool_configs: app agent tool configs
:param conversation_message_task:
:param callbacks: :param callbacks:
:param return_resource:
:param retriever_from:
:return: :return:
""" """
tools = [] tools = []
...@@ -195,15 +188,15 @@ class OrchestratorRuleParser: ...@@ -195,15 +188,15 @@ class OrchestratorRuleParser:
tool = None tool = None
if tool_type == "dataset": if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from) tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
elif tool_type == "web_reader": elif tool_type == "web_reader":
tool = self.to_web_reader_tool(agent_model_instance) tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
elif tool_type == "google_search": elif tool_type == "google_search":
tool = self.to_google_search_tool() tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
elif tool_type == "wikipedia": elif tool_type == "wikipedia":
tool = self.to_wikipedia_tool() tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
elif tool_type == "current_datetime": elif tool_type == "current_datetime":
tool = self.to_current_datetime_tool() tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)
if tool: if tool:
if tool.callbacks is not None: if tool.callbacks is not None:
...@@ -215,12 +208,15 @@ class OrchestratorRuleParser: ...@@ -215,12 +208,15 @@ class OrchestratorRuleParser:
return tools return tools
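The to_tools refactor collapses the per-tool parameters into **kwargs: every to_*_tool factory now accepts tool_config plus whatever shared context it cares about and silently ignores the rest. A hedged, standalone sketch of that dispatch pattern (factory names and configs here are hypothetical, not the Dify ones):

from typing import Callable, Optional

def make_datetime_tool(tool_config: dict, **kwargs) -> Optional[str]:
    # Needs nothing beyond its own config; extra context is ignored.
    return 'datetime-tool'

def make_dataset_tool(tool_config: dict, rest_tokens: int, **kwargs) -> Optional[str]:
    # Picks only the shared context it needs (rest_tokens) out of kwargs.
    return f'dataset-tool(rest_tokens={rest_tokens})'

FACTORIES: dict[str, Callable[..., Optional[str]]] = {
    'current_datetime': make_datetime_tool,
    'dataset': make_dataset_tool,
}

def to_tools(tool_configs: list, **kwargs) -> list:
    tools = []
    for tool_config in tool_configs:
        tool_type, tool_val = list(tool_config.items())[0]
        factory = FACTORIES.get(tool_type)
        if factory:
            tool = factory(tool_config=tool_val, **kwargs)
            if tool:
                tools.append(tool)
    return tools

print(to_tools([{'dataset': {'enabled': True}}], rest_tokens=2048, retriever_from='dev'))
# ['dataset-tool(rest_tokens=2048)']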
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask, def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \ dataset_configs: dict, rest_tokens: int,
return_resource: bool = False, retriever_from: str = 'dev',
**kwargs) \
-> Optional[BaseTool]: -> Optional[BaseTool]:
""" """
A dataset tool is a tool that can be used to retrieve information from a dataset A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens: :param rest_tokens:
:param tool_config: :param tool_config:
:param dataset_configs:
:param conversation_message_task: :param conversation_message_task:
:param return_resource: :param return_resource:
:param retriever_from: :param retriever_from:
...@@ -238,10 +234,20 @@ class OrchestratorRuleParser: ...@@ -238,10 +234,20 @@ class OrchestratorRuleParser:
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
return None return None
k = self._dynamic_calc_retrieve_k(dataset, rest_tokens) top_k = dataset_configs.get("top_k", 2)
# dynamically adjust top_k when the remaining token number is not enough to support top_k
top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
score_threshold = None
score_threshold_config = dataset_configs.get("score_threshold")
if score_threshold_config and score_threshold_config.get("enable"):
score_threshold = score_threshold_config.get("value")
tool = DatasetRetrieverTool.from_dataset( tool = DatasetRetrieverTool.from_dataset(
dataset=dataset, dataset=dataset,
k=k, top_k=top_k,
score_threshold=score_threshold,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)], callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task, conversation_message_task=conversation_message_task,
return_resource=return_resource, return_resource=return_resource,
...@@ -250,7 +256,7 @@ class OrchestratorRuleParser: ...@@ -250,7 +256,7 @@ class OrchestratorRuleParser:
return tool return tool
def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]: def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
""" """
A tool for reading web pages A tool for reading web pages
...@@ -278,7 +284,7 @@ class OrchestratorRuleParser: ...@@ -278,7 +284,7 @@ class OrchestratorRuleParser:
return tool return tool
def to_google_search_tool(self) -> Optional[BaseTool]: def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id) tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
func_kwargs = tool_provider.credentials_to_func_kwargs() func_kwargs = tool_provider.credentials_to_func_kwargs()
if not func_kwargs: if not func_kwargs:
...@@ -296,12 +302,12 @@ class OrchestratorRuleParser: ...@@ -296,12 +302,12 @@ class OrchestratorRuleParser:
return tool return tool
def to_current_datetime_tool(self) -> Optional[BaseTool]: def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
tool = DatetimeTool() tool = DatetimeTool()
return tool return tool
def to_wikipedia_tool(self) -> Optional[BaseTool]: def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
class WikipediaInput(BaseModel): class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.") query: str = Field(..., description="search query.")
...@@ -312,22 +318,18 @@ class OrchestratorRuleParser: ...@@ -312,22 +318,18 @@ class OrchestratorRuleParser:
) )
@classmethod @classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int: def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MAX_K = 10
if rest_tokens == -1: if rest_tokens == -1:
return DEFAULT_K return top_k
processing_rule = dataset.latest_process_rule processing_rule = dataset.latest_process_rule
if not processing_rule: if not processing_rule:
return DEFAULT_K return top_k
if processing_rule.mode == "custom": if processing_rule.mode == "custom":
rules = processing_rule.rules_dict rules = processing_rule.rules_dict
if not rules: if not rules:
return DEFAULT_K return top_k
segmentation = rules["segmentation"] segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"] segment_max_tokens = segmentation["max_tokens"]
...@@ -335,14 +337,7 @@ class OrchestratorRuleParser: ...@@ -335,14 +337,7 @@ class OrchestratorRuleParser:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'] segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens # when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K: if rest_tokens < segment_max_tokens * top_k:
return rest_tokens // segment_max_tokens return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT) return min(top_k, 10)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
return min(context_limit_tokens // segment_max_tokens, MAX_K)
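_dynamic_calc_retrieve_k now starts from the user-configured top_k instead of a fixed DEFAULT_K: it keeps top_k when the token budget is unlimited or large enough, shrinks it to however many whole segments still fit when the budget is tight, and caps the result at 10. A standalone sketch of that arithmetic (segment sizes below are illustrative):

def dynamic_calc_retrieve_k(top_k: int, segment_max_tokens: int, rest_tokens: int) -> int:
    if rest_tokens == -1:                              # unlimited budget
        return top_k
    if rest_tokens < segment_max_tokens * top_k:       # not enough room for top_k segments
        return rest_tokens // segment_max_tokens       # fit as many whole segments as possible
    return min(top_k, 10)                              # never more than 10 segments

assert dynamic_calc_retrieve_k(top_k=4, segment_max_tokens=500, rest_tokens=-1) == 4
assert dynamic_calc_retrieve_k(top_k=4, segment_max_tokens=500, rest_tokens=1200) == 2
assert dynamic_calc_retrieve_k(top_k=4, segment_max_tokens=500, rest_tokens=4000) == 4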
CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n"
BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"
CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
},
"conversation_histories_role": {
"user_prefix": "Human",
"assistant_prefix": "Assistant"
}
}
}
CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
}
COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
}
COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
}
}
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
},
"conversation_histories_role": {
"user_prefix": "用户",
"assistant_prefix": "助手"
}
}
}
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
}
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
}
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
}
}
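These default templates rely on the three special placeholders {{#histories#}}, {{#query#}} and {{#context#}}, plus {{#pre_prompt#}}, which is filled in by a separate prompt-assembly step rather than by the template parser introduced in this PR. A hedged sketch of filling the chat-app completion template (runnable inside the repo; output abbreviated):

from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG
from core.prompt.prompt_template import PromptTemplateParser

text = CHAT_APP_COMPLETION_PROMPT_CONFIG['completion_prompt_config']['prompt']['text']
filled = PromptTemplateParser(text).format({
    '#histories#': 'Human: hi\nAssistant: hello',
    '#query#': 'What can you do?',
})
# '{{#pre_prompt#}}' is left untouched (handled elsewhere); '#histories#' and
# '#query#' are substituted into the <histories> block and the final Human turn.
print(filled)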
import re from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate from core.prompt.prompt_template import PromptTemplateParser
from langchain.schema import BaseMessage
from core.prompt.prompt_template import JinjaPromptTemplate
class PromptBuilder: class PromptBuilder:
@classmethod
def parse_prompt(cls, prompt: str, inputs: dict) -> str:
prompt_template = PromptTemplateParser(prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt = prompt_template.format(prompt_inputs)
return prompt
@classmethod @classmethod
def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content) return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
system_message = system_prompt_template.format(**prompt_inputs)
return system_message
@classmethod @classmethod
def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content) return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
ai_message = ai_prompt_template.format(**prompt_inputs)
return ai_message
@classmethod @classmethod
def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content) return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))
human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
human_message = human_prompt_template.format(**inputs)
return human_message
@classmethod
def process_template(cls, template: str):
processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
# processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
# processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
return processed_template
import re import re
from typing import Any
from jinja2 import Environment, meta REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}")
from langchain import PromptTemplate
from langchain.formatting import StrictFormatter
class JinjaPromptTemplate(PromptTemplate): class PromptTemplateParser:
template_format: str = "jinja2" """
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" Rules:
@classmethod 1. Template variables must be enclosed in `{{}}`.
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: 2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters,
"""Load a prompt template from a template.""" and can only start with letters and underscores.
env = Environment() 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
template = template.replace("{{}}", "{}") 4. In addition to the above, 3 types of special template variable Keys are accepted:
ast = env.parse(template) `{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
input_variables = meta.find_undeclared_variables(ast) """
if "partial_variables" in kwargs:
partial_variables = kwargs["partial_variables"]
input_variables = {
var for var in input_variables if var not in partial_variables
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
class OutLinePromptTemplate(PromptTemplate):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
input_variables = {
v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args: def __init__(self, template: str):
kwargs: Any arguments to be passed to the prompt template. self.template = template
self.variable_keys = self.extract()
Returns: def extract(self) -> list:
A formatted string. # Regular expression to match the template rules
return re.findall(REGEX, self.template)
Example: def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
.. code-block:: python if remove_template_variables:
return PromptTemplateParser.remove_template_variables(value)
return value
prompt.format(variable1="foo") return re.sub(REGEX, replacer, self.template)
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return OneLineFormatter().format(self.template, **kwargs)
@classmethod
class OneLineFormatter(StrictFormatter): def remove_template_variables(cls, text: str):
def parse(self, format_string): return re.sub(REGEX, r'{\1}', text)
last_end = 0
results = []
for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
field_name = match.group(1)
start, end = match.span()
literal_text = format_string[last_end:start]
last_end = end
results.append((literal_text, field_name, '', None))
remaining_literal_text = format_string[last_end:]
if remaining_literal_text:
results.append((remaining_literal_text, None, None, None))
return results
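PromptTemplateParser replaces the earlier Jinja/f-string templates with a single regex: keys are {{letters/digits/underscores}} starting with a letter or underscore, plus the three {{#...#}} specials; unknown keys are left in place, and {{...}} sequences inside substituted values are defused so user input cannot inject new variables. A short usage sketch (import path as introduced in this PR):

from core.prompt.prompt_template import PromptTemplateParser

parser = PromptTemplateParser('Translate into {{target_language}}:\n{{#query#}}')
print(parser.variable_keys)         # ['target_language', '#query#']

print(parser.format({
    'target_language': 'French',
    '#query#': 'good {{morning}}',  # nested variables are defused to '{morning}'
}))
# Translate into French:
# good {morning}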
...@@ -61,36 +61,6 @@ User Input: yo, 你今天咋样? ...@@ -61,36 +61,6 @@ User Input: yo, 你今天咋样?
User Input: User Input:
""" """
CONVERSATION_SUMMARY_PROMPT = (
"Please generate a short summary of the following conversation.\n"
"If the following conversation communicating in English, you should only return an English summary.\n"
"If the following conversation communicating in Chinese, you should only return a Chinese summary.\n"
"[Conversation Start]\n"
"{context}\n"
"[Conversation End]\n\n"
"summary:"
)
INTRODUCTION_GENERATE_PROMPT = (
"I am designing a product for users to interact with an AI through dialogue. "
"The Prompt given to the AI before the conversation is:\n\n"
"```\n{prompt}\n```\n\n"
"Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
"Do not reveal the developer's motivation or deep logic behind the Prompt, "
"but focus on building a relationship with the user:\n"
)
MORE_LIKE_THIS_GENERATE_PROMPT = (
"-----\n"
"{original_completion}\n"
"-----\n\n"
"Please use the above content as a sample for generating the result, "
"and include key information points related to the original sample in the result. "
"Try to rephrase this information in different ways and predict according to the rules below.\n\n"
"-----\n"
"{prompt}\n"
)
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, " "Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n" "and keeping each question under 20 characters.\n"
...@@ -157,10 +127,10 @@ and fill in variables, with a welcome sentence, and keep TLDR. ...@@ -157,10 +127,10 @@ and fill in variables, with a welcome sentence, and keep TLDR.
``` ```
<< MY INTENDED AUDIENCES >> << MY INTENDED AUDIENCES >>
{audiences} {{audiences}}
<< HOPING TO SOLVE >> << HOPING TO SOLVE >>
{hoping_to_solve} {{hoping_to_solve}}
<< OUTPUT >> << OUTPUT >>
""" """
\ No newline at end of file
import json import json
from typing import Type from typing import Type, Optional
from flask import current_app from flask import current_app
from langchain.tools import BaseTool from langchain.tools import BaseTool
...@@ -28,7 +28,8 @@ class DatasetRetrieverTool(BaseTool): ...@@ -28,7 +28,8 @@ class DatasetRetrieverTool(BaseTool):
tenant_id: str tenant_id: str
dataset_id: str dataset_id: str
k: int = 3 top_k: int = 2
score_threshold: Optional[float] = None
conversation_message_task: ConversationMessageTask conversation_message_task: ConversationMessageTask
return_resource: bool return_resource: bool
retriever_from: str retriever_from: str
...@@ -66,7 +67,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -66,7 +67,7 @@ class DatasetRetrieverTool(BaseTool):
) )
) )
documents = kw_table_index.search(query, search_kwargs={'k': self.k}) documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
return str("\n".join([document.page_content for document in documents])) return str("\n".join([document.page_content for document in documents]))
else: else:
...@@ -80,20 +81,21 @@ class DatasetRetrieverTool(BaseTool): ...@@ -80,20 +81,21 @@ class DatasetRetrieverTool(BaseTool):
return '' return ''
except ProviderTokenNotInitError: except ProviderTokenNotInitError:
return '' return ''
embeddings = CacheEmbedding(embedding_model)
embeddings = CacheEmbedding(embedding_model)
vector_index = VectorIndex( vector_index = VectorIndex(
dataset=dataset, dataset=dataset,
config=current_app.config, config=current_app.config,
embeddings=embeddings embeddings=embeddings
) )
if self.k > 0: if self.top_k > 0:
documents = vector_index.search( documents = vector_index.search(
query, query,
search_type='similarity_score_threshold', search_type='similarity_score_threshold',
search_kwargs={ search_kwargs={
'k': self.k, 'k': self.top_k,
'score_threshold': self.score_threshold,
'filter': { 'filter': {
'group_id': [dataset.id] 'group_id': [dataset.id]
} }
......
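DatasetRetrieverTool now carries top_k (default 2) and an optional score_threshold, which flow straight into the vector search. A hedged sketch of how the app's dataset_configs map onto those search arguments (the dataset id is a placeholder):

dataset_configs = {'top_k': 4, 'score_threshold': {'enable': True, 'value': 0.5}}

top_k = dataset_configs.get('top_k', 2)

score_threshold = None
score_threshold_config = dataset_configs.get('score_threshold')
if score_threshold_config and score_threshold_config.get('enable'):
    score_threshold = score_threshold_config.get('value')

search_kwargs = {
    'k': top_k,                          # DatasetRetrieverTool.top_k
    'score_threshold': score_threshold,  # None when the threshold is disabled
    'filter': {'group_id': ['<dataset-id>']},
}
print(search_kwargs)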
...@@ -4,5 +4,4 @@ from .clean_when_document_deleted import handle ...@@ -4,5 +4,4 @@ from .clean_when_document_deleted import handle
from .clean_when_dataset_deleted import handle from .clean_when_dataset_deleted import handle
from .update_app_dataset_join_when_app_model_config_updated import handle from .update_app_dataset_join_when_app_model_config_updated import handle
from .generate_conversation_name_when_first_message_created import handle from .generate_conversation_name_when_first_message_created import handle
from .generate_conversation_summary_when_few_message_created import handle
from .create_document_index import handle from .create_document_index import handle
from events.message_event import message_was_created
from tasks.generate_conversation_summary_task import generate_conversation_summary_task
@message_was_created.connect
def handle(sender, **kwargs):
message = sender
conversation = kwargs.get('conversation')
is_first_message = kwargs.get('is_first_message')
if not is_first_message and conversation.mode == 'chat' and not conversation.summary:
history_message_count = conversation.message_count
if history_message_count >= 5:
generate_conversation_summary_task.delay(conversation.id)
...@@ -28,6 +28,10 @@ model_config_fields = { ...@@ -28,6 +28,10 @@ model_config_fields = {
'dataset_query_variable': fields.String, 'dataset_query_variable': fields.String,
'pre_prompt': fields.String, 'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'), 'agent_mode': fields.Raw(attribute='agent_mode_dict'),
'prompt_type': fields.String,
'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
} }
app_detail_fields = { app_detail_fields = {
......
...@@ -123,6 +123,7 @@ conversation_with_summary_fields = { ...@@ -123,6 +123,7 @@ conversation_with_summary_fields = {
'from_end_user_id': fields.String, 'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String, 'from_end_user_session_id': fields.String,
'from_account_id': fields.String, 'from_account_id': fields.String,
'name': fields.String,
'summary': fields.String(attribute='summary_or_query'), 'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField, 'read_at': TimestampField,
'created_at': TimestampField, 'created_at': TimestampField,
......
"""add advanced prompt templates
Revision ID: b3a09c049e8e
Revises: 2e9819ca5b28
Create Date: 2023-10-10 15:23:23.395420
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b3a09c049e8e'
down_revision = '2e9819ca5b28'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('dataset_configs')
batch_op.drop_column('completion_prompt_config')
batch_op.drop_column('chat_prompt_config')
batch_op.drop_column('prompt_type')
# ### end Alembic commands ###
...@@ -93,6 +93,10 @@ class AppModelConfig(db.Model): ...@@ -93,6 +93,10 @@ class AppModelConfig(db.Model):
agent_mode = db.Column(db.Text) agent_mode = db.Column(db.Text)
sensitive_word_avoidance = db.Column(db.Text) sensitive_word_avoidance = db.Column(db.Text)
retriever_resource = db.Column(db.Text) retriever_resource = db.Column(db.Text)
prompt_type = db.Column(db.String(255), nullable=False, default='simple')
chat_prompt_config = db.Column(db.Text)
completion_prompt_config = db.Column(db.Text)
dataset_configs = db.Column(db.Text)
@property @property
def app(self): def app(self):
...@@ -139,6 +143,18 @@ class AppModelConfig(db.Model): ...@@ -139,6 +143,18 @@ class AppModelConfig(db.Model):
def agent_mode_dict(self) -> dict: def agent_mode_dict(self) -> dict:
return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []} return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []}
@property
def chat_prompt_config_dict(self) -> dict:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
@property
def completion_prompt_config_dict(self) -> dict:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
@property
def dataset_configs_dict(self) -> dict:
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
"provider": "", "provider": "",
...@@ -155,7 +171,11 @@ class AppModelConfig(db.Model): ...@@ -155,7 +171,11 @@ class AppModelConfig(db.Model):
"user_input_form": self.user_input_form_list, "user_input_form": self.user_input_form_list,
"dataset_query_variable": self.dataset_query_variable, "dataset_query_variable": self.dataset_query_variable,
"pre_prompt": self.pre_prompt, "pre_prompt": self.pre_prompt,
"agent_mode": self.agent_mode_dict "agent_mode": self.agent_mode_dict,
"prompt_type": self.prompt_type,
"chat_prompt_config": self.chat_prompt_config_dict,
"completion_prompt_config": self.completion_prompt_config_dict,
"dataset_configs": self.dataset_configs_dict
} }
def from_model_config_dict(self, model_config: dict): def from_model_config_dict(self, model_config: dict):
...@@ -177,6 +197,13 @@ class AppModelConfig(db.Model): ...@@ -177,6 +197,13 @@ class AppModelConfig(db.Model):
self.agent_mode = json.dumps(model_config['agent_mode']) self.agent_mode = json.dumps(model_config['agent_mode'])
self.retriever_resource = json.dumps(model_config['retriever_resource']) \ self.retriever_resource = json.dumps(model_config['retriever_resource']) \
if model_config.get('retriever_resource') else None if model_config.get('retriever_resource') else None
self.prompt_type = model_config.get('prompt_type', 'simple')
self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \
if model_config.get('chat_prompt_config') else None
self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \
if model_config.get('completion_prompt_config') else None
self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
if model_config.get('dataset_configs') else None
return self return self
def copy(self): def copy(self):
...@@ -197,7 +224,11 @@ class AppModelConfig(db.Model): ...@@ -197,7 +224,11 @@ class AppModelConfig(db.Model):
dataset_query_variable=self.dataset_query_variable, dataset_query_variable=self.dataset_query_variable,
pre_prompt=self.pre_prompt, pre_prompt=self.pre_prompt,
agent_mode=self.agent_mode, agent_mode=self.agent_mode,
retriever_resource=self.retriever_resource retriever_resource=self.retriever_resource,
prompt_type=self.prompt_type,
chat_prompt_config=self.chat_prompt_config,
completion_prompt_config=self.completion_prompt_config,
dataset_configs=self.dataset_configs
) )
return new_app_model_config return new_app_model_config
......
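With the four new columns on app_model_configs, an app's model config dict now round-trips the advanced prompt settings as JSON text, and the *_dict properties above decode them back (falling back to {} or the dataset_configs default when the column is NULL). An illustrative fragment of the new keys such a dict can carry (values are examples, not required defaults):

model_config_fragment = {
    'prompt_type': 'advanced',                       # 'simple' (default) or 'advanced'
    'chat_prompt_config': {
        'prompt': [{'role': 'system', 'text': '{{#pre_prompt#}}'}]
    },
    'completion_prompt_config': {},                  # unused when model.mode == 'chat'
    'dataset_configs': {'top_k': 2, 'score_threshold': {'enable': False}},
}
# AppModelConfig.from_model_config_dict(...) serializes each of these to JSON text.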
import copy
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
class AdvancedPromptTemplateService:
def get_prompt(self, args: dict) -> dict:
app_mode = args['app_mode']
model_mode = args['model_mode']
model_name = args['model_name']
has_context = args['has_context']
if 'baichuan' in model_name:
return self.get_baichuan_prompt(app_mode, model_mode, has_context)
else:
return self.get_common_prompt(app_mode, model_mode, has_context)
def get_common_prompt(self, app_mode: str, model_mode: str, has_context: str) -> dict:
if app_mode == 'chat':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
elif app_mode == 'completion':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
def get_completion_prompt(self, prompt_template: dict, has_context: str, context: str) -> dict:
if has_context == 'true':
prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
return prompt_template
def get_chat_prompt(self, prompt_template: dict, has_context: str, context: str) -> dict:
if has_context == 'true':
prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
return prompt_template
def get_baichuan_prompt(self, app_mode: str, model_mode: str, has_context: str) -> dict:
if app_mode == 'chat':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
elif app_mode == 'completion':
if model_mode == 'completion':
return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
elif model_mode == 'chat':
return self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
\ No newline at end of file
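AdvancedPromptTemplateService picks a default template from the app mode, the model mode, and whether a dataset context block should be prepended; note that has_context is compared against the string 'true', and model names containing 'baichuan' get the Chinese templates. A usage sketch (the service's import path is assumed, since it is not shown in this diff):

from services.advanced_prompt_template_service import AdvancedPromptTemplateService  # assumed path

prompt = AdvancedPromptTemplateService().get_prompt({
    'app_mode': 'chat',
    'model_mode': 'completion',
    'model_name': 'gpt-3.5-turbo-instruct',
    'has_context': 'true',   # string, not bool
})
# -> a deep copy of CHAT_APP_COMPLETION_PROMPT_CONFIG with CONTEXT prepended to its prompt text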
...@@ -3,7 +3,7 @@ import uuid ...@@ -3,7 +3,7 @@ import uuid
from core.agent.agent_executor import PlanningStrategy from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType from core.model_providers.models.entity.model_params import ModelType, ModelMode
from models.account import Account from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
...@@ -34,40 +34,28 @@ class AppModelConfigService: ...@@ -34,40 +34,28 @@ class AppModelConfigService:
# max_tokens # max_tokens
if 'max_tokens' not in cp: if 'max_tokens' not in cp:
cp["max_tokens"] = 512 cp["max_tokens"] = 512
#
# if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
# llm_constant.max_context_token_length[model_name]:
# raise ValueError(
# "max_tokens must be an integer greater than 0 "
# "and not exceeding the maximum value of the corresponding model")
#
# temperature # temperature
if 'temperature' not in cp: if 'temperature' not in cp:
cp["temperature"] = 1 cp["temperature"] = 1
#
# if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2:
# raise ValueError("temperature must be a float between 0 and 2")
#
# top_p # top_p
if 'top_p' not in cp: if 'top_p' not in cp:
cp["top_p"] = 1 cp["top_p"] = 1
# if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2:
# raise ValueError("top_p must be a float between 0 and 2")
#
# presence_penalty # presence_penalty
if 'presence_penalty' not in cp: if 'presence_penalty' not in cp:
cp["presence_penalty"] = 0 cp["presence_penalty"] = 0
# if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2:
# raise ValueError("presence_penalty must be a float between -2 and 2")
#
# presence_penalty # presence_penalty
if 'frequency_penalty' not in cp: if 'frequency_penalty' not in cp:
cp["frequency_penalty"] = 0 cp["frequency_penalty"] = 0
# if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2: # stop
# raise ValueError("frequency_penalty must be a float between -2 and 2") if 'stop' not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")
# Filter out extra parameters # Filter out extra parameters
filtered_cp = { filtered_cp = {
...@@ -75,7 +63,8 @@ class AppModelConfigService: ...@@ -75,7 +63,8 @@ class AppModelConfigService:
"temperature": cp["temperature"], "temperature": cp["temperature"],
"top_p": cp["top_p"], "top_p": cp["top_p"],
"presence_penalty": cp["presence_penalty"], "presence_penalty": cp["presence_penalty"],
"frequency_penalty": cp["frequency_penalty"] "frequency_penalty": cp["frequency_penalty"],
"stop": cp["stop"]
} }
return filtered_cp return filtered_cp
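The completion_params validator now also defaults stop to [] (and rejects non-list values) before whitelisting the known keys. An illustrative input/output pair for the filtering above:

cp_in = {'temperature': 0.7, 'unknown_key': 123}   # sparse input with an extra key

# After defaulting and filtering, only the whitelisted keys survive:
cp_out = {
    'max_tokens': 512,        # defaulted
    'temperature': 0.7,       # kept
    'top_p': 1,               # defaulted
    'presence_penalty': 0,    # defaulted
    'frequency_penalty': 0,   # defaulted
    'stop': [],               # new: defaults to [], must be a list when provided
}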
...@@ -211,6 +200,10 @@ class AppModelConfigService: ...@@ -211,6 +200,10 @@ class AppModelConfigService:
model_ids = [m['id'] for m in model_list] model_ids = [m['id'] for m in model_list]
if config["model"]["name"] not in model_ids: if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list") raise ValueError("model.name must be in the specified model list")
# model.mode
if 'mode' not in config['model'] or not config['model']["mode"]:
config['model']["mode"] = ""
# model.completion_params # model.completion_params
if 'completion_params' not in config["model"]: if 'completion_params' not in config["model"]:
...@@ -339,6 +332,9 @@ class AppModelConfigService: ...@@ -339,6 +332,9 @@ class AppModelConfigService:
# dataset_query_variable # dataset_query_variable
AppModelConfigService.is_dataset_query_variable_valid(config, mode) AppModelConfigService.is_dataset_query_variable_valid(config, mode)
# advanced prompt validation
AppModelConfigService.is_advanced_prompt_valid(config, mode)
# Filter out extra parameters # Filter out extra parameters
filtered_config = { filtered_config = {
"opening_statement": config["opening_statement"], "opening_statement": config["opening_statement"],
...@@ -351,12 +347,17 @@ class AppModelConfigService: ...@@ -351,12 +347,17 @@ class AppModelConfigService:
"model": { "model": {
"provider": config["model"]["provider"], "provider": config["model"]["provider"],
"name": config["model"]["name"], "name": config["model"]["name"],
"mode": config['model']["mode"],
"completion_params": config["model"]["completion_params"] "completion_params": config["model"]["completion_params"]
}, },
"user_input_form": config["user_input_form"], "user_input_form": config["user_input_form"],
"dataset_query_variable": config.get('dataset_query_variable'), "dataset_query_variable": config.get('dataset_query_variable'),
"pre_prompt": config["pre_prompt"], "pre_prompt": config["pre_prompt"],
"agent_mode": config["agent_mode"] "agent_mode": config["agent_mode"],
"prompt_type": config["prompt_type"],
"chat_prompt_config": config["chat_prompt_config"],
"completion_prompt_config": config["completion_prompt_config"],
"dataset_configs": config["dataset_configs"]
} }
return filtered_config return filtered_config
...@@ -375,4 +376,51 @@ class AppModelConfigService: ...@@ -375,4 +376,51 @@ class AppModelConfigService:
if dataset_exists and not dataset_query_variable: if dataset_exists and not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist") raise ValueError("Dataset query variable is required when dataset is exist")
@staticmethod
def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
# prompt_type
if 'prompt_type' not in config or not config["prompt_type"]:
config["prompt_type"] = "simple"
if config['prompt_type'] not in ['simple', 'advanced']:
raise ValueError("prompt_type must be in ['simple', 'advanced']")
# chat_prompt_config
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
config["chat_prompt_config"] = {}
if not isinstance(config["chat_prompt_config"], dict):
raise ValueError("chat_prompt_config must be of object type")
# completion_prompt_config
if 'completion_prompt_config' not in config or not config["completion_prompt_config"]:
config["completion_prompt_config"] = {}
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
# dataset_configs
if 'dataset_configs' not in config or not config["dataset_configs"]:
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if config['prompt_type'] == 'advanced':
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
if not user_prefix:
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
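For advanced prompts the validator requires a prompt config matching the model mode, and for chat apps driven by a completion-mode model it backfills empty history role prefixes with 'Human'/'Assistant'. A minimal illustrative config fragment that exercises that path:

config_fragment = {
    'prompt_type': 'advanced',
    'model': {'mode': 'completion'},
    'chat_prompt_config': {},
    'completion_prompt_config': {
        'prompt': {'text': '{{#pre_prompt#}}\n\n<histories>\n{{#histories#}}\n</histories>\n\nHuman: {{#query#}}\n\nAssistant: '},
        'conversation_histories_role': {
            'user_prefix': '',        # empty values are backfilled...
            'assistant_prefix': ''    # ...to 'Human' / 'Assistant'
        }
    },
    'dataset_configs': {'top_k': 2, 'score_threshold': {'enable': False}},
}
# AppModelConfigService.is_advanced_prompt_valid(config_fragment, 'chat') keeps
# prompt_type 'advanced' and sets both prefixes to their defaults.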
...@@ -244,7 +244,8 @@ class CompletionService: ...@@ -244,7 +244,8 @@ class CompletionService:
@classmethod @classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
message_id: str, streaming: bool = True) -> Union[dict | Generator]: message_id: str, streaming: bool = True,
retriever_from: str = 'dev') -> Union[dict | Generator]:
if not user: if not user:
raise ValueError('user cannot be None') raise ValueError('user cannot be None')
...@@ -266,14 +267,11 @@ class CompletionService: ...@@ -266,14 +267,11 @@ class CompletionService:
raise MoreLikeThisDisabledError() raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config app_model_config = message.app_model_config
model_dict = app_model_config.model_dict
if message.override_model_configs: completion_params = model_dict.get('completion_params')
override_model_configs = json.loads(message.override_model_configs) completion_params['temperature'] = 0.9
pre_prompt = override_model_configs.get("pre_prompt", '') model_dict['completion_params'] = completion_params
elif app_model_config: app_model_config.model = json.dumps(model_dict)
pre_prompt = app_model_config.pre_prompt
else:
raise AppModelConfigBrokenError()
generate_task_id = str(uuid.uuid4()) generate_task_id = str(uuid.uuid4())
...@@ -282,58 +280,28 @@ class CompletionService: ...@@ -282,58 +280,28 @@ class CompletionService:
user = cls.get_real_user_instead_of_proxy_obj(user) user = cls.get_real_user_instead_of_proxy_obj(user)
generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={ generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id, 'generate_task_id': generate_task_id,
'detached_app_model': app_model, 'detached_app_model': app_model,
'app_model_config': app_model_config, 'app_model_config': app_model_config,
'detached_message': message, 'query': message.query,
'pre_prompt': pre_prompt, 'inputs': message.inputs,
'detached_user': user, 'detached_user': user,
'streaming': streaming 'detached_conversation': None,
'streaming': streaming,
'is_model_config_override': True,
'retriever_from': retriever_from
}) })
generate_worker_thread.start() generate_worker_thread.start()
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id) # wait for 10 minutes to close the thread
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
generate_task_id)
return cls.compact_response(pubsub, streaming) return cls.compact_response(pubsub, streaming)
@classmethod
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
detached_user: Union[Account, EndUser], streaming: bool):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
app_model = db.session.merge(detached_app_model)
message = db.session.merge(detached_message)
try:
# run
Completion.generate_more_like_this(
task_id=generate_task_id,
app=app_model,
user=user,
message=message,
pre_prompt=pre_prompt,
app_model_config=app_model_config,
streaming=streaming
)
except ConversationTaskStoppedException:
pass
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
finally:
db.session.commit()
@classmethod @classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
if user_inputs is None: if user_inputs is None:
......
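Read together, the hunks above drop the dedicated more-like-this worker: the service now clones the stored model configuration, raises the sampling temperature, and dispatches the original query and inputs through the shared generate_worker. A minimal sketch of the temperature-override step, with attribute names taken from the diff and the surrounding service plumbing assumed:

    import json

    def apply_more_like_this_override(app_model_config, temperature: float = 0.9) -> None:
        # Copy the persisted model settings, raise the temperature so the
        # regenerated answer differs from the original, and write the JSON
        # back onto the config object before the worker is started.
        model_dict = app_model_config.model_dict
        completion_params = model_dict.get('completion_params', {})
        completion_params['temperature'] = temperature
        model_dict['completion_params'] = completion_params
        app_model_config.model = json.dumps(model_dict)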
...@@ -482,6 +482,9 @@ class ProviderService: ...@@ -482,6 +482,9 @@ class ProviderService:
'features': [] 'features': []
} }
if 'mode' in model:
valid_model_dict['model_mode'] = model['mode']
if 'features' in model: if 'features' in model:
valid_model_dict['features'] = model['features'] valid_model_dict['features'] = model['features']
......
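The ProviderService hunk above lets each supported model entry expose its generation mode: when a provider's model list includes a mode, it is copied onto the response as model_mode. A small illustration with assumed example values:

    # Illustrative entry as a provider's fixed model list might return it
    model = {'id': 'gpt-3.5-turbo', 'name': 'gpt-3.5-turbo', 'mode': 'chat'}

    valid_model_dict = {
        'id': model['id'],
        'name': model['name'],
        'features': [],
    }
    if 'mode' in model:
        valid_model_dict['model_mode'] = model['mode']  # 'chat' or 'completion'
    if 'features' in model:
        valid_model_dict['features'] = model['features']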
import logging
import time
import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.generator.llm_generator import LLMGenerator
from core.model_providers.error import LLMError, ProviderTokenNotInitError
from extensions.ext_database import db
from models.model import Conversation, Message
@shared_task(queue='generation')
def generate_conversation_summary_task(conversation_id: str):
"""
Async Generate conversation summary
:param conversation_id:
Usage: generate_conversation_summary_task.delay(conversation_id)
"""
logging.info(click.style('Start generate conversation summary: {}'.format(conversation_id), fg='green'))
start_at = time.perf_counter()
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
raise NotFound('Conversation not found')
try:
# get conversation messages count
history_message_count = conversation.message_count
if history_message_count >= 5 and not conversation.summary:
app_model = conversation.app
if not app_model:
return
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
.order_by(Message.created_at.asc()).all()
conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages)
db.session.add(conversation)
db.session.commit()
except (LLMError, ProviderTokenNotInitError):
conversation.summary = '[No Summary]'
db.session.commit()
pass
except Exception as e:
conversation.summary = '[No Summary]'
db.session.commit()
logging.exception(e)
end_at = time.perf_counter()
logging.info(
click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at),
fg='green'))
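The new task above is meant to be enqueued asynchronously on the 'generation' Celery queue, per its docstring. A hedged usage sketch; the import path and call site are assumptions, only the .delay call comes from the file itself:

    from tasks.generate_conversation_summary_task import generate_conversation_summary_task  # module path assumed

    def schedule_summary(conversation_id: str) -> None:
        # Hand the conversation id to a Celery worker on the 'generation' queue;
        # the worker writes conversation.summary once at least 5 messages exist
        # and no summary has been generated yet.
        generate_conversation_summary_task.delay(conversation_id)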
...@@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): ...@@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt): def test_get_num_tokens(mock_decrypt):
model = get_mock_model('claude-2') model = get_mock_model('claude-2')
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 6 assert rst == 6
......
...@@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker): ...@@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker) openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
rst = openai_model.get_num_tokens([ rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 22 assert rst == 22
......
...@@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt): ...@@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('baichuan2-53b') model = get_mock_model('baichuan2-53b')
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst > 0 assert rst > 0
...@@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker): ...@@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker):
model = get_mock_model('baichuan2-53b') model = get_mock_model('baichuan2-53b')
messages = [ messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
] ]
rst = model.run( rst = model.run(
messages, messages,
...@@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker): ...@@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
model = get_mock_model('baichuan2-53b', streaming=True) model = get_mock_model('baichuan2-53b', streaming=True)
messages = [ messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
] ]
rst = model.run( rst = model.run(
messages messages
......
...@@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock ...@@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock
mocker mocker
) )
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 5 assert rst == 5
...@@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke ...@@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
mocker mocker
) )
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 5 assert rst == 5
......
...@@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): ...@@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt): def test_get_num_tokens(mock_decrypt):
model = get_mock_model('abab5.5-chat') model = get_mock_model('abab5.5-chat')
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 5 assert rst == 5
......
...@@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt): ...@@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('gpt-3.5-turbo') openai_model = get_mock_openai_model('gpt-3.5-turbo')
rst = openai_model.get_num_tokens([ rst = openai_model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 22 assert rst == 22
......
...@@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): ...@@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker): def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('facebook/opt-125m', mocker) model = get_mock_model('facebook/opt-125m', mocker)
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 5 assert rst == 5
......
...@@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): ...@@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker): def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker) model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 7 assert rst == 7
......
...@@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): ...@@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt): def test_get_num_tokens(mock_decrypt):
model = get_mock_model('spark') model = get_mock_model('spark')
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 6 assert rst == 6
......
...@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): ...@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt): def test_get_num_tokens(mock_decrypt):
model = get_mock_model('qwen-turbo') model = get_mock_model('qwen-turbo')
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 5 assert rst == 5
......
...@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): ...@@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt): def test_get_num_tokens(mock_decrypt):
model = get_mock_model('ernie-bot') model = get_mock_model('ernie-bot')
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 5 assert rst == 5
......
...@@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): ...@@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):
def test_get_num_tokens(mock_decrypt, mocker): def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('llama-2-chat', mocker) model = get_mock_model('llama-2-chat', mocker)
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst == 5 assert rst == 5
......
...@@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt): ...@@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('chatglm_lite') model = get_mock_model('chatglm_lite')
rst = model.get_num_tokens([ rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') PromptMessage(type=MessageType.USER, content='Who is your manufacturer?')
]) ])
assert rst > 0 assert rst > 0
...@@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker): ...@@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker):
model = get_mock_model('chatglm_lite') model = get_mock_model('chatglm_lite')
messages = [ messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
] ]
rst = model.run( rst = model.run(
messages, messages,
...@@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker): ...@@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker):
model = get_mock_model('chatglm_lite', streaming=True) model = get_mock_model('chatglm_lite', streaming=True)
messages = [ messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?')
] ]
rst = model.run( rst = model.run(
messages messages
......
from typing import Type from typing import Type
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.providers.base import BaseModelProvider from core.model_providers.providers.base import BaseModelProvider
...@@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider): ...@@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider):
return 'fake' return 'fake'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return [{'id': 'test_model', 'name': 'Test Model'}] return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
return OpenAIModel return OpenAIModel
......
...@@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker): ...@@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker):
provider = FakeModelProvider(provider=Provider()) provider = FakeModelProvider(provider=Provider())
result = provider.get_supported_model_list(ModelType.TEXT_GENERATION) result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
assert result == [{'id': 'test_model', 'name': 'test_model'}] assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}]
def test_check_quota_over_limit(mocker): def test_check_quota_over_limit(mocker):
......