Commit 8186803b authored by takatost's avatar takatost

refactor app

parent baf2b7f3
...@@ -21,7 +21,7 @@ from controllers.console.app.error import ( ...@@ -21,7 +21,7 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
...@@ -94,7 +94,7 @@ class CompletionMessageStopApi(Resource): ...@@ -94,7 +94,7 @@ class CompletionMessageStopApi(Resource):
def post(self, app_model, task_id): def post(self, app_model, task_id):
account = flask_login.current_user account = flask_login.current_user
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
...@@ -172,7 +172,7 @@ class ChatMessageStopApi(Resource): ...@@ -172,7 +172,7 @@ class ChatMessageStopApi(Resource):
def post(self, app_model, task_id): def post(self, app_model, task_id):
account = flask_login.current_user account = flask_login.current_user
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
...@@ -11,7 +11,7 @@ from controllers.console.app.error import ( ...@@ -11,7 +11,7 @@ from controllers.console.app.error import (
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required from libs.login import login_required
......
...@@ -21,7 +21,7 @@ from controllers.console.app.error import ( ...@@ -21,7 +21,7 @@ from controllers.console.app.error import (
) )
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
...@@ -90,7 +90,7 @@ class CompletionStopApi(InstalledAppResource): ...@@ -90,7 +90,7 @@ class CompletionStopApi(InstalledAppResource):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
...@@ -154,7 +154,7 @@ class ChatStopApi(InstalledAppResource): ...@@ -154,7 +154,7 @@ class ChatStopApi(InstalledAppResource):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
...@@ -19,7 +19,7 @@ from controllers.service_api.app.error import ( ...@@ -19,7 +19,7 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
...@@ -85,7 +85,7 @@ class CompletionStopApi(Resource): ...@@ -85,7 +85,7 @@ class CompletionStopApi(Resource):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise AppUnavailableError() raise AppUnavailableError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
...@@ -147,7 +147,7 @@ class ChatStopApi(Resource): ...@@ -147,7 +147,7 @@ class ChatStopApi(Resource):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
...@@ -20,7 +20,7 @@ from controllers.web.error import ( ...@@ -20,7 +20,7 @@ from controllers.web.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
...@@ -84,7 +84,7 @@ class CompletionStopApi(WebApiResource): ...@@ -84,7 +84,7 @@ class CompletionStopApi(WebApiResource):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
...@@ -144,7 +144,7 @@ class ChatStopApi(WebApiResource): ...@@ -144,7 +144,7 @@ class ChatStopApi(WebApiResource):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
...@@ -5,8 +5,8 @@ from datetime import datetime ...@@ -5,8 +5,8 @@ from datetime import datetime
from mimetypes import guess_extension from mimetypes import guess_extension
from typing import Optional, Union, cast from typing import Optional, Union, cast
from core.app_runner.app_runner import AppRunner from core.app.base_app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ( from core.entities.application_entities import (
...@@ -48,13 +48,13 @@ from models.tools import ToolConversationVariables ...@@ -48,13 +48,13 @@ from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseAssistantApplicationRunner(AppRunner): class BaseAgentRunner(AppRunner):
def __init__(self, tenant_id: str, def __init__(self, tenant_id: str,
application_generate_entity: ApplicationGenerateEntity, application_generate_entity: ApplicationGenerateEntity,
app_orchestration_config: AppOrchestrationConfigEntity, app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity, model_config: ModelConfigEntity,
config: AgentEntity, config: AgentEntity,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
message: Message, message: Message,
user_id: str, user_id: str,
memory: Optional[TokenBufferMemory] = None, memory: Optional[TokenBufferMemory] = None,
......
...@@ -3,9 +3,9 @@ import re ...@@ -3,9 +3,9 @@ import re
from collections.abc import Generator from collections.abc import Generator
from typing import Literal, Union from typing import Literal, Union
from core.application_queue_manager import PublishFrom from core.app.app_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
from core.features.assistant_base_runner import BaseAssistantApplicationRunner from core.agent.base_agent_runner import BaseAgentRunner
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
...@@ -27,7 +27,7 @@ from core.tools.errors import ( ...@@ -27,7 +27,7 @@ from core.tools.errors import (
from models.model import Conversation, Message from models.model import Conversation, Message
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): class CotAgentRunner(BaseAgentRunner):
def run(self, conversation: Conversation, def run(self, conversation: Conversation,
message: Message, message: Message,
query: str, query: str,
......
...@@ -3,8 +3,8 @@ import logging ...@@ -3,8 +3,8 @@ import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Union from typing import Any, Union
from core.application_queue_manager import PublishFrom from core.app.app_queue_manager import PublishFrom
from core.features.assistant_base_runner import BaseAssistantApplicationRunner from core.agent.base_agent_runner import BaseAgentRunner
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
...@@ -26,7 +26,7 @@ from models.model import Conversation, Message, MessageAgentThought ...@@ -26,7 +26,7 @@ from models.model import Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, conversation: Conversation, def run(self, conversation: Conversation,
message: Message, message: Message,
query: str, query: str,
......
from core.apps.config_validators.file_upload import FileUploadValidator from core.app.validators.file_upload import FileUploadValidator
from core.apps.config_validators.moderation import ModerationValidator from core.app.validators.moderation import ModerationValidator
from core.apps.config_validators.opening_statement import OpeningStatementValidator from core.app.validators.opening_statement import OpeningStatementValidator
from core.apps.config_validators.retriever_resource import RetrieverResourceValidator from core.app.validators.retriever_resource import RetrieverResourceValidator
from core.apps.config_validators.speech_to_text import SpeechToTextValidator from core.app.validators.speech_to_text import SpeechToTextValidator
from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator from core.app.validators.suggested_questions import SuggestedQuestionsValidator
from core.apps.config_validators.text_to_speech import TextToSpeechValidator from core.app.validators.text_to_speech import TextToSpeechValidator
class AdvancedChatAppConfigValidator: class AdvancedChatAppConfigValidator:
......
import logging import logging
from typing import cast from typing import cast
from core.app_runner.app_runner import AppRunner from core.app.base_app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
from core.features.assistant_cot_runner import AssistantCotApplicationRunner from core.agent.cot_agent_runner import CotAgentRunner
from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
...@@ -19,12 +19,13 @@ from models.tools import ToolConversationVariables ...@@ -19,12 +19,13 @@ from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AssistantApplicationRunner(AppRunner):
class AgentChatAppRunner(AppRunner):
""" """
Assistant Application Runner Agent Application Runner
""" """
def run(self, application_generate_entity: ApplicationGenerateEntity, def run(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
""" """
...@@ -197,7 +198,7 @@ class AssistantApplicationRunner(AppRunner): ...@@ -197,7 +198,7 @@ class AssistantApplicationRunner(AppRunner):
# start agent runner # start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner( assistant_cot_runner = CotAgentRunner(
tenant_id=application_generate_entity.tenant_id, tenant_id=application_generate_entity.tenant_id,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
app_orchestration_config=app_orchestration_config, app_orchestration_config=app_orchestration_config,
...@@ -219,7 +220,7 @@ class AssistantApplicationRunner(AppRunner): ...@@ -219,7 +220,7 @@ class AssistantApplicationRunner(AppRunner):
inputs=inputs, inputs=inputs,
) )
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
assistant_fc_runner = AssistantFunctionCallApplicationRunner( assistant_fc_runner = FunctionCallAgentRunner(
tenant_id=application_generate_entity.tenant_id, tenant_id=application_generate_entity.tenant_id,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
app_orchestration_config=app_orchestration_config, app_orchestration_config=app_orchestration_config,
......
import uuid import uuid
from core.apps.config_validators.dataset import DatasetValidator
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.app.validators.dataset_retrieval import DatasetValidator
from core.app.validators.external_data_fetch import ExternalDataFetchValidator
from core.app.validators.file_upload import FileUploadValidator
from core.app.validators.model_validator import ModelValidator
from core.app.validators.moderation import ModerationValidator
from core.app.validators.opening_statement import OpeningStatementValidator
from core.app.validators.prompt import PromptValidator
from core.app.validators.retriever_resource import RetrieverResourceValidator
from core.app.validators.speech_to_text import SpeechToTextValidator
from core.app.validators.suggested_questions import SuggestedQuestionsValidator
from core.app.validators.text_to_speech import TextToSpeechValidator
from core.app.validators.user_input_form import UserInputFormValidator
from models.model import AppMode
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
class AgentValidator: class AgentChatAppConfigValidator:
@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
"""
Validate for agent chat app model config
:param tenant_id: tenant id
:param config: app model config args
"""
app_mode = AppMode.AGENT_CHAT
related_config_keys = []
# model
config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# user_input_form
config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# external data tools validation
config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# file upload validation
config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# prompt
config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config)
related_config_keys.extend(current_related_config_keys)
# agent_mode
config, current_related_config_keys = cls.validate_agent_mode_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# opening_statement
config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# return retriever resource
config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
@classmethod @classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
""" """
Validate and set defaults for agent feature Validate agent_mode and set defaults for agent feature
:param tenant_id: tenant ID :param tenant_id: tenant ID
:param config: app model config args :param config: app model config args
...@@ -33,7 +113,8 @@ class AgentValidator: ...@@ -33,7 +113,8 @@ class AgentValidator:
if not config["agent_mode"].get("strategy"): if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: if config["agent_mode"]["strategy"] not in [member.value for member in
list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list") raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not config["agent_mode"].get("tools"): if not config["agent_mode"].get("tools"):
......
import json
import logging
import threading
import uuid
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from pydantic import ValidationError
from core.app.app_orchestration_config_converter import AppOrchestrationConfigConverter
from core.app.agent_chat.app_runner import AgentChatAppRunner
from core.app.chat.app_runner import ChatAppRunner
from core.app.generate_task_pipeline import GenerateTaskPipeline
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.entities.application_entities import (
ApplicationGenerateEntity,
InvokeFrom,
)
from core.file.file_obj import FileObj
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message, MessageFile
logger = logging.getLogger(__name__)
class AppManager:
"""
This class is responsible for managing application
"""
def generate(self, tenant_id: str,
app_id: str,
app_model_config_id: str,
app_model_config_dict: dict,
app_model_config_override: bool,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
inputs: dict[str, str],
query: Optional[str] = None,
files: Optional[list[FileObj]] = None,
conversation: Optional[Conversation] = None,
stream: bool = False,
extras: Optional[dict[str, Any]] = None) \
-> Union[dict, Generator]:
"""
Generate App response.
:param tenant_id: workspace ID
:param app_id: app ID
:param app_model_config_id: app model config id
:param app_model_config_dict: app model config dict
:param app_model_config_override: app model config override
:param user: account or end user
:param invoke_from: invoke from source
:param inputs: inputs
:param query: query
:param files: file obj list
:param conversation: conversation
:param stream: is stream
:param extras: extras
"""
# init task id
task_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = ApplicationGenerateEntity(
task_id=task_id,
tenant_id=tenant_id,
app_id=app_id,
app_model_config_id=app_model_config_id,
app_model_config_dict=app_model_config_dict,
app_orchestration_config_entity=AppOrchestrationConfigConverter.convert_from_app_model_config_dict(
tenant_id=tenant_id,
app_model_config_dict=app_model_config_dict
),
app_model_config_override=app_model_config_override,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else inputs,
query=query.replace('\x00', '') if query else None,
files=files if files else [],
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
)
if not stream and application_generate_entity.app_orchestration_config_entity.agent:
raise ValueError("Agent app is not supported in blocking mode.")
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = AppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: ApplicationGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if application_generate_entity.app_orchestration_config_entity.agent:
# agent app
runner = AgentChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
else:
# basic app
runner = ChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
stream: bool = False) -> Union[dict, Generator]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
:param stream: is stream
:return:
"""
# init generate task pipeline
generate_task_pipeline = GenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
try:
return generate_task_pipeline.process(stream=stream)
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise ConversationTaskStoppedException()
else:
logger.exception(e)
raise e
finally:
db.session.remove()
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
-> tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity
:return:
"""
app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_schema = model_type_instance.get_model_schema(
model=app_orchestration_config_entity.model_config.model,
credentials=app_orchestration_config_entity.model_config.credentials
)
app_record = (db.session.query(App)
.filter(App.id == application_generate_entity.app_id).first())
app_mode = app_record.mode
# get from source
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
from_source = 'api'
end_user_id = application_generate_entity.user_id
else:
from_source = 'console'
account_id = application_generate_entity.user_id
override_model_configs = None
if application_generate_entity.app_model_config_override:
override_model_configs = application_generate_entity.app_model_config_dict
introduction = ''
if app_mode == 'chat':
# get conversation introduction
introduction = self._get_conversation_introduction(application_generate_entity)
if not application_generate_entity.conversation_id:
conversation = Conversation(
app_id=app_record.id,
app_model_config_id=application_generate_entity.app_model_config_id,
model_provider=app_orchestration_config_entity.model_config.provider,
model_id=app_orchestration_config_entity.model_config.model,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_mode,
name='New conversation',
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status='normal',
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
)
db.session.add(conversation)
db.session.commit()
else:
conversation = (
db.session.query(Conversation)
.filter(
Conversation.id == application_generate_entity.conversation_id,
Conversation.app_id == app_record.id
).first()
)
currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
message = Message(
app_id=app_record.id,
model_provider=app_orchestration_config_entity.model_config.provider,
model_id=app_orchestration_config_entity.model_config.model,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query or "",
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency=currency,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
agent_based=app_orchestration_config_entity.agent is not None
)
db.session.add(message)
db.session.commit()
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
belongs_to='user',
url=file.url,
upload_file_id=file.upload_file_id,
created_by_role=('account' if account_id else 'end_user'),
created_by=account_id or end_user_id,
)
db.session.add(message_file)
db.session.commit()
return conversation, message
def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
    """
    Build the opening statement shown at the start of a new conversation.

    The configured opening statement may reference template variables; any
    variable present in the generate entity's inputs is substituted in.

    :param application_generate_entity: application generate entity
    :return: rendered conversation introduction (may be empty)
    """
    orchestration_config = application_generate_entity.app_orchestration_config_entity
    opening_statement = orchestration_config.opening_statement
    if not opening_statement:
        return opening_statement

    try:
        parser = PromptTemplateParser(template=opening_statement)
        available_inputs = application_generate_entity.inputs
        template_values = {
            key: available_inputs[key]
            for key in parser.variable_keys
            if key in available_inputs
        }
        opening_statement = parser.format(template_values)
    except KeyError:
        # Rendering is best-effort: fall back to the raw opening statement.
        pass

    return opening_statement
def _get_conversation(self, conversation_id: str) -> Conversation:
    """
    Load a conversation record by its primary key.

    :param conversation_id: conversation id
    :return: conversation row, or None when no row matches
    """
    query = db.session.query(Conversation).filter(Conversation.id == conversation_id)
    return query.first()
def _get_message(self, message_id: str) -> Message:
    """
    Load a message record by its primary key.

    :param message_id: message id
    :return: message row, or None when no row matches
    """
    query = db.session.query(Message).filter(Message.id == message_id)
    return query.first()
import json from typing import cast
import logging
import threading from core.entities.application_entities import AppOrchestrationConfigEntity, SensitiveWordAvoidanceEntity, \
import uuid TextToSpeechEntity, DatasetRetrieveConfigEntity, DatasetEntity, AgentPromptEntity, AgentEntity, AgentToolEntity, \
from collections.abc import Generator ExternalDataVariableEntity, VariableEntity, AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, \
from typing import Any, Optional, Union, cast AdvancedChatPromptTemplateEntity, ModelConfigEntity, FileUploadEntity
from flask import Flask, current_app
from pydantic import ValidationError
from core.app_runner.assistant_app_runner import AssistantApplicationRunner
from core.app_runner.basic_app_runner import BasicApplicationRunner
from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
from core.entities.application_entities import (
AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity,
AgentEntity,
AgentPromptEntity,
AgentToolEntity,
ApplicationGenerateEntity,
AppOrchestrationConfigEntity,
DatasetEntity,
DatasetRetrieveConfigEntity,
ExternalDataVariableEntity,
FileUploadEntity,
InvokeFrom,
ModelConfigEntity,
PromptTemplateEntity,
SensitiveWordAvoidanceEntity,
TextToSpeechEntity,
VariableEntity,
)
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError
from core.file.file_obj import FileObj
from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message, MessageFile
logger = logging.getLogger(__name__)
class ApplicationManager:
"""
This class is responsible for managing application
"""
def generate(self, tenant_id: str,
             app_id: str,
             app_model_config_id: str,
             app_model_config_dict: dict,
             app_model_config_override: bool,
             user: Union[Account, EndUser],
             invoke_from: InvokeFrom,
             inputs: dict[str, str],
             query: Optional[str] = None,
             files: Optional[list[FileObj]] = None,
             conversation: Optional[Conversation] = None,
             stream: bool = False,
             extras: Optional[dict[str, Any]] = None) \
        -> Union[dict, Generator]:
    """
    Generate App response.

    Builds the application generate entity, persists the conversation and
    message records, then runs the app in a background worker thread while
    the caller consumes the result through the generate task pipeline —
    either a blocking dict or a streaming generator depending on `stream`.

    :param tenant_id: workspace ID
    :param app_id: app ID
    :param app_model_config_id: app model config id
    :param app_model_config_dict: app model config dict
    :param app_model_config_override: app model config override
    :param user: account or end user
    :param invoke_from: invoke from source
    :param inputs: inputs
    :param query: query
    :param files: file obj list
    :param conversation: conversation
    :param stream: is stream
    :param extras: extras
    :return: blocking response dict, or a generator of streamed chunks
    """
    # init task id
    task_id = str(uuid.uuid4())

    # init application generate entity
    application_generate_entity = ApplicationGenerateEntity(
        task_id=task_id,
        tenant_id=tenant_id,
        app_id=app_id,
        app_model_config_id=app_model_config_id,
        app_model_config_dict=app_model_config_dict,
        app_orchestration_config_entity=self.convert_from_app_model_config_dict(
            tenant_id=tenant_id,
            app_model_config_dict=app_model_config_dict
        ),
        app_model_config_override=app_model_config_override,
        conversation_id=conversation.id if conversation else None,
        # an existing conversation keeps its originally recorded inputs
        inputs=conversation.inputs if conversation else inputs,
        # remove NUL ('\x00') characters from the query text
        query=query.replace('\x00', '') if query else None,
        files=files if files else [],
        user_id=user.id,
        stream=stream,
        invoke_from=invoke_from,
        extras=extras
    )

    # agent apps are only supported in streaming mode
    if not stream and application_generate_entity.app_orchestration_config_entity.agent:
        raise ValueError("Agent app is not supported in blocking mode.")

    # init generate records
    (
        conversation,
        message
    ) = self._init_generate_records(application_generate_entity)

    # init queue manager
    queue_manager = ApplicationQueueManager(
        task_id=application_generate_entity.task_id,
        user_id=application_generate_entity.user_id,
        invoke_from=application_generate_entity.invoke_from,
        conversation_id=conversation.id,
        app_mode=conversation.mode,
        message_id=message.id
    )

    # new thread; pass the real Flask app object (not the request-local
    # proxy) so the worker can push its own app context
    worker_thread = threading.Thread(target=self._generate_worker, kwargs={
        'flask_app': current_app._get_current_object(),
        'application_generate_entity': application_generate_entity,
        'queue_manager': queue_manager,
        'conversation_id': conversation.id,
        'message_id': message.id,
    })

    worker_thread.start()

    # return response or stream generator
    return self._handle_response(
        application_generate_entity=application_generate_entity,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message,
        stream=stream
    )
def _generate_worker(self, flask_app: Flask,
                     application_generate_entity: ApplicationGenerateEntity,
                     queue_manager: ApplicationQueueManager,
                     conversation_id: str,
                     message_id: str) -> None:
    """
    Generate worker in a new thread.

    Runs the agent or basic app runner inside a fresh Flask app context and
    forwards any failure to the queue manager so the consuming thread sees
    it through the task pipeline.

    :param flask_app: Flask app
    :param application_generate_entity: application generate entity
    :param queue_manager: queue manager
    :param conversation_id: conversation ID
    :param message_id: message ID
    :return:
    """
    with flask_app.app_context():
        try:
            # get conversation and message
            conversation = self._get_conversation(conversation_id)
            message = self._get_message(message_id)

            if application_generate_entity.app_orchestration_config_entity.agent:
                # agent app
                runner = AssistantApplicationRunner()
                runner.run(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    conversation=conversation,
                    message=message
                )
            else:
                # basic app
                runner = BasicApplicationRunner()
                runner.run(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    conversation=conversation,
                    message=message
                )
        except ConversationTaskStoppedException:
            # task was stopped deliberately; nothing to report
            pass
        except InvokeAuthorizationError:
            # publish a generic message instead of the original error detail
            queue_manager.publish_error(
                InvokeAuthorizationError('Incorrect API key provided'),
                PublishFrom.APPLICATION_MANAGER
            )
        except ValidationError as e:
            logger.exception("Validation Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
        except (ValueError, InvokeError) as e:
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
        except Exception as e:
            logger.exception("Unknown Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
        finally:
            # release this thread's scoped database session
            db.session.remove()
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                     queue_manager: ApplicationQueueManager,
                     conversation: Conversation,
                     message: Message,
                     stream: bool = False) -> Union[dict, Generator]:
    """
    Handle response.

    Feeds the queue events through the generate task pipeline and returns
    either a blocking result dict or a streaming generator.

    :param application_generate_entity: application generate entity
    :param queue_manager: queue manager
    :param conversation: conversation
    :param message: message
    :param stream: is stream
    :return: blocking response dict, or a generator of streamed chunks
    """
    # init generate task pipeline
    generate_task_pipeline = GenerateTaskPipeline(
        application_generate_entity=application_generate_entity,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message
    )

    try:
        return generate_task_pipeline.process(stream=stream)
    except ValueError as e:
        # A write to a closed stream means the client stopped consuming;
        # translate it into the dedicated stop exception. Guard `e.args`
        # first: an argument-less ValueError would otherwise raise an
        # IndexError here and mask the original error.
        if e.args and e.args[0] == "I/O operation on closed file.":  # ignore this error
            raise ConversationTaskStoppedException()
        else:
            logger.exception(e)
            raise e
    finally:
        # release this thread's scoped database session
        db.session.remove()
def convert_from_app_model_config_dict(self, tenant_id: str, class AppOrchestrationConfigConverter:
@classmethod
def convert_from_app_model_config_dict(cls, tenant_id: str,
app_model_config_dict: dict, app_model_config_dict: dict,
skip_check: bool = False) \ skip_check: bool = False) \
-> AppOrchestrationConfigEntity: -> AppOrchestrationConfigEntity:
...@@ -396,7 +174,7 @@ class ApplicationManager: ...@@ -396,7 +174,7 @@ class ApplicationManager:
) )
properties['variables'] = [] properties['variables'] = []
# variables and external_data_tools # variables and external_data_tools
for variable in copy_app_model_config_dict.get('user_input_form', []): for variable in copy_app_model_config_dict.get('user_input_form', []):
typ = list(variable.keys())[0] typ = list(variable.keys())[0]
...@@ -446,7 +224,7 @@ class ApplicationManager: ...@@ -446,7 +224,7 @@ class ApplicationManager:
show_retrieve_source = True show_retrieve_source = True
properties['show_retrieve_source'] = show_retrieve_source properties['show_retrieve_source'] = show_retrieve_source
dataset_ids = [] dataset_ids = []
if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}): if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', { datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
...@@ -454,26 +232,23 @@ class ApplicationManager: ...@@ -454,26 +232,23 @@ class ApplicationManager:
'datasets': [] 'datasets': []
}) })
for dataset in datasets.get('datasets', []): for dataset in datasets.get('datasets', []):
keys = list(dataset.keys()) keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != 'dataset': if len(keys) == 0 or keys[0] != 'dataset':
continue continue
dataset = dataset['dataset'] dataset = dataset['dataset']
if 'enabled' not in dataset or not dataset['enabled']: if 'enabled' not in dataset or not dataset['enabled']:
continue continue
dataset_id = dataset.get('id', None) dataset_id = dataset.get('id', None)
if dataset_id: if dataset_id:
dataset_ids.append(dataset_id) dataset_ids.append(dataset_id)
else:
datasets = {'strategy': 'router', 'datasets': []}
if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
and 'enabled' in copy_app_model_config_dict['agent_mode'] \ and 'enabled' in copy_app_model_config_dict['agent_mode'] \
and copy_app_model_config_dict['agent_mode']['enabled']: and copy_app_model_config_dict['agent_mode']['enabled']:
agent_dict = copy_app_model_config_dict.get('agent_mode', {}) agent_dict = copy_app_model_config_dict.get('agent_mode', {})
agent_strategy = agent_dict.get('strategy', 'cot') agent_strategy = agent_dict.get('strategy', 'cot')
...@@ -517,7 +292,7 @@ class ApplicationManager: ...@@ -517,7 +292,7 @@ class ApplicationManager:
dataset_id = tool_item['id'] dataset_id = tool_item['id']
dataset_ids.append(dataset_id) dataset_ids.append(dataset_id)
if 'strategy' in copy_app_model_config_dict['agent_mode'] and \ if 'strategy' in copy_app_model_config_dict['agent_mode'] and \
copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']: copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']:
agent_prompt = agent_dict.get('prompt', None) or {} agent_prompt = agent_dict.get('prompt', None) or {}
...@@ -525,13 +300,18 @@ class ApplicationManager: ...@@ -525,13 +300,18 @@ class ApplicationManager:
model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion') model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
if model_mode == 'completion': if model_mode == 'completion':
agent_prompt_entity = AgentPromptEntity( agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), first_prompt=agent_prompt.get('first_prompt',
next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']), REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['completion'][
'agent_scratchpad']),
) )
else: else:
agent_prompt_entity = AgentPromptEntity( agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), first_prompt=agent_prompt.get('first_prompt',
next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
) )
properties['agent'] = AgentEntity( properties['agent'] = AgentEntity(
...@@ -553,7 +333,7 @@ class ApplicationManager: ...@@ -553,7 +333,7 @@ class ApplicationManager:
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity( retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable, query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model'] dataset_configs['retrieval_model']
) )
) )
...@@ -626,167 +406,3 @@ class ApplicationManager: ...@@ -626,167 +406,3 @@ class ApplicationManager:
) )
return AppOrchestrationConfigEntity(**properties) return AppOrchestrationConfigEntity(**properties)
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
        -> tuple[Conversation, Message]:
    """
    Initialize generate records.

    Creates (or loads) the Conversation row, creates the Message row, and
    attaches a MessageFile row for each uploaded file. Rows are committed as
    they are created so their IDs are available to the worker thread.

    :param application_generate_entity: application generate entity
    :return: (conversation, message) tuple of persisted records
    """
    app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity

    # resolve the model schema to read pricing info for the message record
    model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
    model_type_instance = cast(LargeLanguageModel, model_type_instance)
    model_schema = model_type_instance.get_model_schema(
        model=app_orchestration_config_entity.model_config.model,
        credentials=app_orchestration_config_entity.model_config.credentials
    )

    app_record = (db.session.query(App)
                  .filter(App.id == application_generate_entity.app_id).first())

    app_mode = app_record.mode

    # get from source: web/service calls attribute to an end user,
    # everything else (console, explore, debugger) to an account
    end_user_id = None
    account_id = None
    if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
        from_source = 'api'
        end_user_id = application_generate_entity.user_id
    else:
        from_source = 'console'
        account_id = application_generate_entity.user_id

    # only persist the config snapshot when it overrides the stored config
    override_model_configs = None
    if application_generate_entity.app_model_config_override:
        override_model_configs = application_generate_entity.app_model_config_dict

    introduction = ''
    if app_mode == 'chat':
        # get conversation introduction
        introduction = self._get_conversation_introduction(application_generate_entity)

    if not application_generate_entity.conversation_id:
        # first turn: create a fresh conversation
        conversation = Conversation(
            app_id=app_record.id,
            app_model_config_id=application_generate_entity.app_model_config_id,
            model_provider=app_orchestration_config_entity.model_config.provider,
            model_id=app_orchestration_config_entity.model_config.model,
            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
            mode=app_mode,
            name='New conversation',
            inputs=application_generate_entity.inputs,
            introduction=introduction,
            system_instruction="",
            system_instruction_tokens=0,
            status='normal',
            from_source=from_source,
            from_end_user_id=end_user_id,
            from_account_id=account_id,
        )

        db.session.add(conversation)
        db.session.commit()
    else:
        # follow-up turn: reuse the existing conversation (scoped to the app)
        conversation = (
            db.session.query(Conversation)
            .filter(
                Conversation.id == application_generate_entity.conversation_id,
                Conversation.app_id == app_record.id
            ).first()
        )

    currency = model_schema.pricing.currency if model_schema.pricing else 'USD'

    # message starts as an empty placeholder; the worker fills in the
    # answer, token counts and prices as generation proceeds
    message = Message(
        app_id=app_record.id,
        model_provider=app_orchestration_config_entity.model_config.provider,
        model_id=app_orchestration_config_entity.model_config.model,
        override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
        conversation_id=conversation.id,
        inputs=application_generate_entity.inputs,
        query=application_generate_entity.query or "",
        message="",
        message_tokens=0,
        message_unit_price=0,
        message_price_unit=0,
        answer="",
        answer_tokens=0,
        answer_unit_price=0,
        answer_price_unit=0,
        provider_response_latency=0,
        total_price=0,
        currency=currency,
        from_source=from_source,
        from_end_user_id=end_user_id,
        from_account_id=account_id,
        agent_based=app_orchestration_config_entity.agent is not None
    )

    db.session.add(message)
    db.session.commit()

    # link each uploaded file to the message
    for file in application_generate_entity.files:
        message_file = MessageFile(
            message_id=message.id,
            type=file.type.value,
            transfer_method=file.transfer_method.value,
            belongs_to='user',
            url=file.url,
            upload_file_id=file.upload_file_id,
            created_by_role=('account' if account_id else 'end_user'),
            created_by=account_id or end_user_id,
        )
        db.session.add(message_file)
        db.session.commit()

    return conversation, message
def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
    """
    Render the configured opening statement for a new conversation.

    Template variables in the opening statement are replaced with matching
    values from the generate entity's inputs, when available.

    :param application_generate_entity: application generate entity
    :return: conversation introduction (possibly empty)
    """
    config = application_generate_entity.app_orchestration_config_entity
    statement = config.opening_statement
    if statement:
        try:
            parser = PromptTemplateParser(template=statement)
            supplied = application_generate_entity.inputs
            values = {name: supplied[name] for name in parser.variable_keys if name in supplied}
            statement = parser.format(values)
        except KeyError:
            # Best-effort rendering: keep the raw statement on failure.
            pass
    return statement
def _get_conversation(self, conversation_id: str) -> Conversation:
    """
    Fetch a conversation row by id.

    :param conversation_id: conversation id
    :return: conversation row, or None when it does not exist
    """
    return db.session.query(Conversation) \
        .filter(Conversation.id == conversation_id) \
        .first()
def _get_message(self, message_id: str) -> Message:
    """
    Fetch a message row by id.

    :param message_id: message id
    :return: message row, or None when it does not exist
    """
    return db.session.query(Message) \
        .filter(Message.id == message_id) \
        .first()
...@@ -32,7 +32,7 @@ class PublishFrom(Enum): ...@@ -32,7 +32,7 @@ class PublishFrom(Enum):
TASK_PIPELINE = 2 TASK_PIPELINE = 2
class ApplicationQueueManager: class AppQueueManager:
def __init__(self, task_id: str, def __init__(self, task_id: str,
user_id: str, user_id: str,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
...@@ -50,7 +50,7 @@ class ApplicationQueueManager: ...@@ -50,7 +50,7 @@ class ApplicationQueueManager:
self._message_id = str(message_id) self._message_id = str(message_id)
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
q = queue.Queue() q = queue.Queue()
...@@ -239,7 +239,7 @@ class ApplicationQueueManager: ...@@ -239,7 +239,7 @@ class ApplicationQueueManager:
Check if task is stopped Check if task is stopped
:return: :return:
""" """
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id) stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id)
result = redis_client.get(stopped_cache_key) result = redis_client.get(stopped_cache_key)
if result is not None: if result is not None:
return True return True
......
...@@ -2,7 +2,7 @@ import time ...@@ -2,7 +2,7 @@ import time
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.entities.application_entities import ( from core.entities.application_entities import (
ApplicationGenerateEntity, ApplicationGenerateEntity,
AppOrchestrationConfigEntity, AppOrchestrationConfigEntity,
...@@ -11,10 +11,10 @@ from core.entities.application_entities import ( ...@@ -11,10 +11,10 @@ from core.entities.application_entities import (
ModelConfigEntity, ModelConfigEntity,
PromptTemplateEntity, PromptTemplateEntity,
) )
from core.features.annotation_reply import AnnotationReplyFeature from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.features.external_data_fetch import ExternalDataFetchFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.features.hosting_moderation import HostingModerationFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature from core.moderation.input_moderation import InputModeration
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
...@@ -169,7 +169,7 @@ class AppRunner: ...@@ -169,7 +169,7 @@ class AppRunner:
return prompt_messages, stop return prompt_messages, stop
def direct_output(self, queue_manager: ApplicationQueueManager, def direct_output(self, queue_manager: AppQueueManager,
app_orchestration_config: AppOrchestrationConfigEntity, app_orchestration_config: AppOrchestrationConfigEntity,
prompt_messages: list, prompt_messages: list,
text: str, text: str,
...@@ -210,7 +210,7 @@ class AppRunner: ...@@ -210,7 +210,7 @@ class AppRunner:
) )
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
stream: bool, stream: bool,
agent: bool = False) -> None: agent: bool = False) -> None:
""" """
...@@ -234,7 +234,7 @@ class AppRunner: ...@@ -234,7 +234,7 @@ class AppRunner:
) )
def _handle_invoke_result_direct(self, invoke_result: LLMResult, def _handle_invoke_result_direct(self, invoke_result: LLMResult,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
agent: bool) -> None: agent: bool) -> None:
""" """
Handle invoke result direct Handle invoke result direct
...@@ -248,7 +248,7 @@ class AppRunner: ...@@ -248,7 +248,7 @@ class AppRunner:
) )
def _handle_invoke_result_stream(self, invoke_result: Generator, def _handle_invoke_result_stream(self, invoke_result: Generator,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
agent: bool) -> None: agent: bool) -> None:
""" """
Handle invoke result Handle invoke result
...@@ -306,7 +306,7 @@ class AppRunner: ...@@ -306,7 +306,7 @@ class AppRunner:
:param query: query :param query: query
:return: :return:
""" """
moderation_feature = ModerationFeature() moderation_feature = InputModeration()
return moderation_feature.check( return moderation_feature.check(
app_id=app_id, app_id=app_id,
tenant_id=tenant_id, tenant_id=tenant_id,
...@@ -316,7 +316,7 @@ class AppRunner: ...@@ -316,7 +316,7 @@ class AppRunner:
) )
def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage]) -> bool: prompt_messages: list[PromptMessage]) -> bool:
""" """
Check hosting moderation Check hosting moderation
...@@ -358,7 +358,7 @@ class AppRunner: ...@@ -358,7 +358,7 @@ class AppRunner:
:param query: the query :param query: the query
:return: the filled inputs :return: the filled inputs
""" """
external_data_fetch_feature = ExternalDataFetchFeature() external_data_fetch_feature = ExternalDataFetch()
return external_data_fetch_feature.fetch( return external_data_fetch_feature.fetch(
tenant_id=tenant_id, tenant_id=tenant_id,
app_id=app_id, app_id=app_id,
...@@ -388,4 +388,4 @@ class AppRunner: ...@@ -388,4 +388,4 @@ class AppRunner:
query=query, query=query,
user_id=user_id, user_id=user_id,
invoke_from=invoke_from invoke_from=invoke_from
) )
\ No newline at end of file
import logging import logging
from typing import Optional from typing import Optional
from core.app_runner.app_runner import AppRunner from core.app.base_app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ( from core.entities.application_entities import (
ApplicationGenerateEntity, ApplicationGenerateEntity,
...@@ -10,7 +10,7 @@ from core.entities.application_entities import ( ...@@ -10,7 +10,7 @@ from core.entities.application_entities import (
InvokeFrom, InvokeFrom,
ModelConfigEntity, ModelConfigEntity,
) )
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.moderation.base import ModerationException from core.moderation.base import ModerationException
...@@ -20,13 +20,13 @@ from models.model import App, AppMode, Conversation, Message ...@@ -20,13 +20,13 @@ from models.model import App, AppMode, Conversation, Message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BasicApplicationRunner(AppRunner): class ChatAppRunner(AppRunner):
""" """
Basic Application Runner Chat Application Runner
""" """
def run(self, application_generate_entity: ApplicationGenerateEntity, def run(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
""" """
...@@ -213,7 +213,7 @@ class BasicApplicationRunner(AppRunner): ...@@ -213,7 +213,7 @@ class BasicApplicationRunner(AppRunner):
def retrieve_dataset_context(self, tenant_id: str, def retrieve_dataset_context(self, tenant_id: str,
app_record: App, app_record: App,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
model_config: ModelConfigEntity, model_config: ModelConfigEntity,
dataset_config: DatasetEntity, dataset_config: DatasetEntity,
show_retrieve_source: bool, show_retrieve_source: bool,
...@@ -252,7 +252,7 @@ class BasicApplicationRunner(AppRunner): ...@@ -252,7 +252,7 @@ class BasicApplicationRunner(AppRunner):
and dataset_config.retrieve_config.query_variable): and dataset_config.retrieve_config.query_variable):
query = inputs.get(dataset_config.retrieve_config.query_variable, "") query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrievalFeature() dataset_retrieval = DatasetRetrieval()
return dataset_retrieval.retrieve( return dataset_retrieval.retrieve(
tenant_id=tenant_id, tenant_id=tenant_id,
model_config=model_config, model_config=model_config,
......
from core.apps.config_validators.dataset import DatasetValidator from core.app.validators.dataset_retrieval import DatasetValidator
from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator from core.app.validators.external_data_fetch import ExternalDataFetchValidator
from core.apps.config_validators.file_upload import FileUploadValidator from core.app.validators.file_upload import FileUploadValidator
from core.apps.config_validators.model import ModelValidator from core.app.validators.model_validator import ModelValidator
from core.apps.config_validators.moderation import ModerationValidator from core.app.validators.moderation import ModerationValidator
from core.apps.config_validators.opening_statement import OpeningStatementValidator from core.app.validators.opening_statement import OpeningStatementValidator
from core.apps.config_validators.prompt import PromptValidator from core.app.validators.prompt import PromptValidator
from core.apps.config_validators.retriever_resource import RetrieverResourceValidator from core.app.validators.retriever_resource import RetrieverResourceValidator
from core.apps.config_validators.speech_to_text import SpeechToTextValidator from core.app.validators.speech_to_text import SpeechToTextValidator
from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator from core.app.validators.suggested_questions import SuggestedQuestionsValidator
from core.apps.config_validators.text_to_speech import TextToSpeechValidator from core.app.validators.text_to_speech import TextToSpeechValidator
from core.apps.config_validators.user_input_form import UserInputFormValidator from core.app.validators.user_input_form import UserInputFormValidator
from models.model import AppMode from models.model import AppMode
...@@ -35,7 +35,7 @@ class ChatAppConfigValidator: ...@@ -35,7 +35,7 @@ class ChatAppConfigValidator:
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# external data tools validation # external data tools validation
config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# file upload validation # file upload validation
......
import logging
from typing import Optional
from core.app.base_app_runner import AppRunner
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import (
ApplicationGenerateEntity,
DatasetEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from extensions.ext_database import db
from models.model import App, AppMode, Conversation, Message
logger = logging.getLogger(__name__)
class CompletionAppRunner(AppRunner):
    """
    Completion Application Runner

    Drives a single "completion" (non-chat) generation request end to end:
    token budget pre-check, prompt assembly, input moderation, annotation
    reply short-circuit, external-data variable fill-in, dataset context
    retrieval, hosting moderation, and finally the LLM invocation whose
    result is forwarded through the queue manager.
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: AppQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run application

        Pipeline (order matters; several steps can short-circuit with
        ``direct_output`` and return early):
        token pre-check -> prompt assembly -> input moderation ->
        annotation reply -> external data fill-in -> dataset retrieval ->
        prompt re-assembly -> hosting moderation -> max-token recalc ->
        model invocation -> result handling.

        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        :raises ValueError: if the app referenced by the entity no longer exists
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            # NOTE(review): a completion app normally has no prior turns; this branch
            # presumably serves re-generation flows that carry a conversation id — confirm.
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=application_generate_entity.tenant_id,
                app_orchestration_config_entity=app_orchestration_config,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
            # Inputs were rejected: stream the moderation message back to the
            # client as the final answer and stop processing.
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
            )
            return

        if query:
            # annotation reply
            # If a stored annotation matches the query, answer from it directly
            # instead of calling the model.
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from
            )

            if annotation_reply:
                queue_manager.publish_annotation_reply(
                    message_annotation_id=annotation_reply.id,
                    pub_from=PublishFrom.APPLICATION_MANAGER
                )

                self.direct_output(
                    queue_manager=queue_manager,
                    app_orchestration_config=app_orchestration_config,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream
                )
                return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_orchestration_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
            )

        # get context from datasets
        context = None
        if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids:
            context = self.retrieve_dataset_context(
                tenant_id=app_record.tenant_id,
                app_record=app_record,
                queue_manager=queue_manager,
                model_config=app_orchestration_config.model_config,
                show_retrieve_source=app_orchestration_config.show_retrieve_source,
                dataset_config=app_orchestration_config.dataset,
                message=message,
                inputs=inputs,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                memory=memory
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional), external data, dataset context(optional)
        # Re-run assembly because moderation/external tools may have rewritten
        # inputs/query and dataset retrieval may have produced context.
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context,
            memory=memory
        )

        # check hosting moderation
        # A truthy result means the hosting provider's moderation rejected the
        # prompt; the check itself handles output, so just stop here.
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
        )

        if hosting_moderation_result:
            return

        # Re-calculate the max tokens if sum(prompt_token +  max_tokens) over model token limit
        self.recale_llm_max_tokens(
            model_config=app_orchestration_config.model_config,
            prompt_messages=prompt_messages
        )

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
            model=app_orchestration_config.model_config.model
        )

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=app_orchestration_config.model_config.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream
        )

    def retrieve_dataset_context(self, tenant_id: str,
                                 app_record: App,
                                 queue_manager: AppQueueManager,
                                 model_config: ModelConfigEntity,
                                 dataset_config: DatasetEntity,
                                 show_retrieve_source: bool,
                                 message: Message,
                                 inputs: dict,
                                 query: str,
                                 user_id: str,
                                 invoke_from: InvokeFrom,
                                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
        """
        Retrieve dataset context

        Builds a hit callback for publishing retrieval events, then delegates
        to :class:`DatasetRetrieval` and returns the retrieved context text
        (or ``None`` if nothing was retrieved).

        :param tenant_id: tenant id
        :param app_record: app record
        :param queue_manager: queue manager
        :param model_config: model config
        :param dataset_config: dataset config
        :param show_retrieve_source: show retrieve source
        :param message: message
        :param inputs: inputs
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :param memory: memory
        :return: retrieved context text, or None
        """
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager,
            app_record.id,
            message.id,
            user_id,
            invoke_from
        )

        # TODO
        # Completion apps have no free-form chat query; when the dataset config
        # names a query variable, use that input value as the retrieval query
        # (empty string when the variable is absent).
        if (app_record.mode == AppMode.COMPLETION.value and dataset_config
                and dataset_config.retrieve_config.query_variable):
            query = inputs.get(dataset_config.retrieve_config.query_variable, "")

        dataset_retrieval = DatasetRetrieval()
        return dataset_retrieval.retrieve(
            tenant_id=tenant_id,
            model_config=model_config,
            config=dataset_config,
            query=query,
            invoke_from=invoke_from,
            show_retrieve_source=show_retrieve_source,
            hit_callback=hit_callback,
            memory=memory
        )
\ No newline at end of file
from core.apps.config_validators.dataset import DatasetValidator from core.app.validators.dataset_retrieval import DatasetValidator
from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator from core.app.validators.external_data_fetch import ExternalDataFetchValidator
from core.apps.config_validators.file_upload import FileUploadValidator from core.app.validators.file_upload import FileUploadValidator
from core.apps.config_validators.model import ModelValidator from core.app.validators.model_validator import ModelValidator
from core.apps.config_validators.moderation import ModerationValidator from core.app.validators.moderation import ModerationValidator
from core.apps.config_validators.more_like_this import MoreLikeThisValidator from core.app.validators.more_like_this import MoreLikeThisValidator
from core.apps.config_validators.prompt import PromptValidator from core.app.validators.prompt import PromptValidator
from core.apps.config_validators.text_to_speech import TextToSpeechValidator from core.app.validators.text_to_speech import TextToSpeechValidator
from core.apps.config_validators.user_input_form import UserInputFormValidator from core.app.validators.user_input_form import UserInputFormValidator
from models.model import AppMode from models.model import AppMode
...@@ -32,7 +32,7 @@ class CompletionAppConfigValidator: ...@@ -32,7 +32,7 @@ class CompletionAppConfigValidator:
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# external data tools validation # external data tools validation
config, current_related_config_keys = ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, config) config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# file upload validation # file upload validation
......
...@@ -6,8 +6,8 @@ from typing import Optional, Union, cast ...@@ -6,8 +6,8 @@ from typing import Optional, Union, cast
from pydantic import BaseModel from pydantic import BaseModel
from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler from core.moderation.output_moderation import ModerationRule, OutputModeration
from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import ( from core.entities.queue_entities import (
AnnotationReplyEvent, AnnotationReplyEvent,
...@@ -35,7 +35,7 @@ from core.model_runtime.entities.message_entities import ( ...@@ -35,7 +35,7 @@ from core.model_runtime.entities.message_entities import (
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
...@@ -59,7 +59,7 @@ class GenerateTaskPipeline: ...@@ -59,7 +59,7 @@ class GenerateTaskPipeline:
""" """
def __init__(self, application_generate_entity: ApplicationGenerateEntity, def __init__(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
""" """
...@@ -625,7 +625,7 @@ class GenerateTaskPipeline: ...@@ -625,7 +625,7 @@ class GenerateTaskPipeline:
return prompts return prompts
def _init_output_moderation(self) -> Optional[OutputModerationHandler]: def _init_output_moderation(self) -> Optional[OutputModeration]:
""" """
Init output moderation. Init output moderation.
:return: :return:
...@@ -634,7 +634,7 @@ class GenerateTaskPipeline: ...@@ -634,7 +634,7 @@ class GenerateTaskPipeline:
sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance
if sensitive_word_avoidance: if sensitive_word_avoidance:
return OutputModerationHandler( return OutputModeration(
tenant_id=self._application_generate_entity.tenant_id, tenant_id=self._application_generate_entity.tenant_id,
app_id=self._application_generate_entity.app_id, app_id=self._application_generate_entity.app_id,
rule=ModerationRule( rule=ModerationRule(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from core.external_data_tool.factory import ExternalDataToolFactory from core.external_data_tool.factory import ExternalDataToolFactory
class ExternalDataToolsValidator: class ExternalDataFetchValidator:
@classmethod @classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
""" """
......
from core.apps.config_validators.file_upload import FileUploadValidator from core.app.validators.file_upload import FileUploadValidator
from core.apps.config_validators.moderation import ModerationValidator from core.app.validators.moderation import ModerationValidator
from core.apps.config_validators.text_to_speech import TextToSpeechValidator from core.app.validators.text_to_speech import TextToSpeechValidator
class WorkflowAppConfigValidator: class WorkflowAppConfigValidator:
......
from core.apps.config_validators.agent import AgentValidator
from core.apps.config_validators.external_data_tools import ExternalDataToolsValidator
from core.apps.config_validators.file_upload import FileUploadValidator
from core.apps.config_validators.model import ModelValidator
from core.apps.config_validators.moderation import ModerationValidator
from core.apps.config_validators.opening_statement import OpeningStatementValidator
from core.apps.config_validators.prompt import PromptValidator
from core.apps.config_validators.retriever_resource import RetrieverResourceValidator
from core.apps.config_validators.speech_to_text import SpeechToTextValidator
from core.apps.config_validators.suggested_questions import SuggestedQuestionsValidator
from core.apps.config_validators.text_to_speech import TextToSpeechValidator
from core.apps.config_validators.user_input_form import UserInputFormValidator
from models.model import AppMode
class AgentChatAppConfigValidator:
    @classmethod
    def config_validate(cls, tenant_id: str, config: dict) -> dict:
        """
        Validate for agent chat app model config

        Runs every section validator in a fixed order. Each validator
        normalizes its own section of *config* (filling in defaults) and
        reports the top-level keys it owns; only those keys survive into
        the returned dict.

        :param tenant_id: tenant id
        :param config: app model config args
        :return: validated config restricted to the recognized keys
        """
        app_mode = AppMode.AGENT_CHAT

        # Ordered validation steps; each callable maps config -> (config, owned_keys).
        validation_steps = [
            lambda cfg: ModelValidator.validate_and_set_defaults(tenant_id, cfg),              # model
            lambda cfg: UserInputFormValidator.validate_and_set_defaults(cfg),                 # user_input_form
            lambda cfg: ExternalDataToolsValidator.validate_and_set_defaults(tenant_id, cfg),  # external data tools
            lambda cfg: FileUploadValidator.validate_and_set_defaults(cfg),                    # file upload
            lambda cfg: PromptValidator.validate_and_set_defaults(app_mode, cfg),              # prompt
            lambda cfg: AgentValidator.validate_and_set_defaults(tenant_id, cfg),              # agent_mode
            lambda cfg: OpeningStatementValidator.validate_and_set_defaults(cfg),              # opening_statement
            lambda cfg: SuggestedQuestionsValidator.validate_and_set_defaults(cfg),            # suggested_questions_after_answer
            lambda cfg: SpeechToTextValidator.validate_and_set_defaults(cfg),                  # speech_to_text
            lambda cfg: TextToSpeechValidator.validate_and_set_defaults(cfg),                  # text_to_speech
            lambda cfg: RetrieverResourceValidator.validate_and_set_defaults(cfg),             # return retriever resource
            lambda cfg: ModerationValidator.validate_and_set_defaults(tenant_id, cfg),         # moderation
        ]

        related_config_keys = []
        for step in validation_steps:
            config, owned_keys = step(config)
            related_config_keys.extend(owned_keys)

        related_config_keys = list(set(related_config_keys))

        # Filter out extra parameters
        return {key: config.get(key) for key in related_config_keys}
...@@ -7,7 +7,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen ...@@ -7,7 +7,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult
from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.agent_loop import AgentLoop
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
...@@ -22,7 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -22,7 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
raise_error: bool = True raise_error: bool = True
def __init__(self, model_config: ModelConfigEntity, def __init__(self, model_config: ModelConfigEntity,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
message: Message, message: Message,
message_chain: MessageChain) -> None: message_chain: MessageChain) -> None:
"""Initialize callback handler.""" """Initialize callback handler."""
......
from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
...@@ -10,7 +10,7 @@ from models.model import DatasetRetrieverResource ...@@ -10,7 +10,7 @@ from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler: class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool.""" """Callback handler for dataset tool."""
def __init__(self, queue_manager: ApplicationQueueManager, def __init__(self, queue_manager: AppQueueManager,
app_id: str, app_id: str,
message_id: str, message_id: str,
user_id: str, user_id: str,
......
...@@ -11,7 +11,7 @@ from core.external_data_tool.factory import ExternalDataToolFactory ...@@ -11,7 +11,7 @@ from core.external_data_tool.factory import ExternalDataToolFactory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExternalDataFetchFeature: class ExternalDataFetch:
def fetch(self, tenant_id: str, def fetch(self, tenant_id: str,
app_id: str, app_id: str,
external_data_tools: list[ExternalDataVariableEntity], external_data_tools: list[ExternalDataVariableEntity],
......
...@@ -13,7 +13,7 @@ from sqlalchemy.orm.exc import ObjectDeletedError ...@@ -13,7 +13,7 @@ from sqlalchemy.orm.exc import ObjectDeletedError
from core.docstore.dataset_docstore import DatasetDocumentStore from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
......
...@@ -7,10 +7,10 @@ from core.model_manager import ModelManager ...@@ -7,10 +7,10 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
class LLMGenerator: class LLMGenerator:
......
...@@ -2,7 +2,7 @@ from typing import Any ...@@ -2,7 +2,7 @@ from typing import Any
from langchain.schema import BaseOutputParser, OutputParserException from langchain.schema import BaseOutputParser, OutputParserException
from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE from core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE
from libs.json_in_md_parser import parse_and_check_json_markdown from libs.json_in_md_parser import parse_and_check_json_markdown
......
...@@ -5,7 +5,7 @@ from typing import Any ...@@ -5,7 +5,7 @@ from typing import Any
from langchain.schema import BaseOutputParser from langchain.schema import BaseOutputParser
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
......
...@@ -7,7 +7,7 @@ from core.moderation.factory import ModerationFactory ...@@ -7,7 +7,7 @@ from core.moderation.factory import ModerationFactory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ModerationFeature: class InputModeration:
def check(self, app_id: str, def check(self, app_id: str,
tenant_id: str, tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity, app_orchestration_config_entity: AppOrchestrationConfigEntity,
......
...@@ -6,7 +6,7 @@ from typing import Any, Optional ...@@ -6,7 +6,7 @@ from typing import Any, Optional
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import BaseModel from pydantic import BaseModel
from core.application_queue_manager import PublishFrom from core.app.app_queue_manager import PublishFrom
from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory from core.moderation.factory import ModerationFactory
...@@ -18,7 +18,7 @@ class ModerationRule(BaseModel): ...@@ -18,7 +18,7 @@ class ModerationRule(BaseModel):
config: dict[str, Any] config: dict[str, Any]
class OutputModerationHandler(BaseModel): class OutputModeration(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300 DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str tenant_id: str
......
...@@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import ( ...@@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.prompt.prompt_template import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform from core.prompt.prompt_transform import PromptTransform
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
......
...@@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import ( ...@@ -15,7 +15,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.prompt.prompt_template import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform from core.prompt.prompt_transform import PromptTransform
from models.model import AppMode from models.model import AppMode
...@@ -275,7 +275,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -275,7 +275,7 @@ class SimplePromptTransform(PromptTransform):
return prompt_file_contents[prompt_file_name] return prompt_file_contents[prompt_file_name]
# Get the absolute path of the subdirectory # Get the absolute path of the subdirectory
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts') prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates')
json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
# Open the JSON file and read its content # Open the JSON file and read its content
......
...@@ -10,7 +10,7 @@ from flask import Flask, current_app ...@@ -10,7 +10,7 @@ from flask import Flask, current_app
from flask_login import current_user from flask_login import current_user
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from core.generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
......
...@@ -7,8 +7,8 @@ from langchain.schema.language_model import BaseLanguageModel ...@@ -7,8 +7,8 @@ from langchain.schema.language_model import BaseLanguageModel
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM from core.rag.retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
......
...@@ -12,7 +12,7 @@ from pydantic import root_validator ...@@ -12,7 +12,7 @@ from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM from core.rag.retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessageTool
......
...@@ -13,7 +13,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException ...@@ -13,7 +13,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain from core.rag.retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
......
...@@ -10,10 +10,10 @@ from pydantic import BaseModel, Extra ...@@ -10,10 +10,10 @@ from pydantic import BaseModel, Extra
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages from core.entities.message_entities import prompt_messages_to_lc_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
......
...@@ -5,7 +5,7 @@ from langchain.tools import BaseTool ...@@ -5,7 +5,7 @@ from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
...@@ -15,7 +15,7 @@ from extensions.ext_database import db ...@@ -15,7 +15,7 @@ from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset
class DatasetRetrievalFeature: class DatasetRetrieval:
def retrieve(self, tenant_id: str, def retrieve(self, tenant_id: str,
model_config: ModelConfigEntity, model_config: ModelConfigEntity,
config: DatasetEntity, config: DatasetEntity,
......
...@@ -4,7 +4,7 @@ from langchain.tools import BaseTool ...@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
...@@ -30,7 +30,7 @@ class DatasetRetrieverTool(Tool): ...@@ -30,7 +30,7 @@ class DatasetRetrieverTool(Tool):
if retrieve_config is None: if retrieve_config is None:
return [] return []
feature = DatasetRetrievalFeature() feature = DatasetRetrieval()
# save original retrieve strategy, and set retrieve strategy to SINGLE # save original retrieve strategy, and set retrieve strategy to SINGLE
# Agent only support SINGLE mode # Agent only support SINGLE mode
......
from core.generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
......
...@@ -310,22 +310,28 @@ class AppModelConfig(db.Model): ...@@ -310,22 +310,28 @@ class AppModelConfig(db.Model):
def from_model_config_dict(self, model_config: dict): def from_model_config_dict(self, model_config: dict):
self.opening_statement = model_config['opening_statement'] self.opening_statement = model_config['opening_statement']
self.suggested_questions = json.dumps(model_config['suggested_questions']) self.suggested_questions = json.dumps(model_config['suggested_questions']) \
self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) if model_config.get('suggested_questions') else None
self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \
if model_config.get('suggested_questions_after_answer') else None
self.speech_to_text = json.dumps(model_config['speech_to_text']) \ self.speech_to_text = json.dumps(model_config['speech_to_text']) \
if model_config.get('speech_to_text') else None if model_config.get('speech_to_text') else None
self.text_to_speech = json.dumps(model_config['text_to_speech']) \ self.text_to_speech = json.dumps(model_config['text_to_speech']) \
if model_config.get('text_to_speech') else None if model_config.get('text_to_speech') else None
self.more_like_this = json.dumps(model_config['more_like_this']) self.more_like_this = json.dumps(model_config['more_like_this']) \
if model_config.get('more_like_this') else None
self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \
if model_config.get('sensitive_word_avoidance') else None if model_config.get('sensitive_word_avoidance') else None
self.external_data_tools = json.dumps(model_config['external_data_tools']) \ self.external_data_tools = json.dumps(model_config['external_data_tools']) \
if model_config.get('external_data_tools') else None if model_config.get('external_data_tools') else None
self.model = json.dumps(model_config['model']) self.model = json.dumps(model_config['model']) \
self.user_input_form = json.dumps(model_config['user_input_form']) if model_config.get('model') else None
self.user_input_form = json.dumps(model_config['user_input_form']) \
if model_config.get('user_input_form') else None
self.dataset_query_variable = model_config.get('dataset_query_variable') self.dataset_query_variable = model_config.get('dataset_query_variable')
self.pre_prompt = model_config['pre_prompt'] self.pre_prompt = model_config['pre_prompt']
self.agent_mode = json.dumps(model_config['agent_mode']) self.agent_mode = json.dumps(model_config['agent_mode']) \
if model_config.get('agent_mode') else None
self.retriever_resource = json.dumps(model_config['retriever_resource']) \ self.retriever_resource = json.dumps(model_config['retriever_resource']) \
if model_config.get('retriever_resource') else None if model_config.get('retriever_resource') else None
self.prompt_type = model_config.get('prompt_type', 'simple') self.prompt_type = model_config.get('prompt_type', 'simple')
......
import copy import copy
from core.prompt.advanced_prompt_templates import ( from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
......
from core.apps.app_config_validators.advanced_chat_app import AdvancedChatAppConfigValidator from core.app.advanced_chat.config_validator import AdvancedChatAppConfigValidator
from core.apps.app_config_validators.agent_chat_app import AgentChatAppConfigValidator from core.app.agent_chat.config_validator import AgentChatAppConfigValidator
from core.apps.app_config_validators.chat_app import ChatAppConfigValidator from core.app.chat.config_validator import ChatAppConfigValidator
from core.apps.app_config_validators.completion_app import CompletionAppConfigValidator from core.app.completion.config_validator import CompletionAppConfigValidator
from core.apps.app_config_validators.workflow_app import WorkflowAppConfigValidator from core.app.workflow.config_validator import WorkflowAppConfigValidator
from models.model import AppMode from models.model import AppMode
......
...@@ -4,8 +4,8 @@ from typing import Any, Union ...@@ -4,8 +4,8 @@ from typing import Any, Union
from sqlalchemy import and_ from sqlalchemy import and_
from core.application_manager import ApplicationManager from core.app.app_manager import AppManager
from core.apps.config_validators.model import ModelValidator from core.app.validators.model_validator import ModelValidator
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.file.message_file_parser import MessageFileParser from core.file.message_file_parser import MessageFileParser
from extensions.ext_database import db from extensions.ext_database import db
...@@ -137,7 +137,7 @@ class CompletionService: ...@@ -137,7 +137,7 @@ class CompletionService:
user user
) )
application_manager = ApplicationManager() application_manager = AppManager()
return application_manager.generate( return application_manager.generate(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
app_id=app_model.id, app_id=app_model.id,
...@@ -193,7 +193,7 @@ class CompletionService: ...@@ -193,7 +193,7 @@ class CompletionService:
message.files, app_model_config message.files, app_model_config
) )
application_manager = ApplicationManager() application_manager = AppManager()
return application_manager.generate( return application_manager.generate(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
app_id=app_model.id, app_id=app_model.id,
......
from typing import Optional, Union from typing import Optional, Union
from core.generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account from models.account import Account
......
import json import json
from typing import Optional, Union from typing import Optional, Union
from core.generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
......
import json import json
from typing import Optional from typing import Optional
from core.application_manager import ApplicationManager from core.app.app_manager import AppManager
from core.entities.application_entities import ( from core.entities.application_entities import (
DatasetEntity, DatasetEntity,
DatasetRetrieveConfigEntity, DatasetRetrieveConfigEntity,
...@@ -111,7 +111,7 @@ class WorkflowConverter: ...@@ -111,7 +111,7 @@ class WorkflowConverter:
new_app_mode = self._get_new_app_mode(app_model) new_app_mode = self._get_new_app_mode(app_model)
# convert app model config # convert app model config
application_manager = ApplicationManager() application_manager = AppManager()
app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
app_model_config_dict=app_model_config.to_dict(), app_model_config_dict=app_model_config.to_dict(),
......
...@@ -8,7 +8,7 @@ from core.file.file_obj import FileObj, FileType, FileTransferMethod ...@@ -8,7 +8,7 @@ from core.file.file_obj import FileObj, FileType, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.prompt_template import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import Conversation from models.model import Conversation
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment