Commit 1fc521c8 authored by takatost's avatar takatost

refactor app generate

parent b7ca6d78
...@@ -59,8 +59,7 @@ class CompletionMessageApi(Resource): ...@@ -59,8 +59,7 @@ class CompletionMessageApi(Resource):
user=account, user=account,
args=args, args=args,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming, streaming=streaming
is_model_config_override=True
) )
return compact_response(response) return compact_response(response)
...@@ -126,8 +125,7 @@ class ChatMessageApi(Resource): ...@@ -126,8 +125,7 @@ class ChatMessageApi(Resource):
user=account, user=account,
args=args, args=args,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming, streaming=streaming
is_model_config_override=True
) )
return compact_response(response) return compact_response(response)
......
...@@ -10,9 +10,8 @@ from core.app.app_queue_manager import AppQueueManager ...@@ -10,9 +10,8 @@ from core.app.app_queue_manager import AppQueueManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
EasyUIBasedAppGenerateEntity, ModelConfigWithCredentialsEntity,
EasyUIBasedModelConfigEntity, InvokeFrom, AgentChatAppGenerateEntity,
InvokeFrom,
) )
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
...@@ -49,9 +48,9 @@ logger = logging.getLogger(__name__) ...@@ -49,9 +48,9 @@ logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner): class BaseAgentRunner(AppRunner):
def __init__(self, tenant_id: str, def __init__(self, tenant_id: str,
application_generate_entity: EasyUIBasedAppGenerateEntity, application_generate_entity: AgentChatAppGenerateEntity,
app_config: AgentChatAppConfig, app_config: AgentChatAppConfig,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
config: AgentEntity, config: AgentEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
message: Message, message: Message,
...@@ -122,8 +121,8 @@ class BaseAgentRunner(AppRunner): ...@@ -122,8 +121,8 @@ class BaseAgentRunner(AppRunner):
else: else:
self.stream_tool_call = False self.stream_tool_call = False
def _repack_app_generate_entity(self, app_generate_entity: EasyUIBasedAppGenerateEntity) \ def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> EasyUIBasedAppGenerateEntity: -> AgentChatAppGenerateEntity:
""" """
Repack app generate entity Repack app generate entity
""" """
......
from typing import cast from typing import cast
from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
...@@ -9,11 +9,11 @@ from core.model_runtime.model_providers.__base.large_language_model import Large ...@@ -9,11 +9,11 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
class EasyUIBasedModelConfigEntityConverter: class ModelConfigConverter:
@classmethod @classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, def convert(cls, app_config: EasyUIBasedAppConfig,
skip_check: bool = False) \ skip_check: bool = False) \
-> EasyUIBasedModelConfigEntity: -> ModelConfigWithCredentialsEntity:
""" """
Convert app model config dict to entity. Convert app model config dict to entity.
:param app_config: app config :param app_config: app config
...@@ -91,7 +91,7 @@ class EasyUIBasedModelConfigEntityConverter: ...@@ -91,7 +91,7 @@ class EasyUIBasedModelConfigEntityConverter:
if not skip_check and not model_schema: if not skip_check and not model_schema:
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")
return EasyUIBasedModelConfigEntity( return ModelConfigWithCredentialsEntity(
provider=model_config.provider, provider=model_config.provider,
model=model_config.model, model=model_config.model,
model_schema=model_schema, model_schema=model_schema,
......
from typing import Optional
from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.app_config.entities import WorkflowUIBasedAppConfig
...@@ -10,7 +12,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor ...@@ -10,7 +12,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
) )
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.model import App, AppMode from models.model import App, AppMode, Conversation
from models.workflow import Workflow from models.workflow import Workflow
...@@ -23,7 +25,9 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): ...@@ -23,7 +25,9 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
class AdvancedChatAppConfigManager(BaseAppConfigManager): class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def config_convert(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig: def get_app_config(cls, app_model: App,
workflow: Workflow,
conversation: Optional[Conversation] = None) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict features_dict = workflow.features_dict
app_config = AdvancedChatAppConfig( app_config = AdvancedChatAppConfig(
......
...@@ -19,7 +19,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor ...@@ -19,7 +19,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
) )
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from models.model import App, AppMode, AppModelConfig from models.model import App, AppMode, AppModelConfig, Conversation
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
...@@ -33,19 +33,30 @@ class AgentChatAppConfig(EasyUIBasedAppConfig): ...@@ -33,19 +33,30 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
class AgentChatAppConfigManager(BaseAppConfigManager): class AgentChatAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def config_convert(cls, app_model: App, def get_app_config(cls, app_model: App,
config_from: EasyUIBasedAppModelConfigFrom,
app_model_config: AppModelConfig, app_model_config: AppModelConfig,
config_dict: Optional[dict] = None) -> AgentChatAppConfig: conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
""" """
Convert app model config to agent chat app config Convert app model config to agent chat app config
:param app_model: app model :param app_model: app model
:param config_from: app model config from
:param app_model_config: app model config :param app_model_config: app model config
:param config_dict: app model config dict :param conversation: conversation
:param override_config_dict: app model config dict
:return: :return:
""" """
config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict) if override_config_dict:
config_from = EasyUIBasedAppModelConfigFrom.ARGS
elif conversation:
config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict
app_config = AgentChatAppConfig( app_config = AgentChatAppConfig(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
......
import logging
import threading
import uuid
from typing import Union, Any, Generator
from flask import current_app, Flask
from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, AgentChatAppGenerateEntity
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator]:
"""
Generate App response.
:param app_model: App
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
query = args['query']
if not isinstance(query, str):
raise ValueError('query must be a string')
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {
"auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True
}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
)
# parse files
files = args['files'] if 'files' in args and args['files'] else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_upload_entity,
user
)
else:
file_objs = []
# convert to app config
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict
)
# init application generate entity
application_generate_entity = AgentChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_config=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = AppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AgentChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
...@@ -7,7 +7,8 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner ...@@ -7,7 +7,8 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, \
AgentChatAppGenerateEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
...@@ -26,7 +27,7 @@ class AgentChatAppRunner(AppRunner): ...@@ -26,7 +27,7 @@ class AgentChatAppRunner(AppRunner):
""" """
Agent Application Runner Agent Application Runner
""" """
def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, def run(self, application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
...@@ -288,7 +289,7 @@ class AgentChatAppRunner(AppRunner): ...@@ -288,7 +289,7 @@ class AgentChatAppRunner(AppRunner):
'pool': db_variables.variables 'pool': db_variables.variables
}) })
def _get_usage_of_all_agent_thoughts(self, model_config: EasyUIBasedModelConfigEntity, def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
message: Message) -> LLMUsage: message: Message) -> LLMUsage:
""" """
Get usage of all agent thoughts Get usage of all agent thoughts
......
from core.app.app_config.entities import VariableEntity, AppConfig
class BaseAppGenerator:
def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
if user_inputs is None:
user_inputs = {}
filtered_inputs = {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
for variable_config in variables:
variable = variable_config.variable
if variable not in user_inputs or not user_inputs[variable]:
if variable_config.required:
raise ValueError(f"{variable} is required in input form")
else:
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
continue
value = user_inputs[variable]
if value:
if not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
else:
if variable_config.max_length is not None:
max_length = variable_config.max_length
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value.replace('\x00', '') if value else None
return filtered_inputs
...@@ -5,9 +5,8 @@ from typing import Optional, Union, cast ...@@ -5,9 +5,8 @@ from typing import Optional, Union, cast
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
EasyUIBasedAppGenerateEntity, ModelConfigWithCredentialsEntity,
EasyUIBasedModelConfigEntity, InvokeFrom, AppGenerateEntity, EasyUIBasedAppGenerateEntity,
InvokeFrom,
) )
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
...@@ -27,7 +26,7 @@ from models.model import App, AppMode, Message, MessageAnnotation ...@@ -27,7 +26,7 @@ from models.model import App, AppMode, Message, MessageAnnotation
class AppRunner: class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App, def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str], inputs: dict[str, str],
files: list[FileObj], files: list[FileObj],
...@@ -83,7 +82,7 @@ class AppRunner: ...@@ -83,7 +82,7 @@ class AppRunner:
return rest_tokens return rest_tokens
def recale_llm_max_tokens(self, model_config: EasyUIBasedModelConfigEntity, def recale_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage]): prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = model_config.provider_model_bundle.model_type_instance
...@@ -119,7 +118,7 @@ class AppRunner: ...@@ -119,7 +118,7 @@ class AppRunner:
model_config.parameters[parameter_rule.name] = max_tokens model_config.parameters[parameter_rule.name] = max_tokens
def organize_prompt_messages(self, app_record: App, def organize_prompt_messages(self, app_record: App,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str], inputs: dict[str, str],
files: list[FileObj], files: list[FileObj],
...@@ -292,7 +291,7 @@ class AppRunner: ...@@ -292,7 +291,7 @@ class AppRunner:
def moderation_for_inputs(self, app_id: str, def moderation_for_inputs(self, app_id: str,
tenant_id: str, tenant_id: str,
app_generate_entity: EasyUIBasedAppGenerateEntity, app_generate_entity: AppGenerateEntity,
inputs: dict, inputs: dict,
query: str) -> tuple[bool, dict, str]: query: str) -> tuple[bool, dict, str]:
""" """
......
...@@ -15,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor ...@@ -15,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
SuggestedQuestionsAfterAnswerConfigManager, SuggestedQuestionsAfterAnswerConfigManager,
) )
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig from models.model import App, AppMode, AppModelConfig, Conversation
class ChatAppConfig(EasyUIBasedAppConfig): class ChatAppConfig(EasyUIBasedAppConfig):
...@@ -27,19 +27,30 @@ class ChatAppConfig(EasyUIBasedAppConfig): ...@@ -27,19 +27,30 @@ class ChatAppConfig(EasyUIBasedAppConfig):
class ChatAppConfigManager(BaseAppConfigManager): class ChatAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def config_convert(cls, app_model: App, def get_app_config(cls, app_model: App,
config_from: EasyUIBasedAppModelConfigFrom,
app_model_config: AppModelConfig, app_model_config: AppModelConfig,
config_dict: Optional[dict] = None) -> ChatAppConfig: conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> ChatAppConfig:
""" """
Convert app model config to chat app config Convert app model config to chat app config
:param app_model: app model :param app_model: app model
:param config_from: app model config from
:param app_model_config: app model config :param app_model_config: app model config
:param config_dict: app model config dict :param conversation: conversation
:param override_config_dict: app model config dict
:return: :return:
""" """
config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict) if override_config_dict:
config_from = EasyUIBasedAppModelConfigFrom.ARGS
elif conversation:
config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict
app_config = ChatAppConfig( app_config = ChatAppConfig(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
......
import logging
import threading
import uuid
from typing import Union, Any, Generator
from flask import current_app, Flask
from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_queue_manager import ConversationTaskStoppedException, PublishFrom, AppQueueManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, ChatAppGenerateEntity
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator):
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator]:
"""
Generate App response.
:param app_model: App
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
query = args['query']
if not isinstance(query, str):
raise ValueError('query must be a string')
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {
"auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True
}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
)
# parse files
files = args['files'] if 'files' in args and args['files'] else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_upload_entity,
user
)
else:
file_objs = []
# convert to app config
app_config = ChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict
)
# init application generate entity
application_generate_entity = ChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_config=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = AppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = ChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
...@@ -5,7 +5,7 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom ...@@ -5,7 +5,7 @@ from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.apps.chat.app_config_manager import ChatAppConfig
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
EasyUIBasedAppGenerateEntity, ChatAppGenerateEntity,
) )
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
...@@ -23,7 +23,7 @@ class ChatAppRunner(AppRunner): ...@@ -23,7 +23,7 @@ class ChatAppRunner(AppRunner):
Chat Application Runner Chat Application Runner
""" """
def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, def run(self, application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
......
...@@ -10,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod ...@@ -10,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig from models.model import App, AppMode, AppModelConfig, Conversation
class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfig(EasyUIBasedAppConfig):
...@@ -22,19 +22,26 @@ class CompletionAppConfig(EasyUIBasedAppConfig): ...@@ -22,19 +22,26 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
class CompletionAppConfigManager(BaseAppConfigManager): class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def config_convert(cls, app_model: App, def get_app_config(cls, app_model: App,
config_from: EasyUIBasedAppModelConfigFrom,
app_model_config: AppModelConfig, app_model_config: AppModelConfig,
config_dict: Optional[dict] = None) -> CompletionAppConfig: override_config_dict: Optional[dict] = None) -> CompletionAppConfig:
""" """
Convert app model config to completion app config Convert app model config to completion app config
:param app_model: app model :param app_model: app model
:param config_from: app model config from
:param app_model_config: app model config :param app_model_config: app model config
:param config_dict: app model config dict :param override_config_dict: app model config dict
:return: :return:
""" """
config_dict = cls.convert_to_config_dict(config_from, app_model_config, config_dict) if override_config_dict:
config_from = EasyUIBasedAppModelConfigFrom.ARGS
else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict
app_config = CompletionAppConfig( app_config = CompletionAppConfig(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
......
This diff is collapsed.
...@@ -5,7 +5,7 @@ from core.app.app_queue_manager import AppQueueManager ...@@ -5,7 +5,7 @@ from core.app.app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.apps.completion.app_config_manager import CompletionAppConfig
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
EasyUIBasedAppGenerateEntity, CompletionAppGenerateEntity,
) )
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
...@@ -22,7 +22,7 @@ class CompletionAppRunner(AppRunner): ...@@ -22,7 +22,7 @@ class CompletionAppRunner(AppRunner):
Completion Application Runner Completion Application Runner
""" """
def run(self, application_generate_entity: EasyUIBasedAppGenerateEntity, def run(self, application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
message: Message) -> None: message: Message) -> None:
""" """
......
...@@ -17,7 +17,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig): ...@@ -17,7 +17,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
class WorkflowAppConfigManager(BaseAppConfigManager): class WorkflowAppConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def config_convert(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig: def get_app_config(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig:
features_dict = workflow.features_dict features_dict = workflow.features_dict
app_config = WorkflowAppConfig( app_config = WorkflowAppConfig(
......
...@@ -3,7 +3,7 @@ from typing import Any, Optional ...@@ -3,7 +3,7 @@ from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig, AppConfig
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
...@@ -49,9 +49,9 @@ class InvokeFrom(Enum): ...@@ -49,9 +49,9 @@ class InvokeFrom(Enum):
return 'dev' return 'dev'
class EasyUIBasedModelConfigEntity(BaseModel): class ModelConfigWithCredentialsEntity(BaseModel):
""" """
Model Config Entity. Model Config With Credentials Entity.
""" """
provider: str provider: str
model: str model: str
...@@ -63,21 +63,19 @@ class EasyUIBasedModelConfigEntity(BaseModel): ...@@ -63,21 +63,19 @@ class EasyUIBasedModelConfigEntity(BaseModel):
stop: list[str] = [] stop: list[str] = []
class EasyUIBasedAppGenerateEntity(BaseModel): class AppGenerateEntity(BaseModel):
""" """
EasyUI Based Application Generate Entity. App Generate Entity.
""" """
task_id: str task_id: str
# app config # app config
app_config: EasyUIBasedAppConfig app_config: AppConfig
model_config: EasyUIBasedModelConfigEntity
conversation_id: Optional[str] = None
inputs: dict[str, str] inputs: dict[str, str]
query: Optional[str] = None
files: list[FileObj] = [] files: list[FileObj] = []
user_id: str user_id: str
# extras # extras
stream: bool stream: bool
invoke_from: InvokeFrom invoke_from: InvokeFrom
...@@ -86,26 +84,52 @@ class EasyUIBasedAppGenerateEntity(BaseModel): ...@@ -86,26 +84,52 @@ class EasyUIBasedAppGenerateEntity(BaseModel):
extras: dict[str, Any] = {} extras: dict[str, Any] = {}
class WorkflowUIBasedAppGenerateEntity(BaseModel): class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
""" """
Workflow UI Based Application Generate Entity. Chat Application Generate Entity.
""" """
task_id: str
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: EasyUIBasedAppConfig
model_config: ModelConfigWithCredentialsEntity
inputs: dict[str, str] query: Optional[str] = None
files: list[FileObj] = []
user_id: str
# extras
stream: bool
invoke_from: InvokeFrom
# extra parameters
extras: dict[str, Any] = {}
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None
class AdvancedChatAppGenerateEntity(WorkflowUIBasedAppGenerateEntity):
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Completion Application Generate Entity.
"""
pass
class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Agent Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
query: str
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
"""
Advanced Chat Application Generate Entity.
"""
# app config
app_config: WorkflowUIBasedAppConfig
conversation_id: Optional[str] = None
query: Optional[str] = None
class WorkflowUIBasedAppGenerateEntity(AppGenerateEntity):
"""
Workflow UI Based Application Generate Entity.
"""
# app config
app_config: WorkflowUIBasedAppConfig
import logging import logging
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, EasyUIBasedAppGenerateEntity
from core.helper import moderation from core.helper import moderation
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
......
...@@ -7,7 +7,8 @@ from typing import Optional, Union, cast ...@@ -7,7 +7,8 @@ from typing import Optional, Union, cast
from pydantic import BaseModel from pydantic import BaseModel
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom, CompletionAppGenerateEntity, \
AgentChatAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AnnotationReplyEvent, AnnotationReplyEvent,
QueueAgentMessageEvent, QueueAgentMessageEvent,
...@@ -39,7 +40,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser ...@@ -39,7 +40,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought, MessageFile from models.model import Conversation, Message, MessageAgentThought, MessageFile, AppMode
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -58,7 +59,11 @@ class GenerateTaskPipeline: ...@@ -58,7 +59,11 @@ class GenerateTaskPipeline:
GenerateTaskPipeline is a class that generate stream output and state management for Application. GenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
def __init__(self, application_generate_entity: EasyUIBasedAppGenerateEntity, def __init__(self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message) -> None: message: Message) -> None:
...@@ -425,6 +430,7 @@ class GenerateTaskPipeline: ...@@ -425,6 +430,7 @@ class GenerateTaskPipeline:
self._message.answer_price_unit = usage.completion_price_unit self._message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price self._message.total_price = usage.total_price
self._message.currency = usage.currency
db.session.commit() db.session.commit()
...@@ -432,7 +438,11 @@ class GenerateTaskPipeline: ...@@ -432,7 +438,11 @@ class GenerateTaskPipeline:
self._message, self._message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
conversation=self._conversation, conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None, is_first_message=self._application_generate_entity.app_config.app_mode in [
AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.ADVANCED_CHAT
] and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras extras=self._application_generate_entity.extras
) )
......
import logging import logging
import random import random
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
from extensions.ext_hosting_provider import hosting_configuration from extensions.ext_hosting_provider import hosting_configuration
...@@ -10,7 +10,7 @@ from models.provider import ProviderType ...@@ -10,7 +10,7 @@ from models.provider import ProviderType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check_moderation(model_config: EasyUIBasedModelConfigEntity, text: str) -> bool: def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
moderation_config = hosting_configuration.moderation_config moderation_config = hosting_configuration.moderation_config
if (moderation_config and moderation_config.enabled is True if (moderation_config and moderation_config.enabled is True
and 'openai' in hosting_configuration.provider_map and 'openai' in hosting_configuration.provider_map
......
from typing import Optional from typing import Optional
from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
...@@ -28,7 +28,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -28,7 +28,7 @@ class AdvancedPromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
prompt_messages = [] prompt_messages = []
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
...@@ -62,7 +62,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -62,7 +62,7 @@ class AdvancedPromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
""" """
Get completion model prompt messages. Get completion model prompt messages.
""" """
...@@ -110,7 +110,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -110,7 +110,7 @@ class AdvancedPromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
""" """
Get chat model prompt messages. Get chat model prompt messages.
""" """
...@@ -199,7 +199,7 @@ class AdvancedPromptTransform(PromptTransform): ...@@ -199,7 +199,7 @@ class AdvancedPromptTransform(PromptTransform):
role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity, role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
prompt_template: PromptTemplateParser, prompt_template: PromptTemplateParser,
prompt_inputs: dict, prompt_inputs: dict,
model_config: EasyUIBasedModelConfigEntity) -> dict: model_config: ModelConfigWithCredentialsEntity) -> dict:
if '#histories#' in prompt_template.variable_keys: if '#histories#' in prompt_template.variable_keys:
if memory: if memory:
inputs = {'#histories#': '', **prompt_inputs} inputs = {'#histories#': '', **prompt_inputs}
......
from typing import Optional, cast from typing import Optional, cast
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
...@@ -10,14 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large ...@@ -10,14 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
class PromptTransform: class PromptTransform:
def _append_chat_histories(self, memory: TokenBufferMemory, def _append_chat_histories(self, memory: TokenBufferMemory,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_config: EasyUIBasedModelConfigEntity) -> list[PromptMessage]: model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config) rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, rest_tokens) histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories) prompt_messages.extend(histories)
return prompt_messages return prompt_messages
def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: EasyUIBasedModelConfigEntity) -> int: def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> int:
rest_tokens = 2000 rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
......
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
from typing import Optional from typing import Optional
from core.app.app_config.entities import PromptTemplateEntity from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
...@@ -52,7 +52,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -52,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
files: list[FileObj], files: list[FileObj],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) -> \ model_config: ModelConfigWithCredentialsEntity) -> \
tuple[list[PromptMessage], Optional[list[str]]]: tuple[list[PromptMessage], Optional[list[str]]]:
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
...@@ -81,7 +81,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -81,7 +81,7 @@ class SimplePromptTransform(PromptTransform):
return prompt_messages, stops return prompt_messages, stops
def get_prompt_str_and_rules(self, app_mode: AppMode, def get_prompt_str_and_rules(self, app_mode: AppMode,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
pre_prompt: str, pre_prompt: str,
inputs: dict, inputs: dict,
query: Optional[str] = None, query: Optional[str] = None,
...@@ -162,7 +162,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -162,7 +162,7 @@ class SimplePromptTransform(PromptTransform):
context: Optional[str], context: Optional[str],
files: list[FileObj], files: list[FileObj],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) \ model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]: -> tuple[list[PromptMessage], Optional[list[str]]]:
prompt_messages = [] prompt_messages = []
...@@ -200,7 +200,7 @@ class SimplePromptTransform(PromptTransform): ...@@ -200,7 +200,7 @@ class SimplePromptTransform(PromptTransform):
context: Optional[str], context: Optional[str],
files: list[FileObj], files: list[FileObj],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: EasyUIBasedModelConfigEntity) \ model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]: -> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt # get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules( prompt, prompt_rules = self.get_prompt_str_and_rules(
......
...@@ -5,14 +5,14 @@ from langchain.callbacks.manager import CallbackManagerForChainRun ...@@ -5,14 +5,14 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.rag.retrieval.agent.fake_llm import FakeLLM from core.rag.retrieval.agent.fake_llm import FakeLLM
class LLMChain(LCLLMChain): class LLMChain(LCLLMChain):
model_config: EasyUIBasedModelConfigEntity model_config: ModelConfigWithCredentialsEntity
"""The language model instance to use.""" """The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="") llm: BaseLanguageModel = FakeLLM(response="")
parameters: dict[str, Any] = {} parameters: dict[str, Any] = {}
......
...@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage ...@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import root_validator from pydantic import root_validator
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessageTool
...@@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -21,7 +21,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
""" """
An Multi Dataset Retrieve Agent driven by Router. An Multi Dataset Retrieve Agent driven by Router.
""" """
model_config: EasyUIBasedModelConfigEntity model_config: ModelConfigWithCredentialsEntity
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
...@@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -156,7 +156,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
......
...@@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy ...@@ -12,7 +12,7 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.rag.retrieval.agent.llm_chain import LLMChain from core.rag.retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
...@@ -206,7 +206,7 @@ Thought: {agent_scratchpad} ...@@ -206,7 +206,7 @@ Thought: {agent_scratchpad}
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
......
...@@ -7,7 +7,7 @@ from langchain.callbacks.manager import Callbacks ...@@ -7,7 +7,7 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.message_entities import prompt_messages_to_lc_messages from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation from core.helper import moderation
...@@ -22,9 +22,9 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr ...@@ -22,9 +22,9 @@ from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetr
class AgentConfiguration(BaseModel): class AgentConfiguration(BaseModel):
strategy: PlanningStrategy strategy: PlanningStrategy
model_config: EasyUIBasedModelConfigEntity model_config: ModelConfigWithCredentialsEntity
tools: list[BaseTool] tools: list[BaseTool]
summary_model_config: Optional[EasyUIBasedModelConfigEntity] = None summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None
memory: Optional[TokenBufferMemory] = None memory: Optional[TokenBufferMemory] = None
callbacks: Callbacks = None callbacks: Callbacks = None
max_iterations: int = 6 max_iterations: int = 6
......
...@@ -3,7 +3,7 @@ from typing import Optional, cast ...@@ -3,7 +3,7 @@ from typing import Optional, cast
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity, InvokeFrom from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
...@@ -18,7 +18,7 @@ from models.dataset import Dataset ...@@ -18,7 +18,7 @@ from models.dataset import Dataset
class DatasetRetrieval: class DatasetRetrieval:
def retrieve(self, tenant_id: str, def retrieve(self, tenant_id: str,
model_config: EasyUIBasedModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity, config: DatasetEntity,
query: str, query: str,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
......
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
...@@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType ...@@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType
@message_was_created.connect @message_was_created.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
message = sender message = sender
application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity') application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity')
model_config = application_generate_entity.model_config model_config = application_generate_entity.model_config
provider_model_bundle = model_config.provider_model_bundle provider_model_bundle = model_config.provider_model_bundle
......
from datetime import datetime from datetime import datetime
from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models.provider import Provider from models.provider import Provider
...@@ -9,7 +9,7 @@ from models.provider import Provider ...@@ -9,7 +9,7 @@ from models.provider import Provider
@message_was_created.connect @message_was_created.connect
def handle(sender, **kwargs): def handle(sender, **kwargs):
message = sender message = sender
application_generate_entity: EasyUIBasedAppGenerateEntity = kwargs.get('application_generate_entity') application_generate_entity: ChatAppGenerateEntity = kwargs.get('application_generate_entity')
db.session.query(Provider).filter( db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.tenant_id == application_generate_entity.app_config.tenant_id,
......
import json
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Union from typing import Any, Union
from sqlalchemy import and_ from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.app_manager import EasyUIBasedAppManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.message_file_parser import MessageFileParser from models.model import Account, App, AppMode, EndUser
from extensions.ext_database import db
from models.model import Account, App, AppMode, AppModelConfig, Conversation, EndUser, Message
from services.app_model_config_service import AppModelConfigService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
from services.errors.message import MessageNotExistsError
class CompletionService: class CompletionService:
@classmethod @classmethod
def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any, def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
invoke_from: InvokeFrom, streaming: bool = True, invoke_from: InvokeFrom, streaming: bool = True) -> Union[dict, Generator]:
is_model_config_override: bool = False) -> Union[dict, Generator]: """
# is streaming mode App Completion
inputs = args['inputs'] :param app_model: app model
query = args['query'] :param user: user
files = args['files'] if 'files' in args and args['files'] else [] :param args: args
auto_generate_name = args['auto_generate_name'] \ :param invoke_from: invoke from
if 'auto_generate_name' in args else True :param streaming: streaming
:return:
if app_model.mode != AppMode.COMPLETION.value: """
if not query: if app_model.mode == AppMode.COMPLETION.value:
raise ValueError('query is required') return CompletionAppGenerator().generate(
app_model=app_model,
if query: user=user,
if not isinstance(query, str): args=args,
raise ValueError('query must be a string') invoke_from=invoke_from,
stream=streaming
query = query.replace('\x00', '') )
elif app_model.mode == AppMode.CHAT.value:
conversation_id = args['conversation_id'] if 'conversation_id' in args else None return ChatAppGenerator().generate(
app_model=app_model,
conversation = None user=user,
app_model_config_dict = None args=args,
if conversation_id: invoke_from=invoke_from,
conversation_filter = [ stream=streaming
Conversation.id == args['conversation_id'], )
Conversation.app_id == app_model.id, elif app_model.mode == AppMode.AGENT_CHAT.value:
Conversation.status == 'normal' return AgentChatAppGenerator().generate(
] app_model=app_model,
user=user,
if isinstance(user, Account): args=args,
conversation_filter.append(Conversation.from_account_id == user.id) invoke_from=invoke_from,
else: stream=streaming
conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
if not conversation:
raise ConversationNotExistsError()
if conversation.status != 'normal':
raise ConversationCompletedError()
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
if not app_model_config:
raise AppModelConfigBrokenError()
else:
if app_model.app_model_config_id is None:
raise AppModelConfigBrokenError()
app_model_config = app_model.app_model_config
if not app_model_config:
raise AppModelConfigBrokenError()
if is_model_config_override:
if not isinstance(user, Account):
raise Exception("Only account can override model config")
# validate config
app_model_config_dict = AppModelConfigService.validate_configuration(
tenant_id=app_model.tenant_id,
config=args['model_config'],
app_mode=AppMode.value_of(app_model.mode)
)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(app_model_config_dict or app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_upload_entity,
user
) )
else: else:
file_objs = [] raise ValueError('Invalid app mode')
application_manager = EasyUIBasedAppManager()
return application_manager.generate(
app_model=app_model,
app_model_config=app_model_config,
app_model_config_dict=app_model_config_dict,
user=user,
invoke_from=invoke_from,
inputs=inputs,
query=query,
files=file_objs,
conversation=conversation,
stream=streaming,
extras={
"auto_generate_conversation_name": auto_generate_name
}
)
@classmethod @classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
-> Union[dict, Generator]: -> Union[dict, Generator]:
if not user: """
raise ValueError('user cannot be None') Generate more like this
:param app_model: app model
message = db.session.query(Message).filter( :param user: user
Message.id == message_id, :param message_id: message id
Message.app_id == app_model.id, :param invoke_from: invoke from
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), :param streaming: streaming
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), :return:
Message.from_account_id == (user.id if isinstance(user, Account) else None), """
).first() return CompletionAppGenerator().generate_more_like_this(
if not message:
raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config
more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
model_dict = app_model_config.model_dict
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
app_model_config.model = json.dumps(model_dict)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_upload_entity = FileUploadConfigManager.convert(current_app_model_config.to_dict())
if file_upload_entity:
file_objs = message_file_parser.transform_message_files(
message.files, file_upload_entity
)
else:
file_objs = []
application_manager = EasyUIBasedAppManager()
return application_manager.generate(
app_model=app_model, app_model=app_model,
app_model_config=current_app_model_config, message_id=message_id,
app_model_config_dict=app_model_config.to_dict(),
user=user, user=user,
invoke_from=invoke_from, invoke_from=invoke_from,
inputs=message.inputs, stream=streaming
query=message.query,
files=file_objs,
conversation=None,
stream=streaming,
extras={
"auto_generate_conversation_name": False
}
) )
...@@ -8,9 +8,11 @@ from core.app.app_config.entities import ( ...@@ -8,9 +8,11 @@ from core.app.app_config.entities import (
FileUploadEntity, FileUploadEntity,
ModelConfigEntity, ModelConfigEntity,
PromptTemplateEntity, PromptTemplateEntity,
VariableEntity, VariableEntity, EasyUIBasedAppConfig,
) )
from core.app.app_manager import EasyUIBasedAppManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.helper import encrypter from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
...@@ -87,8 +89,7 @@ class WorkflowConverter: ...@@ -87,8 +89,7 @@ class WorkflowConverter:
new_app_mode = self._get_new_app_mode(app_model) new_app_mode = self._get_new_app_mode(app_model)
# convert app model config # convert app model config
application_manager = EasyUIBasedAppManager() app_config = self._convert_to_app_config(
app_config = application_manager.convert_to_app_config(
app_model=app_model, app_model=app_model,
app_model_config=app_model_config app_model_config=app_model_config
) )
...@@ -190,6 +191,30 @@ class WorkflowConverter: ...@@ -190,6 +191,30 @@ class WorkflowConverter:
return workflow return workflow
def _convert_to_app_config(self, app_model: App,
app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
app_mode = AppMode.value_of(app_model.mode)
if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config
)
elif app_mode == AppMode.CHAT:
app_config = ChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config
)
elif app_mode == AppMode.COMPLETION:
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config
)
else:
raise ValueError("Invalid app mode")
return app_config
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
""" """
Convert to Start Node Convert to Start Node
...@@ -566,6 +591,6 @@ class WorkflowConverter: ...@@ -566,6 +591,6 @@ class WorkflowConverter:
:return: :return:
""" """
return db.session.query(APIBasedExtension).filter( return db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id APIBasedExtension.id == api_based_extension_id
).first() ).first()
from unittest.mock import MagicMock from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import EasyUIBasedModelConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.simple_prompt_transform import SimplePromptTransform
...@@ -139,7 +139,7 @@ def test_get_common_chat_app_prompt_template_with_p(): ...@@ -139,7 +139,7 @@ def test_get_common_chat_app_prompt_template_with_p():
def test__get_chat_model_prompt_messages(): def test__get_chat_model_prompt_messages():
model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.provider = 'openai' model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-4' model_config_mock.model = 'gpt-4'
...@@ -191,7 +191,7 @@ def test__get_chat_model_prompt_messages(): ...@@ -191,7 +191,7 @@ def test__get_chat_model_prompt_messages():
def test__get_completion_model_prompt_messages(): def test__get_completion_model_prompt_messages():
model_config_mock = MagicMock(spec=EasyUIBasedModelConfigEntity) model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.provider = 'openai' model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-3.5-turbo-instruct' model_config_mock.model = 'gpt-3.5-turbo-instruct'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment