Commit b75cd251 authored by takatost's avatar takatost

optimize db connections

parent 7693ba87
import json
from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, BadRequest
......@@ -6,6 +8,8 @@ from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.agent.entities import AgentToolEntity
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
......@@ -14,10 +18,8 @@ from fields.app_fields import (
from libs.login import login_required
from services.app_service import AppService
from models.model import App, AppModelConfig, AppMode
from services.workflow_service import WorkflowService
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager
from core.entities.application_entities import AgentToolEntity
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow']
......@@ -108,36 +110,38 @@ class AppApi(Resource):
def get(self, app_model):
"""Get app detail"""
# get original app model config
model_config: AppModelConfig = app_model.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
model_config: AppModelConfig = app_model.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
db.session.commit()
return app_model
......
......@@ -8,7 +8,7 @@ from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.entities.application_entities import AgentToolEntity
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
......@@ -38,81 +38,82 @@ class ModelConfigResource(Resource):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app_model.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()
# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()
# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
db.session.add(new_app_model_config)
db.session.flush()
......
......@@ -123,7 +123,8 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get("datasets")
need_manual_query_datasets = (config.get("dataset_configs")
and config["dataset_configs"].get("datasets", {}).get("datasets"))
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
......
......@@ -153,8 +153,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
db.session.close()
# chatbot app
runner = AdvancedChatAppRunner()
runner.run(
......
......@@ -72,6 +72,8 @@ class AdvancedChatAppRunner(AppRunner):
):
return
db.session.close()
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
......
......@@ -193,4 +193,4 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
......@@ -201,8 +201,8 @@ class AgentChatAppRunner(AppRunner):
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
db.session.refresh(conversation)
db.session.refresh(message)
conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
message = db.session.query(Message).filter(Message.id == message.id).first()
db.session.close()
# start agent runner
......
......@@ -193,4 +193,4 @@ class ChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
......@@ -182,7 +182,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
def generate_more_like_this(self, app_model: App,
message_id: str,
......
......@@ -160,6 +160,8 @@ class CompletionAppRunner(AppRunner):
model=application_generate_entity.model_config.model
)
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=application_generate_entity.model_config.parameters,
......
......@@ -64,8 +64,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
else:
logger.exception(e)
raise e
finally:
db.session.remove()
def _get_conversation_by_user(self, app_model: App, conversation_id: str,
user: Union[Account, EndUser]) -> Conversation:
......
......@@ -57,6 +57,8 @@ class WorkflowAppRunner:
):
return
db.session.close()
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
......
......@@ -5,8 +5,8 @@ import mimetypes
from os import listdir, path
from typing import Any, Union
from core.agent.entities import AgentToolEntity
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.entities.application_entities import AgentToolEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.provider_manager import ProviderManager
from core.tools.entities.common_entities import I18nObject
......
......@@ -322,7 +322,7 @@ class AppModelConfig(db.Model):
}
def from_model_config_dict(self, model_config: dict):
self.opening_statement = model_config['opening_statement']
self.opening_statement = model_config.get('opening_statement')
self.suggested_questions = json.dumps(model_config['suggested_questions']) \
if model_config.get('suggested_questions') else None
self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment