Unverified commit cc2d71c2, authored by takatost, committed by GitHub

feat: optimize override app model config convert (#874)

parent cd116139
......@@ -124,12 +124,29 @@ class AppListApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
default_model = ModelFactory.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_GENERATION
)
if default_model:
default_model_provider = default_model.provider_name
default_model_name = default_model.model_name
else:
raise ProviderNotInitializeError(
f"No Text Generation Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
if args['model_config'] is not None:
# validate config
model_config_dict = args['model_config']
model_config_dict["model"]["provider"] = default_model_provider
model_config_dict["model"]["name"] = default_model_name
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=args['model_config']
config=model_config_dict
)
app = App(
......@@ -141,21 +158,8 @@ class AppListApi(Resource):
status='normal'
)
app_model_config = AppModelConfig(
provider="",
model_id="",
configs={},
opening_statement=model_configuration['opening_statement'],
suggested_questions=json.dumps(model_configuration['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'],
agent_mode=json.dumps(model_configuration['agent_mode']),
)
app_model_config = AppModelConfig()
app_model_config = app_model_config.from_model_config_dict(model_configuration)
else:
if 'mode' not in args or args['mode'] is None:
abort(400, message="mode is required")
......@@ -165,20 +169,10 @@ class AppListApi(Resource):
app = App(**model_config_template['app'])
app_model_config = AppModelConfig(**model_config_template['model_config'])
default_model = ModelFactory.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_GENERATION
)
if default_model:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.provider_name
model_dict['name'] = default_model.model_name
model_dict['provider'] = default_model_provider
model_dict['name'] = default_model_name
app_model_config.model = json.dumps(model_dict)
else:
raise ProviderNotInitializeError(
f"No Text Generation Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
app.name = args['name']
app.mode = args['mode']
......@@ -416,22 +410,9 @@ class AppCopy(Resource):
@staticmethod
def create_app_model_config_copy(app_config, copy_app_id):
copy_app_model_config = AppModelConfig(
app_id=copy_app_id,
provider=app_config.provider,
model_id=app_config.model_id,
configs=app_config.configs,
opening_statement=app_config.opening_statement,
suggested_questions=app_config.suggested_questions,
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
speech_to_text=app_config.speech_to_text,
more_like_this=app_config.more_like_this,
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
model=app_config.model,
user_input_form=app_config.user_input_form,
pre_prompt=app_config.pre_prompt,
agent_mode=app_config.agent_mode
)
copy_app_model_config = app_config.copy()
copy_app_model_config.app_id = copy_app_id
return copy_app_model_config
@setup_required
......
......@@ -35,20 +35,8 @@ class ModelConfigResource(Resource):
new_app_model_config = AppModelConfig(
app_id=app_model.id,
provider="",
model_id="",
configs={},
opening_statement=model_configuration['opening_statement'],
suggested_questions=json.dumps(model_configuration['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
pre_prompt=model_configuration['pre_prompt'],
agent_mode=json.dumps(model_configuration['agent_mode']),
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
db.session.add(new_app_model_config)
db.session.flush()
......
......@@ -112,7 +112,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2:
if len(intermediate_steps) >= 2 and self.summary_llm:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
......
......@@ -65,7 +65,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_model_instance.client,
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
......@@ -74,7 +75,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client,
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
......@@ -83,7 +85,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client,
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
......
......@@ -60,17 +60,7 @@ class ConversationMessageTask:
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = {
"model": self.app_model_config.model_dict,
"pre_prompt": self.app_model_config.pre_prompt,
"agent_mode": self.app_model_config.agent_mode_dict,
"opening_statement": self.app_model_config.opening_statement,
"suggested_questions": self.app_model_config.suggested_questions_list,
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
"more_like_this": self.app_model_config.more_like_this_dict,
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
"user_input_form": self.app_model_config.user_input_form_list,
}
override_model_configs = self.app_model_config.to_dict()
introduction = ''
system_instruction = ''
......
......@@ -2,7 +2,7 @@ import logging
from langchain.schema import OutputParserException
from core.model_providers.error import LLMError
from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
......@@ -108,6 +108,7 @@ class LLMGenerator:
_input = prompt.format_prompt(histories=histories)
try:
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
......@@ -115,6 +116,8 @@ class LLMGenerator:
temperature=0
)
)
except ProviderTokenNotInitError:
return []
prompts = [PromptMessage(content=_input.to_string())]
......
......@@ -14,6 +14,7 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
......@@ -78,6 +79,7 @@ class OrchestratorRuleParser:
elif planning_strategy == PlanningStrategy.ROUTER:
planning_strategy = PlanningStrategy.REACT_ROUTER
try:
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_kwargs=ModelKwargs(
......@@ -85,6 +87,8 @@ class OrchestratorRuleParser:
max_tokens=500
)
)
except ProviderTokenNotInitError as e:
summary_model_instance = None
tools = self.to_tools(
tool_configs=tool_configs,
......
import json
import re
from typing import Any
from langchain.schema import BaseOutputParser
from core.model_providers.error import LLMError
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
......@@ -12,5 +15,10 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
def parse(self, text: str) -> Any:
    """Extract the JSON array of suggested questions from raw LLM output.

    The model may surround the array with extra prose, so a ``["...", ...]``
    span is located with a regex and only that span is JSON-decoded. The
    whole text is deliberately not decoded first, because that would raise
    on any surrounding prose before the extraction could run.

    :param text: raw completion text returned by the LLM
    :return: the decoded JSON value of the matched span
    :raises LLMError: if no ``["..."]`` span can be found in ``text``
    """
    json_string = text.strip()
    # DOTALL lets the match span line breaks inside or around the array.
    action_match = re.search(r".*(\[\".+\"\]).*", json_string, re.DOTALL)
    if action_match is None:
        # f-string so the offending output actually appears in the message.
        raise LLMError(f"Could not parse LLM output: {text}")

    # strict=False tolerates raw control characters (e.g. newlines) inside
    # the JSON strings.
    return json.loads(action_match.group(1).strip(), strict=False)
......@@ -148,6 +148,46 @@ class AppModelConfig(db.Model):
"agent_mode": self.agent_mode_dict
}
def from_model_config_dict(self, model_config: dict):
    """Populate this instance's fields from a plain model-config dict.

    ``opening_statement`` and ``pre_prompt`` are stored as given; every
    other entry is serialized with ``json.dumps``. ``provider``,
    ``model_id`` and ``configs`` are reset to empty values.

    :param model_config: dict holding ``opening_statement``,
        ``suggested_questions``, ``suggested_questions_after_answer``,
        ``more_like_this``, ``sensitive_word_avoidance``, ``model``,
        ``user_input_form``, ``pre_prompt`` and ``agent_mode``;
        ``speech_to_text`` is optional
    :return: ``self``, so the call can be chained
    :raises KeyError: if one of the required keys is missing
    """
    self.provider = ""
    self.model_id = ""
    self.configs = {}

    self.opening_statement = model_config['opening_statement']
    for key in ('suggested_questions', 'suggested_questions_after_answer'):
        setattr(self, key, json.dumps(model_config[key]))

    # speech_to_text is the only optional entry: a missing or falsy value
    # is stored as None instead of being serialized.
    speech_to_text = model_config.get('speech_to_text')
    self.speech_to_text = json.dumps(speech_to_text) if speech_to_text else None

    for key in ('more_like_this', 'sensitive_word_avoidance', 'model', 'user_input_form'):
        setattr(self, key, json.dumps(model_config[key]))

    self.pre_prompt = model_config['pre_prompt']
    self.agent_mode = json.dumps(model_config['agent_mode'])
    return self
def copy(self):
    """Build a new ``AppModelConfig`` carrying over this instance's settings.

    ``id`` and ``app_id`` are taken from this instance unchanged, while
    ``provider``, ``model_id`` and ``configs`` are reset to empty values
    rather than copied. All remaining fields are copied as stored.

    :return: a new ``AppModelConfig`` instance
    """
    # NOTE(review): the copy keeps this instance's ``id``; a caller that
    # persists it as a separate row presumably needs to assign a new id
    # first - TODO confirm against callers.
    field_values = {
        'id': self.id,
        'app_id': self.app_id,
        'provider': "",
        'model_id': "",
        'configs': {},
    }
    carried_over = (
        'opening_statement',
        'suggested_questions',
        'suggested_questions_after_answer',
        'speech_to_text',
        'more_like_this',
        'sensitive_word_avoidance',
        'model',
        'user_input_form',
        'pre_prompt',
        'agent_mode',
    )
    for name in carried_over:
        field_values[name] = getattr(self, name)

    return AppModelConfig(**field_values)
class RecommendedApp(db.Model):
__tablename__ = 'recommended_apps'
__table_args__ = (
......@@ -234,7 +274,8 @@ class Conversation(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select',
passive_deletes="all")
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
......@@ -429,7 +470,7 @@ class Message(db.Model):
@property
def agent_thoughts(self):
return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id)\
return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \
.order_by(MessageAgentThought.position.asc()).all()
......@@ -557,7 +598,8 @@ class Site(db.Model):
@property
def app_base_url(self):
return (current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
return (
current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/'))
class ApiToken(db.Model):
......
......@@ -63,26 +63,23 @@ class CompletionService:
raise ConversationCompletedError()
if not conversation.override_model_configs:
app_model_config = db.session.query(AppModelConfig).get(conversation.app_model_config_id)
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
if not app_model_config:
raise AppModelConfigBrokenError()
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
id=conversation.app_model_config_id,
app_id=app_model.id,
provider="",
model_id="",
configs="",
opening_statement=conversation_override_model_configs['opening_statement'],
suggested_questions=json.dumps(conversation_override_model_configs['suggested_questions']),
model=json.dumps(conversation_override_model_configs['model']),
user_input_form=json.dumps(conversation_override_model_configs['user_input_form']),
pre_prompt=conversation_override_model_configs['pre_prompt'],
agent_mode=json.dumps(conversation_override_model_configs['agent_mode']),
)
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
if is_model_config_override:
# build new app model config
if 'model' not in args['model_config']:
......@@ -99,19 +96,8 @@ class CompletionService:
app_model_config_model = app_model_config.model_dict
app_model_config_model['completion_params'] = completion_params
app_model_config = AppModelConfig(
id=app_model_config.id,
app_id=app_model.id,
provider="",
model_id="",
configs="",
opening_statement=app_model_config.opening_statement,
suggested_questions=app_model_config.suggested_questions,
model=json.dumps(app_model_config_model),
user_input_form=app_model_config.user_input_form,
pre_prompt=app_model_config.pre_prompt,
agent_mode=app_model_config.agent_mode,
)
app_model_config = app_model_config.copy()
app_model_config.model = json.dumps(app_model_config_model)
else:
if app_model.app_model_config_id is None:
raise AppModelConfigBrokenError()
......@@ -135,20 +121,10 @@ class CompletionService:
app_model_config = AppModelConfig(
id=app_model_config.id,
app_id=app_model.id,
provider="",
model_id="",
configs="",
opening_statement=model_config['opening_statement'],
suggested_questions=json.dumps(model_config['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
more_like_this=json.dumps(model_config['more_like_this']),
sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']),
model=json.dumps(model_config['model']),
user_input_form=json.dumps(model_config['user_input_form']),
pre_prompt=model_config['pre_prompt'],
agent_mode=json.dumps(model_config['agent_mode']),
)
app_model_config = app_model_config.from_model_config_dict(model_config)
# clean input by app_model_config form rules
inputs = cls.get_cleaned_inputs(inputs, app_model_config)
......
import json
from typing import Optional, Union, List
from core.completion import Completion
......@@ -5,8 +6,10 @@ from core.generator.llm_generator import LLMGenerator
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser, Message, MessageFeedback
from models.model import App, EndUser, Message, MessageFeedback, AppModelConfig
from services.conversation_service import ConversationService
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError, LastMessageNotExistsError, \
SuggestedQuestionsAfterAnswerDisabledError
......@@ -172,12 +175,6 @@ class MessageService:
if not user:
raise ValueError('user cannot be None')
app_model_config = app_model.app_model_config
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
raise SuggestedQuestionsAfterAnswerDisabledError()
message = cls.get_message(
app_model=app_model,
user=user,
......@@ -190,10 +187,38 @@ class MessageService:
user=user
)
if not conversation:
raise ConversationNotExistsError()
if conversation.status != 'normal':
raise ConversationCompletedError()
if not conversation.override_model_configs:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
if not app_model_config:
raise AppModelConfigBrokenError()
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
id=conversation.app_model_config_id,
app_id=app_model.id,
)
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if check_enabled and suggested_questions_after_answer.get("enabled", False) is False:
raise SuggestedQuestionsAfterAnswerDisabledError()
# get memory of conversation (read-only)
memory = Completion.get_memory_from_conversation(
tenant_id=app_model.tenant_id,
app_model_config=app_model.app_model_config,
app_model_config=app_model_config,
conversation=conversation,
max_token_limit=3000,
message_limit=3,
......@@ -209,4 +234,3 @@ class MessageService:
)
return questions
......@@ -6,7 +6,7 @@ from celery import shared_task
from werkzeug.exceptions import NotFound
from core.generator.llm_generator import LLMGenerator
from core.model_providers.error import LLMError
from core.model_providers.error import LLMError, ProviderTokenNotInitError
from extensions.ext_database import db
from models.model import Conversation, Message
......@@ -40,10 +40,16 @@ def generate_conversation_summary_task(conversation_id: str):
conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages)
db.session.add(conversation)
db.session.commit()
end_at = time.perf_counter()
logging.info(click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at), fg='green'))
except LLMError:
except (LLMError, ProviderTokenNotInitError):
conversation.summary = '[No Summary]'
db.session.commit()
pass
except Exception as e:
conversation.summary = '[No Summary]'
db.session.commit()
logging.exception(e)
end_at = time.perf_counter()
logging.info(
click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at),
fg='green'))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment