Commit b75cd251 authored by takatost's avatar takatost

optimize db connections

parent 7693ba87
import json
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse from flask_restful import Resource, inputs, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, BadRequest from werkzeug.exceptions import Forbidden, BadRequest
...@@ -6,6 +8,8 @@ from controllers.console import api ...@@ -6,6 +8,8 @@ from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.agent.entities import AgentToolEntity
from extensions.ext_database import db
from fields.app_fields import ( from fields.app_fields import (
app_detail_fields, app_detail_fields,
app_detail_fields_with_site, app_detail_fields_with_site,
...@@ -14,10 +18,8 @@ from fields.app_fields import ( ...@@ -14,10 +18,8 @@ from fields.app_fields import (
from libs.login import login_required from libs.login import login_required
from services.app_service import AppService from services.app_service import AppService
from models.model import App, AppModelConfig, AppMode from models.model import App, AppModelConfig, AppMode
from services.workflow_service import WorkflowService
from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.entities.application_entities import AgentToolEntity
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow'] ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow']
...@@ -108,36 +110,38 @@ class AppApi(Resource): ...@@ -108,36 +110,38 @@ class AppApi(Resource):
def get(self, app_model): def get(self, app_model):
"""Get app detail""" """Get app detail"""
# get original app model config # get original app model config
model_config: AppModelConfig = app_model.app_model_config if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
agent_mode = model_config.agent_mode_dict model_config: AppModelConfig = app_model.app_model_config
# decrypt agent tool parameters if it's secret-input agent_mode = model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []: # decrypt agent tool parameters if it's secret-input
agent_tool_entity = AgentToolEntity(**tool) for tool in agent_mode.get('tools') or []:
# get tool agent_tool_entity = AgentToolEntity(**tool)
tool_runtime = ToolManager.get_agent_tool_runtime( # get tool
tenant_id=current_user.current_tenant_id, tool_runtime = ToolManager.get_agent_tool_runtime(
agent_tool=agent_tool_entity, tenant_id=current_user.current_tenant_id,
agent_callback=None agent_tool=agent_tool_entity,
) agent_callback=None
manager = ToolParameterConfigurationManager( )
tenant_id=current_user.current_tenant_id, manager = ToolParameterConfigurationManager(
tool_runtime=tool_runtime, tenant_id=current_user.current_tenant_id,
provider_name=agent_tool_entity.provider_id, tool_runtime=tool_runtime,
provider_type=agent_tool_entity.provider_type, provider_name=agent_tool_entity.provider_id,
) provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters: # get decrypted parameters
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) if agent_tool_entity.tool_parameters:
masked_parameter = manager.mask_tool_parameters(parameters or {}) parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
else: masked_parameter = manager.mask_tool_parameters(parameters or {})
masked_parameter = {} else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter # override tool parameters
tool['tool_parameters'] = masked_parameter
# override agent mode
model_config.agent_mode = json.dumps(agent_mode) # override agent mode
model_config.agent_mode = json.dumps(agent_mode)
db.session.commit()
return app_model return app_model
......
...@@ -8,7 +8,7 @@ from controllers.console import api ...@@ -8,7 +8,7 @@ from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.entities.application_entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated from events.app_event import app_model_config_was_updated
...@@ -38,81 +38,82 @@ class ModelConfigResource(Resource): ...@@ -38,81 +38,82 @@ class ModelConfigResource(Resource):
) )
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
# get original app model config if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( # get original app model config
AppModelConfig.id == app.app_model_config_id original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
).first() AppModelConfig.id == app_model.app_model_config_id
agent_mode = original_app_model_config.agent_mode_dict ).first()
# decrypt agent tool parameters if it's secret-input agent_mode = original_app_model_config.agent_mode_dict
parameter_map = {} # decrypt agent tool parameters if it's secret-input
masked_parameter_map = {} parameter_map = {}
tool_map = {} masked_parameter_map = {}
for tool in agent_mode.get('tools') or []: tool_map = {}
agent_tool_entity = AgentToolEntity(**tool) for tool in agent_mode.get('tools') or []:
# get tool agent_tool_entity = AgentToolEntity(**tool)
tool_runtime = ToolManager.get_agent_tool_runtime( # get tool
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity, agent_tool=agent_tool_entity,
agent_callback=None agent_callback=None
) )
manager = ToolParameterConfigurationManager(
manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id,
tenant_id=current_user.current_tenant_id, tool_runtime=tool_runtime,
tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id,
provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type,
provider_type=agent_tool_entity.provider_type, )
)
manager.delete_tool_parameters_cache() # get decrypted parameters
if agent_tool_entity.tool_parameters:
# override parameters if it equals to masked parameters parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
if agent_tool_entity.tool_parameters: masked_parameter = manager.mask_tool_parameters(parameters or {})
if key not in masked_parameter_map: else:
continue parameters = {}
masked_parameter = {}
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key] key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
# encrypt parameters parameter_map[key] = parameters
if agent_tool_entity.tool_parameters: tool_map[key] = tool_runtime
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# encrypt agent tool parameters if it's secret-input
# update app model config agent_mode = new_app_model_config.agent_mode_dict
new_app_model_config.agent_mode = json.dumps(agent_mode) for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()
# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
db.session.add(new_app_model_config) db.session.add(new_app_model_config)
db.session.flush() db.session.flush()
......
...@@ -123,7 +123,8 @@ class DatasetConfigManager: ...@@ -123,7 +123,8 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict): if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type") raise ValueError("dataset_configs must be of object type")
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get("datasets") need_manual_query_datasets = (config.get("dataset_configs")
and config["dataset_configs"].get("datasets", {}).get("datasets"))
if need_manual_query_datasets and app_mode == AppMode.COMPLETION: if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion # Only check when mode is completion
......
...@@ -153,8 +153,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ...@@ -153,8 +153,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)
db.session.close()
# chatbot app # chatbot app
runner = AdvancedChatAppRunner() runner = AdvancedChatAppRunner()
runner.run( runner.run(
......
...@@ -72,6 +72,8 @@ class AdvancedChatAppRunner(AppRunner): ...@@ -72,6 +72,8 @@ class AdvancedChatAppRunner(AppRunner):
): ):
return return
db.session.close()
# RUN WORKFLOW # RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow( workflow_engine_manager.run_workflow(
......
...@@ -193,4 +193,4 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): ...@@ -193,4 +193,4 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating") logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally: finally:
db.session.remove() db.session.close()
...@@ -201,8 +201,8 @@ class AgentChatAppRunner(AppRunner): ...@@ -201,8 +201,8 @@ class AgentChatAppRunner(AppRunner):
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
db.session.refresh(conversation) conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
db.session.refresh(message) message = db.session.query(Message).filter(Message.id == message.id).first()
db.session.close() db.session.close()
# start agent runner # start agent runner
......
...@@ -193,4 +193,4 @@ class ChatAppGenerator(MessageBasedAppGenerator): ...@@ -193,4 +193,4 @@ class ChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating") logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally: finally:
db.session.remove() db.session.close()
...@@ -182,7 +182,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ...@@ -182,7 +182,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating") logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally: finally:
db.session.remove() db.session.close()
def generate_more_like_this(self, app_model: App, def generate_more_like_this(self, app_model: App,
message_id: str, message_id: str,
......
...@@ -160,6 +160,8 @@ class CompletionAppRunner(AppRunner): ...@@ -160,6 +160,8 @@ class CompletionAppRunner(AppRunner):
model=application_generate_entity.model_config.model model=application_generate_entity.model_config.model
) )
db.session.close()
invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=application_generate_entity.model_config.parameters, model_parameters=application_generate_entity.model_config.parameters,
......
...@@ -64,8 +64,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): ...@@ -64,8 +64,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
else: else:
logger.exception(e) logger.exception(e)
raise e raise e
finally:
db.session.remove()
def _get_conversation_by_user(self, app_model: App, conversation_id: str, def _get_conversation_by_user(self, app_model: App, conversation_id: str,
user: Union[Account, EndUser]) -> Conversation: user: Union[Account, EndUser]) -> Conversation:
......
...@@ -57,6 +57,8 @@ class WorkflowAppRunner: ...@@ -57,6 +57,8 @@ class WorkflowAppRunner:
): ):
return return
db.session.close()
# RUN WORKFLOW # RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager() workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow( workflow_engine_manager.run_workflow(
......
...@@ -5,8 +5,8 @@ import mimetypes ...@@ -5,8 +5,8 @@ import mimetypes
from os import listdir, path from os import listdir, path
from typing import Any, Union from typing import Any, Union
from core.agent.entities import AgentToolEntity
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.entities.application_entities import AgentToolEntity
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
......
...@@ -322,7 +322,7 @@ class AppModelConfig(db.Model): ...@@ -322,7 +322,7 @@ class AppModelConfig(db.Model):
} }
def from_model_config_dict(self, model_config: dict): def from_model_config_dict(self, model_config: dict):
self.opening_statement = model_config['opening_statement'] self.opening_statement = model_config.get('opening_statement')
self.suggested_questions = json.dumps(model_config['suggested_questions']) \ self.suggested_questions = json.dumps(model_config['suggested_questions']) \
if model_config.get('suggested_questions') else None if model_config.get('suggested_questions') else None
self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment