Commit 86d65f44 authored by takatost's avatar takatost

restore completion app

parent 10e6fe9b
...@@ -78,7 +78,7 @@ class AppListApi(Resource): ...@@ -78,7 +78,7 @@ class AppListApi(Resource):
"""Create app""" """Create app"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json') parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('mode', type=str, choices=[mode.value for mode in AppMode], location='json') parser.add_argument('mode', type=str, choices=['chat', 'agent', 'workflow'], location='json')
parser.add_argument('icon', type=str, location='json') parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json') parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('model_config', type=dict, location='json') parser.add_argument('model_config', type=dict, location='json')
......
...@@ -37,7 +37,7 @@ class CompletionMessageApi(Resource): ...@@ -37,7 +37,7 @@ class CompletionMessageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, required=True, location='json')
...@@ -90,7 +90,7 @@ class CompletionMessageStopApi(Resource): ...@@ -90,7 +90,7 @@ class CompletionMessageStopApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model, task_id): def post(self, app_model, task_id):
account = flask_login.current_user account = flask_login.current_user
......
...@@ -29,7 +29,7 @@ class CompletionConversationApi(Resource): ...@@ -29,7 +29,7 @@ class CompletionConversationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields) @marshal_with(conversation_pagination_fields)
def get(self, app_model): def get(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
...@@ -102,7 +102,7 @@ class CompletionConversationDetailApi(Resource): ...@@ -102,7 +102,7 @@ class CompletionConversationDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields) @marshal_with(conversation_message_detail_fields)
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
......
...@@ -330,7 +330,7 @@ class AverageResponseTimeStatistic(Resource): ...@@ -330,7 +330,7 @@ class AverageResponseTimeStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW) @get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
......
...@@ -12,6 +12,7 @@ from werkzeug.exceptions import InternalServerError, NotFound ...@@ -12,6 +12,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ( from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError, ProviderNotInitializeError,
...@@ -23,10 +24,13 @@ from controllers.console.explore.error import ( ...@@ -23,10 +24,13 @@ from controllers.console.explore.error import (
NotCompletionAppError, NotCompletionAppError,
) )
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService from services.message_service import MessageService
...@@ -72,6 +76,48 @@ class MessageFeedbackApi(InstalledAppResource): ...@@ -72,6 +76,48 @@ class MessageFeedbackApi(InstalledAppResource):
return {'result': 'success'} return {'result': 'success'}
class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id):
app_model = installed_app.app
if app_model.mode != 'completion':
raise NotCompletionAppError()
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
def compact_response(response: Union[dict, Generator]) -> Response: def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict): if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json') return Response(response=json.dumps(response), status=200, mimetype='application/json')
...@@ -120,4 +166,5 @@ class MessageSuggestedQuestionApi(InstalledAppResource): ...@@ -120,4 +166,5 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages') api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback') api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question') api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
...@@ -11,6 +11,7 @@ from werkzeug.exceptions import InternalServerError, NotFound ...@@ -11,6 +11,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.web import api from controllers.web import api
from controllers.web.error import ( from controllers.web.error import (
AppMoreLikeThisDisabledError,
AppSuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError,
CompletionRequestError, CompletionRequestError,
NotChatAppError, NotChatAppError,
...@@ -20,11 +21,14 @@ from controllers.web.error import ( ...@@ -20,11 +21,14 @@ from controllers.web.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields from fields.message_fields import agent_thought_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService from services.message_service import MessageService
...@@ -109,6 +113,48 @@ class MessageFeedbackApi(WebApiResource): ...@@ -109,6 +113,48 @@ class MessageFeedbackApi(WebApiResource):
return {'result': 'success'} return {'result': 'success'}
class MessageMoreLikeThisApi(WebApiResource):
def get(self, app_model, end_user, message_id):
if app_model.mode != 'completion':
raise NotCompletionAppError()
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=end_user,
message_id=message_id,
invoke_from=InvokeFrom.WEB_APP,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
def compact_response(response: Union[dict, Generator]) -> Response: def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict): if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json') return Response(response=json.dumps(response), status=200, mimetype='application/json')
...@@ -156,4 +202,5 @@ class MessageSuggestedQuestionApi(WebApiResource): ...@@ -156,4 +202,5 @@ class MessageSuggestedQuestionApi(WebApiResource):
api.add_resource(MessageListApi, '/messages') api.add_resource(MessageListApi, '/messages')
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks') api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions') api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')
...@@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage, ...@@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.simple_prompt_transform import SimplePromptTransform
from models.model import App, Message, MessageAnnotation from models.model import App, Message, MessageAnnotation, AppMode
class AppRunner: class AppRunner:
...@@ -140,11 +141,11 @@ class AppRunner: ...@@ -140,11 +141,11 @@ class AppRunner:
:param memory: memory :param memory: memory
:return: :return:
""" """
prompt_transform = SimplePromptTransform()
# get prompt without memory and context # get prompt without memory and context
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
prompt_transform = SimplePromptTransform()
prompt_messages, stop = prompt_transform.get_prompt( prompt_messages, stop = prompt_transform.get_prompt(
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity, prompt_template_entity=prompt_template_entity,
inputs=inputs, inputs=inputs,
query=query if query else '', query=query if query else '',
...@@ -154,7 +155,17 @@ class AppRunner: ...@@ -154,7 +155,17 @@ class AppRunner:
model_config=model_config model_config=model_config
) )
else: else:
raise NotImplementedError("Advanced prompt is not supported yet.") prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query if query else '',
files=files,
context=context,
memory=memory,
model_config=model_config
)
stop = model_config.stop
return prompt_messages, stop return prompt_messages, stop
......
...@@ -11,10 +11,9 @@ class PromptTransform: ...@@ -11,10 +11,9 @@ class PromptTransform:
def _append_chat_histories(self, memory: TokenBufferMemory, def _append_chat_histories(self, memory: TokenBufferMemory,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_config: ModelConfigEntity) -> list[PromptMessage]: model_config: ModelConfigEntity) -> list[PromptMessage]:
if memory: rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
histories = self._get_history_messages_list_from_memory(memory, rest_tokens) prompt_messages.extend(histories)
prompt_messages.extend(histories)
return prompt_messages return prompt_messages
......
...@@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform):
""" """
def get_prompt(self, def get_prompt(self,
app_mode: AppMode,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict, inputs: dict,
query: str, query: str,
...@@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform):
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages( prompt_messages, stops = self._get_chat_model_prompt_messages(
app_mode=app_mode,
pre_prompt=prompt_template_entity.simple_prompt_template, pre_prompt=prompt_template_entity.simple_prompt_template,
inputs=inputs, inputs=inputs,
query=query, query=query,
...@@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform):
) )
else: else:
prompt_messages, stops = self._get_completion_model_prompt_messages( prompt_messages, stops = self._get_completion_model_prompt_messages(
app_mode=app_mode,
pre_prompt=prompt_template_entity.simple_prompt_template, pre_prompt=prompt_template_entity.simple_prompt_template,
inputs=inputs, inputs=inputs,
query=query, query=query,
...@@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform): ...@@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform):
"prompt_rules": prompt_rules "prompt_rules": prompt_rules
} }
def _get_chat_model_prompt_messages(self, pre_prompt: str, def _get_chat_model_prompt_messages(self, app_mode: AppMode,
pre_prompt: str,
inputs: dict, inputs: dict,
query: str, query: str,
context: Optional[str], context: Optional[str],
...@@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
# get prompt # get prompt
prompt, _ = self.get_prompt_str_and_rules( prompt, _ = self.get_prompt_str_and_rules(
app_mode=AppMode.CHAT, app_mode=app_mode,
model_config=model_config, model_config=model_config,
pre_prompt=pre_prompt, pre_prompt=pre_prompt,
inputs=inputs, inputs=inputs,
...@@ -175,19 +179,25 @@ class SimplePromptTransform(PromptTransform): ...@@ -175,19 +179,25 @@ class SimplePromptTransform(PromptTransform):
) )
if prompt: if prompt:
prompt_messages.append(SystemPromptMessage(content=prompt)) if query:
prompt_messages.append(SystemPromptMessage(content=prompt))
else:
prompt_messages.append(UserPromptMessage(content=prompt))
prompt_messages = self._append_chat_histories( if memory:
memory=memory, prompt_messages = self._append_chat_histories(
prompt_messages=prompt_messages, memory=memory,
model_config=model_config prompt_messages=prompt_messages,
) model_config=model_config
)
prompt_messages.append(self.get_last_user_message(query, files)) if query:
prompt_messages.append(self.get_last_user_message(query, files))
return prompt_messages, None return prompt_messages, None
def _get_completion_model_prompt_messages(self, pre_prompt: str, def _get_completion_model_prompt_messages(self, app_mode: AppMode,
pre_prompt: str,
inputs: dict, inputs: dict,
query: str, query: str,
context: Optional[str], context: Optional[str],
...@@ -197,7 +207,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -197,7 +207,7 @@ class SimplePromptTransform(PromptTransform):
-> tuple[list[PromptMessage], Optional[list[str]]]: -> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt # get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules( prompt, prompt_rules = self.get_prompt_str_and_rules(
app_mode=AppMode.CHAT, app_mode=app_mode,
model_config=model_config, model_config=model_config,
pre_prompt=pre_prompt, pre_prompt=pre_prompt,
inputs=inputs, inputs=inputs,
...@@ -220,7 +230,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -220,7 +230,7 @@ class SimplePromptTransform(PromptTransform):
# get prompt # get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules( prompt, prompt_rules = self.get_prompt_str_and_rules(
app_mode=AppMode.CHAT, app_mode=app_mode,
model_config=model_config, model_config=model_config,
pre_prompt=pre_prompt, pre_prompt=pre_prompt,
inputs=inputs, inputs=inputs,
...@@ -289,13 +299,13 @@ class SimplePromptTransform(PromptTransform): ...@@ -289,13 +299,13 @@ class SimplePromptTransform(PromptTransform):
is_baichuan = True is_baichuan = True
if is_baichuan: if is_baichuan:
if app_mode == AppMode.WORKFLOW: if app_mode == AppMode.COMPLETION:
return 'baichuan_completion' return 'baichuan_completion'
else: else:
return 'baichuan_chat' return 'baichuan_chat'
# common # common
if app_mode == AppMode.WORKFLOW: if app_mode == AppMode.COMPLETION:
return 'common_completion' return 'common_completion'
else: else:
return 'common_chat' return 'common_chat'
...@@ -316,6 +316,9 @@ class AppModelConfigService: ...@@ -316,6 +316,9 @@ class AppModelConfigService:
if "tool_parameters" not in tool: if "tool_parameters" not in tool:
raise ValueError("tool_parameters is required in agent_mode.tools") raise ValueError("tool_parameters is required in agent_mode.tools")
# dataset_query_variable
cls.is_dataset_query_variable_valid(config, app_mode)
# advanced prompt validation # advanced prompt validation
cls.is_advanced_prompt_valid(config, app_mode) cls.is_advanced_prompt_valid(config, app_mode)
...@@ -441,6 +444,21 @@ class AppModelConfigService: ...@@ -441,6 +444,21 @@ class AppModelConfigService:
config=config config=config
) )
@classmethod
def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
# Only check when mode is completion
if mode != 'completion':
return
agent_mode = config.get("agent_mode", {})
tools = agent_mode.get("tools", [])
dataset_exists = "dataset" in str(tools)
dataset_query_variable = config.get("dataset_query_variable")
if dataset_exists and not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist")
@classmethod @classmethod
def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None: def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
# prompt_type # prompt_type
......
...@@ -8,10 +8,12 @@ from core.application_manager import ApplicationManager ...@@ -8,10 +8,12 @@ from core.application_manager import ApplicationManager
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.file.message_file_parser import MessageFileParser from core.file.message_file_parser import MessageFileParser
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Account, App, AppModelConfig, Conversation, EndUser from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
from services.app_model_config_service import AppModelConfigService from services.app_model_config_service import AppModelConfigService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
from services.errors.message import MessageNotExistsError
class CompletionService: class CompletionService:
...@@ -155,6 +157,62 @@ class CompletionService: ...@@ -155,6 +157,62 @@ class CompletionService:
} }
) )
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
-> Union[dict, Generator]:
if not user:
raise ValueError('user cannot be None')
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
).first()
if not message:
raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config
more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
model_dict = app_model_config.model_dict
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
app_model_config.model = json.dumps(model_dict)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_objs = message_file_parser.transform_message_files(
message.files, app_model_config
)
application_manager = ApplicationManager()
return application_manager.generate(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_model_config_id=app_model_config.id,
app_model_config_dict=app_model_config.to_dict(),
app_model_config_override=True,
user=user,
invoke_from=invoke_from,
inputs=message.inputs,
query=message.query,
files=file_objs,
conversation=None,
stream=streaming,
extras={
"auto_generate_conversation_name": False
}
)
@classmethod @classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
if user_inputs is None: if user_inputs is None:
......
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
__all__ = [ __all__ = [
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
'completion', 'audio', 'file' 'app', 'completion', 'audio', 'file'
] ]
from . import * from . import *
class MoreLikeThisDisabledError(Exception):
pass
...@@ -160,6 +160,7 @@ def test__get_chat_model_prompt_messages(): ...@@ -160,6 +160,7 @@ def test__get_chat_model_prompt_messages():
context = "yes or no." context = "yes or no."
query = "How are you?" query = "How are you?"
prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
app_mode=AppMode.CHAT,
pre_prompt=pre_prompt, pre_prompt=pre_prompt,
inputs=inputs, inputs=inputs,
query=query, query=query,
...@@ -214,6 +215,7 @@ def test__get_completion_model_prompt_messages(): ...@@ -214,6 +215,7 @@ def test__get_completion_model_prompt_messages():
context = "yes or no." context = "yes or no."
query = "How are you?" query = "How are you?"
prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
app_mode=AppMode.CHAT,
pre_prompt=pre_prompt, pre_prompt=pre_prompt,
inputs=inputs, inputs=inputs,
query=query, query=query,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment