Commit 2b98c0b4 authored by takatost

refactor app generate

parent f672f698
@@ -59,8 +59,7 @@ class CompletionMessageApi(Resource):
            user=account,
            args=args,
            invoke_from=InvokeFrom.DEBUGGER,
-            streaming=streaming,
-            is_model_config_override=True
+            streaming=streaming
        )

        return compact_response(response)
@@ -126,8 +125,7 @@ class ChatMessageApi(Resource):
            user=account,
            args=args,
            invoke_from=InvokeFrom.DEBUGGER,
-            streaming=streaming,
-            is_model_config_override=True
+            streaming=streaming
        )

        return compact_response(response)
...
@@ -10,9 +10,8 @@ from core.app.app_queue_manager import AppQueueManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
-    EasyUIBasedAppGenerateEntity,
-    EasyUIBasedModelConfigEntity,
-    InvokeFrom,
+    ModelConfigWithCredentialsEntity,
+    InvokeFrom, AgentChatAppGenerateEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -49,9 +48,9 @@ logger = logging.getLogger(__name__)

class BaseAgentRunner(AppRunner):
    def __init__(self, tenant_id: str,
-                 application_generate_entity: EasyUIBasedAppGenerateEntity,
+                 application_generate_entity: AgentChatAppGenerateEntity,
                 app_config: AgentChatAppConfig,
-                 model_config: EasyUIBasedModelConfigEntity,
+                 model_config: ModelConfigWithCredentialsEntity,
                 config: AgentEntity,
                 queue_manager: AppQueueManager,
                 message: Message,
@@ -122,8 +121,8 @@ class BaseAgentRunner(AppRunner):
        else:
            self.stream_tool_call = False

-    def _repack_app_generate_entity(self, app_generate_entity: EasyUIBasedAppGenerateEntity) \
-            -> EasyUIBasedAppGenerateEntity:
+    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
+            -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
...
from typing import cast

from core.app.app_config.entities import EasyUIBasedAppConfig
-from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.model_entities import ModelType
@@ -9,11 +9,11 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.provider_manager import ProviderManager


-class EasyUIBasedModelConfigEntityConverter:
+class ModelConfigConverter:
    @classmethod
    def convert(cls, app_config: EasyUIBasedAppConfig,
                skip_check: bool = False) \
-            -> EasyUIBasedModelConfigEntity:
+            -> ModelConfigWithCredentialsEntity:
        """
        Convert app model config dict to entity.
        :param app_config: app config
@@ -91,7 +91,7 @@ class EasyUIBasedModelConfigEntityConverter:
        if not skip_check and not model_schema:
            raise ValueError(f"Model {model_name} not exist.")

-        return EasyUIBasedModelConfigEntity(
+        return ModelConfigWithCredentialsEntity(
            provider=model_config.provider,
            model=model_config.model,
            model_schema=model_schema,
...
+from typing import Optional
+
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
@@ -10,7 +12,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
-from models.model import App, AppMode
+from models.model import App, AppMode, Conversation
from models.workflow import Workflow
@@ -23,7 +25,9 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):

class AdvancedChatAppConfigManager(BaseAppConfigManager):
    @classmethod
-    def config_convert(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
+    def get_app_config(cls, app_model: App,
+                       workflow: Workflow,
+                       conversation: Optional[Conversation] = None) -> AdvancedChatAppConfig:
        features_dict = workflow.features_dict

        app_config = AdvancedChatAppConfig(
...
@@ -19,7 +19,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.entities.agent_entities import PlanningStrategy
-from models.model import App, AppMode, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, Conversation

OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
@@ -33,19 +33,30 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):

class AgentChatAppConfigManager(BaseAppConfigManager):
    @classmethod
-    def config_convert(cls, app_model: App,
-                       config_from: EasyUIBasedAppModelConfigFrom,
-                       app_model_config: AppModelConfig,
-                       config_dict: Optional[dict] = None) -> AgentChatAppConfig:
+    def get_app_config(cls, app_model: App,
+                       app_model_config: AppModelConfig,
+                       conversation: Optional[Conversation] = None,
+                       override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
        """
        Convert app model config to agent chat app config
        :param app_model: app model
-        :param config_from: app model config from
        :param app_model_config: app model config
-        :param config_dict: app model config dict
+        :param conversation: conversation
+        :param override_config_dict: app model config dict
        :return:
        """
-        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)
+        if override_config_dict:
+            config_from = EasyUIBasedAppModelConfigFrom.ARGS
+        elif conversation:
+            config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
+        else:
+            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
+
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
+            app_model_config_dict = app_model_config.to_dict()
+            config_dict = app_model_config_dict.copy()
+        else:
+            config_dict = override_config_dict

        app_config = AgentChatAppConfig(
            tenant_id=app_model.tenant_id,
...
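A minimal standalone sketch of the config-source precedence that get_app_config now resolves internally (plain Python; ConfigSource stands in for EasyUIBasedAppModelConfigFrom and is assumed to carry the same three members):

from enum import Enum
from typing import Optional

class ConfigSource(Enum):
    # stand-in for EasyUIBasedAppModelConfigFrom
    ARGS = "args"
    CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
    APP_LATEST_CONFIG = "app-latest-config"

def resolve_config_source(override_config_dict: Optional[dict],
                          conversation: Optional[object]) -> ConfigSource:
    # explicit override (debugger args) wins, then a conversation-pinned config,
    # otherwise the app's latest saved config
    if override_config_dict:
        return ConfigSource.ARGS
    if conversation:
        return ConfigSource.CONVERSATION_SPECIFIC_CONFIG
    return ConfigSource.APP_LATEST_CONFIG

assert resolve_config_source({"model": {}}, None) is ConfigSource.ARGS
assert resolve_config_source(None, object()) is ConfigSource.CONVERSATION_SPECIFIC_CONFIG
assert resolve_config_source(None, None) is ConfigSource.APP_LATEST_CONFIG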
import logging
import threading
import uuid
from typing import Union, Any, Generator
from flask import current_app, Flask
from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, AgentChatAppGenerateEntity
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator]:
"""
Generate App response.
:param app_model: App
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
query = args['query']
if not isinstance(query, str):
raise ValueError('query must be a string')
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {
"auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True
}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
)
# parse files
files = args['files'] if 'files' in args and args['files'] else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_upload_entity,
user
)
else:
file_objs = []
# convert to app config
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict
)
# init application generate entity
application_generate_entity = AgentChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_config=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = AppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AgentChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
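A minimal call sketch for AgentChatAppGenerator.generate, assuming app_model (an agent-chat App) and account are already loaded; the argument values are illustrative:

generator = AgentChatAppGenerator()
response = generator.generate(
    app_model=app_model,              # models.model.App in agent-chat mode (assumed loaded)
    user=account,                     # Account or EndUser (assumed loaded)
    args={
        'query': 'Summarize the latest ticket for me',
        'inputs': {},
        'conversation_id': None,      # falsy -> a new conversation is created
        'auto_generate_name': True,
        'files': [],
    },
    invoke_from=InvokeFrom.DEBUGGER,  # model_config overrides are only allowed from the debugger
    stream=True,
)
for chunk in response:                # stream=True returns a generator of response chunks
    print(chunk)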
@@ -7,7 +7,8 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_runner import AppRunner
-from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, EasyUIBasedModelConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, \
+    AgentChatAppGenerateEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -26,7 +27,7 @@ class AgentChatAppRunner(AppRunner):
    """
    Agent Application Runner
    """

-    def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
+    def run(self, application_generate_entity: AgentChatAppGenerateEntity,
            queue_manager: AppQueueManager,
            conversation: Conversation,
            message: Message) -> None:
@@ -288,7 +289,7 @@ class AgentChatAppRunner(AppRunner):
            'pool': db_variables.variables
        })

-    def _get_usage_of_all_agent_thoughts(self, model_config: EasyUIBasedModelConfigEntity,
+    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
                                         message: Message) -> LLMUsage:
        """
        Get usage of all agent thoughts
...
from core.app.app_config.entities import VariableEntity, AppConfig


class BaseAppGenerator:
    def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
        if user_inputs is None:
            user_inputs = {}

        filtered_inputs = {}

        # Filter input variables from form configuration, handle required fields, default values, and option values
        variables = app_config.variables
        for variable_config in variables:
            variable = variable_config.variable

            if variable not in user_inputs or not user_inputs[variable]:
                if variable_config.required:
                    raise ValueError(f"{variable} is required in input form")
                else:
                    filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""

                continue

            value = user_inputs[variable]

            if value:
                if not isinstance(value, str):
                    raise ValueError(f"{variable} in input form must be a string")

            if variable_config.type == VariableEntity.Type.SELECT:
                options = variable_config.options if variable_config.options is not None else []
                if value not in options:
                    raise ValueError(f"{variable} in input form must be one of the following: {options}")
            else:
                if variable_config.max_length is not None:
                    max_length = variable_config.max_length
                    if len(value) > max_length:
                        raise ValueError(f'{variable} in input form must be less than {max_length} characters')

            filtered_inputs[variable] = value.replace('\x00', '') if value else None

        return filtered_inputs
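A simplified standalone mirror of the cleaning rules above (required inputs, defaults, select options, max length, NUL stripping), using plain dicts in place of the AppConfig/VariableEntity models; illustrative only, not the commit's API:

def clean_inputs(user_inputs, variable_configs):
    # variable_configs: list of dicts with keys variable/required/default/type/options/max_length
    user_inputs = user_inputs or {}
    cleaned = {}
    for cfg in variable_configs:
        name = cfg['variable']
        value = user_inputs.get(name)
        if not value:
            if cfg.get('required'):
                raise ValueError(f"{name} is required in input form")
            cleaned[name] = cfg.get('default') or ""
            continue
        if not isinstance(value, str):
            raise ValueError(f"{name} in input form must be a string")
        if cfg.get('type') == 'select':  # 'select' stands in for VariableEntity.Type.SELECT
            if value not in (cfg.get('options') or []):
                raise ValueError(f"{name} must be one of {cfg.get('options')}")
        elif cfg.get('max_length') and len(value) > cfg['max_length']:
            raise ValueError(f"{name} must be less than {cfg['max_length']} characters")
        cleaned[name] = value.replace('\x00', '')
    return cleaned

print(clean_inputs({'topic': 'weather\x00'},
                   [{'variable': 'topic', 'required': True, 'max_length': 48}]))
# -> {'topic': 'weather'}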
@@ -5,9 +5,8 @@ from typing import Optional, Union, cast
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
-    EasyUIBasedAppGenerateEntity,
-    EasyUIBasedModelConfigEntity,
-    InvokeFrom,
+    ModelConfigWithCredentialsEntity,
+    InvokeFrom, AppGenerateEntity, EasyUIBasedAppGenerateEntity,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
@@ -27,7 +26,7 @@ from models.model import App, AppMode, Message, MessageAnnotation

class AppRunner:
    def get_pre_calculate_rest_tokens(self, app_record: App,
-                                      model_config: EasyUIBasedModelConfigEntity,
+                                      model_config: ModelConfigWithCredentialsEntity,
                                      prompt_template_entity: PromptTemplateEntity,
                                      inputs: dict[str, str],
                                      files: list[FileObj],
@@ -83,7 +82,7 @@ class AppRunner:
        return rest_tokens

-    def recale_llm_max_tokens(self, model_config: EasyUIBasedModelConfigEntity,
+    def recale_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
                              prompt_messages: list[PromptMessage]):
        # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
        model_type_instance = model_config.provider_model_bundle.model_type_instance
@@ -119,7 +118,7 @@ class AppRunner:
        model_config.parameters[parameter_rule.name] = max_tokens

    def organize_prompt_messages(self, app_record: App,
-                                 model_config: EasyUIBasedModelConfigEntity,
+                                 model_config: ModelConfigWithCredentialsEntity,
                                 prompt_template_entity: PromptTemplateEntity,
                                 inputs: dict[str, str],
                                 files: list[FileObj],
@@ -292,7 +291,7 @@ class AppRunner:
    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
-                              app_generate_entity: EasyUIBasedAppGenerateEntity,
+                              app_generate_entity: AppGenerateEntity,
                              inputs: dict,
                              query: str) -> tuple[bool, dict, str]:
        """
...
@@ -15,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
    SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import App, AppMode, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, Conversation


class ChatAppConfig(EasyUIBasedAppConfig):
@@ -27,19 +27,30 @@ class ChatAppConfig(EasyUIBasedAppConfig):

class ChatAppConfigManager(BaseAppConfigManager):
    @classmethod
-    def config_convert(cls, app_model: App,
-                       config_from: EasyUIBasedAppModelConfigFrom,
-                       app_model_config: AppModelConfig,
-                       config_dict: Optional[dict] = None) -> ChatAppConfig:
+    def get_app_config(cls, app_model: App,
+                       app_model_config: AppModelConfig,
+                       conversation: Optional[Conversation] = None,
+                       override_config_dict: Optional[dict] = None) -> ChatAppConfig:
        """
        Convert app model config to chat app config
        :param app_model: app model
-        :param config_from: app model config from
        :param app_model_config: app model config
-        :param config_dict: app model config dict
+        :param conversation: conversation
+        :param override_config_dict: app model config dict
        :return:
        """
-        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)
+        if override_config_dict:
+            config_from = EasyUIBasedAppModelConfigFrom.ARGS
+        elif conversation:
+            config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
+        else:
+            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
+
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
+            app_model_config_dict = app_model_config.to_dict()
+            config_dict = app_model_config_dict.copy()
+        else:
+            config_dict = override_config_dict

        app_config = ChatAppConfig(
            tenant_id=app_model.tenant_id,
...
import logging
import threading
import uuid
from typing import Union, Any, Generator
from flask import current_app, Flask
from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator):
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator]:
"""
Generate App response.
:param app_model: App
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
query = args['query']
if not isinstance(query, str):
raise ValueError('query must be a string')
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {
"auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True
}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
)
# parse files
files = args['files'] if 'files' in args and args['files'] else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_upload_entity,
user
)
else:
file_objs = []
# convert to app config
app_config = ChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict
)
# init application generate entity
application_generate_entity = ChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_config=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = AppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = ChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
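A minimal call sketch for continuing an existing conversation through ChatAppGenerator.generate, assuming app_model, account and existing_conversation_id are already available; values are illustrative:

response = ChatAppGenerator().generate(
    app_model=app_model,
    user=account,
    args={
        'query': 'And what about tomorrow?',
        'inputs': {},
        'conversation_id': existing_conversation_id,  # reuses the conversation's inputs and pinned config
        'auto_generate_name': False,
        'files': [],
    },
    invoke_from=InvokeFrom.WEB_APP,
    stream=True,
)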
@@ -5,7 +5,7 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
from core.app.entities.app_invoke_entities import (
-    EasyUIBasedAppGenerateEntity,
+    ChatAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -23,7 +23,7 @@ class ChatAppRunner(AppRunner):
    Chat Application Runner
    """

-    def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
+    def run(self, application_generate_entity: ChatAppGenerateEntity,
            queue_manager: AppQueueManager,
            conversation: Conversation,
            message: Message) -> None:
...
@@ -10,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import App, AppMode, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, Conversation


class CompletionAppConfig(EasyUIBasedAppConfig):
@@ -22,19 +22,26 @@ class CompletionAppConfig(EasyUIBasedAppConfig):

class CompletionAppConfigManager(BaseAppConfigManager):
    @classmethod
-    def config_convert(cls, app_model: App,
-                       config_from: EasyUIBasedAppModelConfigFrom,
-                       app_model_config: AppModelConfig,
-                       config_dict: Optional[dict] = None) -> CompletionAppConfig:
+    def get_app_config(cls, app_model: App,
+                       app_model_config: AppModelConfig,
+                       override_config_dict: Optional[dict] = None) -> CompletionAppConfig:
        """
        Convert app model config to completion app config
        :param app_model: app model
-        :param config_from: app model config from
        :param app_model_config: app model config
-        :param config_dict: app model config dict
+        :param override_config_dict: app model config dict
        :return:
        """
-        config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict)
+        if override_config_dict:
+            config_from = EasyUIBasedAppModelConfigFrom.ARGS
+        else:
+            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
+
+        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
+            app_model_config_dict = app_model_config.to_dict()
+            config_dict = app_model_config_dict.copy()
+        else:
+            config_dict = override_config_dict

        app_config = CompletionAppConfig(
            tenant_id=app_model.tenant_id,
...
import json
import logging
import threading
import uuid
from typing import Union, Any, Generator
from flask import current_app, Flask
from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, CompletionAppGenerateEntity
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser, Message
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator]:
"""
Generate App response.
:param app_model: App
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
query = args['query']
if not isinstance(query, str):
raise ValueError('query must be a string')
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {}
# get conversation
conversation = None
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
)
# parse files
files = args['files'] if 'files' in args and args['files'] else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_upload_entity,
user
)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
)
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_config=ModelConfigConverter.convert(app_config),
inputs=self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = AppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get message
message = self._get_message(message_id)
# chatbot app
runner = CompletionAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
)
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
def generate_more_like_this(self, app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator]:
"""
Generate App response.
:param app_model: App
:param message_id: message ID
:param user: account or end user
:param invoke_from: invoke from source
:param stream: is stream
"""
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
).first()
if not message:
raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config
more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict['model']
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
override_model_config_dict['model'] = model_dict
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.validate_and_transform_files_arg(
message.files,
file_upload_entity,
user
)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
)
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_config=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras={}
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = AppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
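A minimal call sketch for the more-like-this path above, assuming app_model (a completion App with more_like_this enabled), end_user and message_id are already available; it re-runs a stored completion message with temperature forced to 0.9:

generator = CompletionAppGenerator()
result = generator.generate_more_like_this(
    app_model=app_model,          # completion-mode App (assumed loaded)
    message_id=message_id,        # ID of an existing message of that app (assumed)
    user=end_user,                # EndUser (API) or Account (console)
    invoke_from=InvokeFrom.SERVICE_API,
    stream=False,                 # blocking mode returns a dict-like response
)
print(result)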
@@ -5,7 +5,7 @@ from core.app.app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
from core.app.entities.app_invoke_entities import (
-    EasyUIBasedAppGenerateEntity,
+    CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance
@@ -22,7 +22,7 @@ class CompletionAppRunner(AppRunner):
    Completion Application Runner
    """

-    def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
+    def run(self, application_generate_entity: CompletionAppGenerateEntity,
            queue_manager: AppQueueManager,
            message: Message) -> None:
        """
...
import json
import logging
-import threading
-import uuid
-from collections.abc import Generator
-from typing import Any, Optional, Union, cast
+from typing import Union, Generator, Optional

-from flask import Flask, current_app
-from pydantic import ValidationError
+from sqlalchemy import and_

-from core.app.app_config.easy_ui_based_app.model_config.converter import EasyUIBasedModelConfigEntityConverter
-from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom, VariableEntity
-from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
-from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
-from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
-from core.app.apps.chat.app_config_manager import ChatAppConfigManager
-from core.app.apps.chat.app_runner import ChatAppRunner
-from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
-from core.app.apps.completion.app_runner import CompletionAppRunner
-from core.app.entities.app_invoke_entities import (
-    EasyUIBasedAppGenerateEntity,
-    InvokeFrom,
-)
+from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
+from core.app.app_queue_manager import ConversationTaskStoppedException, AppQueueManager
+from core.app.apps.base_app_generator import BaseAppGenerator
+from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity, AppGenerateEntity, \
+    CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
from core.app.generate_task_pipeline import GenerateTaskPipeline
-from core.file.file_obj import FileObj
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from models.account import Account
-from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
+from models.model import Conversation, Message, AppMode, MessageFile, App, EndUser, AppModelConfig
+from services.errors.app_model_config import AppModelConfigBrokenError
+from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError

logger = logging.getLogger(__name__)


-class EasyUIBasedAppManager:
+class MessageBasedAppGenerator(BaseAppGenerator):
def generate(self, app_model: App,
app_model_config: AppModelConfig,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
inputs: dict[str, str],
app_model_config_dict: Optional[dict] = None,
query: Optional[str] = None,
files: Optional[list[FileObj]] = None,
conversation: Optional[Conversation] = None,
stream: bool = False,
extras: Optional[dict[str, Any]] = None) \
-> Union[dict, Generator]:
"""
Generate App response.
:param app_model: App
:param app_model_config: app model config
:param user: account or end user
:param invoke_from: invoke from source
:param inputs: inputs
:param app_model_config_dict: app model config dict
:param query: query
:param files: file obj list
:param conversation: conversation
:param stream: is stream
:param extras: extras
"""
# init task id
task_id = str(uuid.uuid4())
# convert to app config
app_config = self.convert_to_app_config(
app_model=app_model,
app_model_config=app_model_config,
app_model_config_dict=app_model_config_dict,
conversation=conversation
)
# init application generate entity
application_generate_entity = EasyUIBasedAppGenerateEntity(
task_id=task_id,
app_config=app_config,
model_config=EasyUIBasedModelConfigEntityConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query.replace('\x00', '') if query else None,
files=files if files else [],
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
)
if not stream and application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT:
raise ValueError("Agent app is not supported in blocking mode.")
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = AppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
def convert_to_app_config(self, app_model: App,
app_model_config: AppModelConfig,
app_model_config_dict: Optional[dict] = None,
conversation: Optional[Conversation] = None) -> EasyUIBasedAppConfig:
if app_model_config_dict:
config_from = EasyUIBasedAppModelConfigFrom.ARGS
elif conversation:
config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
app_mode = AppMode.value_of(app_model.mode)
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_config = AgentChatAppConfigManager.config_convert(
app_model=app_model,
config_from=config_from,
app_model_config=app_model_config,
config_dict=app_model_config_dict
)
elif app_mode == AppMode.CHAT:
app_config = ChatAppConfigManager.config_convert(
app_model=app_model,
config_from=config_from,
app_model_config=app_model_config,
config_dict=app_model_config_dict
)
elif app_mode == AppMode.COMPLETION:
app_config = CompletionAppConfigManager.config_convert(
app_model=app_model,
config_from=config_from,
app_model_config=app_model_config,
config_dict=app_model_config_dict
)
else:
raise ValueError("Invalid app mode")
return app_config
def _get_cleaned_inputs(self, user_inputs: dict, app_config: EasyUIBasedAppConfig):
if user_inputs is None:
user_inputs = {}
filtered_inputs = {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
for variable_config in variables:
variable = variable_config.variable
if variable not in user_inputs or not user_inputs[variable]:
if variable_config.required:
raise ValueError(f"{variable} is required in input form")
else:
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
continue
value = user_inputs[variable]
if value:
if not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
else:
if variable_config.max_length is not None:
max_length = variable_config.max_length
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value.replace('\x00', '') if value else None
return filtered_inputs
def _generate_worker(self, flask_app: Flask,
application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if application_generate_entity.app_config.app_mode == AppMode.AGENT_CHAT:
# agent app
runner = AgentChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
elif application_generate_entity.app_config.app_mode == AppMode.CHAT:
# chatbot app
runner = ChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
elif application_generate_entity.app_config.app_mode == AppMode.COMPLETION:
# completion app
runner = CompletionAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
)
else:
raise ValueError("Invalid app mode")
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
-    def _handle_response(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
+    def _handle_response(self, application_generate_entity: Union[
+                             ChatAppGenerateEntity,
+                             CompletionAppGenerateEntity,
+                             AgentChatAppGenerateEntity
+                         ],
                         queue_manager: AppQueueManager,
                         conversation: Conversation,
                         message: Message,
@@ -303,27 +59,66 @@ class EasyUIBasedAppManager:
        finally:
            db.session.remove()

-    def _init_generate_records(self, application_generate_entity: EasyUIBasedAppGenerateEntity) \
+    def _get_conversation_by_user(self, app_model: App, conversation_id: str,
+                                  user: Union[Account, EndUser]) -> Conversation:
conversation_filter = [
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.status == 'normal'
]
if isinstance(user, Account):
conversation_filter.append(Conversation.from_account_id == user.id)
else:
conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
if not conversation:
raise ConversationNotExistsError()
if conversation.status != 'normal':
raise ConversationCompletedError()
return conversation
def _get_app_model_config(self, app_model: App,
conversation: Optional[Conversation] = None) \
-> AppModelConfig:
if conversation:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
if not app_model_config:
raise AppModelConfigBrokenError()
else:
if app_model.app_model_config_id is None:
raise AppModelConfigBrokenError()
app_model_config = app_model.app_model_config
if not app_model_config:
raise AppModelConfigBrokenError()
return app_model_config
+    def _init_generate_records(self,
+                               application_generate_entity: Union[
+                                   ChatAppGenerateEntity,
+                                   CompletionAppGenerateEntity,
+                                   AgentChatAppGenerateEntity
+                               ],
+                               conversation: Optional[Conversation] = None) \
            -> tuple[Conversation, Message]:
        """
        Initialize generate records
        :param application_generate_entity: application generate entity
        :return:
        """
-        model_type_instance = application_generate_entity.model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-        model_schema = model_type_instance.get_model_schema(
-            model=application_generate_entity.model_config.model,
-            credentials=application_generate_entity.model_config.credentials
-        )
-
        app_config = application_generate_entity.app_config

-        app_record = (db.session.query(App)
-                      .filter(App.id == app_config.app_id).first())
-
-        app_mode = app_record.mode
-
        # get from source
        end_user_id = None
        account_id = None
@@ -335,22 +130,21 @@ class EasyUIBasedAppManager:
            account_id = application_generate_entity.user_id

        override_model_configs = None
-        if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS:
+        if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
+                and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:
            override_model_configs = app_config.app_model_config_dict

-        introduction = ''
-        if app_mode == 'chat':
-            # get conversation introduction
-            introduction = self._get_conversation_introduction(application_generate_entity)
+        # get conversation introduction
+        introduction = self._get_conversation_introduction(application_generate_entity)

-        if not application_generate_entity.conversation_id:
+        if not conversation:
            conversation = Conversation(
-                app_id=app_record.id,
+                app_id=app_config.app_id,
                app_model_config_id=app_config.app_model_config_id,
                model_provider=application_generate_entity.model_config.provider,
                model_id=application_generate_entity.model_config.model,
                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
-                mode=app_mode,
+                mode=app_config.app_mode.value,
                name='New conversation',
                inputs=application_generate_entity.inputs,
                introduction=introduction,
@@ -364,19 +158,9 @@ class EasyUIBasedAppManager:
            db.session.add(conversation)
            db.session.commit()
-        else:
-            conversation = (
-                db.session.query(Conversation)
-                .filter(
-                    Conversation.id == application_generate_entity.conversation_id,
-                    Conversation.app_id == app_record.id
-                ).first()
-            )
-
-        currency = model_schema.pricing.currency if model_schema.pricing else 'USD'

        message = Message(
-            app_id=app_record.id,
+            app_id=app_config.app_id,
            model_provider=application_generate_entity.model_config.provider,
            model_id=application_generate_entity.model_config.model,
            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
@@ -393,11 +177,10 @@ class EasyUIBasedAppManager:
            answer_price_unit=0,
            provider_response_latency=0,
            total_price=0,
-            currency=currency,
+            currency='USD',
            from_source=from_source,
            from_end_user_id=end_user_id,
-            from_account_id=account_id,
-            agent_based=app_config.app_mode == AppMode.AGENT_CHAT,
+            from_account_id=account_id
        )

        db.session.add(message)
@@ -419,7 +202,7 @@ class EasyUIBasedAppManager:
return conversation, message return conversation, message
def _get_conversation_introduction(self, application_generate_entity: EasyUIBasedAppGenerateEntity) -> str: def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
""" """
Get conversation introduction Get conversation introduction
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
......
...@@ -17,7 +17,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig): ...@@ -17,7 +17,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
class WorkflowAppConfigManager(BaseAppConfigManager): class WorkflowAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def config_convert(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig: def get_app_config(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig:
features_dict = workflow.features_dict features_dict = workflow.features_dict
app_config = WorkflowAppConfig( app_config = WorkflowAppConfig(
......
...@@ -3,7 +3,7 @@ from typing import Any, Optional ...@@ -3,7 +3,7 @@ from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig, AppConfig
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
...@@ -49,9 +49,9 @@ class InvokeFrom(Enum): ...@@ -49,9 +49,9 @@ class InvokeFrom(Enum):
return 'dev' return 'dev'
class EasyUIBasedModelConfigEntity(BaseModel): class ModelConfigWithCredentialsEntity(BaseModel):
""" """
Model Config Entity. Model Config With Credentials Entity.
""" """
provider: str provider: str
model: str model: str
...@@ -63,21 +63,19 @@ class EasyUIBasedModelConfigEntity(BaseModel): ...@@ -63,21 +63,19 @@ class EasyUIBasedModelConfigEntity(BaseModel):
stop: list[str] = [] stop: list[str] = []
class EasyUIBasedAppGenerateEntity(BaseModel): class AppGenerateEntity(BaseModel):
""" """
EasyUI Based Application Generate Entity. App Generate Entity.
""" """
task_id: str task_id: str
# app config # app config
app_config: EasyUIBasedAppConfig app_config: AppConfig
model_config: EasyUIBasedModelConfigEntity
conversation_id: Optional[str] = None
inputs: dict[str, str] inputs: dict[str, str]
query: Optional[str] = None
files: list[FileObj] = [] files: list[FileObj] = []
user_id: str user_id: str
# extras # extras
stream: bool stream: bool
invoke_from: InvokeFrom invoke_from: InvokeFrom
...@@ -86,26 +84,52 @@ class EasyUIBasedAppGenerateEntity(BaseModel): ...@@ -86,26 +84,52 @@ class EasyUIBasedAppGenerateEntity(BaseModel):
extras: dict[str, Any] = {} extras: dict[str, Any] = {}
class WorkflowUIBasedAppGenerateEntity(BaseModel): class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
""" """
Workflow UI Based Application Generate Entity. EasyUI Based Application Generate Entity.
""" """
task_id: str
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: EasyUIBasedAppConfig
model_config: ModelConfigWithCredentialsEntity
inputs: dict[str, str] query: Optional[str] = None
files: list[FileObj] = []
user_id: str
# extras
stream: bool
invoke_from: InvokeFrom
# extra parameters
extras: dict[str, Any] = {}
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None
class AdvancedChatAppGenerateEntity(WorkflowUIBasedAppGenerateEntity):
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Completion Application Generate Entity.
"""
pass
class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Agent Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
query: str
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
"""
Advanced Chat Application Generate Entity.
"""
# app config
app_config: WorkflowUIBasedAppConfig
conversation_id: Optional[str] = None
query: Optional[str] = None
class WorkflowUIBasedAppGenerateEntity(AppGenerateEntity):
"""
Workflow UI Based Application Generate Entity.
"""
# app config
app_config: WorkflowUIBasedAppConfig
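The refactor replaces the single EasyUIBasedAppGenerateEntity with a hierarchy rooted at AppGenerateEntity plus per-mode subclasses. A rough sketch of how shared code can narrow on the subclass instead of on mode flags; the class and field names come from this diff, the dispatch function itself is an assumption:

    from core.app.entities.app_invoke_entities import (
        AgentChatAppGenerateEntity,
        AppGenerateEntity,
        ChatAppGenerateEntity,
        CompletionAppGenerateEntity,
    )

    def describe_generation(entity: AppGenerateEntity) -> str:
        # Chat-style entities carry an optional conversation_id; completion entities do not.
        if isinstance(entity, (ChatAppGenerateEntity, AgentChatAppGenerateEntity)) and entity.conversation_id:
            return f"continue conversation {entity.conversation_id}"
        if isinstance(entity, CompletionAppGenerateEntity):
            return "one-shot completion"
        return "new generation"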
import logging import logging
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, EasyUIBasedAppGenerateEntity
from core.helper import moderation from core.helper import moderation
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
......
...@@ -7,7 +7,8 @@ from typing import Optional, Union, cast ...@@ -7,7 +7,8 @@ from typing import Optional, Union, cast
from pydantic import BaseModel from pydantic import BaseModel
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom, CompletionAppGenerateEntity, \
AgentChatAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AnnotationReplyEvent, AnnotationReplyEvent,
QueueAgentMessageEvent, QueueAgentMessageEvent,
...@@ -39,7 +40,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser ...@@ -39,7 +40,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought, MessageFile from models.model import Conversation, Message, MessageAgentThought, MessageFile, AppMode
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -58,7 +59,11 @@ class GenerateTaskPipeline: ...@@ -58,7 +59,11 @@ class GenerateTaskPipeline:
GenerateTaskPipeline is a class that generates stream output and state management for Application. GenerateTaskPipeline is a class that generates stream output and state management for Application.
""" """
def __init__(self, application_generate_entity: EasyUIBasedAppGenerateEntity, def __init__(self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
...@@ -425,6 +430,7 @@ class GenerateTaskPipeline: ...@@ -425,6 +430,7 @@ class GenerateTaskPipeline:
self._message.answer_price_unit = usage.completion_price_unit self._message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price self._message.total_price = usage.total_price
self._message.currency = usage.currency
db.session.commit() db.session.commit()
...@@ -432,7 +438,11 @@ class GenerateTaskPipeline: ...@@ -432,7 +438,11 @@ class GenerateTaskPipeline:
self._message, self._message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
conversation=self._conversation, conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None, is_first_message=self._application_generate_entity.app_config.app_mode in [
AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.ADVANCED_CHAT
] and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras extras=self._application_generate_entity.extras
) )
......
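Two behavioural tweaks in this hunk: the saved message now records the usage currency reported by the model run instead of a fixed value, and is_first_message is only derived for chat-style app modes. A sketch of that check pulled out as a helper, assuming the same semantics as the inline expression above:

    from models.model import AppMode

    CHAT_LIKE_MODES = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}

    def _is_first_message(entity) -> bool:
        # Short-circuit keeps completion/workflow entities safe: they have no
        # conversation_id attribute and are never "first messages".
        return (entity.app_config.app_mode in CHAT_LIKE_MODES
                and entity.conversation_id is None)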
import logging import logging
import random import random
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
from extensions.ext_hosting_provider import hosting_configuration from extensions.ext_hosting_provider import hosting_configuration
...@@ -10,7 +10,7 @@ from models.provider import ProviderType ...@@ -10,7 +10,7 @@ from models.provider import ProviderType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check_moderation(model_config: EasyUIBasedModelConfigEntity, text: str) -> bool: def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
moderation_config = hosting_configuration.moderation_config moderation_config = hosting_configuration.moderation_config
if (moderation_config and moderation_config.enabled is True if (moderation_config and moderation_config.enabled is True
and 'openai' in hosting_configuration.provider_map and 'openai' in hosting_configuration.provider_map
......
from typing import Optional from typing import Optional
from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
...@@ -28,7 +28,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -28,7 +28,7 @@ class AdvancedPromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
prompt_messages = [] prompt_messages = []
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
...@@ -62,7 +62,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -62,7 +62,7 @@ class AdvancedPromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
""" """
Get completion model prompt messages. Get completion model prompt messages.
""" """
...@@ -110,7 +110,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -110,7 +110,7 @@ class AdvancedPromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
""" """
Get chat model prompt messages. Get chat model prompt messages.
""" """
...@@ -199,7 +199,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -199,7 +199,7 @@ class AdvancedPromptTransform(PromptTransform):
role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
prompt_template: PromptTemplateParser, prompt_template: PromptTemplateParser,
prompt_inputs: dict, prompt_inputs: dict,
model_config: EasyUIBasedModelConfigEntity) -> dict: model_config: ModelConfigWithCredentialsEntity) -> dict:
if '#histories#' in prompt_template.variable_keys: if '#histories#' in prompt_template.variable_keys:
if memory: if memory:
inputs = {'#histories#': '', **prompt_inputs} inputs = {'#histories#': '', **prompt_inputs}
......
from typing import Optional, cast from typing import Optional, cast
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
...@@ -10,14 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large ...@@ -10,14 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
class PromptTransform: class PromptTransform:
def _append_chat_histories(self, memory: TokenBufferMemory, def _append_chat_histories(self, memory: TokenBufferMemory,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config) rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, rest_tokens) histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories) prompt_messages.extend(histories)
return prompt_messages return prompt_messages
def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: EasyUIBasedModelConfigEntity) -> int: def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> int:
rest_tokens = 2000 rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
......
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
from typing import Optional from typing import Optional
from core.app.app_config.entities import PromptTemplateEntity from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
...@@ -52,7 +52,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -52,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) -> \ model_config: ModelConfigWithCredentialsEntity) -> \
tuple[list[PromptMessage], Optional[list[str]]]: tuple[list[PromptMessage], Optional[list[str]]]:
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
...@@ -81,7 +81,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -81,7 +81,7 @@ class SimplePromptTransform(PromptTransform):
return prompt_messages, stops return prompt_messages, stops
def get_prompt_str_and_rules(self, app_mode: AppMode, def get_prompt_str_and_rules(self, app_mode: AppMode,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
pre_prompt: str, pre_prompt: str,
inputs: dict, inputs: dict,
query: Optional[str] = None, query: Optional[str] = None,
...@@ -162,7 +162,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -162,7 +162,7 @@ class SimplePromptTransform(PromptTransform):
context: Optional[str], context: Optional[str],
files: list[FileObj], files: list[FileObj],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) \ model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]: -> tuple[list[PromptMessage], Optional[list[str]]]:
prompt_messages = [] prompt_messages = []
...@@ -200,7 +200,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -200,7 +200,7 @@ class SimplePromptTransform(PromptTransform):
context: Optional[str], context: Optional[str],
files: list[FileObj], files: list[FileObj],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) \ model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]: -> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt # get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules( prompt, prompt_rules = self.get_prompt_str_and_rules(
......
...@@ -5,14 +5,14 @@ from langchain.callbacks.manager import CallbackManagerForChainRun ...@@ -5,14 +5,14 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.rag.retrieval.agent.fake_llm import FakeLLM from core.rag.retrieval.agent.fake_llm import FakeLLM
class LLMChain(LCLLMChain): class LLMChain(LCLLMChain):
model_config: EasyUIBasedModelConfigEntity model_config: ModelConfigWithCredentialsEntity
"""The language model instance to use.""" """The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="") llm: BaseLanguageModel = FakeLLM(response="")
parameters: dict[str, Any] = {} parameters: dict[str, Any] = {}
......
...@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage ...@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import root_validator from pydantic import root_validator
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessageTool
...@@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
""" """
A Multi Dataset Retrieve Agent driven by Router. A Multi Dataset Retrieve Agent driven by Router.
""" """
model_config: EasyUIBasedModelConfigEntity model_config: ModelConfigWithCredentialsEntity
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
...@@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
......
...@@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy ...@@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.rag.retrieval.agent.llm_chain import LLMChain from core.rag.retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
...@@ -206,7 +206,7 @@ Thought: {agent_scratchpad} ...@@ -206,7 +206,7 @@ Thought: {agent_scratchpad}
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
......
...@@ -7,7 +7,7 @@ from langchain.callbacks.manager import Callbacks ...@@ -7,7 +7,7 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.message_entities import prompt_messages_to_lc_messages from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation from core.helper import moderation
...@@ -22,9 +22,9 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr ...@@ -22,9 +22,9 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr
class AgentConfiguration(BaseModel): class AgentConfiguration(BaseModel):
strategy: PlanningStrategy strategy: PlanningStrategy
model_config: EasyUIBasedModelConfigEntity model_config: ModelConfigWithCredentialsEntity
tools: list[BaseTool] tools: list[BaseTool]
summary_model_config: Optional[EasyUIBasedModelConfigEntity] = None summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None
memory: Optional[TokenBufferMemory] = None memory: Optional[TokenBufferMemory] = None
callbacks: Callbacks = None callbacks: Callbacks = None
max_iterations: int = 6 max_iterations: int = 6
......
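Only the type annotations change in this file; construction stays the same. An illustrative instantiation, assuming a PlanningStrategy member such as ROUTER and an already-built ModelConfigWithCredentialsEntity (all values below are placeholders):

    config = AgentConfiguration(
        strategy=PlanningStrategy.ROUTER,   # assumed member name
        model_config=model_config,          # ModelConfigWithCredentialsEntity
        tools=tools,                        # list[BaseTool]
        summary_model_config=None,
        memory=None,
        max_iterations=6,
    )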
...@@ -3,7 +3,7 @@ from typing import Optional, cast ...@@ -3,7 +3,7 @@ from typing import Optional, cast
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity, InvokeFrom from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
...@@ -18,7 +18,7 @@ from models.dataset import Dataset ...@@ -18,7 +18,7 @@ from models.dataset import Dataset
class DatasetRetrieval: class DatasetRetrieval:
def retrieve(self, tenant_id: str, def retrieve(self, tenant_id: str,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity, config: DatasetEntity,
query: str, query: str,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
......
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
...@@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType ...@@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType
@message_was_created.connect @message_was_created.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
message = sender message = sender
application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity') application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity')
model_config = application_generate_entity.model_config model_config = application_generate_entity.model_config
provider_model_bundle = model_config.provider_model_bundle provider_model_bundle = model_config.provider_model_bundle
......
from datetime import datetime from datetime import datetime
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models.provider import Provider from models.provider import Provider
...@@ -9,7 +9,7 @@ from models.provider import Provider ...@@ -9,7 +9,7 @@ from models.provider import Provider
@message_was_created.connect @message_was_created.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
message = sender message = sender
application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity') application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity')
db.session.query(Provider).filter( db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.tenant_id == application_generate_entity.app_config.tenant_id,
......
import json
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Union from typing import Any, Union
from sqlalchemy import and_ from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.app_manager import EasyUIBasedAppManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.message_file_parser import MessageFileParser from models.model import Account, App, AppMode, EndUser
from extensions.ext_database import db
from models.model import Account, App, AppMode, AppModelConfig, Conversation, EndUser, Message
from services.app_model_config_service import AppModelConfigService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
from services.errors.message import MessageNotExistsError
class CompletionService: class CompletionService:
@classmethod @classmethod
def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
invoke_from: InvokeFrom, streaming: bool = True, invoke_from: InvokeFrom, streaming: bool = True) -> Union[dict, Generator]:
is_model_config_override: bool = False) -> Union[dict, Generator]: """
# is streaming mode App Completion
inputs = args['inputs'] :param app_model: app model
query = args['query'] :param user: user
files = args['files'] if 'files' in args and args['files'] else [] :param args: args
auto_generate_name = args['auto_generate_name'] \ :param invoke_from: invoke from
if 'auto_generate_name' in args else True :param streaming: streaming
:return:
if app_model.mode != AppMode.COMPLETION.value: """
if not query: if app_model.mode == AppMode.COMPLETION.value:
raise ValueError('query is required') return CompletionAppGenerator().generate(
app_model=app_model,
if query: user=user,
if not isinstance(query, str): args=args,
raise ValueError('query must be a string') invoke_from=invoke_from,
stream=streaming
query = query.replace('\x00', '') )
elif app_model.mode == AppMode.CHAT.value:
conversation_id = args['conversation_id'] if 'conversation_id' in args else None return ChatAppGenerator().generate(
app_model=app_model,
conversation = None user=user,
app_model_config_dict = None args=args,
if conversation_id: invoke_from=invoke_from,
conversation_filter = [ stream=streaming
Conversation.id == args['conversation_id'], )
Conversation.app_id == app_model.id, elif app_model.mode == AppMode.AGENT_CHAT.value:
Conversation.status == 'normal' return AgentChatAppGenerator().generate(
] app_model=app_model,
user=user,
if isinstance(user, Account): args=args,
conversation_filter.append(Conversation.from_account_id == user.id) invoke_from=invoke_from,
else: stream=streaming
conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
if not conversation:
raise ConversationNotExistsError()
if conversation.status != 'normal':
raise ConversationCompletedError()
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
if not app_model_config:
raise AppModelConfigBrokenError()
else:
if app_model.app_model_config_id is None:
raise AppModelConfigBrokenError()
app_model_config = app_model.app_model_config
if not app_model_config:
raise AppModelConfigBrokenError()
if is_model_config_override:
if not isinstance(user, Account):
raise Exception("Only account can override model config")
# validate config
app_model_config_dict = AppModelConfigService.validate_configuration(
tenant_id=app_model.tenant_id,
config=args['model_config'],
app_mode=AppMode.value_of(app_model.mode)
)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(app_model_config_dict or app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_upload_entity,
user
) )
else: else:
file_objs = [] raise ValueError('Invalid app mode')
application_manager = EasyUIBasedAppManager()
return application_manager.generate(
app_model=app_model,
app_model_config=app_model_config,
app_model_config_dict=app_model_config_dict,
user=user,
invoke_from=invoke_from,
inputs=inputs,
query=query,
files=file_objs,
conversation=conversation,
stream=streaming,
extras={
"auto_generate_conversation_name": auto_generate_name
}
)
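The service is now a thin dispatcher over the per-mode generators. A hypothetical call site; the args keys mirror what the removed implementation parsed, but each generator defines the exact schema it accepts:

    from core.app.entities.app_invoke_entities import InvokeFrom
    from services.completion_service import CompletionService

    response = CompletionService.completion(
        app_model=app_model,            # App row loaded elsewhere
        user=current_user,              # Account or EndUser
        args={"inputs": {}, "query": "Hello", "files": [], "conversation_id": None},
        invoke_from=InvokeFrom.DEBUGGER,
        streaming=True,
    )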
@classmethod @classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
-> Union[dict, Generator]: -> Union[dict, Generator]:
if not user: """
raise ValueError('user cannot be None') Generate more like this
:param app_model: app model
message = db.session.query(Message).filter( :param user: user
Message.id == message_id, :param message_id: message id
Message.app_id == app_model.id, :param invoke_from: invoke from
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), :param streaming: streaming
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), :return:
Message.from_account_id == (user.id if isinstance(user, Account) else None), """
).first() return CompletionAppGenerator().generate_more_like_this(
if not message:
raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config
more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
model_dict = app_model_config.model_dict
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
app_model_config.model = json.dumps(model_dict)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(current_app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.transform_message_files(
message.files, file_upload_entity
)
else:
file_objs = []
application_manager = EasyUIBasedAppManager()
return application_manager.generate(
app_model=app_model, app_model=app_model,
app_model_config=current_app_model_config, message_id=message_id,
app_model_config_dict=app_model_config.to_dict(),
user=user, user=user,
invoke_from=invoke_from, invoke_from=invoke_from,
inputs=message.inputs, stream=streaming
query=message.query,
files=file_objs,
conversation=None,
stream=streaming,
extras={
"auto_generate_conversation_name": False
}
) )
...@@ -8,9 +8,11 @@ from core.app.app_config.entities import ( ...@@ -8,9 +8,11 @@ from core.app.app_config.entities import (
FileUploadEntity, FileUploadEntity,
ModelConfigEntity, ModelConfigEntity,
PromptTemplateEntity, PromptTemplateEntity,
VariableEntity, VariableEntity, EasyUIBasedAppConfig,
) )
from core.app.app_manager import EasyUIBasedAppManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.helper import encrypter from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
...@@ -87,8 +89,7 @@ class WorkflowConverter: ...@@ -87,8 +89,7 @@ class WorkflowConverter:
new_app_mode = self._get_new_app_mode(app_model) new_app_mode = self._get_new_app_mode(app_model)
# convert app model config # convert app model config
application_manager = EasyUIBasedAppManager() app_config = self._convert_to_app_config(
app_config = application_manager.convert_to_app_config(
app_model=app_model, app_model=app_model,
app_model_config=app_model_config app_model_config=app_model_config
) )
...@@ -190,6 +191,30 @@ class WorkflowConverter: ...@@ -190,6 +191,30 @@ class WorkflowConverter:
return workflow return workflow
def _convert_to_app_config(self, app_model: App,
app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
app_mode = AppMode.value_of(app_model.mode)
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config
)
elif app_mode == AppMode.CHAT:
app_config = ChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config
)
elif app_mode == AppMode.COMPLETION:
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config
)
else:
raise ValueError("Invalid app mode")
return app_config
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
""" """
Convert to Start Node Convert to Start Node
...@@ -566,6 +591,6 @@ class WorkflowConverter: ...@@ -566,6 +591,6 @@ class WorkflowConverter:
:return: :return:
""" """
return db.session.query(APIBasedExtension).filter( return db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id APIBasedExtension.id == api_based_extension_id
).first() ).first()
from unittest.mock import MagicMock from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.simple_prompt_transform import SimplePromptTransform
...@@ -139,7 +139,7 @@ def test_get_common_chat_app_prompt_template_with_p(): ...@@ -139,7 +139,7 @@ def test_get_common_chat_app_prompt_template_with_p():
def test__get_chat_model_prompt_messages(): def test__get_chat_model_prompt_messages():
model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.provider = 'openai' model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-4' model_config_mock.model = 'gpt-4'
...@@ -191,7 +191,7 @@ def test__get_chat_model_prompt_messages(): ...@@ -191,7 +191,7 @@ def test__get_chat_model_prompt_messages():
def test__get_completion_model_prompt_messages(): def test__get_completion_model_prompt_messages():
model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.provider = 'openai' model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-3.5-turbo-instruct' model_config_mock.model = 'gpt-3.5-turbo-instruct'
......