Commit 5c7ea08b authored by takatost's avatar takatost

refactor apps

parent 5e389962
...@@ -37,7 +37,7 @@ class ChatMessageAudioApi(Resource): ...@@ -37,7 +37,7 @@ class ChatMessageAudioApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def post(self, app_model): def post(self, app_model):
file = request.files['file'] file = request.files['file']
......
...@@ -22,7 +22,7 @@ from controllers.console.app.wraps import get_app_model ...@@ -22,7 +22,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.app.app_queue_manager import AppQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
...@@ -103,7 +103,7 @@ class ChatMessageApi(Resource): ...@@ -103,7 +103,7 @@ class ChatMessageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, required=True, location='json')
...@@ -168,7 +168,7 @@ class ChatMessageStopApi(Resource): ...@@ -168,7 +168,7 @@ class ChatMessageStopApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def post(self, app_model, task_id): def post(self, app_model, task_id):
account = flask_login.current_user account = flask_login.current_user
......
...@@ -112,7 +112,7 @@ class CompletionConversationDetailApi(Resource): ...@@ -112,7 +112,7 @@ class CompletionConversationDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def delete(self, app_model, conversation_id): def delete(self, app_model, conversation_id):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
...@@ -133,7 +133,7 @@ class ChatConversationApi(Resource): ...@@ -133,7 +133,7 @@ class ChatConversationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@marshal_with(conversation_with_summary_pagination_fields) @marshal_with(conversation_with_summary_pagination_fields)
def get(self, app_model): def get(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
...@@ -218,7 +218,7 @@ class ChatConversationDetailApi(Resource): ...@@ -218,7 +218,7 @@ class ChatConversationDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@marshal_with(conversation_detail_fields) @marshal_with(conversation_detail_fields)
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
...@@ -227,7 +227,7 @@ class ChatConversationDetailApi(Resource): ...@@ -227,7 +227,7 @@ class ChatConversationDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@account_initialization_required @account_initialization_required
def delete(self, app_model, conversation_id): def delete(self, app_model, conversation_id):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
......
...@@ -42,7 +42,7 @@ class ChatMessageListApi(Resource): ...@@ -42,7 +42,7 @@ class ChatMessageListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@account_initialization_required @account_initialization_required
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model): def get(self, app_model):
...@@ -194,7 +194,7 @@ class MessageSuggestedQuestionApi(Resource): ...@@ -194,7 +194,7 @@ class MessageSuggestedQuestionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def get(self, app_model, message_id): def get(self, app_model, message_id):
message_id = str(message_id) message_id = str(message_id)
......
...@@ -203,7 +203,7 @@ class AverageSessionInteractionStatistic(Resource): ...@@ -203,7 +203,7 @@ class AverageSessionInteractionStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.CHAT) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def get(self, app_model): def get(self, app_model):
account = current_user account = current_user
......
...@@ -22,7 +22,7 @@ from controllers.console.app.error import ( ...@@ -22,7 +22,7 @@ from controllers.console.app.error import (
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_queue_manager import AppQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
......
...@@ -24,7 +24,7 @@ from controllers.console.explore.error import ( ...@@ -24,7 +24,7 @@ from controllers.console.explore.error import (
NotCompletionAppError, NotCompletionAppError,
) )
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
......
...@@ -20,7 +20,7 @@ from controllers.service_api.app.error import ( ...@@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
) )
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.app_queue_manager import AppQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
......
...@@ -21,7 +21,7 @@ from controllers.web.error import ( ...@@ -21,7 +21,7 @@ from controllers.web.error import (
) )
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.app_queue_manager import AppQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
......
...@@ -21,7 +21,7 @@ from controllers.web.error import ( ...@@ -21,7 +21,7 @@ from controllers.web.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
......
...@@ -5,17 +5,15 @@ from datetime import datetime ...@@ -5,17 +5,15 @@ from datetime import datetime
from mimetypes import guess_extension from mimetypes import guess_extension
from typing import Optional, Union, cast from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_queue_manager import AppQueueManager from core.app.app_queue_manager import AppQueueManager
from core.app.base_app_runner import AppRunner from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_runner import AppRunner
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ( from core.app.entities.app_invoke_entities import (
AgentEntity, EasyUIBasedAppGenerateEntity,
AgentToolEntity, InvokeFrom, EasyUIBasedModelConfigEntity,
ApplicationGenerateEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
) )
from core.file.message_file_parser import FileTransferMethod from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
...@@ -50,9 +48,9 @@ logger = logging.getLogger(__name__) ...@@ -50,9 +48,9 @@ logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner): class BaseAgentRunner(AppRunner):
def __init__(self, tenant_id: str, def __init__(self, tenant_id: str,
application_generate_entity: ApplicationGenerateEntity, application_generate_entity: EasyUIBasedAppGenerateEntity,
app_orchestration_config: AppOrchestrationConfigEntity, app_config: AgentChatAppConfig,
model_config: ModelConfigEntity, model_config: EasyUIBasedModelConfigEntity,
config: AgentEntity, config: AgentEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
message: Message, message: Message,
...@@ -66,7 +64,7 @@ class BaseAgentRunner(AppRunner): ...@@ -66,7 +64,7 @@ class BaseAgentRunner(AppRunner):
""" """
Agent runner Agent runner
:param tenant_id: tenant id :param tenant_id: tenant id
:param app_orchestration_config: app orchestration config :param app_config: app generate entity
:param model_config: model config :param model_config: model config
:param config: dataset config :param config: dataset config
:param queue_manager: queue manager :param queue_manager: queue manager
...@@ -78,7 +76,7 @@ class BaseAgentRunner(AppRunner): ...@@ -78,7 +76,7 @@ class BaseAgentRunner(AppRunner):
""" """
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.app_orchestration_config = app_orchestration_config self.app_config = app_config
self.model_config = model_config self.model_config = model_config
self.config = config self.config = config
self.queue_manager = queue_manager self.queue_manager = queue_manager
...@@ -97,16 +95,16 @@ class BaseAgentRunner(AppRunner): ...@@ -97,16 +95,16 @@ class BaseAgentRunner(AppRunner):
# init dataset tools # init dataset tools
hit_callback = DatasetIndexToolCallbackHandler( hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=queue_manager, queue_manager=queue_manager,
app_id=self.application_generate_entity.app_id, app_id=self.app_config.app_id,
message_id=message.id, message_id=message.id,
user_id=user_id, user_id=user_id,
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
) )
self.dataset_tools = DatasetRetrieverTool.get_dataset_tools( self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
tenant_id=tenant_id, tenant_id=tenant_id,
dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [], dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None, retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_orchestration_config.show_retrieve_source, return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback hit_callback=hit_callback
) )
...@@ -124,14 +122,15 @@ class BaseAgentRunner(AppRunner): ...@@ -124,14 +122,15 @@ class BaseAgentRunner(AppRunner):
else: else:
self.stream_tool_call = False self.stream_tool_call = False
def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity: def _repack_app_generate_entity(self, app_generate_entity: EasyUIBasedAppGenerateEntity) \
-> EasyUIBasedAppGenerateEntity:
""" """
Repack app orchestration config Repack app generate entity
""" """
if app_orchestration_config.prompt_template.simple_prompt_template is None: if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
app_orchestration_config.prompt_template.simple_prompt_template = '' app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
return app_orchestration_config return app_generate_entity
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
""" """
...@@ -351,7 +350,7 @@ class BaseAgentRunner(AppRunner): ...@@ -351,7 +350,7 @@ class BaseAgentRunner(AppRunner):
)) ))
db.session.close() db.session.close()
return result return result
def create_agent_thought(self, message_id: str, message: str, def create_agent_thought(self, message_id: str, message: str,
...@@ -462,7 +461,7 @@ class BaseAgentRunner(AppRunner): ...@@ -462,7 +461,7 @@ class BaseAgentRunner(AppRunner):
db.session.commit() db.session.commit()
db.session.close() db.session.close()
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
""" """
Transform tool message into agent thought Transform tool message into agent thought
......
...@@ -5,7 +5,7 @@ from typing import Literal, Union ...@@ -5,7 +5,7 @@ from typing import Literal, Union
from core.agent.base_agent_runner import BaseAgentRunner from core.agent.base_agent_runner import BaseAgentRunner
from core.app.app_queue_manager import PublishFrom from core.app.app_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
...@@ -27,7 +27,7 @@ from core.tools.errors import ( ...@@ -27,7 +27,7 @@ from core.tools.errors import (
from models.model import Conversation, Message from models.model import Conversation, Message
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): class CotAgentRunner(BaseAgentRunner):
_is_first_iteration = True _is_first_iteration = True
_ignore_observation_providers = ['wenxin'] _ignore_observation_providers = ['wenxin']
...@@ -39,30 +39,33 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -39,30 +39,33 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
""" """
Run Cot agent application Run Cot agent application
""" """
app_orchestration_config = self.app_orchestration_config app_generate_entity = self.application_generate_entity
self._repack_app_orchestration_config(app_orchestration_config) self._repack_app_generate_entity(app_generate_entity)
agent_scratchpad: list[AgentScratchpadUnit] = [] agent_scratchpad: list[AgentScratchpadUnit] = []
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages) self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
if 'Observation' not in app_orchestration_config.model_config.stop: # check model mode
if app_orchestration_config.model_config.provider not in self._ignore_observation_providers: if 'Observation' not in app_generate_entity.model_config.stop:
app_orchestration_config.model_config.stop.append('Observation') if app_generate_entity.model_config.provider not in self._ignore_observation_providers:
app_generate_entity.model_config.stop.append('Observation')
app_config = self.app_config
# override inputs # override inputs
inputs = inputs or {} inputs = inputs or {}
instruction = self.app_orchestration_config.prompt_template.simple_prompt_template instruction = app_config.prompt_template.simple_prompt_template
instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1 iteration_step = 1
max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
prompt_messages = self.history_prompt_messages prompt_messages = self.history_prompt_messages
# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
prompt_messages_tools: list[PromptMessageTool] = [] prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {} tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: for tool in app_config.agent.tools if app_config.agent else []:
try: try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception: except Exception:
...@@ -122,11 +125,11 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -122,11 +125,11 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# update prompt messages # update prompt messages
prompt_messages = self._organize_cot_prompt_messages( prompt_messages = self._organize_cot_prompt_messages(
mode=app_orchestration_config.model_config.mode, mode=app_generate_entity.model_config.mode,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=prompt_messages_tools, tools=prompt_messages_tools,
agent_scratchpad=agent_scratchpad, agent_scratchpad=agent_scratchpad,
agent_prompt_message=app_orchestration_config.agent.prompt, agent_prompt_message=app_config.agent.prompt,
instruction=instruction, instruction=instruction,
input=query input=query
) )
...@@ -136,9 +139,9 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -136,9 +139,9 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# invoke model # invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters, model_parameters=app_generate_entity.model_config.parameters,
tools=[], tools=[],
stop=app_orchestration_config.model_config.stop, stop=app_generate_entity.model_config.stop,
stream=True, stream=True,
user=self.user_id, user=self.user_id,
callbacks=[], callbacks=[],
...@@ -550,7 +553,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -550,7 +553,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
""" """
convert agent scratchpad list to str convert agent scratchpad list to str
""" """
next_iteration = self.app_orchestration_config.agent.prompt.next_iteration next_iteration = self.app_config.agent.prompt.next_iteration
result = '' result = ''
for scratchpad in agent_scratchpad: for scratchpad in agent_scratchpad:
......
from enum import Enum
from typing import Literal, Any, Union, Optional
from pydantic import BaseModel
class AgentToolEntity(BaseModel):
    """
    Agent Tool Entity.

    Identifies one tool an agent may invoke, addressed by provider and tool name.
    """
    # 'builtin' for platform-shipped tools, 'api' for user-defined API tools
    provider_type: Literal["builtin", "api"]
    provider_id: str
    tool_name: str
    # configured parameter values keyed by parameter name
    # (pydantic copies the mutable default per instance, so sharing is not an issue)
    tool_parameters: dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
    """
    Agent Prompt Entity.

    Prompt templates driving the iterative agent loop.
    """
    # prompt used on the first iteration of the agent loop
    first_prompt: str
    # prompt fragment appended on each subsequent iteration (the agent scratchpad)
    next_iteration: str
class AgentScratchpadUnit(BaseModel):
    """
    Agent First Prompt Entity.

    One scratchpad unit for a single agent iteration: the model's thought,
    the parsed action (if any), and the resulting observation.
    """
    class Action(BaseModel):
        """
        Action Entity.
        """
        action_name: str
        # tool input; either a parsed dict or the raw string it was parsed from
        action_input: Union[dict, str]
    # full model response text for this iteration
    agent_response: Optional[str] = None
    thought: Optional[str] = None
    # raw, unparsed action text as emitted by the model
    action_str: Optional[str] = None
    # tool output observed after executing the action
    observation: Optional[str] = None
    action: Optional[Action] = None
class AgentEntity(BaseModel):
    """
    Agent Entity.

    Top-level agent configuration: which model to use, which reasoning
    strategy to apply, and which tools are available.
    """
    class Strategy(Enum):
        """
        Agent Strategy.
        """
        CHAIN_OF_THOUGHT = 'chain-of-thought'
        FUNCTION_CALLING = 'function-calling'
    # presumably the model provider name — TODO confirm against callers
    provider: str
    model: str
    strategy: Strategy
    prompt: Optional[AgentPromptEntity] = None
    # annotation made explicitly Optional to match the None default
    # (pydantic v1 already infers Optional from a None default, so behavior is unchanged)
    tools: Optional[list[AgentToolEntity]] = None
    max_iteration: int = 5
...@@ -34,9 +34,11 @@ class FunctionCallAgentRunner(BaseAgentRunner): ...@@ -34,9 +34,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
""" """
Run FunctionCall agent application Run FunctionCall agent application
""" """
app_orchestration_config = self.app_orchestration_config app_generate_entity = self.application_generate_entity
prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or '' app_config = self.app_config
prompt_template = app_config.prompt_template.simple_prompt_template or ''
prompt_messages = self.history_prompt_messages prompt_messages = self.history_prompt_messages
prompt_messages = self.organize_prompt_messages( prompt_messages = self.organize_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
...@@ -47,7 +49,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): ...@@ -47,7 +49,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
prompt_messages_tools: list[PromptMessageTool] = [] prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {} tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: for tool in app_config.agent.tools if app_config.agent else []:
try: try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception: except Exception:
...@@ -67,7 +69,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): ...@@ -67,7 +69,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_instances[dataset_tool.identity.name] = dataset_tool tool_instances[dataset_tool.identity.name] = dataset_tool
iteration_step = 1 iteration_step = 1
max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call # continue to run until there is not any tool call
function_call_state = True function_call_state = True
...@@ -110,9 +112,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): ...@@ -110,9 +112,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# invoke model # invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters, model_parameters=app_generate_entity.model_config.parameters,
tools=prompt_messages_tools, tools=prompt_messages_tools,
stop=app_orchestration_config.model_config.stop, stop=app_generate_entity.model_config.stop,
stream=self.stream_tool_call, stream=self.stream_tool_call,
user=self.user_id, user=self.user_id,
callbacks=[], callbacks=[],
......
from typing import Union, Optional
from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import \
SuggestedQuestionsAfterAnswerConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import AppModelConfig
class BaseAppConfigManager:

    @classmethod
    def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom,
                               app_model_config: Union[AppModelConfig, dict],
                               config_dict: Optional[dict] = None) -> dict:
        """
        Build a plain config dict from an app model config.

        For configs that come from request args, ``config_dict`` is returned
        as-is; otherwise the stored app model config is serialized.

        :param config_from: where the app model config originates
        :param app_model_config: app model config record (or dict)
        :param config_dict: pre-built config dict, used only for ARGS configs
        :return: config dict
        """
        if config_from == EasyUIBasedAppModelConfigFrom.ARGS:
            return config_dict

        # serialize the stored record; return a copy so callers can mutate freely
        return dict(app_model_config.to_dict())

    @classmethod
    def convert_features(cls, config_dict: dict) -> AppAdditionalFeatures:
        """
        Extract the additional-feature settings from an app config dict.

        :param config_dict: app config dict
        :return: populated AppAdditionalFeatures
        """
        # work on a copy — the feature managers receive the same dict in turn
        feature_config = config_dict.copy()

        features = AppAdditionalFeatures()
        features.show_retrieve_source = RetrievalResourceConfigManager.convert(
            config=feature_config
        )
        features.file_upload = FileUploadConfigManager.convert(
            config=feature_config
        )
        features.opening_statement, features.suggested_questions = \
            OpeningStatementConfigManager.convert(
                config=feature_config
            )
        features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
            config=feature_config
        )
        features.more_like_this = MoreLikeThisConfigManager.convert(
            config=feature_config
        )
        features.speech_to_text = SpeechToTextConfigManager.convert(
            config=feature_config
        )
        features.text_to_speech = TextToSpeechConfigManager.convert(
            config=feature_config
        )

        return features
import logging from typing import Optional
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
from core.moderation.factory import ModerationFactory from core.moderation.factory import ModerationFactory
logger = logging.getLogger(__name__)
class SensitiveWordAvoidanceConfigManager:
    @classmethod
    def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
        """
        Extract the sensitive-word-avoidance (moderation) settings from an
        app config dict.

        :param config: app config dict
        :return: entity when the feature is present and enabled, otherwise None
        """
        avoidance_config = config.get('sensitive_word_avoidance')
        if not avoidance_config:
            return None

        # the feature must be explicitly enabled
        if not avoidance_config.get('enabled'):
            return None

        return SensitiveWordAvoidanceEntity(
            type=avoidance_config.get('type'),
            config=avoidance_config.get('config'),
        )
class ModerationValidator:
@classmethod @classmethod
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \ def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-> tuple[dict, list[str]]: -> tuple[dict, list[str]]:
......
from typing import Optional
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
class AgentConfigManager:
    @classmethod
    def convert(cls, config: dict) -> Optional[AgentEntity]:
        """
        Convert an app model config dict to an agent entity.

        :param config: app model config dict
        :return: agent entity when agent mode is enabled, otherwise None
        """
        if 'agent_mode' in config and config['agent_mode'] \
                and 'enabled' in config['agent_mode'] \
                and config['agent_mode']['enabled']:
            agent_dict = config.get('agent_mode', {})
            agent_strategy = agent_dict.get('strategy', 'cot')

            if agent_strategy == 'function_call':
                strategy = AgentEntity.Strategy.FUNCTION_CALLING
            elif agent_strategy in ('cot', 'react'):
                strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
            else:
                # old configs, try to detect default strategy
                if config['model']['provider'] == 'openai':
                    strategy = AgentEntity.Strategy.FUNCTION_CALLING
                else:
                    strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

            agent_tools = []
            for tool in agent_dict.get('tools', []):
                keys = tool.keys()
                # new-standard tool entries carry at least 4 keys
                # (provider_type, provider_id, tool_name, enabled, ...)
                if len(keys) >= 4:
                    if "enabled" not in tool or not tool["enabled"]:
                        continue

                    agent_tool_properties = {
                        'provider_type': tool['provider_type'],
                        'provider_id': tool['provider_id'],
                        'tool_name': tool['tool_name'],
                        'tool_parameters': tool.get('tool_parameters', {})
                    }

                    agent_tools.append(AgentToolEntity(**agent_tool_properties))

            # BUG FIX: agent_prompt_entity was only assigned inside the branch
            # below, so a config without a 'strategy' key (or with the legacy
            # 'react_router'/'router' strategies) raised NameError when
            # constructing AgentEntity. Default it to None (prompt is Optional).
            agent_prompt_entity = None

            if 'strategy' in config['agent_mode'] and \
                    config['agent_mode']['strategy'] not in ['react_router', 'router']:
                agent_prompt = agent_dict.get('prompt', None) or {}
                # check model mode to pick the matching ReAct template family
                model_mode = config.get('model', {}).get('mode', 'completion')
                if model_mode == 'completion':
                    agent_prompt_entity = AgentPromptEntity(
                        first_prompt=agent_prompt.get(
                            'first_prompt',
                            REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
                        next_iteration=agent_prompt.get(
                            'next_iteration',
                            REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']),
                    )
                else:
                    agent_prompt_entity = AgentPromptEntity(
                        first_prompt=agent_prompt.get(
                            'first_prompt',
                            REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
                        next_iteration=agent_prompt.get(
                            'next_iteration',
                            REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
                    )

            return AgentEntity(
                provider=config['model']['provider'],
                model=config['model']['name'],
                strategy=strategy,
                prompt=agent_prompt_entity,
                tools=agent_tools,
                max_iteration=agent_dict.get('max_iteration', 5)
            )

        return None
import uuid from typing import Optional
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from models.model import AppMode from models.model import AppMode
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
class DatasetValidator: class DatasetConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[DatasetEntity]:
    """
    Convert an app model config dict to a dataset entity.

    Collects dataset ids from both the current config shape
    (``dataset_configs.datasets``) and the legacy agent-mode shape
    (``agent_mode.tools`` entries keyed ``dataset``), then builds the
    retrieve config from ``dataset_configs``.

    :param config: app model config dict
    :return: dataset entity, or None when no enabled dataset is configured
    """
    dataset_ids = []

    # current config shape: dataset_configs.datasets.datasets[*].dataset
    if 'datasets' in config.get('dataset_configs', {}):
        datasets = config.get('dataset_configs', {}).get('datasets', {
            'strategy': 'router',
            'datasets': []
        })

        for dataset in datasets.get('datasets', []):
            keys = list(dataset.keys())
            # each entry must be a single-key dict {'dataset': {...}}
            if len(keys) == 0 or keys[0] != 'dataset':
                continue

            dataset = dataset['dataset']

            if 'enabled' not in dataset or not dataset['enabled']:
                continue

            dataset_id = dataset.get('id', None)
            if dataset_id:
                dataset_ids.append(dataset_id)

    # legacy config shape: agent_mode.tools[*] as a single-key {'dataset': {...}}
    if 'agent_mode' in config and config['agent_mode'] \
            and 'enabled' in config['agent_mode'] \
            and config['agent_mode']['enabled']:
        agent_dict = config.get('agent_mode', {})

        for tool in agent_dict.get('tools', []):
            keys = tool.keys()
            if len(keys) == 1:
                # old standard
                key = list(tool.keys())[0]

                if key != 'dataset':
                    continue

                tool_item = tool[key]

                if "enabled" not in tool_item or not tool_item["enabled"]:
                    continue

                dataset_id = tool_item['id']
                dataset_ids.append(dataset_id)

    if len(dataset_ids) == 0:
        return None

    # dataset configs
    dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
    query_variable = config.get('dataset_query_variable')

    # 'single' retrieval mode carries no ranking parameters; any other mode
    # forwards top_k / score_threshold / reranking_model as well
    if dataset_configs['retrieval_model'] == 'single':
        return DatasetEntity(
            dataset_ids=dataset_ids,
            retrieve_config=DatasetRetrieveConfigEntity(
                query_variable=query_variable,
                retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                    dataset_configs['retrieval_model']
                )
            )
        )
    else:
        return DatasetEntity(
            dataset_ids=dataset_ids,
            retrieve_config=DatasetRetrieveConfigEntity(
                query_variable=query_variable,
                retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                    dataset_configs['retrieval_model']
                ),
                top_k=dataset_configs.get('top_k'),
                score_threshold=dataset_configs.get('score_threshold'),
                reranking_model=dataset_configs.get('reranking_model')
            )
        )
@classmethod @classmethod
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
""" """
......
from typing import cast
from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.provider_manager import ProviderManager
class EasyUIBasedModelConfigEntityConverter:
    """Builds an EasyUIBasedModelConfigEntity from an app config, resolving the
    provider model bundle and (optionally) validating credentials and quota."""

    @classmethod
    def convert(cls, app_config: EasyUIBasedAppConfig,
                skip_check: bool = False) \
            -> EasyUIBasedModelConfigEntity:
        """
        Convert app model config dict to entity.

        :param app_config: app config
        :param skip_check: when True, missing credentials / unknown models are
            tolerated instead of raising
        :raises ProviderTokenNotInitError: provider token not init error
        :raises ModelCurrentlyNotSupportError: hosted model not permitted
        :raises QuotaExceededError: provider quota exhausted
        :raises ValueError: model does not exist
        :return: model config entity
        """
        model_config = app_config.model

        provider_manager = ProviderManager()
        provider_model_bundle = provider_manager.get_provider_model_bundle(
            tenant_id=app_config.tenant_id,
            provider=model_config.provider,
            model_type=ModelType.LLM
        )

        provider_name = provider_model_bundle.configuration.provider.provider
        model_name = model_config.model

        model_type_instance = provider_model_bundle.model_type_instance
        # narrowing: LLM bundles carry a LargeLanguageModel instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        # check model credentials
        model_credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=ModelType.LLM,
            model=model_config.model
        )

        if model_credentials is None:
            if not skip_check:
                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
            else:
                # best-effort mode: continue with empty credentials
                model_credentials = {}

        if not skip_check:
            # check model availability/status for this tenant
            provider_model = provider_model_bundle.configuration.get_provider_model(
                model=model_config.model,
                model_type=ModelType.LLM
            )

            if provider_model is None:
                model_name = model_config.model
                raise ValueError(f"Model {model_name} not exist.")

            if provider_model.status == ModelStatus.NO_CONFIGURE:
                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
            elif provider_model.status == ModelStatus.NO_PERMISSION:
                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config: 'stop' is extracted from completion params and passed
        # separately (note: mutates model_config.parameters in place)
        completion_params = model_config.parameters
        stop = []
        if 'stop' in completion_params:
            stop = completion_params['stop']
            del completion_params['stop']

        # get model mode; fall back to asking the model implementation
        model_mode = model_config.mode
        if not model_mode:
            mode_enum = model_type_instance.get_model_mode(
                model=model_config.model,
                credentials=model_credentials
            )

            model_mode = mode_enum.value

        model_schema = model_type_instance.get_model_schema(
            model_config.model,
            model_credentials
        )

        if not skip_check and not model_schema:
            raise ValueError(f"Model {model_name} not exist.")

        return EasyUIBasedModelConfigEntity(
            provider=model_config.provider,
            model=model_config.model,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )
from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import ModelType, ModelPropertyKey
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
class ModelValidator: class ModelConfigManager:
@classmethod
def convert(cls, config: dict) -> ModelConfigEntity:
    """
    Convert an app model config dict to a model config entity.

    :param config: app model config dict
    :raises ValueError: when the 'model' section is missing
    :return: model config entity
    """
    # model config
    model_config = config.get('model')

    if not model_config:
        raise ValueError("model is required")

    # ROBUSTNESS FIX: completion_params may be absent/None; default to an
    # empty dict instead of crashing on `'stop' in None` below.
    completion_params = model_config.get('completion_params') or {}
    stop = []
    # 'stop' sequences are carried separately from the other parameters
    if 'stop' in completion_params:
        stop = completion_params['stop']
        del completion_params['stop']

    # get model mode (may be None; entity treats it as optional)
    model_mode = model_config.get('mode')

    return ModelConfigEntity(
        provider=config['model']['provider'],
        model=config['model']['name'],
        mode=model_mode,
        parameters=completion_params,
        stop=stop,
    )
@classmethod @classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
""" """
......
from core.app.app_config.entities import PromptTemplateEntity, \
from core.entities.application_entities import PromptTemplateEntity AdvancedChatPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
from models.model import AppMode from models.model import AppMode
class PromptValidator: class PromptTemplateConfigManager:
@classmethod
def convert(cls, config: dict) -> PromptTemplateEntity:
    """
    Build a prompt-template entity from an app model config dict.

    :param config: app model config dict
    :raises ValueError: when 'prompt_type' is missing
    :return: prompt template entity (simple or advanced form)
    """
    if not config.get("prompt_type"):
        raise ValueError("prompt_type is required")

    prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])

    # simple mode: just the pre_prompt text
    if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
        return PromptTemplateEntity(
            prompt_type=prompt_type,
            simple_prompt_template=config.get("pre_prompt", "")
        )

    # advanced mode: optional chat-style and completion-style templates
    advanced_chat_template = None
    chat_config = config.get("chat_prompt_config", {})
    if chat_config:
        chat_messages = [
            {
                "text": message["text"],
                "role": PromptMessageRole.value_of(message["role"])
            }
            for message in chat_config.get("prompt", [])
        ]
        advanced_chat_template = AdvancedChatPromptTemplateEntity(
            messages=chat_messages
        )

    advanced_completion_template = None
    completion_config = config.get("completion_prompt_config", {})
    if completion_config:
        template_params = {
            'prompt': completion_config['prompt']['text'],
        }
        if 'conversation_histories_role' in completion_config:
            history_roles = completion_config['conversation_histories_role']
            template_params['role_prefix'] = {
                'user': history_roles['user_prefix'],
                'assistant': history_roles['assistant_prefix']
            }
        advanced_completion_template = AdvancedCompletionPromptTemplateEntity(
            **template_params
        )

    return PromptTemplateEntity(
        prompt_type=prompt_type,
        advanced_chat_prompt_template=advanced_chat_template,
        advanced_completion_prompt_template=advanced_completion_template
    )
@classmethod @classmethod
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
""" """
...@@ -83,4 +134,4 @@ class PromptValidator: ...@@ -83,4 +134,4 @@ class PromptValidator:
if not isinstance(config["post_prompt"], str): if not isinstance(config["post_prompt"], str):
raise ValueError("post_prompt must be of string type") raise ValueError("post_prompt must be of string type")
return config return config
\ No newline at end of file
import re import re
from typing import Tuple
from core.app.app_config.entities import VariableEntity, ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
class BasicVariablesConfigManager:
    @classmethod
    def convert(cls, config: dict) -> Tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
        """
        Convert an app model config dict to variable entities and
        external-data-variable entities.

        :param config: app model config dict
        :return: (variables, external_data_variables)
        """
        external_data_variables = []
        variables = []

        # legacy top-level external_data_tools entries
        external_data_tools = config.get('external_data_tools', [])
        for external_data_tool in external_data_tools:
            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
                continue

            external_data_variables.append(
                ExternalDataVariableEntity(
                    variable=external_data_tool['variable'],
                    type=external_data_tool['type'],
                    config=external_data_tool['config']
                )
            )

        # variables and external_data_tools declared in user_input_form;
        # each entry is a single-key dict keyed by the field type
        for variable in config.get('user_input_form', []):
            if not variable:
                # ROBUSTNESS FIX: an empty entry used to raise IndexError
                # on list(variable.keys())[0]; skip it instead.
                continue

            typ = list(variable.keys())[0]
            if typ == 'external_data_tool':
                val = variable[typ]
                external_data_variables.append(
                    ExternalDataVariableEntity(
                        variable=val['variable'],
                        type=val['type'],
                        config=val['config']
                    )
                )
            elif typ in [
                VariableEntity.Type.TEXT_INPUT.value,
                VariableEntity.Type.PARAGRAPH.value,
                VariableEntity.Type.NUMBER.value,
            ]:
                variables.append(
                    VariableEntity(
                        type=VariableEntity.Type.value_of(typ),
                        variable=variable[typ].get('variable'),
                        description=variable[typ].get('description'),
                        label=variable[typ].get('label'),
                        required=variable[typ].get('required', False),
                        max_length=variable[typ].get('max_length'),
                        default=variable[typ].get('default'),
                    )
                )
            elif typ == VariableEntity.Type.SELECT.value:
                variables.append(
                    VariableEntity(
                        type=VariableEntity.Type.SELECT,
                        variable=variable[typ].get('variable'),
                        description=variable[typ].get('description'),
                        label=variable[typ].get('label'),
                        required=variable[typ].get('required', False),
                        options=variable[typ].get('options'),
                        default=variable[typ].get('default'),
                    )
                )

        return variables, external_data_variables

    @classmethod
    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
        """
        Validate and set defaults for the user input form and its
        external data tools.

        :param tenant_id: workspace id
        :param config: app model config args
        :return: (config, related config keys)
        """
        related_config_keys = []
        config, current_related_config_keys = cls.validate_variables_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        config, current_related_config_keys = cls.validate_external_data_tools_and_set_defaults(tenant_id, config)
        related_config_keys.extend(current_related_config_keys)

        return config, related_config_keys
class UserInputFormValidator:
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
""" """
Validate and set defaults for user input form Validate and set defaults for user input form
...@@ -59,3 +147,38 @@ class UserInputFormValidator: ...@@ -59,3 +147,38 @@ class UserInputFormValidator:
raise ValueError("default value in user_input_form must be in the options list") raise ValueError("default value in user_input_form must be in the options list")
return config, ["user_input_form"] return config, ["user_input_form"]
@classmethod
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
    """
    Validate and set defaults for external data fetch feature

    :param tenant_id: workspace id
    :param config: app model config args
    :return: (config, related config keys)
    """
    if not config.get("external_data_tools"):
        config["external_data_tools"] = []

    if not isinstance(config["external_data_tools"], list):
        raise ValueError("external_data_tools must be of list type")

    for tool in config["external_data_tools"]:
        # disabled (or unmarked) tools are normalized and skipped
        if "enabled" not in tool or not tool["enabled"]:
            tool["enabled"] = False

        if not tool["enabled"]:
            continue

        if "type" not in tool or not tool["type"]:
            raise ValueError("external_data_tools[].type is required")

        typ = tool["type"]
        # BUG FIX: this used to rebind the outer `config` variable
        # (`config = tool["config"]`), so the method returned the last
        # enabled tool's config dict instead of the app model config.
        tool_config = tool["config"]
        ExternalDataToolFactory.validate_config(
            name=typ,
            tenant_id=tenant_id,
            config=tool_config
        )

    return config, ["external_data_tools"]
\ No newline at end of file
from enum import Enum from enum import Enum
from typing import Any, Literal, Optional, Union from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj
from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import AIModelEntity from models.model import AppMode
class ModelConfigEntity(BaseModel): class ModelConfigEntity(BaseModel):
...@@ -15,10 +13,7 @@ class ModelConfigEntity(BaseModel): ...@@ -15,10 +13,7 @@ class ModelConfigEntity(BaseModel):
""" """
provider: str provider: str
model: str model: str
model_schema: Optional[AIModelEntity] = None mode: Optional[str] = None
mode: str
provider_model_bundle: ProviderModelBundle
credentials: dict[str, Any] = {}
parameters: dict[str, Any] = {} parameters: dict[str, Any] = {}
stop: list[str] = [] stop: list[str] = []
...@@ -194,149 +189,53 @@ class FileUploadEntity(BaseModel): ...@@ -194,149 +189,53 @@ class FileUploadEntity(BaseModel):
image_config: Optional[dict[str, Any]] = None image_config: Optional[dict[str, Any]] = None
class AgentToolEntity(BaseModel): class AppAdditionalFeatures(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api"]
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
    """
    Agent Prompt Entity.

    Prompt templates driving an agent's ReAct loop.
    """
    # prompt text used on the first iteration
    first_prompt: str
    # prompt text used on each subsequent iteration
    next_iteration: str
class AgentScratchpadUnit(BaseModel):
    """
    Agent First Prompt Entity.

    One recorded step of an agent's scratchpad: the model's raw
    thought/action text, the parsed action (if any) and the observation.
    """

    class Action(BaseModel):
        """
        Action Entity.
        """
        # name of the action/tool the agent chose
        action_name: str
        # action input; either a parsed dict or the raw string
        action_input: Union[dict, str]

    # response text produced in this step, if any
    agent_response: Optional[str] = None
    # model's reasoning ("thought") text
    thought: Optional[str] = None
    # raw, unparsed action text
    action_str: Optional[str] = None
    # observed result after executing the action (presumably tool output)
    observation: Optional[str] = None
    # structured action parsed from action_str, when parsing succeeded
    action: Optional[Action] = None
class AgentEntity(BaseModel):
    """
    Agent Entity.

    Configuration of an agent run: model, reasoning strategy, prompt
    templates, available tools and the iteration cap.
    """

    class Strategy(Enum):
        """
        Agent Strategy.
        """
        CHAIN_OF_THOUGHT = 'chain-of-thought'
        FUNCTION_CALLING = 'function-calling'

    # model provider name
    provider: str
    # model name
    model: str
    strategy: Strategy
    # optional prompt templates for the agent loop
    prompt: Optional[AgentPromptEntity] = None
    # NOTE(review): annotated as list but defaults to None -- callers may
    # receive tools=None rather than []; confirm before tightening the type.
    tools: list[AgentToolEntity] = None
    # upper bound on agent iterations
    max_iteration: int = 5
class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
"""
model_config: ModelConfigEntity
prompt_template: PromptTemplateEntity
variables: list[VariableEntity] = []
external_data_variables: list[ExternalDataVariableEntity] = []
agent: Optional[AgentEntity] = None
# features
dataset: Optional[DatasetEntity] = None
file_upload: Optional[FileUploadEntity] = None file_upload: Optional[FileUploadEntity] = None
opening_statement: Optional[str] = None opening_statement: Optional[str] = None
suggested_questions: list[str] = []
suggested_questions_after_answer: bool = False suggested_questions_after_answer: bool = False
show_retrieve_source: bool = False show_retrieve_source: bool = False
more_like_this: bool = False more_like_this: bool = False
speech_to_text: bool = False speech_to_text: bool = False
text_to_speech: Optional[TextToSpeechEntity] = None text_to_speech: Optional[TextToSpeechEntity] = None
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
class InvokeFrom(Enum): class AppConfig(BaseModel):
""" """
Invoke From. Application Config Entity.
""" """
SERVICE_API = 'service-api' tenant_id: str
WEB_APP = 'web-app' app_id: str
EXPLORE = 'explore' app_mode: AppMode
DEBUGGER = 'debugger' additional_features: AppAdditionalFeatures
variables: list[VariableEntity] = []
@classmethod sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
def value_of(cls, value: str) -> 'InvokeFrom':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid invoke from value {value}')
def to_source(self) -> str:
"""
Get source of invoke from.
:return: source
"""
if self == InvokeFrom.WEB_APP:
return 'web_app'
elif self == InvokeFrom.DEBUGGER:
return 'dev'
elif self == InvokeFrom.EXPLORE:
return 'explore_app'
elif self == InvokeFrom.SERVICE_API:
return 'api'
return 'dev' class EasyUIBasedAppModelConfigFrom(Enum):
"""
App Model Config From.
"""
ARGS = 'args'
APP_LATEST_CONFIG = 'app-latest-config'
CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
class ApplicationGenerateEntity(BaseModel): class EasyUIBasedAppConfig(AppConfig):
""" """
Application Generate Entity. Easy UI Based App Config Entity.
""" """
task_id: str app_model_config_from: EasyUIBasedAppModelConfigFrom
tenant_id: str
app_id: str
app_model_config_id: str app_model_config_id: str
# for save
app_model_config_dict: dict app_model_config_dict: dict
app_model_config_override: bool model: ModelConfigEntity
prompt_template: PromptTemplateEntity
# Converted from app_model_config to Entity object, or directly covered by external input dataset: Optional[DatasetEntity] = None
app_orchestration_config_entity: AppOrchestrationConfigEntity external_data_variables: list[ExternalDataVariableEntity] = []
conversation_id: Optional[str] = None
inputs: dict[str, str] class WorkflowUIBasedAppConfig(AppConfig):
query: Optional[str] = None """
files: list[FileObj] = [] Workflow UI Based App Config Entity.
user_id: str """
# extras workflow_id: str
stream: bool
invoke_from: InvokeFrom
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
from typing import Optional
from core.app.app_config.entities import FileUploadEntity
class FileUploadConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[FileUploadEntity]:
    """
    Convert an app model config dict to a file-upload entity.

    :param config: app model config dict
    :return: entity when image upload is enabled, otherwise None
    """
    # Flatten the original triple-nested membership checks into .get chains.
    file_upload_dict = config.get('file_upload') or {}
    image_dict = file_upload_dict.get('image') or {}

    if not image_dict.get('enabled'):
        return None

    return FileUploadEntity(
        image_config={
            'number_limits': image_dict['number_limits'],
            'detail': image_dict['detail'],
            'transfer_methods': image_dict['transfer_methods']
        }
    )
class FileUploadValidator:
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
""" """
......
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
:param config: model config args
"""
more_like_this = False
more_like_this_dict = config.get('more_like_this')
if more_like_this_dict:
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
more_like_this = True
return more_like_this
class MoreLikeThisValidator:
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
""" """
......
from typing import Tuple
class OpeningStatementValidator: class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> Tuple[str, list]:
    """
    Extract the opening statement and suggested questions from the config.

    :param config: app model config dict
    :return: (opening_statement, suggested_questions)
    """
    # Both values are optional in the config; absent keys yield None.
    return config.get('opening_statement'), config.get('suggested_questions')
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
""" """
......
class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get('retriever_resource')
if retriever_resource_dict:
if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
show_retrieve_source = True
return show_retrieve_source
class RetrieverResourceValidator:
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
""" """
......
class SpeechToTextConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
:param config: model config args
"""
speech_to_text = False
speech_to_text_dict = config.get('speech_to_text')
if speech_to_text_dict:
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
speech_to_text = True
return speech_to_text
class SpeechToTextValidator:
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
""" """
......
class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
:param config: model config args
"""
suggested_questions_after_answer = False
suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
if suggested_questions_after_answer_dict:
if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
suggested_questions_after_answer = True
return suggested_questions_after_answer
class SuggestedQuestionsValidator:
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
""" """
...@@ -16,7 +29,8 @@ class SuggestedQuestionsValidator: ...@@ -16,7 +29,8 @@ class SuggestedQuestionsValidator:
if not isinstance(config["suggested_questions_after_answer"], dict): if not isinstance(config["suggested_questions_after_answer"], dict):
raise ValueError("suggested_questions_after_answer must be of dict type") raise ValueError("suggested_questions_after_answer must be of dict type")
if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]: if "enabled" not in config["suggested_questions_after_answer"] or not \
config["suggested_questions_after_answer"]["enabled"]:
config["suggested_questions_after_answer"]["enabled"] = False config["suggested_questions_after_answer"]["enabled"] = False
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
......
from core.app.app_config.entities import TextToSpeechEntity
class TextToSpeechValidator: class TextToSpeechConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
    """
    Convert an app model config dict to the text-to-speech setting.

    NOTE(review): despite the ``-> bool`` annotation, this returns a
    TextToSpeechEntity when the feature is enabled and False otherwise.
    Callers appear to rely on truthiness only -- confirm before changing.

    :param config: app model config dict
    """
    text_to_speech = False
    text_to_speech_dict = config.get('text_to_speech')
    if text_to_speech_dict:
        if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
            # build the full entity so voice/language settings travel along
            text_to_speech = TextToSpeechEntity(
                enabled=text_to_speech_dict.get('enabled'),
                voice=text_to_speech_dict.get('voice'),
                language=text_to_speech_dict.get('language'),
            )

    return text_to_speech
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
""" """
......
from core.app.app_config.entities import VariableEntity
from models.workflow import Workflow
class WorkflowVariablesConfigManager:
    @classmethod
    def convert(cls, workflow: Workflow) -> list[VariableEntity]:
        """
        Build variable entities from the workflow's start-node user input form.

        :param workflow: workflow instance
        :return: list of variable entities
        """
        # one entity per user input form entry
        return [
            VariableEntity(**form_item)
            for form_item in workflow.user_input_form()
        ]
This diff is collapsed.
This diff is collapsed.
...@@ -6,8 +6,8 @@ from typing import Any ...@@ -6,8 +6,8 @@ from typing import Any
from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import DeclarativeMeta
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.queue_entities import ( from core.app.entities.queue_entities import (
AnnotationReplyEvent, AnnotationReplyEvent,
AppQueueEvent, AppQueueEvent,
QueueAgentMessageEvent, QueueAgentMessageEvent,
......
from core.app.validators.file_upload import FileUploadValidator from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.validators.moderation import ModerationValidator from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.validators.opening_statement import OpeningStatementValidator from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.validators.retriever_resource import RetrieverResourceValidator from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.validators.speech_to_text import SpeechToTextValidator from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.validators.suggested_questions import SuggestedQuestionsValidator from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.validators.text_to_speech import TextToSpeechValidator from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import \
SuggestedQuestionsAfterAnswerConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.model import AppMode, App
from models.workflow import Workflow
class AdvancedChatAppConfigValidator: class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
"""
Advanced Chatbot App Config Entity.
"""
pass
class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def config_convert(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
    """
    Build an AdvancedChatAppConfig from an app record and its workflow.

    :param app_model: app record
    :param workflow: workflow whose features dict drives the config
    :return: advanced chat app config
    """
    features_dict = workflow.features_dict

    # derive each config section before assembling the entity
    sensitive_word_avoidance = SensitiveWordAvoidanceConfigManager.convert(
        config=features_dict
    )
    variables = WorkflowVariablesConfigManager.convert(workflow=workflow)
    additional_features = cls.convert_features(features_dict)

    return AdvancedChatAppConfig(
        tenant_id=app_model.tenant_id,
        app_id=app_model.id,
        app_mode=AppMode.value_of(app_model.mode),
        workflow_id=workflow.id,
        sensitive_word_avoidance=sensitive_word_avoidance,
        variables=variables,
        additional_features=additional_features
    )
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
""" """
...@@ -20,31 +53,32 @@ class AdvancedChatAppConfigValidator: ...@@ -20,31 +53,32 @@ class AdvancedChatAppConfigValidator:
related_config_keys = [] related_config_keys = []
# file upload validation # file upload validation
config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# opening_statement # opening_statement
config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# suggested_questions_after_answer # suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# speech_to_text # speech_to_text
config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# text_to_speech # text_to_speech
config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# return retriever resource # return retriever resource
config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# moderation validation # moderation validation
config, current_related_config_keys = ModerationValidator.validate_and_set_defaults( config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, tenant_id=tenant_id,
config=config, config=config,
only_structure_validate=only_structure_validate only_structure_validate=only_structure_validate
...@@ -57,3 +91,4 @@ class AdvancedChatAppConfigValidator: ...@@ -57,3 +91,4 @@ class AdvancedChatAppConfigValidator:
filtered_config = {key: config.get(key) for key in related_config_keys} filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config return filtered_config
import uuid import uuid
from typing import Optional
from core.app.validators.dataset_retrieval import DatasetValidator
from core.app.validators.external_data_fetch import ExternalDataFetchValidator from core.agent.entities import AgentEntity
from core.app.validators.file_upload import FileUploadValidator from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.validators.model_validator import ModelValidator from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
from core.app.validators.moderation import ModerationValidator from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.app.validators.opening_statement import OpeningStatementValidator from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.validators.prompt import PromptValidator from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
from core.app.validators.retriever_resource import RetrieverResourceValidator from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
from core.app.validators.speech_to_text import SpeechToTextValidator from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.validators.suggested_questions import SuggestedQuestionsValidator from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, DatasetEntity
from core.app.validators.text_to_speech import TextToSpeechValidator from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.validators.user_input_form import UserInputFormValidator from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import \
SuggestedQuestionsAfterAnswerConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from models.model import AppMode from models.model import AppMode, App, AppModelConfig
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
class AgentChatAppConfig(EasyUIBasedAppConfig):
    """
    Config entity for the Agent Chatbot app.

    Extends the easy-UI based config with an optional agent section.
    """
    # Agent settings (strategy, tools, ...); None when no agent is configured.
    agent: Optional[AgentEntity] = None
class AgentChatAppConfigManager(BaseAppConfigManager):
    @classmethod
    def config_convert(cls, app_model: App,
                       config_from: EasyUIBasedAppModelConfigFrom,
                       app_model_config: AppModelConfig,
                       config_dict: Optional[dict] = None) -> AgentChatAppConfig:
        """
        Convert an app model config to an agent chat app config.

        :param app_model: app the configuration belongs to
        :param config_from: source the model config was loaded from
        :param app_model_config: app model config record
        :param config_dict: optional pre-extracted config dict
        :return: populated AgentChatAppConfig entity
        """
        # Normalize the possible config sources into a single dict.
        merged_config = cls.convert_to_config_dict(config_from, app_model_config, config_dict)

        entity = AgentChatAppConfig(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_mode=AppMode.value_of(app_model.mode),
            app_model_config_from=config_from,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=merged_config,
            # Each section is delegated to its dedicated manager.
            model=ModelConfigManager.convert(config=merged_config),
            prompt_template=PromptTemplateConfigManager.convert(config=merged_config),
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=merged_config),
            dataset=DatasetConfigManager.convert(config=merged_config),
            agent=AgentConfigManager.convert(config=merged_config),
            additional_features=cls.convert_features(merged_config)
        )

        # User-input variables and external data tool variables come from
        # the same manager and are assigned after construction.
        entity.variables, entity.external_data_variables = BasicVariablesConfigManager.convert(
            config=merged_config
        )

        return entity
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict: def config_validate(cls, tenant_id: str, config: dict) -> dict:
""" """
...@@ -32,23 +90,19 @@ class AgentChatAppConfigValidator: ...@@ -32,23 +90,19 @@ class AgentChatAppConfigValidator:
related_config_keys = [] related_config_keys = []
# model # model
config, current_related_config_keys = ModelValidator.validate_and_set_defaults(tenant_id, config) config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# user_input_form # user_input_form
config, current_related_config_keys = UserInputFormValidator.validate_and_set_defaults(config) config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# external data tools validation
config, current_related_config_keys = ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# file upload validation # file upload validation
config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# prompt # prompt
config, current_related_config_keys = PromptValidator.validate_and_set_defaults(app_mode, config) config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# agent_mode # agent_mode
...@@ -56,27 +110,29 @@ class AgentChatAppConfigValidator: ...@@ -56,27 +110,29 @@ class AgentChatAppConfigValidator:
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# opening_statement # opening_statement
config, current_related_config_keys = OpeningStatementValidator.validate_and_set_defaults(config) config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# suggested_questions_after_answer # suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsValidator.validate_and_set_defaults(config) config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# speech_to_text # speech_to_text
config, current_related_config_keys = SpeechToTextValidator.validate_and_set_defaults(config) config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# text_to_speech # text_to_speech
config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# return retriever resource # return retriever resource
config, current_related_config_keys = RetrieverResourceValidator.validate_and_set_defaults(config) config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# moderation validation # moderation validation
config, current_related_config_keys = ModerationValidator.validate_and_set_defaults(tenant_id, config) config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys)) related_config_keys = list(set(related_config_keys))
...@@ -143,7 +199,7 @@ class AgentChatAppConfigValidator: ...@@ -143,7 +199,7 @@ class AgentChatAppConfigValidator:
except ValueError: except ValueError:
raise ValueError("id in dataset must be of UUID type") raise ValueError("id in dataset must be of UUID type")
if not DatasetValidator.is_dataset_exists(tenant_id, tool_item["id"]): if not DatasetConfigManager.is_dataset_exists(tenant_id, tool_item["id"]):
raise ValueError("Dataset ID does not exist, please check your permission.") raise ValueError("Dataset ID does not exist, please check your permission.")
else: else:
# latest style, use key-value pair # latest style, use key-value pair
......
...@@ -2,10 +2,12 @@ import logging ...@@ -2,10 +2,12 @@ import logging
from typing import cast from typing import cast
from core.agent.cot_agent_runner import CotAgentRunner from core.agent.cot_agent_runner import CotAgentRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.base_app_runner import AppRunner from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, EasyUIBasedModelConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
...@@ -24,7 +26,7 @@ class AgentChatAppRunner(AppRunner): ...@@ -24,7 +26,7 @@ class AgentChatAppRunner(AppRunner):
""" """
Agent Application Runner Agent Application Runner
""" """
def run(self, application_generate_entity: ApplicationGenerateEntity, def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
...@@ -36,12 +38,13 @@ class AgentChatAppRunner(AppRunner): ...@@ -36,12 +38,13 @@ class AgentChatAppRunner(AppRunner):
:param message: message :param message: message
:return: :return:
""" """
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
inputs = application_generate_entity.inputs inputs = application_generate_entity.inputs
query = application_generate_entity.query query = application_generate_entity.query
files = application_generate_entity.files files = application_generate_entity.files
...@@ -53,8 +56,8 @@ class AgentChatAppRunner(AppRunner): ...@@ -53,8 +56,8 @@ class AgentChatAppRunner(AppRunner):
# Not Include: memory, external data, dataset context # Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens( self.get_pre_calculate_rest_tokens(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query
...@@ -64,22 +67,22 @@ class AgentChatAppRunner(AppRunner): ...@@ -64,22 +67,22 @@ class AgentChatAppRunner(AppRunner):
if application_generate_entity.conversation_id: if application_generate_entity.conversation_id:
# get memory of conversation (read-only) # get memory of conversation (read-only)
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model model=application_generate_entity.model_config.model
) )
memory = TokenBufferMemory( memory = TokenBufferMemory(
conversation=conversation, conversation=conversation,
model_instance=model_instance model_instance=model_instance
) )
# organize all inputs and template to prompt messages # organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
# memory(optional) # memory(optional)
prompt_messages, _ = self.organize_prompt_messages( prompt_messages, _ = self.organize_prompt_messages(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
...@@ -91,15 +94,15 @@ class AgentChatAppRunner(AppRunner): ...@@ -91,15 +94,15 @@ class AgentChatAppRunner(AppRunner):
# process sensitive_word_avoidance # process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs( _, inputs, query = self.moderation_for_inputs(
app_id=app_record.id, app_id=app_record.id,
tenant_id=application_generate_entity.tenant_id, tenant_id=app_config.tenant_id,
app_orchestration_config_entity=app_orchestration_config, app_generate_entity=application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
) )
except ModerationException as e: except ModerationException as e:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=str(e), text=str(e),
stream=application_generate_entity.stream stream=application_generate_entity.stream
...@@ -123,7 +126,7 @@ class AgentChatAppRunner(AppRunner): ...@@ -123,7 +126,7 @@ class AgentChatAppRunner(AppRunner):
) )
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=annotation_reply.content, text=annotation_reply.content,
stream=application_generate_entity.stream stream=application_generate_entity.stream
...@@ -131,7 +134,7 @@ class AgentChatAppRunner(AppRunner): ...@@ -131,7 +134,7 @@ class AgentChatAppRunner(AppRunner):
return return
# fill in variable inputs from external data tools if exists # fill in variable inputs from external data tools if exists
external_data_tools = app_orchestration_config.external_data_variables external_data_tools = app_config.external_data_variables
if external_data_tools: if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools( inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
...@@ -146,8 +149,8 @@ class AgentChatAppRunner(AppRunner): ...@@ -146,8 +149,8 @@ class AgentChatAppRunner(AppRunner):
# memory(optional), external data, dataset context(optional) # memory(optional), external data, dataset context(optional)
prompt_messages, _ = self.organize_prompt_messages( prompt_messages, _ = self.organize_prompt_messages(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
...@@ -164,25 +167,25 @@ class AgentChatAppRunner(AppRunner): ...@@ -164,25 +167,25 @@ class AgentChatAppRunner(AppRunner):
if hosting_moderation_result: if hosting_moderation_result:
return return
agent_entity = app_orchestration_config.agent agent_entity = app_config.agent
# load tool variables # load tool variables
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
tenant_id=application_generate_entity.tenant_id) tenant_id=app_config.tenant_id)
# convert db variables to tool variables # convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
# init model instance # init model instance
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model model=application_generate_entity.model_config.model
) )
prompt_message, _ = self.organize_prompt_messages( prompt_message, _ = self.organize_prompt_messages(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
...@@ -203,10 +206,10 @@ class AgentChatAppRunner(AppRunner): ...@@ -203,10 +206,10 @@ class AgentChatAppRunner(AppRunner):
# start agent runner # start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = CotAgentRunner( assistant_cot_runner = CotAgentRunner(
tenant_id=application_generate_entity.tenant_id, tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
app_orchestration_config=app_orchestration_config, app_config=app_config,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
config=agent_entity, config=agent_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
message=message, message=message,
...@@ -225,10 +228,10 @@ class AgentChatAppRunner(AppRunner): ...@@ -225,10 +228,10 @@ class AgentChatAppRunner(AppRunner):
) )
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
assistant_fc_runner = FunctionCallAgentRunner( assistant_fc_runner = FunctionCallAgentRunner(
tenant_id=application_generate_entity.tenant_id, tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
app_orchestration_config=app_orchestration_config, app_config=app_config,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
config=agent_entity, config=agent_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
message=message, message=message,
...@@ -289,7 +292,7 @@ class AgentChatAppRunner(AppRunner): ...@@ -289,7 +292,7 @@ class AgentChatAppRunner(AppRunner):
'pool': db_variables.variables 'pool': db_variables.variables
}) })
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, def _get_usage_of_all_agent_thoughts(self, model_config: EasyUIBasedModelConfigEntity,
message: Message) -> LLMUsage: message: Message) -> LLMUsage:
""" """
Get usage of all agent thoughts Get usage of all agent thoughts
......
...@@ -2,16 +2,13 @@ import time ...@@ -2,16 +2,13 @@ import time
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
from core.app.app_config.entities import PromptTemplateEntity, ExternalDataVariableEntity
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.entities.application_entities import ( from core.app.entities.app_invoke_entities import (
ApplicationGenerateEntity, EasyUIBasedAppGenerateEntity,
AppOrchestrationConfigEntity, InvokeFrom, EasyUIBasedModelConfigEntity,
ExternalDataVariableEntity,
InvokeFrom,
ModelConfigEntity,
PromptTemplateEntity,
) )
from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
...@@ -29,7 +26,7 @@ from models.model import App, AppMode, Message, MessageAnnotation ...@@ -29,7 +26,7 @@ from models.model import App, AppMode, Message, MessageAnnotation
class AppRunner: class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App, def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigEntity, model_config: EasyUIBasedModelConfigEntity,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str], inputs: dict[str, str],
files: list[FileObj], files: list[FileObj],
...@@ -85,7 +82,7 @@ class AppRunner: ...@@ -85,7 +82,7 @@ class AppRunner:
return rest_tokens return rest_tokens
def recalc_llm_max_tokens(self, model_config: ModelConfigEntity, def recale_llm_max_tokens(self, model_config: EasyUIBasedModelConfigEntity,
prompt_messages: list[PromptMessage]): prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = model_config.provider_model_bundle.model_type_instance
...@@ -121,7 +118,7 @@ class AppRunner: ...@@ -121,7 +118,7 @@ class AppRunner:
model_config.parameters[parameter_rule.name] = max_tokens model_config.parameters[parameter_rule.name] = max_tokens
def organize_prompt_messages(self, app_record: App, def organize_prompt_messages(self, app_record: App,
model_config: ModelConfigEntity, model_config: EasyUIBasedModelConfigEntity,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str], inputs: dict[str, str],
files: list[FileObj], files: list[FileObj],
...@@ -170,7 +167,7 @@ class AppRunner: ...@@ -170,7 +167,7 @@ class AppRunner:
return prompt_messages, stop return prompt_messages, stop
def direct_output(self, queue_manager: AppQueueManager, def direct_output(self, queue_manager: AppQueueManager,
app_orchestration_config: AppOrchestrationConfigEntity, app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list, prompt_messages: list,
text: str, text: str,
stream: bool, stream: bool,
...@@ -178,7 +175,7 @@ class AppRunner: ...@@ -178,7 +175,7 @@ class AppRunner:
""" """
Direct output Direct output
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param app_orchestration_config: app orchestration config :param app_generate_entity: app generate entity
:param prompt_messages: prompt messages :param prompt_messages: prompt messages
:param text: text :param text: text
:param stream: stream :param stream: stream
...@@ -189,7 +186,7 @@ class AppRunner: ...@@ -189,7 +186,7 @@ class AppRunner:
index = 0 index = 0
for token in text: for token in text:
queue_manager.publish_chunk_message(LLMResultChunk( queue_manager.publish_chunk_message(LLMResultChunk(
model=app_orchestration_config.model_config.model, model=app_generate_entity.model_config.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
index=index, index=index,
...@@ -201,7 +198,7 @@ class AppRunner: ...@@ -201,7 +198,7 @@ class AppRunner:
queue_manager.publish_message_end( queue_manager.publish_message_end(
llm_result=LLMResult( llm_result=LLMResult(
model=app_orchestration_config.model_config.model, model=app_generate_entity.model_config.model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text), message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage() usage=usage if usage else LLMUsage.empty_usage()
...@@ -294,14 +291,14 @@ class AppRunner: ...@@ -294,14 +291,14 @@ class AppRunner:
def moderation_for_inputs(self, app_id: str, def moderation_for_inputs(self, app_id: str,
tenant_id: str, tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity, app_generate_entity: EasyUIBasedAppGenerateEntity,
inputs: dict, inputs: dict,
query: str) -> tuple[bool, dict, str]: query: str) -> tuple[bool, dict, str]:
""" """
Process sensitive_word_avoidance. Process sensitive_word_avoidance.
:param app_id: app id :param app_id: app id
:param tenant_id: tenant id :param tenant_id: tenant id
:param app_orchestration_config_entity: app orchestration config entity :param app_generate_entity: app generate entity
:param inputs: inputs :param inputs: inputs
:param query: query :param query: query
:return: :return:
...@@ -310,12 +307,12 @@ class AppRunner: ...@@ -310,12 +307,12 @@ class AppRunner:
return moderation_feature.check( return moderation_feature.check(
app_id=app_id, app_id=app_id,
tenant_id=tenant_id, tenant_id=tenant_id,
app_orchestration_config_entity=app_orchestration_config_entity, app_config=app_generate_entity.app_config,
inputs=inputs, inputs=inputs,
query=query, query=query,
) )
def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage]) -> bool: prompt_messages: list[PromptMessage]) -> bool:
""" """
...@@ -334,7 +331,7 @@ class AppRunner: ...@@ -334,7 +331,7 @@ class AppRunner:
if moderation_result: if moderation_result:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_orchestration_config=application_generate_entity.app_orchestration_config_entity, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text="I apologize for any confusion, " \ text="I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest.", "but I'm an AI assistant to be helpful, harmless, and honest.",
......
from typing import Optional
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import \
SuggestedQuestionsAfterAnswerConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import AppMode, App, AppModelConfig
class ChatAppConfig(EasyUIBasedAppConfig):
    """
    Configuration entity for chatbot apps.

    Adds no fields of its own; it exists so chat apps carry a distinct
    config type (useful for isinstance checks and casting at call sites).
    """
class ChatAppConfigManager(BaseAppConfigManager):
    @classmethod
    def config_convert(cls, app_model: App,
                       config_from: EasyUIBasedAppModelConfigFrom,
                       app_model_config: AppModelConfig,
                       config_dict: Optional[dict] = None) -> ChatAppConfig:
        """
        Convert an app model config into a ChatAppConfig entity.

        :param app_model: app model record
        :param config_from: where the app model config originates
        :param app_model_config: app model config record
        :param config_dict: optional raw config dict; derived from
            app_model_config when not supplied
        :return: populated ChatAppConfig
        """
        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)

        chat_config = ChatAppConfig(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_mode=AppMode.value_of(app_model.mode),
            app_model_config_from=config_from,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=config_dict,
            model=ModelConfigManager.convert(config=config_dict),
            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
            dataset=DatasetConfigManager.convert(config=config_dict),
            additional_features=cls.convert_features(config_dict)
        )

        # Variables and external data variables come from one converter call.
        variables, external_data_variables = BasicVariablesConfigManager.convert(config=config_dict)
        chat_config.variables = variables
        chat_config.external_data_variables = external_data_variables

        return chat_config

    @classmethod
    def config_validate(cls, tenant_id: str, config: dict) -> dict:
        """
        Validate a chat app model config and fill in defaults.

        :param tenant_id: tenant id
        :param config: app model config args
        :return: config filtered down to the keys the validators own
        """
        app_mode = AppMode.CHAT

        # Each step normalizes one config section and reports which top-level
        # keys it owns; steps run in order, threading the config through.
        validation_steps = [
            lambda c: ModelConfigManager.validate_and_set_defaults(tenant_id, c),          # model
            lambda c: BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, c), # user_input_form
            lambda c: FileUploadConfigManager.validate_and_set_defaults(c),                # file upload
            lambda c: PromptTemplateConfigManager.validate_and_set_defaults(app_mode, c),  # prompt
            lambda c: DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, c),  # dataset_query_variable
            lambda c: OpeningStatementConfigManager.validate_and_set_defaults(c),          # opening_statement
            lambda c: SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(c),  # suggested questions
            lambda c: SpeechToTextConfigManager.validate_and_set_defaults(c),              # speech_to_text
            lambda c: TextToSpeechConfigManager.validate_and_set_defaults(c),              # text_to_speech
            lambda c: RetrievalResourceConfigManager.validate_and_set_defaults(c),         # retriever resource
            lambda c: SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, c),  # moderation
        ]

        related_config_keys = []
        for step in validation_steps:
            config, owned_keys = step(config)
            related_config_keys.extend(owned_keys)

        # Filter out parameters no validator claimed (dedupe via set).
        return {key: config.get(key) for key in set(related_config_keys)}
import logging import logging
from typing import cast
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.base_app_runner import AppRunner from core.app.apps.chat.app_config_manager import ChatAppConfig
from core.app.apps.base_app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ( from core.app.entities.app_invoke_entities import (
ApplicationGenerateEntity, EasyUIBasedAppGenerateEntity,
) )
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
...@@ -21,7 +23,7 @@ class ChatAppRunner(AppRunner): ...@@ -21,7 +23,7 @@ class ChatAppRunner(AppRunner):
Chat Application Runner Chat Application Runner
""" """
def run(self, application_generate_entity: ApplicationGenerateEntity, def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
...@@ -33,12 +35,13 @@ class ChatAppRunner(AppRunner): ...@@ -33,12 +35,13 @@ class ChatAppRunner(AppRunner):
:param message: message :param message: message
:return: :return:
""" """
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
inputs = application_generate_entity.inputs inputs = application_generate_entity.inputs
query = application_generate_entity.query query = application_generate_entity.query
files = application_generate_entity.files files = application_generate_entity.files
...@@ -50,8 +53,8 @@ class ChatAppRunner(AppRunner): ...@@ -50,8 +53,8 @@ class ChatAppRunner(AppRunner):
# Not Include: memory, external data, dataset context # Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens( self.get_pre_calculate_rest_tokens(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query
...@@ -61,8 +64,8 @@ class ChatAppRunner(AppRunner): ...@@ -61,8 +64,8 @@ class ChatAppRunner(AppRunner):
if application_generate_entity.conversation_id: if application_generate_entity.conversation_id:
# get memory of conversation (read-only) # get memory of conversation (read-only)
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model model=application_generate_entity.model_config.model
) )
memory = TokenBufferMemory( memory = TokenBufferMemory(
...@@ -75,8 +78,8 @@ class ChatAppRunner(AppRunner): ...@@ -75,8 +78,8 @@ class ChatAppRunner(AppRunner):
# memory(optional) # memory(optional)
prompt_messages, stop = self.organize_prompt_messages( prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
...@@ -88,15 +91,15 @@ class ChatAppRunner(AppRunner): ...@@ -88,15 +91,15 @@ class ChatAppRunner(AppRunner):
# process sensitive_word_avoidance # process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs( _, inputs, query = self.moderation_for_inputs(
app_id=app_record.id, app_id=app_record.id,
tenant_id=application_generate_entity.tenant_id, tenant_id=app_config.tenant_id,
app_orchestration_config_entity=app_orchestration_config, app_generate_entity=application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
) )
except ModerationException as e: except ModerationException as e:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=str(e), text=str(e),
stream=application_generate_entity.stream stream=application_generate_entity.stream
...@@ -120,7 +123,7 @@ class ChatAppRunner(AppRunner): ...@@ -120,7 +123,7 @@ class ChatAppRunner(AppRunner):
) )
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=annotation_reply.content, text=annotation_reply.content,
stream=application_generate_entity.stream stream=application_generate_entity.stream
...@@ -128,7 +131,7 @@ class ChatAppRunner(AppRunner): ...@@ -128,7 +131,7 @@ class ChatAppRunner(AppRunner):
return return
# fill in variable inputs from external data tools if exists # fill in variable inputs from external data tools if exists
external_data_tools = app_orchestration_config.external_data_variables external_data_tools = app_config.external_data_variables
if external_data_tools: if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools( inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
...@@ -140,7 +143,7 @@ class ChatAppRunner(AppRunner): ...@@ -140,7 +143,7 @@ class ChatAppRunner(AppRunner):
# get context from datasets # get context from datasets
context = None context = None
if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler( hit_callback = DatasetIndexToolCallbackHandler(
queue_manager, queue_manager,
app_record.id, app_record.id,
...@@ -152,11 +155,11 @@ class ChatAppRunner(AppRunner): ...@@ -152,11 +155,11 @@ class ChatAppRunner(AppRunner):
dataset_retrieval = DatasetRetrieval() dataset_retrieval = DatasetRetrieval()
context = dataset_retrieval.retrieve( context = dataset_retrieval.retrieve(
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
config=app_orchestration_config.dataset, config=app_config.dataset,
query=query, query=query,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_orchestration_config.show_retrieve_source, show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback, hit_callback=hit_callback,
memory=memory memory=memory
) )
...@@ -166,8 +169,8 @@ class ChatAppRunner(AppRunner): ...@@ -166,8 +169,8 @@ class ChatAppRunner(AppRunner):
# memory(optional), external data, dataset context(optional) # memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.organize_prompt_messages( prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
...@@ -186,22 +189,22 @@ class ChatAppRunner(AppRunner): ...@@ -186,22 +189,22 @@ class ChatAppRunner(AppRunner):
return return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens( self.recale_llm_max_tokens(
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_messages=prompt_messages prompt_messages=prompt_messages
) )
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model model=application_generate_entity.model_config.model
) )
db.session.close() db.session.close()
invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters, model_parameters=application_generate_entity.model_config.parameters,
stop=stop, stop=stop,
stream=application_generate_entity.stream, stream=application_generate_entity.stream,
user=application_generate_entity.user_id, user=application_generate_entity.user_id,
......
from typing import Optional
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import AppMode, App, AppModelConfig
class CompletionAppConfig(EasyUIBasedAppConfig):
    """
    Configuration entity for completion apps.

    Adds no fields of its own; it exists so completion apps carry a
    distinct config type (useful for isinstance checks and casting).
    """
class CompletionAppConfigManager(BaseAppConfigManager):
    @classmethod
    def config_convert(cls, app_model: App,
                       config_from: EasyUIBasedAppModelConfigFrom,
                       app_model_config: AppModelConfig,
                       config_dict: Optional[dict] = None) -> CompletionAppConfig:
        """
        Convert an app model config into a CompletionAppConfig entity.

        :param app_model: app model record
        :param config_from: where the app model config originates
        :param app_model_config: app model config record
        :param config_dict: optional raw config dict; derived from
            app_model_config when not supplied
        :return: populated CompletionAppConfig
        """
        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)

        completion_config = CompletionAppConfig(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_mode=AppMode.value_of(app_model.mode),
            app_model_config_from=config_from,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=config_dict,
            model=ModelConfigManager.convert(config=config_dict),
            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
            dataset=DatasetConfigManager.convert(config=config_dict),
            additional_features=cls.convert_features(config_dict)
        )

        # Variables and external data variables come from one converter call.
        variables, external_data_variables = BasicVariablesConfigManager.convert(config=config_dict)
        completion_config.variables = variables
        completion_config.external_data_variables = external_data_variables

        return completion_config

    @classmethod
    def config_validate(cls, tenant_id: str, config: dict) -> dict:
        """
        Validate a completion app model config and fill in defaults.

        :param tenant_id: tenant id
        :param config: app model config args
        :return: config filtered down to the keys the validators own
        """
        app_mode = AppMode.COMPLETION

        # Each step normalizes one config section and reports which top-level
        # keys it owns; steps run in order, threading the config through.
        validation_steps = [
            lambda c: ModelConfigManager.validate_and_set_defaults(tenant_id, c),          # model
            lambda c: BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, c), # user_input_form
            lambda c: FileUploadConfigManager.validate_and_set_defaults(c),                # file upload
            lambda c: PromptTemplateConfigManager.validate_and_set_defaults(app_mode, c),  # prompt
            lambda c: DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, c),  # dataset_query_variable
            lambda c: TextToSpeechConfigManager.validate_and_set_defaults(c),              # text_to_speech
            lambda c: MoreLikeThisConfigManager.validate_and_set_defaults(c),              # more_like_this
            lambda c: SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, c),  # moderation
        ]

        related_config_keys = []
        for step in validation_steps:
            config, owned_keys = step(config)
            related_config_keys.extend(owned_keys)

        # Filter out parameters no validator claimed (dedupe via set).
        return {key: config.get(key) for key in set(related_config_keys)}
import logging import logging
from typing import cast
from core.app.app_queue_manager import AppQueueManager from core.app.app_queue_manager import AppQueueManager
from core.app.base_app_runner import AppRunner from core.app.apps.completion.app_config_manager import CompletionAppConfig
from core.app.apps.base_app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ( from core.app.entities.app_invoke_entities import (
ApplicationGenerateEntity, EasyUIBasedAppGenerateEntity,
) )
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.moderation.base import ModerationException from core.moderation.base import ModerationException
...@@ -20,7 +22,7 @@ class CompletionAppRunner(AppRunner): ...@@ -20,7 +22,7 @@ class CompletionAppRunner(AppRunner):
Completion Application Runner Completion Application Runner
""" """
def run(self, application_generate_entity: ApplicationGenerateEntity, def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
message: Message) -> None: message: Message) -> None:
""" """
...@@ -30,12 +32,13 @@ class CompletionAppRunner(AppRunner): ...@@ -30,12 +32,13 @@ class CompletionAppRunner(AppRunner):
:param message: message :param message: message
:return: :return:
""" """
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
inputs = application_generate_entity.inputs inputs = application_generate_entity.inputs
query = application_generate_entity.query query = application_generate_entity.query
files = application_generate_entity.files files = application_generate_entity.files
...@@ -47,8 +50,8 @@ class CompletionAppRunner(AppRunner): ...@@ -47,8 +50,8 @@ class CompletionAppRunner(AppRunner):
# Not Include: memory, external data, dataset context # Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens( self.get_pre_calculate_rest_tokens(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query
...@@ -58,8 +61,8 @@ class CompletionAppRunner(AppRunner): ...@@ -58,8 +61,8 @@ class CompletionAppRunner(AppRunner):
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
prompt_messages, stop = self.organize_prompt_messages( prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query query=query
...@@ -70,15 +73,15 @@ class CompletionAppRunner(AppRunner): ...@@ -70,15 +73,15 @@ class CompletionAppRunner(AppRunner):
# process sensitive_word_avoidance # process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs( _, inputs, query = self.moderation_for_inputs(
app_id=app_record.id, app_id=app_record.id,
tenant_id=application_generate_entity.tenant_id, tenant_id=app_config.tenant_id,
app_orchestration_config_entity=app_orchestration_config, app_generate_entity=application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
) )
except ModerationException as e: except ModerationException as e:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config, app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
text=str(e), text=str(e),
stream=application_generate_entity.stream stream=application_generate_entity.stream
...@@ -86,7 +89,7 @@ class CompletionAppRunner(AppRunner): ...@@ -86,7 +89,7 @@ class CompletionAppRunner(AppRunner):
return return
# fill in variable inputs from external data tools if exists # fill in variable inputs from external data tools if exists
external_data_tools = app_orchestration_config.external_data_variables external_data_tools = app_config.external_data_variables
if external_data_tools: if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools( inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
...@@ -98,7 +101,7 @@ class CompletionAppRunner(AppRunner): ...@@ -98,7 +101,7 @@ class CompletionAppRunner(AppRunner):
# get context from datasets # get context from datasets
context = None context = None
if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler( hit_callback = DatasetIndexToolCallbackHandler(
queue_manager, queue_manager,
app_record.id, app_record.id,
...@@ -107,18 +110,18 @@ class CompletionAppRunner(AppRunner): ...@@ -107,18 +110,18 @@ class CompletionAppRunner(AppRunner):
application_generate_entity.invoke_from application_generate_entity.invoke_from
) )
dataset_config = app_orchestration_config.dataset dataset_config = app_config.dataset
if dataset_config and dataset_config.retrieve_config.query_variable: if dataset_config and dataset_config.retrieve_config.query_variable:
query = inputs.get(dataset_config.retrieve_config.query_variable, "") query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrieval() dataset_retrieval = DatasetRetrieval()
context = dataset_retrieval.retrieve( context = dataset_retrieval.retrieve(
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
config=dataset_config, config=dataset_config,
query=query, query=query,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_orchestration_config.show_retrieve_source, show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback hit_callback=hit_callback
) )
...@@ -127,8 +130,8 @@ class CompletionAppRunner(AppRunner): ...@@ -127,8 +130,8 @@ class CompletionAppRunner(AppRunner):
# memory(optional), external data, dataset context(optional) # memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.organize_prompt_messages( prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
...@@ -147,19 +150,19 @@ class CompletionAppRunner(AppRunner): ...@@ -147,19 +150,19 @@ class CompletionAppRunner(AppRunner):
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recale_llm_max_tokens( self.recale_llm_max_tokens(
model_config=app_orchestration_config.model_config, model_config=application_generate_entity.model_config,
prompt_messages=prompt_messages prompt_messages=prompt_messages
) )
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model model=application_generate_entity.model_config.model
) )
invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters, model_parameters=application_generate_entity.model_config.parameters,
stop=stop, stop=stop,
stream=application_generate_entity.stream, stream=application_generate_entity.stream,
user=application_generate_entity.user_id, user=application_generate_entity.user_id,
......
from core.app.validators.file_upload import FileUploadValidator from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.validators.moderation import ModerationValidator from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.validators.text_to_speech import TextToSpeechValidator from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.model import AppMode, App
from models.workflow import Workflow
class WorkflowAppConfigValidator: class WorkflowAppConfig(WorkflowUIBasedAppConfig):
"""
Workflow App Config Entity.
"""
pass
class WorkflowAppConfigManager(BaseAppConfigManager):
@classmethod
def config_convert(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig:
features_dict = workflow.features_dict
app_config = WorkflowAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=AppMode.value_of(app_model.mode),
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict)
)
return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
""" """
...@@ -16,15 +48,15 @@ class WorkflowAppConfigValidator: ...@@ -16,15 +48,15 @@ class WorkflowAppConfigValidator:
related_config_keys = [] related_config_keys = []
# file upload validation # file upload validation
config, current_related_config_keys = FileUploadValidator.validate_and_set_defaults(config) config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# text_to_speech # text_to_speech
config, current_related_config_keys = TextToSpeechValidator.validate_and_set_defaults(config) config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# moderation validation # moderation validation
config, current_related_config_keys = ModerationValidator.validate_and_set_defaults( config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, tenant_id=tenant_id,
config=config, config=config,
only_structure_validate=only_structure_validate only_structure_validate=only_structure_validate
......
from core.app.validators.dataset_retrieval import DatasetValidator
from core.app.validators.external_data_fetch import ExternalDataFetchValidator
from core.app.validators.file_upload import FileUploadValidator
from core.app.validators.model_validator import ModelValidator
from core.app.validators.moderation import ModerationValidator
from core.app.validators.opening_statement import OpeningStatementValidator
from core.app.validators.prompt import PromptValidator
from core.app.validators.retriever_resource import RetrieverResourceValidator
from core.app.validators.speech_to_text import SpeechToTextValidator
from core.app.validators.suggested_questions import SuggestedQuestionsValidator
from core.app.validators.text_to_speech import TextToSpeechValidator
from core.app.validators.user_input_form import UserInputFormValidator
from models.model import AppMode
class ChatAppConfigValidator:
    @classmethod
    def config_validate(cls, tenant_id: str, config: dict) -> dict:
        """
        Validate for chat app model config.

        Runs the chain of per-feature validators over the raw config dict,
        collecting the config keys each validator owns, then returns a copy
        of the config restricted to exactly those keys.

        :param tenant_id: tenant id
        :param config: app model config args
        :return: validated config containing only validator-owned keys
        """
        app_mode = AppMode.CHAT

        # Each step validates one feature section, applies its defaults, and
        # reports the config keys it is responsible for. Order matters and
        # mirrors the original sequential validation.
        validation_steps = [
            # model
            lambda cfg: ModelValidator.validate_and_set_defaults(tenant_id, cfg),
            # user_input_form
            UserInputFormValidator.validate_and_set_defaults,
            # external data tools validation
            lambda cfg: ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, cfg),
            # file upload validation
            FileUploadValidator.validate_and_set_defaults,
            # prompt
            lambda cfg: PromptValidator.validate_and_set_defaults(app_mode, cfg),
            # dataset_query_variable
            lambda cfg: DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, cfg),
            # opening_statement
            OpeningStatementValidator.validate_and_set_defaults,
            # suggested_questions_after_answer
            SuggestedQuestionsValidator.validate_and_set_defaults,
            # speech_to_text
            SpeechToTextValidator.validate_and_set_defaults,
            # text_to_speech
            TextToSpeechValidator.validate_and_set_defaults,
            # return retriever resource
            RetrieverResourceValidator.validate_and_set_defaults,
            # moderation validation
            lambda cfg: ModerationValidator.validate_and_set_defaults(tenant_id, cfg),
        ]

        related_config_keys = []
        for step in validation_steps:
            config, step_keys = step(config)
            related_config_keys.extend(step_keys)

        # Filter out extra parameters: keep only keys some validator claimed.
        return {key: config.get(key) for key in set(related_config_keys)}
from core.app.validators.dataset_retrieval import DatasetValidator
from core.app.validators.external_data_fetch import ExternalDataFetchValidator
from core.app.validators.file_upload import FileUploadValidator
from core.app.validators.model_validator import ModelValidator
from core.app.validators.moderation import ModerationValidator
from core.app.validators.more_like_this import MoreLikeThisValidator
from core.app.validators.prompt import PromptValidator
from core.app.validators.text_to_speech import TextToSpeechValidator
from core.app.validators.user_input_form import UserInputFormValidator
from models.model import AppMode
class CompletionAppConfigValidator:
    @classmethod
    def config_validate(cls, tenant_id: str, config: dict) -> dict:
        """
        Validate for completion app model config.

        Runs the chain of per-feature validators over the raw config dict,
        collecting the config keys each validator owns, then returns a copy
        of the config restricted to exactly those keys.

        :param tenant_id: tenant id
        :param config: app model config args
        :return: validated config containing only validator-owned keys
        """
        app_mode = AppMode.COMPLETION

        # Each step validates one feature section, applies its defaults, and
        # reports the config keys it is responsible for. Order matters and
        # mirrors the original sequential validation.
        validation_steps = [
            # model
            lambda cfg: ModelValidator.validate_and_set_defaults(tenant_id, cfg),
            # user_input_form
            UserInputFormValidator.validate_and_set_defaults,
            # external data tools validation
            lambda cfg: ExternalDataFetchValidator.validate_and_set_defaults(tenant_id, cfg),
            # file upload validation
            FileUploadValidator.validate_and_set_defaults,
            # prompt
            lambda cfg: PromptValidator.validate_and_set_defaults(app_mode, cfg),
            # dataset_query_variable
            lambda cfg: DatasetValidator.validate_and_set_defaults(tenant_id, app_mode, cfg),
            # text_to_speech
            TextToSpeechValidator.validate_and_set_defaults,
            # more_like_this
            MoreLikeThisValidator.validate_and_set_defaults,
            # moderation validation
            lambda cfg: ModerationValidator.validate_and_set_defaults(tenant_id, cfg),
        ]

        related_config_keys = []
        for step in validation_steps:
            config, step_keys = step(config)
            related_config_keys.extend(step_keys)

        # Filter out extra parameters: keep only keys some validator claimed.
        return {key: config.get(key) for key in set(related_config_keys)}
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj
from core.model_runtime.entities.model_entities import AIModelEntity
class InvokeFrom(Enum):
    """
    Invoke From.
    """
    SERVICE_API = 'service-api'
    WEB_APP = 'web-app'
    EXPLORE = 'explore'
    DEBUGGER = 'debugger'

    @classmethod
    def value_of(cls, value: str) -> 'InvokeFrom':
        """
        Get value of given mode.

        :param value: mode value
        :return: mode
        :raises ValueError: if no member has the given value
        """
        # Enum's call syntax performs the value lookup; re-raise with the
        # same message the original hand-rolled loop produced.
        try:
            return cls(value)
        except ValueError:
            raise ValueError(f'invalid invoke from value {value}') from None

    def to_source(self) -> str:
        """
        Get source of invoke from.

        :return: source
        """
        # Map each entry point to its source tag; anything unmapped falls
        # back to 'dev', matching the original if/elif chain.
        return {
            InvokeFrom.WEB_APP: 'web_app',
            InvokeFrom.DEBUGGER: 'dev',
            InvokeFrom.EXPLORE: 'explore_app',
            InvokeFrom.SERVICE_API: 'api',
        }.get(self, 'dev')
class EasyUIBasedModelConfigEntity(BaseModel):
    """
    Model Config Entity.

    Resolved model configuration for an EasyUI-based app invocation:
    identifies the provider/model pair and carries the credentials and
    parameters needed to call the model.
    """
    # provider name
    provider: str
    # model name
    model: str
    # schema entity describing this model's properties
    model_schema: AIModelEntity
    # completion mode string, e.g. 'chat' (compared against 'chat' by callers)
    mode: str
    # bundle linking the model to its configured provider
    provider_model_bundle: ProviderModelBundle
    # resolved provider credentials for the call
    credentials: dict[str, Any] = {}
    # model invocation parameters (NOTE(review): presumably temperature etc. — set by caller)
    parameters: dict[str, Any] = {}
    # stop sequences passed to the model
    stop: list[str] = []
class EasyUIBasedAppGenerateEntity(BaseModel):
    """
    EasyUI Based Application Generate Entity.

    Bundles everything required to run one generation for an EasyUI-based
    app: the app config, the resolved model config, the user's inputs/query,
    attached files, and invocation metadata.
    """
    # unique id for this generation task
    task_id: str
    # app config
    app_config: EasyUIBasedAppConfig
    # NOTE(review): `model_config` collides with pydantic v2's reserved
    # attribute name; this relies on pydantic v1 field semantics.
    model_config: EasyUIBasedModelConfigEntity
    # conversation to continue, if any
    conversation_id: Optional[str] = None
    # user-provided variable values
    inputs: dict[str, str]
    # user query text, if any
    query: Optional[str] = None
    # files attached to the request
    files: list[FileObj] = []
    # id of the invoking user
    user_id: str
    # extras
    # whether to stream the response
    stream: bool
    # where the invocation originated (API, web app, explore, debugger)
    invoke_from: InvokeFrom
    # extra parameters, like: auto_generate_conversation_name
    extras: dict[str, Any] = {}
class WorkflowUIBasedAppGenerateEntity(BaseModel):
    """
    Workflow UI Based Application Generate Entity.

    Inputs for one generation of a workflow-driven app: the app config,
    user inputs/files, and invocation metadata. Unlike the EasyUI variant
    there is no resolved model config or query field.
    """
    # unique id for this generation task
    task_id: str
    # app config
    app_config: WorkflowUIBasedAppConfig
    # user-provided variable values
    inputs: dict[str, str]
    # files attached to the request
    files: list[FileObj] = []
    # id of the invoking user
    user_id: str
    # extras
    # whether to stream the response
    stream: bool
    # where the invocation originated (API, web app, explore, debugger)
    invoke_from: InvokeFrom
    # extra parameters
    extras: dict[str, Any] = {}
class AdvancedChatAppGenerateEntity(WorkflowUIBasedAppGenerateEntity):
    """
    Workflow-based generate entity for advanced-chat apps: adds the
    conversational fields on top of the workflow entity.
    """
    # conversation to continue; NOTE(review): None presumably starts a new
    # conversation — confirm against the caller
    conversation_id: Optional[str] = None
    # user query text (required for chat, unlike the workflow base class)
    query: str
import logging import logging
from typing import Optional from typing import Optional
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset
......
import logging import logging
from core.entities.application_entities import ApplicationGenerateEntity from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
from core.helper import moderation from core.helper import moderation
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
...@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) ...@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
class HostingModerationFeature: class HostingModerationFeature:
def check(self, application_generate_entity: ApplicationGenerateEntity, def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list[PromptMessage]) -> bool: prompt_messages: list[PromptMessage]) -> bool:
""" """
Check hosting moderation Check hosting moderation
...@@ -16,8 +16,7 @@ class HostingModerationFeature: ...@@ -16,8 +16,7 @@ class HostingModerationFeature:
:param prompt_messages: prompt messages :param prompt_messages: prompt messages
:return: :return:
""" """
app_orchestration_config = application_generate_entity.app_orchestration_config_entity model_config = application_generate_entity.model_config
model_config = app_orchestration_config.model_config
text = "" text = ""
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
......
...@@ -7,8 +7,8 @@ from typing import Optional, Union, cast ...@@ -7,8 +7,8 @@ from typing import Optional, Union, cast
from pydantic import BaseModel from pydantic import BaseModel
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, InvokeFrom
from core.entities.queue_entities import ( from core.app.entities.queue_entities import (
AnnotationReplyEvent, AnnotationReplyEvent,
QueueAgentMessageEvent, QueueAgentMessageEvent,
QueueAgentThoughtEvent, QueueAgentThoughtEvent,
...@@ -58,7 +58,7 @@ class GenerateTaskPipeline: ...@@ -58,7 +58,7 @@ class GenerateTaskPipeline:
GenerateTaskPipeline is a class that generate stream output and state management for Application. GenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
def __init__(self, application_generate_entity: ApplicationGenerateEntity, def __init__(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
...@@ -75,7 +75,7 @@ class GenerateTaskPipeline: ...@@ -75,7 +75,7 @@ class GenerateTaskPipeline:
self._message = message self._message = message
self._task_state = TaskState( self._task_state = TaskState(
llm_result=LLMResult( llm_result=LLMResult(
model=self._application_generate_entity.app_orchestration_config_entity.model_config.model, model=self._application_generate_entity.model_config.model,
prompt_messages=[], prompt_messages=[],
message=AssistantPromptMessage(content=""), message=AssistantPromptMessage(content=""),
usage=LLMUsage.empty_usage() usage=LLMUsage.empty_usage()
...@@ -127,7 +127,7 @@ class GenerateTaskPipeline: ...@@ -127,7 +127,7 @@ class GenerateTaskPipeline:
if isinstance(event, QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result self._task_state.llm_result = event.llm_result
else: else:
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config model_config = self._application_generate_entity.model_config
model = model_config.model model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_type_instance = cast(LargeLanguageModel, model_type_instance)
...@@ -210,7 +210,7 @@ class GenerateTaskPipeline: ...@@ -210,7 +210,7 @@ class GenerateTaskPipeline:
if isinstance(event, QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result self._task_state.llm_result = event.llm_result
else: else:
model_config = self._application_generate_entity.app_orchestration_config_entity.model_config model_config = self._application_generate_entity.model_config
model = model_config.model model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_type_instance = cast(LargeLanguageModel, model_type_instance)
...@@ -569,7 +569,7 @@ class GenerateTaskPipeline: ...@@ -569,7 +569,7 @@ class GenerateTaskPipeline:
:return: :return:
""" """
prompts = [] prompts = []
if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat': if self._application_generate_entity.model_config.mode == 'chat':
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
if prompt_message.role == PromptMessageRole.USER: if prompt_message.role == PromptMessageRole.USER:
role = 'user' role = 'user'
...@@ -638,13 +638,13 @@ class GenerateTaskPipeline: ...@@ -638,13 +638,13 @@ class GenerateTaskPipeline:
Init output moderation. Init output moderation.
:return: :return:
""" """
app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity app_config = self._application_generate_entity.app_config
sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance sensitive_word_avoidance = app_config.sensitive_word_avoidance
if sensitive_word_avoidance: if sensitive_word_avoidance:
return OutputModeration( return OutputModeration(
tenant_id=self._application_generate_entity.tenant_id, tenant_id=app_config.tenant_id,
app_id=self._application_generate_entity.app_id, app_id=app_config.app_id,
rule=ModerationRule( rule=ModerationRule(
type=sensitive_word_avoidance.type, type=sensitive_word_avoidance.type,
config=sensitive_word_avoidance.config config=sensitive_word_avoidance.config
......
from core.external_data_tool.factory import ExternalDataToolFactory
class ExternalDataFetchValidator:
    @classmethod
    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
        """
        Validate and set defaults for external data fetch feature.

        Normalizes `config["external_data_tools"]` (defaulting to an empty
        list, forcing an explicit `enabled` flag on each tool) and validates
        each enabled tool's own config via its factory.

        :param tenant_id: workspace id
        :param config: app model config args
        :return: tuple of (config with defaults applied, config keys owned here)
        :raises ValueError: if external_data_tools is not a list, or an
            enabled tool is missing its type
        """
        if not config.get("external_data_tools"):
            config["external_data_tools"] = []

        if not isinstance(config["external_data_tools"], list):
            raise ValueError("external_data_tools must be of list type")

        for tool in config["external_data_tools"]:
            if "enabled" not in tool or not tool["enabled"]:
                tool["enabled"] = False

            if not tool["enabled"]:
                continue

            if "type" not in tool or not tool["type"]:
                raise ValueError("external_data_tools[].type is required")

            typ = tool["type"]
            # Bug fix: use a distinct local for the tool's own config. The
            # original reassigned the outer `config` here, so the function
            # returned the last enabled tool's config instead of the app
            # config dict.
            tool_config = tool["config"]
            ExternalDataToolFactory.validate_config(
                name=typ,
                tenant_id=tenant_id,
                config=tool_config
            )

        return config, ["external_data_tools"]
from typing import Optional

from pydantic import BaseModel


class AgentLoop(BaseModel):
    """
    One iteration of an agent's reasoning loop (thought -> tool call ->
    observation), together with the prompt/completion token accounting and
    timing for that iteration.
    """
    # 1-based position of this loop within the agent run
    position: int = 1

    # Fields defaulting to None are declared Optional explicitly; pydantic v1
    # inferred Optional from the None default, pydantic v2 requires it.
    thought: Optional[str] = None
    tool_name: Optional[str] = None
    tool_input: Optional[str] = None
    tool_output: Optional[str] = None

    prompt: Optional[str] = None
    prompt_tokens: int = 0
    completion: Optional[str] = None
    completion_tokens: int = 0

    # wall-clock latency of this loop in seconds
    latency: Optional[float] = None

    # loop lifecycle state; starts at 'llm_started'
    status: str = 'llm_started'
    completed: bool = False

    # NOTE(review): presumably epoch timestamps — confirm with callers
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DatasetQuery, DocumentSegment from models.dataset import DatasetQuery, DocumentSegment
......
...@@ -5,7 +5,7 @@ from typing import Optional ...@@ -5,7 +5,7 @@ from typing import Optional
from flask import Flask, current_app from flask import Flask, current_app
from core.entities.application_entities import ExternalDataVariableEntity from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory from core.external_data_tool.factory import ExternalDataToolFactory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -3,6 +3,7 @@ from typing import Optional ...@@ -3,6 +3,7 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.app.app_config.entities import FileUploadEntity
from core.file.upload_file_parser import UploadFileParser from core.file.upload_file_parser import UploadFileParser
from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db from extensions.ext_database import db
...@@ -50,7 +51,7 @@ class FileObj(BaseModel): ...@@ -50,7 +51,7 @@ class FileObj(BaseModel):
transfer_method: FileTransferMethod transfer_method: FileTransferMethod
url: Optional[str] url: Optional[str]
upload_file_id: Optional[str] upload_file_id: Optional[str]
file_config: dict file_upload_entity: FileUploadEntity
@property @property
def data(self) -> Optional[str]: def data(self) -> Optional[str]:
...@@ -63,7 +64,7 @@ class FileObj(BaseModel): ...@@ -63,7 +64,7 @@ class FileObj(BaseModel):
@property @property
def prompt_message_content(self) -> ImagePromptMessageContent: def prompt_message_content(self) -> ImagePromptMessageContent:
if self.type == FileType.IMAGE: if self.type == FileType.IMAGE:
image_config = self.file_config.get('image') image_config = self.file_upload_entity.image_config
return ImagePromptMessageContent( return ImagePromptMessageContent(
data=self.data, data=self.data,
......
from typing import Optional, Union from typing import Union
import requests import requests
from core.app.app_config.entities import FileUploadEntity
from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.model import AppModelConfig, EndUser, MessageFile, UploadFile from models.model import EndUser, MessageFile, UploadFile
from services.file_service import IMAGE_EXTENSIONS from services.file_service import IMAGE_EXTENSIONS
...@@ -15,18 +16,16 @@ class MessageFileParser: ...@@ -15,18 +16,16 @@ class MessageFileParser:
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.app_id = app_id self.app_id = app_id
def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig, def validate_and_transform_files_arg(self, files: list[dict], file_upload_entity: FileUploadEntity,
user: Union[Account, EndUser]) -> list[FileObj]: user: Union[Account, EndUser]) -> list[FileObj]:
""" """
validate and transform files arg validate and transform files arg
:param files: :param files:
:param app_model_config: :param file_upload_entity:
:param user: :param user:
:return: :return:
""" """
file_upload_config = app_model_config.file_upload_dict
for file in files: for file in files:
if not isinstance(file, dict): if not isinstance(file, dict):
raise ValueError('Invalid file format, must be dict') raise ValueError('Invalid file format, must be dict')
...@@ -45,17 +44,17 @@ class MessageFileParser: ...@@ -45,17 +44,17 @@ class MessageFileParser:
raise ValueError('Missing file upload_file_id') raise ValueError('Missing file upload_file_id')
# transform files to file objs # transform files to file objs
type_file_objs = self._to_file_objs(files, file_upload_config) type_file_objs = self._to_file_objs(files, file_upload_entity)
# validate files # validate files
new_files = [] new_files = []
for file_type, file_objs in type_file_objs.items(): for file_type, file_objs in type_file_objs.items():
if file_type == FileType.IMAGE: if file_type == FileType.IMAGE:
# parse and validate files # parse and validate files
image_config = file_upload_config.get('image') image_config = file_upload_entity.image_config
# check if image file feature is enabled # check if image file feature is enabled
if not image_config['enabled']: if not image_config:
continue continue
# Validate number of files # Validate number of files
...@@ -96,27 +95,27 @@ class MessageFileParser: ...@@ -96,27 +95,27 @@ class MessageFileParser:
# return all file objs # return all file objs
return new_files return new_files
def transform_message_files(self, files: list[MessageFile], file_upload_config: Optional[dict]) -> list[FileObj]: def transform_message_files(self, files: list[MessageFile], file_upload_entity: FileUploadEntity) -> list[FileObj]:
""" """
transform message files transform message files
:param files: :param files:
:param file_upload_config: :param file_upload_entity:
:return: :return:
""" """
# transform files to file objs # transform files to file objs
type_file_objs = self._to_file_objs(files, file_upload_config) type_file_objs = self._to_file_objs(files, file_upload_entity)
# return all file objs # return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(self, files: list[Union[dict, MessageFile]], def _to_file_objs(self, files: list[Union[dict, MessageFile]],
file_upload_config: dict) -> dict[FileType, list[FileObj]]: file_upload_entity: FileUploadEntity) -> dict[FileType, list[FileObj]]:
""" """
transform files to file objs transform files to file objs
:param files: :param files:
:param file_upload_config: :param file_upload_entity:
:return: :return:
""" """
type_file_objs: dict[FileType, list[FileObj]] = { type_file_objs: dict[FileType, list[FileObj]] = {
...@@ -133,7 +132,7 @@ class MessageFileParser: ...@@ -133,7 +132,7 @@ class MessageFileParser:
if file.belongs_to == FileBelongsTo.ASSISTANT.value: if file.belongs_to == FileBelongsTo.ASSISTANT.value:
continue continue
file_obj = self._to_file_obj(file, file_upload_config) file_obj = self._to_file_obj(file, file_upload_entity)
if file_obj.type not in type_file_objs: if file_obj.type not in type_file_objs:
continue continue
...@@ -141,7 +140,7 @@ class MessageFileParser: ...@@ -141,7 +140,7 @@ class MessageFileParser:
return type_file_objs return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj: def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_entity: FileUploadEntity) -> FileObj:
""" """
transform file to file obj transform file to file obj
...@@ -156,7 +155,7 @@ class MessageFileParser: ...@@ -156,7 +155,7 @@ class MessageFileParser:
transfer_method=transfer_method, transfer_method=transfer_method,
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
file_config=file_upload_config file_upload_entity=file_upload_entity
) )
else: else:
return FileObj( return FileObj(
...@@ -166,7 +165,7 @@ class MessageFileParser: ...@@ -166,7 +165,7 @@ class MessageFileParser:
transfer_method=FileTransferMethod.value_of(file.transfer_method), transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url, url=file.url,
upload_file_id=file.upload_file_id or None, upload_file_id=file.upload_file_id or None,
file_config=file_upload_config file_upload_entity=file_upload_entity
) )
def _check_image_remote_url(self, url): def _check_image_remote_url(self, url):
......
import logging import logging
import random import random
from core.entities.application_entities import ModelConfigEntity from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
from extensions.ext_hosting_provider import hosting_configuration from extensions.ext_hosting_provider import hosting_configuration
...@@ -10,7 +10,7 @@ from models.provider import ProviderType ...@@ -10,7 +10,7 @@ from models.provider import ProviderType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check_moderation(model_config: ModelConfigEntity, text: str) -> bool: def check_moderation(model_config: EasyUIBasedModelConfigEntity, text: str) -> bool:
moderation_config = hosting_configuration.moderation_config moderation_config = hosting_configuration.moderation_config
if (moderation_config and moderation_config.enabled is True if (moderation_config and moderation_config.enabled is True
and 'openai' in hosting_configuration.provider_map and 'openai' in hosting_configuration.provider_map
......
from core.app.app_config.entities import FileUploadEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file.message_file_parser import MessageFileParser from core.file.message_file_parser import MessageFileParser
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
...@@ -43,12 +45,18 @@ class TokenBufferMemory: ...@@ -43,12 +45,18 @@ class TokenBufferMemory:
for message in messages: for message in messages:
files = message.message_files files = message.message_files
if files: if files:
file_objs = message_file_parser.transform_message_files( if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
files, file_upload_entity = FileUploadConfigManager.convert(message.app_model_config.to_dict())
message.app_model_config.file_upload_dict else:
if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] file_upload_entity = FileUploadConfigManager.convert(message.workflow_run.workflow.features_dict)
else message.workflow_run.workflow.features_dict.get('file_upload', {})
) if file_upload_entity:
file_objs = message_file_parser.transform_message_files(
files,
file_upload_entity
)
else:
file_objs = []
if not file_objs: if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query)) prompt_messages.append(UserPromptMessage(content=message.query))
......
import logging import logging
from core.entities.application_entities import AppOrchestrationConfigEntity from core.app.app_config.entities import AppConfig
from core.moderation.base import ModerationAction, ModerationException from core.moderation.base import ModerationAction, ModerationException
from core.moderation.factory import ModerationFactory from core.moderation.factory import ModerationFactory
...@@ -10,22 +10,22 @@ logger = logging.getLogger(__name__) ...@@ -10,22 +10,22 @@ logger = logging.getLogger(__name__)
class InputModeration: class InputModeration:
def check(self, app_id: str, def check(self, app_id: str,
tenant_id: str, tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity, app_config: AppConfig,
inputs: dict, inputs: dict,
query: str) -> tuple[bool, dict, str]: query: str) -> tuple[bool, dict, str]:
""" """
Process sensitive_word_avoidance. Process sensitive_word_avoidance.
:param app_id: app id :param app_id: app id
:param tenant_id: tenant id :param tenant_id: tenant id
:param app_orchestration_config_entity: app orchestration config entity :param app_config: app config
:param inputs: inputs :param inputs: inputs
:param query: query :param query: query
:return: :return:
""" """
if not app_orchestration_config_entity.sensitive_word_avoidance: if not app_config.sensitive_word_avoidance:
return False, inputs, query return False, inputs, query
sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance sensitive_word_avoidance_config = app_config.sensitive_word_avoidance
moderation_type = sensitive_word_avoidance_config.type moderation_type = sensitive_word_avoidance_config.type
moderation_factory = ModerationFactory( moderation_factory = ModerationFactory(
......
from typing import Optional from typing import Optional
from core.entities.application_entities import ( from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity
AdvancedCompletionPromptTemplateEntity, from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
ModelConfigEntity,
PromptTemplateEntity,
)
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
...@@ -31,7 +28,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -31,7 +28,7 @@ class AdvancedPromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) -> list[PromptMessage]: model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
prompt_messages = [] prompt_messages = []
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
...@@ -65,7 +62,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -65,7 +62,7 @@ class AdvancedPromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) -> list[PromptMessage]: model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
""" """
Get completion model prompt messages. Get completion model prompt messages.
""" """
...@@ -113,7 +110,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -113,7 +110,7 @@ class AdvancedPromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) -> list[PromptMessage]: model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
""" """
Get chat model prompt messages. Get chat model prompt messages.
""" """
...@@ -202,7 +199,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -202,7 +199,7 @@ class AdvancedPromptTransform(PromptTransform):
role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
prompt_template: PromptTemplateParser, prompt_template: PromptTemplateParser,
prompt_inputs: dict, prompt_inputs: dict,
model_config: ModelConfigEntity) -> dict: model_config: EasyUIBasedModelConfigEntity) -> dict:
if '#histories#' in prompt_template.variable_keys: if '#histories#' in prompt_template.variable_keys:
if memory: if memory:
inputs = {'#histories#': '', **prompt_inputs} inputs = {'#histories#': '', **prompt_inputs}
......
from typing import Optional, cast from typing import Optional, cast
from core.entities.application_entities import ModelConfigEntity from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
...@@ -10,14 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large ...@@ -10,14 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
class PromptTransform: class PromptTransform:
def _append_chat_histories(self, memory: TokenBufferMemory, def _append_chat_histories(self, memory: TokenBufferMemory,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_config: ModelConfigEntity) -> list[PromptMessage]: model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config) rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, rest_tokens) histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories) prompt_messages.extend(histories)
return prompt_messages return prompt_messages
def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> int: def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: EasyUIBasedModelConfigEntity) -> int:
rest_tokens = 2000 rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
......
...@@ -3,10 +3,8 @@ import json ...@@ -3,10 +3,8 @@ import json
import os import os
from typing import Optional from typing import Optional
from core.entities.application_entities import ( from core.app.app_config.entities import PromptTemplateEntity
ModelConfigEntity, from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
PromptTemplateEntity,
)
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
...@@ -54,7 +52,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -54,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) -> \ model_config: EasyUIBasedModelConfigEntity) -> \
tuple[list[PromptMessage], Optional[list[str]]]: tuple[list[PromptMessage], Optional[list[str]]]:
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
...@@ -83,7 +81,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -83,7 +81,7 @@ class SimplePromptTransform(PromptTransform):
return prompt_messages, stops return prompt_messages, stops
def get_prompt_str_and_rules(self, app_mode: AppMode, def get_prompt_str_and_rules(self, app_mode: AppMode,
model_config: ModelConfigEntity, model_config: EasyUIBasedModelConfigEntity,
pre_prompt: str, pre_prompt: str,
inputs: dict, inputs: dict,
query: Optional[str] = None, query: Optional[str] = None,
...@@ -164,7 +162,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -164,7 +162,7 @@ class SimplePromptTransform(PromptTransform):
context: Optional[str], context: Optional[str],
files: list[FileObj], files: list[FileObj],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) \ model_config: EasyUIBasedModelConfigEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]: -> tuple[list[PromptMessage], Optional[list[str]]]:
prompt_messages = [] prompt_messages = []
...@@ -202,7 +200,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -202,7 +200,7 @@ class SimplePromptTransform(PromptTransform):
context: Optional[str], context: Optional[str],
files: list[FileObj], files: list[FileObj],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) \ model_config: EasyUIBasedModelConfigEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]: -> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt # get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules( prompt, prompt_rules = self.get_prompt_str_and_rules(
......
import logging
from typing import Optional
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class AgentLLMCallback(Callback):
    """Model-runtime callback adapter for the agent loop.

    Forwards LLM lifecycle events (before invoke, after invoke, error) to the
    wrapped ``AgentLoopGatherCallbackHandler`` so the agent loop can record the
    prompts sent to the model, the final result, and any failure. Streaming
    chunks are intentionally ignored.
    """

    def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
        # Target handler; every relevant event below is delegated to it.
        self.agent_callback = agent_callback

    def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
        """Notify the agent handler of the prompt messages about to be sent.

        :param llm_instance: LLM instance
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        # Only the prompt messages are of interest to the agent loop here.
        self.agent_callback.on_llm_before_invoke(prompt_messages=prompt_messages)

    def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
                     prompt_messages: list[PromptMessage], model_parameters: dict,
                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                     stream: bool = True, user: Optional[str] = None):
        """No-op: the agent loop does not consume individual stream chunks.

        :param llm_instance: LLM instance
        :param chunk: chunk
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        pass

    def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> None:
        """Forward the completed LLM result to the agent handler.

        :param llm_instance: LLM instance
        :param result: result
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.agent_callback.on_llm_after_invoke(result=result)

    def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> None:
        """Forward an invocation failure to the agent handler.

        :param llm_instance: LLM instance
        :param ex: exception
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        self.agent_callback.on_llm_error(error=ex)
...@@ -5,19 +5,17 @@ from langchain.callbacks.manager import CallbackManagerForChainRun ...@@ -5,19 +5,17 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from core.entities.application_entities import ModelConfigEntity from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.rag.retrieval.agent.fake_llm import FakeLLM from core.rag.retrieval.agent.fake_llm import FakeLLM
class LLMChain(LCLLMChain): class LLMChain(LCLLMChain):
model_config: ModelConfigEntity model_config: EasyUIBasedModelConfigEntity
"""The language model instance to use.""" """The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="") llm: BaseLanguageModel = FakeLLM(response="")
parameters: dict[str, Any] = {} parameters: dict[str, Any] = {}
agent_llm_callback: Optional[AgentLLMCallback] = None
def generate( def generate(
self, self,
...@@ -38,7 +36,6 @@ class LLMChain(LCLLMChain): ...@@ -38,7 +36,6 @@ class LLMChain(LCLLMChain):
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stream=False, stream=False,
stop=stop, stop=stop,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
model_parameters=self.parameters model_parameters=self.parameters
) )
......
...@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage ...@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import root_validator from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessageTool
...@@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
""" """
An Multi Dataset Retrieve Agent driven by Router. An Multi Dataset Retrieve Agent driven by Router.
""" """
model_config: ModelConfigEntity model_config: EasyUIBasedModelConfigEntity
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
...@@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_config: ModelConfigEntity, model_config: EasyUIBasedModelConfigEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
......
...@@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy ...@@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.entities.application_entities import ModelConfigEntity from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.rag.retrieval.agent.llm_chain import LLMChain from core.rag.retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
...@@ -206,7 +206,7 @@ Thought: {agent_scratchpad} ...@@ -206,7 +206,7 @@ Thought: {agent_scratchpad}
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_config: ModelConfigEntity, model_config: EasyUIBasedModelConfigEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
......
...@@ -7,13 +7,12 @@ from langchain.callbacks.manager import Callbacks ...@@ -7,13 +7,12 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from core.rag.retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
...@@ -23,15 +22,14 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr ...@@ -23,15 +22,14 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr
class AgentConfiguration(BaseModel): class AgentConfiguration(BaseModel):
strategy: PlanningStrategy strategy: PlanningStrategy
model_config: ModelConfigEntity model_config: EasyUIBasedModelConfigEntity
tools: list[BaseTool] tools: list[BaseTool]
summary_model_config: Optional[ModelConfigEntity] = None summary_model_config: Optional[EasyUIBasedModelConfigEntity] = None
memory: Optional[TokenBufferMemory] = None memory: Optional[TokenBufferMemory] = None
callbacks: Callbacks = None callbacks: Callbacks = None
max_iterations: int = 6 max_iterations: int = 6
max_execution_time: Optional[float] = None max_execution_time: Optional[float] = None
early_stopping_method: str = "generate" early_stopping_method: str = "generate"
agent_llm_callback: Optional[AgentLLMCallback] = None
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
class Config: class Config:
......
...@@ -2,9 +2,10 @@ from typing import Optional, cast ...@@ -2,9 +2,10 @@ from typing import Optional, cast
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom, EasyUIBasedModelConfigEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
...@@ -17,7 +18,7 @@ from models.dataset import Dataset ...@@ -17,7 +18,7 @@ from models.dataset import Dataset
class DatasetRetrieval: class DatasetRetrieval:
def retrieve(self, tenant_id: str, def retrieve(self, tenant_id: str,
model_config: ModelConfigEntity, model_config: EasyUIBasedModelConfigEntity,
config: DatasetEntity, config: DatasetEntity,
query: str, query: str,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
......
...@@ -2,8 +2,9 @@ from typing import Any ...@@ -2,8 +2,9 @@ from typing import Any
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
......
from core.entities.application_entities import ApplicationGenerateEntity from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
...@@ -8,9 +8,9 @@ from models.provider import Provider, ProviderType ...@@ -8,9 +8,9 @@ from models.provider import Provider, ProviderType
@message_was_created.connect @message_was_created.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
message = sender message = sender
application_generate_entity: ApplicationGenerateEntity = kwargs.get('application_generate_entity') application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity')
model_config = application_generate_entity.app_orchestration_config_entity.model_config model_config = application_generate_entity.model_config
provider_model_bundle = model_config.provider_model_bundle provider_model_bundle = model_config.provider_model_bundle
provider_configuration = provider_model_bundle.configuration provider_configuration = provider_model_bundle.configuration
...@@ -43,7 +43,7 @@ def handle(sender, **kwargs): ...@@ -43,7 +43,7 @@ def handle(sender, **kwargs):
if used_quota is not None: if used_quota is not None:
db.session.query(Provider).filter( db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.tenant_id, Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == model_config.provider, Provider.provider_name == model_config.provider,
Provider.provider_type == ProviderType.SYSTEM.value, Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value, Provider.quota_type == system_configuration.current_quota_type.value,
......
from datetime import datetime from datetime import datetime
from core.entities.application_entities import ApplicationGenerateEntity from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models.provider import Provider from models.provider import Provider
...@@ -9,10 +9,10 @@ from models.provider import Provider ...@@ -9,10 +9,10 @@ from models.provider import Provider
@message_was_created.connect @message_was_created.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
message = sender message = sender
application_generate_entity: ApplicationGenerateEntity = kwargs.get('application_generate_entity') application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity')
db.session.query(Provider).filter( db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.tenant_id, Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == application_generate_entity.app_orchestration_config_entity.model_config.provider Provider.provider_name == application_generate_entity.model_config.provider
).update({'last_used': datetime.utcnow()}) ).update({'last_used': datetime.utcnow()})
db.session.commit() db.session.commit()
...@@ -105,6 +105,18 @@ class App(db.Model): ...@@ -105,6 +105,18 @@ class App(db.Model):
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
return tenant return tenant
@property
def is_agent(self) -> bool:
app_model_config = self.app_model_config
if not app_model_config:
return False
if not app_model_config.agent_mode:
return False
if self.app_model_config.agent_mode_dict.get('enabled', False) \
and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']:
return True
return False
@property @property
def deleted_tools(self) -> list: def deleted_tools(self) -> list:
# get agent mode tools # get agent mode tools
......
...@@ -129,7 +129,7 @@ class Workflow(db.Model): ...@@ -129,7 +129,7 @@ class Workflow(db.Model):
def features_dict(self): def features_dict(self):
return self.features if not self.features else json.loads(self.features) return self.features if not self.features else json.loads(self.features)
def user_input_form(self): def user_input_form(self) -> list:
# get start node from graph # get start node from graph
if not self.graph: if not self.graph:
return [] return []
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment