Unverified Commit f073dca2 authored by takatost, committed by GitHub

feat: optimize db connection when llm invoking (#2774)

parent 8b1e35d7
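The diff applies one pattern throughout: eagerly reload the ORM objects that are still needed with `db.session.refresh()`, then call `db.session.close()` so the pooled database connection is released before the long-running LLM invocation instead of being held idle for its whole duration. Below is a minimal, self-contained sketch of that pattern using a plain SQLAlchemy `scoped_session`; the `Message` model and `fake_invoke_llm()` helper are illustrative stand-ins, not the project's actual classes.

```python
import time

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker

Base = declarative_base()

class Message(Base):
    __tablename__ = "messages"
    id = Column(Integer, primary_key=True)
    answer = Column(String, default="")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
db_session = scoped_session(sessionmaker(bind=engine))

message = Message()
db_session.add(message)
db_session.commit()          # commit expires attributes by default
db_session.refresh(message)  # reload them so the detached object stays readable
db_session.close()           # return the connection to the pool

def fake_invoke_llm() -> str:
    """Stand-in for the slow, network-bound model call (no DB access needed)."""
    time.sleep(0.1)
    return "hello"

answer = fake_invoke_llm()   # no DB connection is held during this call

# The scoped session opens a new session/connection on next use;
# re-query before writing the result back.
message = db_session.query(Message).filter(Message.id == message.id).first()
message.answer = answer
db_session.commit()
db_session.close()
```

The `refresh()` calls in the hunks below serve the same purpose as the one here: without them, `expire_on_commit` would leave the detached `conversation`/`message` objects unreadable once the session is closed.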
@@ -195,6 +195,10 @@ class AssistantApplicationRunner(AppRunner):
         if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

+        db.session.refresh(conversation)
+        db.session.refresh(message)
+        db.session.close()
+
         # start agent runner
         if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
             assistant_cot_runner = AssistantCotApplicationRunner(
...
@@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner):
             model=app_orchestration_config.model_config.model
         )

+        db.session.close()
+
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
             model_parameters=app_orchestration_config.model_config.parameters,
...
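In the basic runner the session is closed right before `model_instance.invoke_llm(...)`. The motivation is capacity: a streaming completion can run for tens of seconds, and if every in-flight request keeps its connection checked out across the whole call, the pool drains quickly. A back-of-the-envelope illustration with invented numbers (none measured from this project); SQLAlchemy's default `QueuePool` allows `pool_size=5` plus `max_overflow=10`, i.e. 15 connections:

```python
# Illustrative numbers only.
pool_limit = 5 + 10          # SQLAlchemy QueuePool defaults: pool_size + max_overflow
llm_latency_s = 20.0         # a slow streaming completion
requests_per_s = 2.0

# If each request holds its DB connection across the whole LLM call
# (Little's law: average concurrency = arrival rate x holding time):
held_across_call = requests_per_s * llm_latency_s   # 40 > 15 -> pool timeouts

# If the session is closed first, a connection is only held for the short
# read/write phases around the call (say ~50 ms each):
held_with_close = requests_per_s * 0.1               # ~0.2 -> comfortable

print(held_across_call, held_with_close, pool_limit)
```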
@@ -89,6 +89,10 @@ class GenerateTaskPipeline:
         Process generate task pipeline.
         :return:
         """
+        db.session.refresh(self._conversation)
+        db.session.refresh(self._message)
+        db.session.close()
+
         if stream:
             return self._process_stream_response()
         else:
@@ -303,6 +307,7 @@ class GenerateTaskPipeline:
                     .first()
                 )
                 db.session.refresh(agent_thought)
+                db.session.close()

                 if agent_thought:
                     response = {
@@ -330,6 +335,8 @@ class GenerateTaskPipeline:
                     .filter(MessageFile.id == event.message_file_id)
                     .first()
                 )
+                db.session.close()
+
                 # get extension
                 if '.' in message_file.url:
                     extension = f'.{message_file.url.split(".")[-1]}'
@@ -413,6 +420,7 @@ class GenerateTaskPipeline:
         usage = llm_result.usage

         self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+        self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

         self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
         self._message.message_tokens = usage.prompt_tokens
...
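Because the pipeline closes its session up front, `self._message` and `self._conversation` are detached by the time results are saved, which is why the last hunk re-queries them by primary key before mutating them. The sketch below (again with a toy `Message` model, not the project's) shows the failure this avoids: with `expire_on_commit` left at its default, touching an expired attribute on a detached instance raises `DetachedInstanceError`.

```python
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.orm.exc import DetachedInstanceError

Base = declarative_base()

class Message(Base):
    __tablename__ = "messages"
    id = Column(Integer, primary_key=True)
    message_tokens = Column(Integer, default=0)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
db_session = scoped_session(sessionmaker(bind=engine))

msg = Message()
db_session.add(msg)
db_session.commit()   # attributes are expired (expire_on_commit=True)
db_session.close()    # instance is now detached, so they cannot be reloaded

try:
    _ = msg.message_tokens
except DetachedInstanceError:
    print("expired attribute on a detached instance")

# The pattern used in the diff: re-query on the (new) scoped session,
# then mutate and commit.
fresh = db_session.query(Message).first()
fresh.message_tokens = 42
db_session.commit()
db_session.close()
```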
@@ -201,7 +201,7 @@ class ApplicationManager:
                 logger.exception("Unknown Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         finally:
-            db.session.remove()
+            db.session.close()

     def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                          queue_manager: ApplicationQueueManager,
@@ -233,8 +233,6 @@ class ApplicationManager:
             else:
                 logger.exception(e)
                 raise e
-        finally:
-            db.session.remove()

     def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
             -> AppOrchestrationConfigEntity:
@@ -651,6 +649,7 @@ class ApplicationManager:

             db.session.add(conversation)
             db.session.commit()
+            db.session.refresh(conversation)
         else:
             conversation = (
                 db.session.query(Conversation)
@@ -689,6 +688,7 @@ class ApplicationManager:

         db.session.add(message)
         db.session.commit()
+        db.session.refresh(message)

         for file in application_generate_entity.files:
             message_file = MessageFile(
...
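Two smaller changes sit in ApplicationManager: the `finally:` blocks now call `db.session.close()` instead of `db.session.remove()`, and freshly committed `conversation`/`message` rows are refreshed straight after `commit()` so their attributes are loaded before the session is closed later on. The difference between `close()` and `remove()` on a `scoped_session` (which the existing `remove()` call implies `db.session` is) is sketched below with plain SQLAlchemy: both return the connection to the pool, but `close()` keeps the registry's session reusable while `remove()` also discards it from the registry.

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite://")
Session = scoped_session(sessionmaker(bind=engine))

s1 = Session()     # session bound to the current scope (current thread by default)

Session.close()    # ends the transaction, returns the connection to the pool...
s2 = Session()     # ...but the registry hands back the same Session object
assert s2 is s1

Session.remove()   # close() plus discarding the session from the registry
s3 = Session()     # the next call builds a brand-new Session
assert s3 is not s1
```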
@@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.agent_thought_count = db.session.query(MessageAgentThought).filter(
             MessageAgentThought.message_id == self.message.id,
         ).count()
+        db.session.close()

         # check if model supports stream tool call
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
@@ -341,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner):
                 created_by=self.user_id,
             )
             db.session.add(message_file)
+            db.session.commit()
+            db.session.refresh(message_file)
+
             result.append((
                 message_file,
                 message.save_as
             ))

-        db.session.commit()
+        db.session.close()

         return result

     def create_agent_thought(self, message_id: str, message: str,
@@ -384,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner):
         db.session.add(thought)
         db.session.commit()
+        db.session.refresh(thought)
+        db.session.close()

         self.agent_thought_count += 1
@@ -401,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner):
         """
         Save agent thought
         """
+        agent_thought = db.session.query(MessageAgentThought).filter(
+            MessageAgentThought.id == agent_thought.id
+        ).first()
+
         if thought is not None:
             agent_thought.thought = thought
@@ -451,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner):
             agent_thought.tool_labels_str = json.dumps(labels)

         db.session.commit()
+        db.session.close()

     def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
         """
@@ -523,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner):
         """
         convert tool variables to db variables
         """
+        db_variables = db.session.query(ToolConversationVariables).filter(
+            ToolConversationVariables.conversation_id == self.message.conversation_id,
+        ).first()
+
         db_variables.updated_at = datetime.utcnow()
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()
+        db.session.close()

     def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
@@ -581,4 +597,6 @@ class BaseAssistantApplicationRunner(AppRunner):
             if message.answer:
                 result.append(AssistantPromptMessage(content=message.answer))

+        db.session.close()
+
         return result
\ No newline at end of file