Unverified commit cc2d71c2 authored by takatost, committed by GitHub

feat: optimize override app model config convert (#874)

parent cd116139
...@@ -124,12 +124,29 @@ class AppListApi(Resource): ...@@ -124,12 +124,29 @@ class AppListApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']: if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden() raise Forbidden()
default_model = ModelFactory.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_GENERATION
)
if default_model:
default_model_provider = default_model.provider_name
default_model_name = default_model.model_name
else:
raise ProviderNotInitializeError(
f"No Text Generation Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
if args['model_config'] is not None: if args['model_config'] is not None:
# validate config # validate config
model_config_dict = args['model_config']
model_config_dict["model"]["provider"] = default_model_provider
model_config_dict["model"]["name"] = default_model_name
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
account=current_user, account=current_user,
config=args['model_config'] config=model_config_dict
) )
app = App( app = App(
...@@ -141,21 +158,8 @@ class AppListApi(Resource): ...@@ -141,21 +158,8 @@ class AppListApi(Resource):
status='normal' status='normal'
) )
app_model_config = AppModelConfig( app_model_config = AppModelConfig()
provider="", app_model_config = app_model_config.from_model_config_dict(model_configuration)
model_id="",
configs={},
opening_statement=model_configuration['opening_statement'],
suggested_questions=json.dumps(model_configuration['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'],
agent_mode=json.dumps(model_configuration['agent_mode']),
)
else: else:
if 'mode' not in args or args['mode'] is None: if 'mode' not in args or args['mode'] is None:
abort(400, message="mode is required") abort(400, message="mode is required")
...@@ -165,20 +169,10 @@ class AppListApi(Resource): ...@@ -165,20 +169,10 @@ class AppListApi(Resource):
app = App(**model_config_template['app']) app = App(**model_config_template['app'])
app_model_config = AppModelConfig(**model_config_template['model_config']) app_model_config = AppModelConfig(**model_config_template['model_config'])
default_model = ModelFactory.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_GENERATION
)
if default_model:
model_dict = app_model_config.model_dict model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.provider_name model_dict['provider'] = default_model_provider
model_dict['name'] = default_model.model_name model_dict['name'] = default_model_name
app_model_config.model = json.dumps(model_dict) app_model_config.model = json.dumps(model_dict)
else:
raise ProviderNotInitializeError(
f"No Text Generation Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
app.name = args['name'] app.name = args['name']
app.mode = args['mode'] app.mode = args['mode']
...@@ -416,22 +410,9 @@ class AppCopy(Resource): ...@@ -416,22 +410,9 @@ class AppCopy(Resource):
@staticmethod @staticmethod
def create_app_model_config_copy(app_config, copy_app_id): def create_app_model_config_copy(app_config, copy_app_id):
copy_app_model_config = AppModelConfig( copy_app_model_config = app_config.copy()
app_id=copy_app_id, copy_app_model_config.app_id = copy_app_id
provider=app_config.provider,
model_id=app_config.model_id,
configs=app_config.configs,
opening_statement=app_config.opening_statement,
suggested_questions=app_config.suggested_questions,
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
speech_to_text=app_config.speech_to_text,
more_like_this=app_config.more_like_this,
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
model=app_config.model,
user_input_form=app_config.user_input_form,
pre_prompt=app_config.pre_prompt,
agent_mode=app_config.agent_mode
)
return copy_app_model_config return copy_app_model_config
@setup_required @setup_required
......
...@@ -35,20 +35,8 @@ class ModelConfigResource(Resource): ...@@ -35,20 +35,8 @@ class ModelConfigResource(Resource):
new_app_model_config = AppModelConfig( new_app_model_config = AppModelConfig(
app_id=app_model.id, app_id=app_model.id,
provider="",
model_id="",
configs={},
opening_statement=model_configuration['opening_statement'],
suggested_questions=json.dumps(model_configuration['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'],
agent_mode=json.dumps(model_configuration['agent_mode']),
) )
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
db.session.add(new_app_model_config) db.session.add(new_app_model_config)
db.session.flush() db.session.flush()
......
...@@ -112,7 +112,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): ...@@ -112,7 +112,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "") "I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs): def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2: if len(intermediate_steps) >= 2 and self.summary_llm:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation) should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps] for _, observation in should_summary_intermediate_steps]
......
...@@ -65,7 +65,8 @@ class AgentExecutor: ...@@ -65,7 +65,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client, llm=self.configuration.model_instance.client,
tools=self.configuration.tools, tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(), output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_model_instance.client, summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True verbose=True
) )
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
...@@ -74,7 +75,8 @@ class AgentExecutor: ...@@ -74,7 +75,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client, llm=self.configuration.model_instance.client,
tools=self.configuration.tools, tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client, summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True verbose=True
) )
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL: elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
...@@ -83,7 +85,8 @@ class AgentExecutor: ...@@ -83,7 +85,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client, llm=self.configuration.model_instance.client,
tools=self.configuration.tools, tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client, summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True verbose=True
) )
elif self.configuration.strategy == PlanningStrategy.ROUTER: elif self.configuration.strategy == PlanningStrategy.ROUTER:
......
...@@ -60,17 +60,7 @@ class ConversationMessageTask: ...@@ -60,17 +60,7 @@ class ConversationMessageTask:
def init(self): def init(self):
override_model_configs = None override_model_configs = None
if self.is_override: if self.is_override:
override_model_configs = { override_model_configs = self.app_model_config.to_dict()
"model": self.app_model_config.model_dict,
"pre_prompt": self.app_model_config.pre_prompt,
"agent_mode": self.app_model_config.agent_mode_dict,
"opening_statement": self.app_model_config.opening_statement,
"suggested_questions": self.app_model_config.suggested_questions_list,
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
"more_like_this": self.app_model_config.more_like_this_dict,
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
"user_input_form": self.app_model_config.user_input_form_list,
}
introduction = '' introduction = ''
system_instruction = '' system_instruction = ''
......
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
from langchain.schema import OutputParserException from langchain.schema import OutputParserException
from core.model_providers.error import LLMError from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs from core.model_providers.models.entity.model_params import ModelKwargs
...@@ -108,6 +108,7 @@ class LLMGenerator: ...@@ -108,6 +108,7 @@ class LLMGenerator:
_input = prompt.format_prompt(histories=histories) _input = prompt.format_prompt(histories=histories)
try:
model_instance = ModelFactory.get_text_generation_model( model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id, tenant_id=tenant_id,
model_kwargs=ModelKwargs( model_kwargs=ModelKwargs(
...@@ -115,6 +116,8 @@ class LLMGenerator: ...@@ -115,6 +116,8 @@ class LLMGenerator:
temperature=0 temperature=0
) )
) )
except ProviderTokenNotInitError:
return []
prompts = [PromptMessage(content=_input.to_string())] prompts = [PromptMessage(content=_input.to_string())]
......
...@@ -14,6 +14,7 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa ...@@ -14,6 +14,7 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tool.dataset_retriever_tool import DatasetRetrieverTool
...@@ -78,6 +79,7 @@ class OrchestratorRuleParser: ...@@ -78,6 +79,7 @@ class OrchestratorRuleParser:
elif planning_strategy == PlanningStrategy.ROUTER: elif planning_strategy == PlanningStrategy.ROUTER:
planning_strategy = PlanningStrategy.REACT_ROUTER planning_strategy = PlanningStrategy.REACT_ROUTER
try:
summary_model_instance = ModelFactory.get_text_generation_model( summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
model_kwargs=ModelKwargs( model_kwargs=ModelKwargs(
...@@ -85,6 +87,8 @@ class OrchestratorRuleParser: ...@@ -85,6 +87,8 @@ class OrchestratorRuleParser:
max_tokens=500 max_tokens=500
) )
) )
except ProviderTokenNotInitError as e:
summary_model_instance = None
tools = self.to_tools( tools = self.to_tools(
tool_configs=tool_configs, tool_configs=tool_configs,
......
import json import json
import re
from typing import Any from typing import Any
from langchain.schema import BaseOutputParser from langchain.schema import BaseOutputParser
from core.model_providers.error import LLMError
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
...@@ -12,5 +15,10 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): ...@@ -12,5 +15,10 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
def parse(self, text: str) -> Any: def parse(self, text: str) -> Any:
json_string = text.strip() json_string = text.strip()
json_obj = json.loads(json_string) action_match = re.search(r".*(\[\".+\"\]).*", json_string, re.DOTALL)
if action_match is not None:
json_obj = json.loads(action_match.group(1).strip(), strict=False)
else:
raise LLMError("Could not parse LLM output: {text}")
return json_obj return json_obj
...@@ -148,6 +148,46 @@ class AppModelConfig(db.Model): ...@@ -148,6 +148,46 @@ class AppModelConfig(db.Model):
"agent_mode": self.agent_mode_dict "agent_mode": self.agent_mode_dict
} }
def from_model_config_dict(self, model_config: dict):
    """Populate this AppModelConfig from a plain model-config dict.

    Scalar fields (``opening_statement``, ``pre_prompt``) are stored as-is;
    structured fields are serialized to JSON strings, matching how the
    columns are persisted. ``provider``/``model_id``/``configs`` are reset
    to empty values — presumably legacy columns superseded by the JSON
    ``model`` field; confirm against the schema.

    Returns ``self`` so the call can be chained after construction.
    """
    # Legacy columns, intentionally blanked.
    self.provider = ""
    self.model_id = ""
    self.configs = {}

    # Plain-string fields.
    self.opening_statement = model_config['opening_statement']
    self.pre_prompt = model_config['pre_prompt']

    # JSON-serialized fields (required keys — KeyError if absent).
    for field_name in ('suggested_questions',
                       'suggested_questions_after_answer',
                       'more_like_this',
                       'sensitive_word_avoidance',
                       'model',
                       'user_input_form',
                       'agent_mode'):
        setattr(self, field_name, json.dumps(model_config[field_name]))

    # Optional field: only serialized when present and truthy.
    speech_to_text = model_config.get('speech_to_text')
    self.speech_to_text = json.dumps(speech_to_text) if speech_to_text else None

    return self
def copy(self):
    """Return a detached in-memory duplicate of this config.

    The duplicate keeps the same ``id`` and ``app_id`` and carries over
    every serialized config field, while ``provider``/``model_id``/
    ``configs`` are reset to empty values rather than copied.
    NOTE(review): callers appear to use the result as a transient
    override object (mutating ``app_id``/``model`` afterwards) — confirm
    it is never persisted as-is, since the primary key is reused.
    """
    carried_over = dict(
        id=self.id,
        app_id=self.app_id,
        opening_statement=self.opening_statement,
        suggested_questions=self.suggested_questions,
        suggested_questions_after_answer=self.suggested_questions_after_answer,
        speech_to_text=self.speech_to_text,
        more_like_this=self.more_like_this,
        sensitive_word_avoidance=self.sensitive_word_avoidance,
        model=self.model,
        user_input_form=self.user_input_form,
        pre_prompt=self.pre_prompt,
        agent_mode=self.agent_mode,
    )
    return AppModelConfig(provider="", model_id="", configs={}, **carried_over)
class RecommendedApp(db.Model): class RecommendedApp(db.Model):
__tablename__ = 'recommended_apps' __tablename__ = 'recommended_apps'
__table_args__ = ( __table_args__ = (
...@@ -234,7 +274,8 @@ class Conversation(db.Model): ...@@ -234,7 +274,8 @@ class Conversation(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select',
passive_deletes="all")
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
...@@ -429,7 +470,7 @@ class Message(db.Model): ...@@ -429,7 +470,7 @@ class Message(db.Model):
@property @property
def agent_thoughts(self): def agent_thoughts(self):
return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id)\ return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \
.order_by(MessageAgentThought.position.asc()).all() .order_by(MessageAgentThought.position.asc()).all()
...@@ -557,7 +598,8 @@ class Site(db.Model): ...@@ -557,7 +598,8 @@ class Site(db.Model):
@property @property
def app_base_url(self): def app_base_url(self):
return (current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/')) return (
current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
class ApiToken(db.Model): class ApiToken(db.Model):
......
...@@ -63,26 +63,23 @@ class CompletionService: ...@@ -63,26 +63,23 @@ class CompletionService:
raise ConversationCompletedError() raise ConversationCompletedError()
if not conversation.override_model_configs: if not conversation.override_model_configs:
app_model_config = db.session.query(AppModelConfig).get(conversation.app_model_config_id) app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
if not app_model_config: if not app_model_config:
raise AppModelConfigBrokenError() raise AppModelConfigBrokenError()
else: else:
conversation_override_model_configs = json.loads(conversation.override_model_configs) conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig( app_model_config = AppModelConfig(
id=conversation.app_model_config_id, id=conversation.app_model_config_id,
app_id=app_model.id, app_id=app_model.id,
provider="",
model_id="",
configs="",
opening_statement=conversation_override_model_configs['opening_statement'],
suggested_questions=json.dumps(conversation_override_model_configs['suggested_questions']),
model=json.dumps(conversation_override_model_configs['model']),
user_input_form=json.dumps(conversation_override_model_configs['user_input_form']),
pre_prompt=conversation_override_model_configs['pre_prompt'],
agent_mode=json.dumps(conversation_override_model_configs['agent_mode']),
) )
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
if is_model_config_override: if is_model_config_override:
# build new app model config # build new app model config
if 'model' not in args['model_config']: if 'model' not in args['model_config']:
...@@ -99,19 +96,8 @@ class CompletionService: ...@@ -99,19 +96,8 @@ class CompletionService:
app_model_config_model = app_model_config.model_dict app_model_config_model = app_model_config.model_dict
app_model_config_model['completion_params'] = completion_params app_model_config_model['completion_params'] = completion_params
app_model_config = AppModelConfig( app_model_config = app_model_config.copy()
id=app_model_config.id, app_model_config.model = json.dumps(app_model_config_model)
app_id=app_model.id,
provider="",
model_id="",
configs="",
opening_statement=app_model_config.opening_statement,
suggested_questions=app_model_config.suggested_questions,
model=json.dumps(app_model_config_model),
user_input_form=app_model_config.user_input_form,
pre_prompt=app_model_config.pre_prompt,
agent_mode=app_model_config.agent_mode,
)
else: else:
if app_model.app_model_config_id is None: if app_model.app_model_config_id is None:
raise AppModelConfigBrokenError() raise AppModelConfigBrokenError()
...@@ -135,20 +121,10 @@ class CompletionService: ...@@ -135,20 +121,10 @@ class CompletionService:
app_model_config = AppModelConfig( app_model_config = AppModelConfig(
id=app_model_config.id, id=app_model_config.id,
app_id=app_model.id, app_id=app_model.id,
provider="",
model_id="",
configs="",
opening_statement=model_config['opening_statement'],
suggested_questions=json.dumps(model_config['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
more_like_this=json.dumps(model_config['more_like_this']),
sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']),
model=json.dumps(model_config['model']),
user_input_form=json.dumps(model_config['user_input_form']),
pre_prompt=model_config['pre_prompt'],
agent_mode=json.dumps(model_config['agent_mode']),
) )
app_model_config = app_model_config.from_model_config_dict(model_config)
# clean input by app_model_config form rules # clean input by app_model_config form rules
inputs = cls.get_cleaned_inputs(inputs, app_model_config) inputs = cls.get_cleaned_inputs(inputs, app_model_config)
......
import json
from typing import Optional, Union, List from typing import Optional, Union, List
from core.completion import Completion from core.completion import Completion
...@@ -5,8 +6,10 @@ from core.generator.llm_generator import LLMGenerator ...@@ -5,8 +6,10 @@ from core.generator.llm_generator import LLMGenerator
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.model import App, EndUser, Message, MessageFeedback from models.model import App, EndUser, Message, MessageFeedback, AppModelConfig
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError, LastMessageNotExistsError, \ from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError, LastMessageNotExistsError, \
SuggestedQuestionsAfterAnswerDisabledError SuggestedQuestionsAfterAnswerDisabledError
...@@ -172,12 +175,6 @@ class MessageService: ...@@ -172,12 +175,6 @@ class MessageService:
if not user: if not user:
raise ValueError('user cannot be None') raise ValueError('user cannot be None')
app_model_config = app_model.app_model_config
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
raise SuggestedQuestionsAfterAnswerDisabledError()
message = cls.get_message( message = cls.get_message(
app_model=app_model, app_model=app_model,
user=user, user=user,
...@@ -190,10 +187,38 @@ class MessageService: ...@@ -190,10 +187,38 @@ class MessageService:
user=user user=user
) )
if not conversation:
raise ConversationNotExistsError()
if conversation.status != 'normal':
raise ConversationCompletedError()
if not conversation.override_model_configs:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
if not app_model_config:
raise AppModelConfigBrokenError()
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
id=conversation.app_model_config_id,
app_id=app_model.id,
)
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
raise SuggestedQuestionsAfterAnswerDisabledError()
# get memory of conversation (read-only) # get memory of conversation (read-only)
memory = Completion.get_memory_from_conversation( memory = Completion.get_memory_from_conversation(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
app_model_config=app_model.app_model_config, app_model_config=app_model_config,
conversation=conversation, conversation=conversation,
max_token_limit=3000, max_token_limit=3000,
message_limit=3, message_limit=3,
...@@ -209,4 +234,3 @@ class MessageService: ...@@ -209,4 +234,3 @@ class MessageService:
) )
return questions return questions
...@@ -6,7 +6,7 @@ from celery import shared_task ...@@ -6,7 +6,7 @@ from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.generator.llm_generator import LLMGenerator from core.generator.llm_generator import LLMGenerator
from core.model_providers.error import LLMError from core.model_providers.error import LLMError, ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Conversation, Message from models.model import Conversation, Message
...@@ -40,10 +40,16 @@ def generate_conversation_summary_task(conversation_id: str): ...@@ -40,10 +40,16 @@ def generate_conversation_summary_task(conversation_id: str):
conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages) conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages)
db.session.add(conversation) db.session.add(conversation)
db.session.commit() db.session.commit()
except (LLMError, ProviderTokenNotInitError):
end_at = time.perf_counter() conversation.summary = '[No Summary]'
logging.info(click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at), fg='green')) db.session.commit()
except LLMError:
pass pass
except Exception as e: except Exception as e:
conversation.summary = '[No Summary]'
db.session.commit()
logging.exception(e) logging.exception(e)
end_at = time.perf_counter()
logging.info(
click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at),
fg='green'))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment