Commit 5d2ff412 authored by takatost's avatar takatost

Merge branch 'feat/workflow-backend' into deploy/dev

parents c8757ef0 93ed3946
...@@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair ...@@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair
from models.account import Tenant from models.account import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.model import Account from models.model import Account, App, AppMode, Conversation
from models.provider import Provider, ProviderModel from models.provider import Provider, ProviderModel
...@@ -263,8 +263,62 @@ def vdb_migrate(): ...@@ -263,8 +263,62 @@ def vdb_migrate():
fg='green')) fg='green'))
@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.')
def convert_to_agent_apps():
"""
Convert Agent Assistant to Agent App.
"""
click.echo(click.style('Start convert to agent apps.', fg='green'))
proceeded_app_ids = []
while True:
# fetch first 1000 apps
sql_query = """SELECT a.id AS id FROM apps a
INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
WHERE a.mode = 'chat' AND am.agent_mode is not null
and (am.agent_mode like '%"strategy": "function_call"%' or am.agent_mode like '%"strategy": "react"%')
and am.agent_mode like '{"enabled": true%' ORDER BY a.created_at DESC LIMIT 1000"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query))
apps = []
for i in rs:
app_id = str(i.id)
if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id)
app = db.session.query(App).filter(App.id == app_id).first()
apps.append(app)
if len(apps) == 0:
break
for app in apps:
click.echo('Converting app: {}'.format(app.id))
try:
app.mode = AppMode.AGENT_CHAT.value
db.session.commit()
# update conversation mode to agent
db.session.query(Conversation).filter(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT.value}
)
db.session.commit()
click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
except Exception as e:
click.echo(
click.style('Convert app error: {} {}'.format(e.__class__.__name__,
str(e)), fg='red'))
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
def register_commands(app): def register_commands(app):
app.cli.add_command(reset_password) app.cli.add_command(reset_password)
app.cli.add_command(reset_email) app.cli.add_command(reset_email)
app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(vdb_migrate) app.cli.add_command(vdb_migrate)
app.cli.add_command(convert_to_agent_apps)
This diff is collapsed.
import json from models.model import AppMode
model_templates = { default_app_templates = {
# completion default mode # workflow default mode
'completion_default': { AppMode.WORKFLOW: {
'app': { 'app': {
'mode': 'completion', 'mode': AppMode.WORKFLOW.value,
'enable_site': True, 'enable_site': True,
'enable_api': True, 'enable_api': True
'is_demo': False, }
'api_rpm': 0, },
'api_rph': 0,
'status': 'normal' # chat default mode
AppMode.CHAT: {
'app': {
'mode': AppMode.CHAT.value,
'enable_site': True,
'enable_api': True
}, },
'model_config': { 'model_config': {
'provider': '', 'model': {
'model_id': '',
'configs': {},
'model': json.dumps({
"provider": "openai", "provider": "openai",
"name": "gpt-3.5-turbo-instruct", "name": "gpt-4",
"mode": "completion", "mode": "chat",
"completion_params": {} "completion_params": {}
}), }
'user_input_form': json.dumps([
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
]),
'pre_prompt': '{{query}}'
} }
}, },
# chat default mode # advanced-chat default mode
'chat_default': { AppMode.ADVANCED_CHAT: {
'app': { 'app': {
'mode': 'chat', 'mode': AppMode.ADVANCED_CHAT.value,
'enable_site': True, 'enable_site': True,
'enable_api': True, 'enable_api': True
'is_demo': False, }
'api_rpm': 0, },
'api_rph': 0,
'status': 'normal' # agent-chat default mode
AppMode.AGENT_CHAT: {
'app': {
'mode': AppMode.AGENT_CHAT.value,
'enable_site': True,
'enable_api': True
}, },
'model_config': { 'model_config': {
'provider': '', 'model': {
'model_id': '',
'configs': {},
'model': json.dumps({
"provider": "openai", "provider": "openai",
"name": "gpt-3.5-turbo", "name": "gpt-4",
"mode": "chat", "mode": "chat",
"completion_params": {} "completion_params": {}
}) }
} }
}, }
} }
...@@ -5,10 +5,10 @@ bp = Blueprint('console', __name__, url_prefix='/console/api') ...@@ -5,10 +5,10 @@ bp = Blueprint('console', __name__, url_prefix='/console/api')
api = ExternalApi(bp) api = ExternalApi(bp)
# Import other controllers # Import other controllers
from . import admin, apikey, extension, feature, setup, version from . import admin, apikey, extension, feature, setup, version, ping
# Import app controllers # Import app controllers
from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message, from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message,
model_config, site, statistic) model_config, site, statistic, workflow, workflow_app_log)
# Import auth controllers # Import auth controllers
from .auth import activate, data_source_oauth, login, oauth from .auth import activate, data_source_oauth, login, oauth
# Import billing controllers # Import billing controllers
......
from controllers.console.app.error import AppUnavailableError
from extensions.ext_database import db
from flask_login import current_user
from models.model import App
from werkzeug.exceptions import NotFound
def _get_app(app_id, mode=None):
app = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app:
raise NotFound("App not found")
if mode and app.mode != mode:
raise NotFound("The {} app not found".format(mode))
return app
This diff is collapsed.
...@@ -6,7 +6,6 @@ from werkzeug.exceptions import InternalServerError ...@@ -6,7 +6,6 @@ from werkzeug.exceptions import InternalServerError
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
AudioTooLargeError, AudioTooLargeError,
...@@ -18,11 +17,13 @@ from controllers.console.app.error import ( ...@@ -18,11 +17,13 @@ from controllers.console.app.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
UnsupportedAudioTypeError, UnsupportedAudioTypeError,
) )
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required from libs.login import login_required
from models.model import AppMode
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
...@@ -36,15 +37,13 @@ class ChatMessageAudioApi(Resource): ...@@ -36,15 +37,13 @@ class ChatMessageAudioApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_id): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
app_id = str(app_id) def post(self, app_model):
app_model = _get_app(app_id, 'chat')
file = request.files['file'] file = request.files['file']
try: try:
response = AudioService.transcript_asr( response = AudioService.transcript_asr(
tenant_id=app_model.tenant_id, app_model=app_model,
file=file, file=file,
end_user=None, end_user=None,
) )
...@@ -80,15 +79,12 @@ class ChatMessageTextApi(Resource): ...@@ -80,15 +79,12 @@ class ChatMessageTextApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_id): @get_app_model
app_id = str(app_id) def post(self, app_model):
app_model = _get_app(app_id, None)
try: try:
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id, app_model=app_model,
text=request.form['text'], text=request.form['text'],
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False streaming=False
) )
...@@ -120,9 +116,11 @@ class ChatMessageTextApi(Resource): ...@@ -120,9 +116,11 @@ class ChatMessageTextApi(Resource):
class TextModesApi(Resource): class TextModesApi(Resource):
def get(self, app_id: str): @setup_required
app_model = _get_app(str(app_id)) @login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('language', type=str, required=True, location='args') parser.add_argument('language', type=str, required=True, location='args')
......
...@@ -10,7 +10,6 @@ from werkzeug.exceptions import InternalServerError, NotFound ...@@ -10,7 +10,6 @@ from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
CompletionRequestError, CompletionRequestError,
...@@ -19,14 +18,16 @@ from controllers.console.app.error import ( ...@@ -19,14 +18,16 @@ from controllers.console.app.error import (
ProviderNotInitializeError, ProviderNotInitializeError,
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import login_required from libs.login import login_required
from models.model import AppMode
from services.completion_service import CompletionService from services.completion_service import CompletionService
...@@ -36,12 +37,8 @@ class CompletionMessageApi(Resource): ...@@ -36,12 +37,8 @@ class CompletionMessageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_id): @get_app_model(mode=AppMode.COMPLETION)
app_id = str(app_id) def post(self, app_model):
# get app info
app_model = _get_app(app_id, 'completion')
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='') parser.add_argument('query', type=str, location='json', default='')
...@@ -62,8 +59,7 @@ class CompletionMessageApi(Resource): ...@@ -62,8 +59,7 @@ class CompletionMessageApi(Resource):
user=account, user=account,
args=args, args=args,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming, streaming=streaming
is_model_config_override=True
) )
return compact_response(response) return compact_response(response)
...@@ -93,15 +89,11 @@ class CompletionMessageStopApi(Resource): ...@@ -93,15 +89,11 @@ class CompletionMessageStopApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_id, task_id): @get_app_model(mode=AppMode.COMPLETION)
app_id = str(app_id) def post(self, app_model, task_id):
# get app info
_get_app(app_id, 'completion')
account = flask_login.current_user account = flask_login.current_user
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
...@@ -110,12 +102,8 @@ class ChatMessageApi(Resource): ...@@ -110,12 +102,8 @@ class ChatMessageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_id): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
app_id = str(app_id) def post(self, app_model):
# get app info
app_model = _get_app(app_id, 'chat')
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json') parser.add_argument('query', type=str, required=True, location='json')
...@@ -137,8 +125,7 @@ class ChatMessageApi(Resource): ...@@ -137,8 +125,7 @@ class ChatMessageApi(Resource):
user=account, user=account,
args=args, args=args,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming, streaming=streaming
is_model_config_override=True
) )
return compact_response(response) return compact_response(response)
...@@ -179,15 +166,11 @@ class ChatMessageStopApi(Resource): ...@@ -179,15 +166,11 @@ class ChatMessageStopApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_id, task_id): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
app_id = str(app_id) def post(self, app_model, task_id):
# get app info
_get_app(app_id, 'chat')
account = flask_login.current_user account = flask_login.current_user
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
...@@ -9,7 +9,7 @@ from sqlalchemy.orm import joinedload ...@@ -9,7 +9,7 @@ from sqlalchemy.orm import joinedload
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.app import _get_app from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
...@@ -21,7 +21,7 @@ from fields.conversation_fields import ( ...@@ -21,7 +21,7 @@ from fields.conversation_fields import (
) )
from libs.helper import datetime_string from libs.helper import datetime_string
from libs.login import login_required from libs.login import login_required
from models.model import Conversation, Message, MessageAnnotation from models.model import AppMode, Conversation, Message, MessageAnnotation
class CompletionConversationApi(Resource): class CompletionConversationApi(Resource):
...@@ -29,10 +29,9 @@ class CompletionConversationApi(Resource): ...@@ -29,10 +29,9 @@ class CompletionConversationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields) @marshal_with(conversation_pagination_fields)
def get(self, app_id): def get(self, app_model):
app_id = str(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args') parser.add_argument('keyword', type=str, location='args')
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
...@@ -43,10 +42,7 @@ class CompletionConversationApi(Resource): ...@@ -43,10 +42,7 @@ class CompletionConversationApi(Resource):
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
args = parser.parse_args() args = parser.parse_args()
# get app info query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion')
app = _get_app(app_id, 'completion')
query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'completion')
if args['keyword']: if args['keyword']:
query = query.join( query = query.join(
...@@ -106,24 +102,22 @@ class CompletionConversationDetailApi(Resource): ...@@ -106,24 +102,22 @@ class CompletionConversationDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields) @marshal_with(conversation_message_detail_fields)
def get(self, app_id, conversation_id): def get(self, app_model, conversation_id):
app_id = str(app_id)
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'completion') return _get_conversation(app_model, conversation_id)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, app_id, conversation_id): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
app_id = str(app_id) def delete(self, app_model, conversation_id):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
app = _get_app(app_id, 'chat')
conversation = db.session.query(Conversation) \ conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
...@@ -139,10 +133,9 @@ class ChatConversationApi(Resource): ...@@ -139,10 +133,9 @@ class ChatConversationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@marshal_with(conversation_with_summary_pagination_fields) @marshal_with(conversation_with_summary_pagination_fields)
def get(self, app_id): def get(self, app_model):
app_id = str(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args') parser.add_argument('keyword', type=str, location='args')
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
...@@ -154,10 +147,7 @@ class ChatConversationApi(Resource): ...@@ -154,10 +147,7 @@ class ChatConversationApi(Resource):
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args() args = parser.parse_args()
# get app info query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'chat')
app = _get_app(app_id, 'chat')
query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'chat')
if args['keyword']: if args['keyword']:
query = query.join( query = query.join(
...@@ -228,25 +218,22 @@ class ChatConversationDetailApi(Resource): ...@@ -228,25 +218,22 @@ class ChatConversationDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@marshal_with(conversation_detail_fields) @marshal_with(conversation_detail_fields)
def get(self, app_id, conversation_id): def get(self, app_model, conversation_id):
app_id = str(app_id)
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'chat') return _get_conversation(app_model, conversation_id)
@setup_required @setup_required
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@account_initialization_required @account_initialization_required
def delete(self, app_id, conversation_id): def delete(self, app_model, conversation_id):
app_id = str(app_id)
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
# get app info
app = _get_app(app_id, 'chat')
conversation = db.session.query(Conversation) \ conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
...@@ -263,12 +250,9 @@ api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations') ...@@ -263,12 +250,9 @@ api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>') api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
def _get_conversation(app_id, conversation_id, mode): def _get_conversation(app_model, conversation_id):
# get app info
app = _get_app(app_id, mode)
conversation = db.session.query(Conversation) \ conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first() .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
......
...@@ -85,3 +85,9 @@ class TooManyFilesError(BaseHTTPException): ...@@ -85,3 +85,9 @@ class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files' error_code = 'too_many_files'
description = "Only one file is allowed." description = "Only one file is allowed."
code = 400 code = 400
class DraftWorkflowNotExist(BaseHTTPException):
error_code = 'draft_workflow_not_exist'
description = "Draft workflow need to be initialized."
code = 400
...@@ -11,7 +11,7 @@ from controllers.console.app.error import ( ...@@ -11,7 +11,7 @@ from controllers.console.app.error import (
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required from libs.login import login_required
......
...@@ -10,17 +10,15 @@ from flask_restful.inputs import int_range ...@@ -10,17 +10,15 @@ from flask_restful.inputs import int_range
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import ( from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError, ProviderNotInitializeError,
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
...@@ -28,10 +26,8 @@ from fields.conversation_fields import annotation_fields, message_detail_fields ...@@ -28,10 +26,8 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required from libs.login import login_required
from models.model import Conversation, Message, MessageAnnotation, MessageFeedback from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.message_service import MessageService from services.message_service import MessageService
...@@ -46,14 +42,10 @@ class ChatMessageListApi(Resource): ...@@ -46,14 +42,10 @@ class ChatMessageListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@account_initialization_required @account_initialization_required
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_id): def get(self, app_model):
app_id = str(app_id)
# get app info
app = _get_app(app_id, 'chat')
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args') parser.add_argument('first_id', type=uuid_value, location='args')
...@@ -62,7 +54,7 @@ class ChatMessageListApi(Resource): ...@@ -62,7 +54,7 @@ class ChatMessageListApi(Resource):
conversation = db.session.query(Conversation).filter( conversation = db.session.query(Conversation).filter(
Conversation.id == args['conversation_id'], Conversation.id == args['conversation_id'],
Conversation.app_id == app.id Conversation.app_id == app_model.id
).first() ).first()
if not conversation: if not conversation:
...@@ -110,12 +102,8 @@ class MessageFeedbackApi(Resource): ...@@ -110,12 +102,8 @@ class MessageFeedbackApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_id): @get_app_model
app_id = str(app_id) def post(self, app_model):
# get app info
app = _get_app(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('message_id', required=True, type=uuid_value, location='json') parser.add_argument('message_id', required=True, type=uuid_value, location='json')
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
...@@ -125,7 +113,7 @@ class MessageFeedbackApi(Resource): ...@@ -125,7 +113,7 @@ class MessageFeedbackApi(Resource):
message = db.session.query(Message).filter( message = db.session.query(Message).filter(
Message.id == message_id, Message.id == message_id,
Message.app_id == app.id Message.app_id == app_model.id
).first() ).first()
if not message: if not message:
...@@ -141,7 +129,7 @@ class MessageFeedbackApi(Resource): ...@@ -141,7 +129,7 @@ class MessageFeedbackApi(Resource):
raise ValueError('rating cannot be None when feedback not exists') raise ValueError('rating cannot be None when feedback not exists')
else: else:
feedback = MessageFeedback( feedback = MessageFeedback(
app_id=app.id, app_id=app_model.id,
conversation_id=message.conversation_id, conversation_id=message.conversation_id,
message_id=message.id, message_id=message.id,
rating=args['rating'], rating=args['rating'],
...@@ -160,21 +148,20 @@ class MessageAnnotationApi(Resource): ...@@ -160,21 +148,20 @@ class MessageAnnotationApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check('annotation') @cloud_edition_billing_resource_check('annotation')
@get_app_model
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('message_id', required=False, type=uuid_value, location='json') parser.add_argument('message_id', required=False, type=uuid_value, location='json')
parser.add_argument('question', required=True, type=str, location='json') parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json') parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument('annotation_reply', required=False, type=dict, location='json') parser.add_argument('annotation_reply', required=False, type=dict, location='json')
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id) annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
return annotation return annotation
...@@ -183,65 +170,15 @@ class MessageAnnotationCountApi(Resource): ...@@ -183,65 +170,15 @@ class MessageAnnotationCountApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): @get_app_model
app_id = str(app_id) def get(self, app_model):
# get app info
app = _get_app(app_id)
count = db.session.query(MessageAnnotation).filter( count = db.session.query(MessageAnnotation).filter(
MessageAnnotation.app_id == app.id MessageAnnotation.app_id == app_model.id
).count() ).count()
return {'count': count} return {'count': count}
class MessageMoreLikeThisApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id, message_id):
app_id = str(app_id)
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'],
location='args')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
# get app info
app_model = _get_app(app_id, 'completion')
try:
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
def compact_response(response: Union[dict, Generator]) -> Response: def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict): if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json') return Response(response=json.dumps(response), status=200, mimetype='application/json')
...@@ -257,13 +194,10 @@ class MessageSuggestedQuestionApi(Resource): ...@@ -257,13 +194,10 @@ class MessageSuggestedQuestionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id, message_id): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
app_id = str(app_id) def get(self, app_model, message_id):
message_id = str(message_id) message_id = str(message_id)
# get app info
app_model = _get_app(app_id, 'chat')
try: try:
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, app_model=app_model,
...@@ -294,14 +228,11 @@ class MessageApi(Resource): ...@@ -294,14 +228,11 @@ class MessageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
@marshal_with(message_detail_fields) @marshal_with(message_detail_fields)
def get(self, app_id, message_id): def get(self, app_model, message_id):
app_id = str(app_id)
message_id = str(message_id) message_id = str(message_id)
# get app info
app_model = _get_app(app_id)
message = db.session.query(Message).filter( message = db.session.query(Message).filter(
Message.id == message_id, Message.id == message_id,
Message.app_id == app_model.id Message.app_id == app_model.id
...@@ -313,7 +244,6 @@ class MessageApi(Resource): ...@@ -313,7 +244,6 @@ class MessageApi(Resource):
return message return message
api.add_resource(MessageMoreLikeThisApi, '/apps/<uuid:app_id>/completion-messages/<uuid:message_id>/more-like-this')
api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions') api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages') api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages')
api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks') api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
......
...@@ -4,13 +4,13 @@ from flask_login import current_user ...@@ -4,13 +4,13 @@ from flask_login import current_user
from flask_restful import Resource from flask_restful import Resource
from controllers.console import api from controllers.console import api
from controllers.console.app import _get_app from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from events.app_event import app_model_config_was_updated from events.app_event import app_model_config_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import login_required
from models.model import AppModelConfig from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService from services.app_model_config_service import AppModelConfigService
...@@ -19,33 +19,29 @@ class ModelConfigResource(Resource): ...@@ -19,33 +19,29 @@ class ModelConfigResource(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_id): @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config""" """Modify app model config"""
app_id = str(app_id)
app = _get_app(app_id)
# validate config # validate config
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json, config=request.json,
app_mode=app.mode app_mode=AppMode.value_of(app_model.mode)
) )
new_app_model_config = AppModelConfig( new_app_model_config = AppModelConfig(
app_id=app.id, app_id=app_model.id,
) )
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
db.session.add(new_app_model_config) db.session.add(new_app_model_config)
db.session.flush() db.session.flush()
app.app_model_config_id = new_app_model_config.id app_model.app_model_config_id = new_app_model_config.id
db.session.commit() db.session.commit()
app_model_config_was_updated.send( app_model_config_was_updated.send(
app, app_model,
app_model_config=new_app_model_config app_model_config=new_app_model_config
) )
......
...@@ -4,7 +4,7 @@ from werkzeug.exceptions import Forbidden, NotFound ...@@ -4,7 +4,7 @@ from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.app import _get_app from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
...@@ -34,13 +34,11 @@ class AppSite(Resource): ...@@ -34,13 +34,11 @@ class AppSite(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_id): def post(self, app_model):
args = parse_app_site_args() args = parse_app_site_args()
app_id = str(app_id)
app_model = _get_app(app_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
...@@ -82,11 +80,9 @@ class AppSiteAccessTokenReset(Resource): ...@@ -82,11 +80,9 @@ class AppSiteAccessTokenReset(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_id): def post(self, app_model):
app_id = str(app_id)
app_model = _get_app(app_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
......
...@@ -7,12 +7,13 @@ from flask_login import current_user ...@@ -7,12 +7,13 @@ from flask_login import current_user
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from controllers.console import api from controllers.console import api
from controllers.console.app import _get_app from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import datetime_string from libs.helper import datetime_string
from libs.login import login_required from libs.login import login_required
from models.model import AppMode
class DailyConversationStatistic(Resource): class DailyConversationStatistic(Resource):
...@@ -20,10 +21,9 @@ class DailyConversationStatistic(Resource): ...@@ -20,10 +21,9 @@ class DailyConversationStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): @get_app_model
def get(self, app_model):
account = current_user account = current_user
app_id = str(app_id)
app_model = _get_app(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
...@@ -81,10 +81,9 @@ class DailyTerminalsStatistic(Resource): ...@@ -81,10 +81,9 @@ class DailyTerminalsStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): @get_app_model
def get(self, app_model):
account = current_user account = current_user
app_id = str(app_id)
app_model = _get_app(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
...@@ -141,10 +140,9 @@ class DailyTokenCostStatistic(Resource): ...@@ -141,10 +140,9 @@ class DailyTokenCostStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): @get_app_model
def get(self, app_model):
account = current_user account = current_user
app_id = str(app_id)
app_model = _get_app(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
...@@ -205,10 +203,9 @@ class AverageSessionInteractionStatistic(Resource): ...@@ -205,10 +203,9 @@ class AverageSessionInteractionStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def get(self, app_model):
account = current_user account = current_user
app_id = str(app_id)
app_model = _get_app(app_id, 'chat')
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
...@@ -271,10 +268,9 @@ class UserSatisfactionRateStatistic(Resource): ...@@ -271,10 +268,9 @@ class UserSatisfactionRateStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): @get_app_model
def get(self, app_model):
account = current_user account = current_user
app_id = str(app_id)
app_model = _get_app(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
...@@ -334,10 +330,9 @@ class AverageResponseTimeStatistic(Resource): ...@@ -334,10 +330,9 @@ class AverageResponseTimeStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): @get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model):
account = current_user account = current_user
app_id = str(app_id)
app_model = _get_app(app_id, 'completion')
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
...@@ -396,10 +391,9 @@ class TokensPerSecondStatistic(Resource): ...@@ -396,10 +391,9 @@ class TokensPerSecondStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): @get_app_model
def get(self, app_model):
account = current_user account = current_user
app_id = str(app_id)
app_model = _get_app(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
......
import json
from flask_restful import Resource, marshal_with, reqparse
from controllers.console import api
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.workflow_fields import workflow_fields
from libs.login import current_user, login_required
from models.model import App, AppMode
from services.workflow_service import WorkflowService
class DraftWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
def get(self, app_model: App):
"""
Get draft workflow
"""
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
if not workflow:
raise DraftWorkflowNotExist()
# return workflow, if not found, return None (initiate graph by frontend)
return workflow
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Sync draft workflow
"""
parser = reqparse.RequestParser()
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
workflow_service = WorkflowService()
workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args.get('graph'),
features=args.get('features'),
account=current_user
)
return {
"result": "success"
}
class DraftWorkflowRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Run draft workflow
"""
# TODO
workflow_service = WorkflowService()
workflow_service.run_draft_workflow(app_model=app_model, account=current_user)
# TODO
return {
"result": "success"
}
class WorkflowTaskStopApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App, task_id: str):
"""
Stop workflow task
"""
# TODO
workflow_service = WorkflowService()
workflow_service.stop_workflow_task(app_model=app_model, task_id=task_id, account=current_user)
return {
"result": "success"
}
class DraftWorkflowNodeRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Run draft workflow node
"""
# TODO
workflow_service = WorkflowService()
workflow_service.run_draft_workflow_node(app_model=app_model, node_id=node_id, account=current_user)
# TODO
return {
"result": "success"
}
class PublishedWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
def get(self, app_model: App):
"""
Get published workflow
"""
# fetch published workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model)
# return workflow, if not found, return None
return workflow
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Publish workflow
"""
workflow_service = WorkflowService()
workflow_service.publish_workflow(app_model=app_model, account=current_user)
return {
"result": "success"
}
class DefaultBlockConfigsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
"""
Get default block config
"""
# Get default block configs
workflow_service = WorkflowService()
return workflow_service.get_default_block_configs()
class DefaultBlockConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, block_type: str):
"""
Get default block config
"""
parser = reqparse.RequestParser()
parser.add_argument('q', type=str, location='args')
args = parser.parse_args()
filters = None
if args.get('q'):
try:
filters = json.loads(args.get('q'))
except json.JSONDecodeError:
raise ValueError('Invalid filters')
# Get default block configs
workflow_service = WorkflowService()
return workflow_service.get_default_block_config(
node_type=block_type,
filters=filters
)
class ConvertToWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model: App):
"""
Convert basic mode of chatbot app to workflow mode
Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App
"""
# convert to workflow mode
workflow_service = WorkflowService()
workflow = workflow_service.convert_to_workflow(
app_model=app_model,
account=current_user
)
# return workflow
return workflow
api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop')
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<uuid:node_id>/run')
api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/published')
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs/:block_type')
api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
from models.model import App, AppMode
from services.workflow_app_service import WorkflowAppService
class WorkflowAppLogApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_app_log_pagination_fields)
def get(self, app_model: App):
"""
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args')
parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args')
parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
args = parser.parse_args()
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model,
args=args
)
return workflow_app_log_pagination
api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs')
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.workflow_run_fields import (
workflow_run_detail_fields,
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from libs.helper import uuid_value
from libs.login import login_required
from models.model import App, AppMode
from services.workflow_run_service import WorkflowRunService
class WorkflowRunListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_pagination_fields)
def get(self, app_model: App):
"""
Get workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model,
args=args
)
return result
class WorkflowRunDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_detail_fields)
def get(self, app_model: App, run_id):
"""
Get workflow run detail
"""
run_id = str(run_id)
workflow_run_service = WorkflowRunService()
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id)
return workflow_run
class WorkflowRunNodeExecutionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_list_fields)
def get(self, app_model: App, run_id):
"""
Get workflow run node execution list
"""
run_id = str(run_id)
workflow_run_service = WorkflowRunService()
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
return {
'data': node_executions
}
api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs')
api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>')
api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
from collections.abc import Callable
from functools import wraps
from typing import Optional, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.model import App, AppMode
def get_app_model(view: Optional[Callable] = None, *,
mode: Union[AppMode, list[AppMode]] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
if not kwargs.get('app_id'):
raise ValueError('missing app_id in path parameters')
app_id = kwargs.get('app_id')
app_id = str(app_id)
del kwargs['app_id']
app_model = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app_model:
raise AppNotFoundError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode == AppMode.CHANNEL:
raise AppNotFoundError()
if mode is not None:
if isinstance(mode, list):
modes = mode
else:
modes = [mode]
if app_mode not in modes:
mode_values = {m.value for m in modes}
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
kwargs['app_model'] = app_model
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
...@@ -19,7 +19,6 @@ from controllers.console.app.error import ( ...@@ -19,7 +19,6 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from models.model import AppModelConfig
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
...@@ -32,16 +31,12 @@ from services.errors.audio import ( ...@@ -32,16 +31,12 @@ from services.errors.audio import (
class ChatAudioApi(InstalledAppResource): class ChatAudioApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file'] file = request.files['file']
try: try:
response = AudioService.transcript_asr( response = AudioService.transcript_asr(
tenant_id=app_model.tenant_id, app_model=app_model,
file=file, file=file,
end_user=None end_user=None
) )
...@@ -76,16 +71,11 @@ class ChatAudioApi(InstalledAppResource): ...@@ -76,16 +71,11 @@ class ChatAudioApi(InstalledAppResource):
class ChatTextApi(InstalledAppResource): class ChatTextApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.text_to_speech_dict['enabled']:
raise AppUnavailableError()
try: try:
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id, app_model=app_model,
text=request.form['text'], text=request.form['text'],
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False streaming=False
) )
return {'data': response.data.decode('latin1')} return {'data': response.data.decode('latin1')}
......
...@@ -21,8 +21,8 @@ from controllers.console.app.error import ( ...@@ -21,8 +21,8 @@ from controllers.console.app.error import (
) )
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
...@@ -90,7 +90,7 @@ class CompletionStopApi(InstalledAppResource): ...@@ -90,7 +90,7 @@ class CompletionStopApi(InstalledAppResource):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
...@@ -154,7 +154,7 @@ class ChatStopApi(InstalledAppResource): ...@@ -154,7 +154,7 @@ class ChatStopApi(InstalledAppResource):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
...@@ -34,8 +34,7 @@ class InstalledAppsListApi(Resource): ...@@ -34,8 +34,7 @@ class InstalledAppsListApi(Resource):
'is_pinned': installed_app.is_pinned, 'is_pinned': installed_app.is_pinned,
'last_used_at': installed_app.last_used_at, 'last_used_at': installed_app.last_used_at,
'editable': current_user.role in ["owner", "admin"], 'editable': current_user.role in ["owner", "admin"],
'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id, 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id
'is_agent': installed_app.is_agent
} }
for installed_app in installed_apps for installed_app in installed_apps
] ]
......
...@@ -24,7 +24,7 @@ from controllers.console.explore.error import ( ...@@ -24,7 +24,7 @@ from controllers.console.explore.error import (
NotCompletionAppError, NotCompletionAppError,
) )
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
......
...@@ -4,9 +4,10 @@ from flask import current_app ...@@ -4,9 +4,10 @@ from flask import current_app
from flask_restful import fields, marshal_with from flask_restful import fields, marshal_with
from controllers.console import api from controllers.console import api
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from extensions.ext_database import db from extensions.ext_database import db
from models.model import AppModelConfig, InstalledApp from models.model import AppMode, AppModelConfig, InstalledApp
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
...@@ -45,30 +46,55 @@ class AppParameterApi(InstalledAppResource): ...@@ -45,30 +46,55 @@ class AppParameterApi(InstalledAppResource):
def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model = installed_app.app app_model = installed_app.app
app_model_config = app_model.app_model_config
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form
else:
app_model_config = app_model.app_model_config
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get('user_input_form', [])
return { return {
'opening_statement': app_model_config.opening_statement, 'opening_statement': features_dict.get('opening_statement'),
'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions': features_dict.get('suggested_questions', []),
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
'speech_to_text': app_model_config.speech_to_text_dict, {"enabled": False}),
'text_to_speech': app_model_config.text_to_speech_dict, 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
'retriever_resource': app_model_config.retriever_resource_dict, 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
'annotation_reply': app_model_config.annotation_reply_dict, 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
'more_like_this': app_model_config.more_like_this_dict, 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
'user_input_form': app_model_config.user_input_form_list, 'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, 'user_input_form': user_input_form,
'file_upload': app_model_config.file_upload_dict, 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
{"enabled": False, "type": "", "configs": []}),
'file_upload': features_dict.get('file_upload', {"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"]
}}),
'system_parameters': { 'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
} }
} }
class ExploreAppMetaApi(InstalledAppResource): class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Get app meta""" """Get app meta"""
app_model_config: AppModelConfig = installed_app.app.app_model_config app_model_config: AppModelConfig = installed_app.app.app_model_config
if not app_model_config:
return {
'tool_icons': {}
}
agent_config = app_model_config.agent_mode_dict or {} agent_config = app_model_config.agent_mode_dict or {}
meta = { meta = {
'tool_icons': {} 'tool_icons': {}
...@@ -77,7 +103,7 @@ class ExploreAppMetaApi(InstalledAppResource): ...@@ -77,7 +103,7 @@ class ExploreAppMetaApi(InstalledAppResource):
# get all tools # get all tools
tools = agent_config.get('tools', []) tools = agent_config.get('tools', [])
url_prefix = (current_app.config.get("CONSOLE_API_URL") url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ "/console/api/workspaces/current/tool-provider/builtin/") + "/console/api/workspaces/current/tool-provider/builtin/")
for tool in tools: for tool in tools:
keys = list(tool.keys()) keys = list(tool.keys())
if len(keys) >= 4: if len(keys) >= 4:
...@@ -94,12 +120,14 @@ class ExploreAppMetaApi(InstalledAppResource): ...@@ -94,12 +120,14 @@ class ExploreAppMetaApi(InstalledAppResource):
) )
meta['tool_icons'][tool_name] = json.loads(provider.icon) meta['tool_icons'][tool_name] = json.loads(provider.icon)
except: except:
meta['tool_icons'][tool_name] = { meta['tool_icons'][tool_name] = {
"background": "#252525", "background": "#252525",
"content": "\ud83d\ude01" "content": "\ud83d\ude01"
} }
return meta return meta
api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters', endpoint='installed_app_parameters')
api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters',
endpoint='installed_app_parameters')
api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta') api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, fields, marshal_with from flask_restful import Resource, fields, marshal_with, reqparse
from sqlalchemy import and_
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import api
from controllers.console.app.error import AppNotFoundError from controllers.console.app.error import AppNotFoundError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from models.model import App, RecommendedApp
from models.model import App, InstalledApp, RecommendedApp from services.app_service import AppService
from services.account_service import TenantService
app_fields = { app_fields = {
'id': fields.String, 'id': fields.String,
...@@ -27,11 +24,7 @@ recommended_app_fields = { ...@@ -27,11 +24,7 @@ recommended_app_fields = {
'privacy_policy': fields.String, 'privacy_policy': fields.String,
'category': fields.String, 'category': fields.String,
'position': fields.Integer, 'position': fields.Integer,
'is_listed': fields.Boolean, 'is_listed': fields.Boolean
'install_count': fields.Integer,
'installed': fields.Boolean,
'editable': fields.Boolean,
'is_agent': fields.Boolean
} }
recommended_app_list_fields = { recommended_app_list_fields = {
...@@ -41,11 +34,19 @@ recommended_app_list_fields = { ...@@ -41,11 +34,19 @@ recommended_app_list_fields = {
class RecommendedAppListApi(Resource): class RecommendedAppListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_fields) @marshal_with(recommended_app_list_fields)
def get(self): def get(self):
language_prefix = current_user.interface_language if current_user.interface_language else languages[0] # language args
parser = reqparse.RequestParser()
parser.add_argument('language', type=str, location='args')
args = parser.parse_args()
if args.get('language') and args.get('language') in languages:
language_prefix = args.get('language')
elif current_user and current_user.interface_language:
language_prefix = current_user.interface_language
else:
language_prefix = languages[0]
recommended_apps = db.session.query(RecommendedApp).filter( recommended_apps = db.session.query(RecommendedApp).filter(
RecommendedApp.is_listed == True, RecommendedApp.is_listed == True,
...@@ -53,16 +54,8 @@ class RecommendedAppListApi(Resource): ...@@ -53,16 +54,8 @@ class RecommendedAppListApi(Resource):
).all() ).all()
categories = set() categories = set()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
recommended_apps_result = [] recommended_apps_result = []
for recommended_app in recommended_apps: for recommended_app in recommended_apps:
installed = db.session.query(InstalledApp).filter(
and_(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id == current_user.current_tenant_id
)
).first() is not None
app = recommended_app.app app = recommended_app.app
if not app or not app.is_public: if not app or not app.is_public:
continue continue
...@@ -80,11 +73,7 @@ class RecommendedAppListApi(Resource): ...@@ -80,11 +73,7 @@ class RecommendedAppListApi(Resource):
'privacy_policy': site.privacy_policy, 'privacy_policy': site.privacy_policy,
'category': recommended_app.category, 'category': recommended_app.category,
'position': recommended_app.position, 'position': recommended_app.position,
'is_listed': recommended_app.is_listed, 'is_listed': recommended_app.is_listed
'install_count': recommended_app.install_count,
'installed': installed,
'editable': current_user.role in ['owner', 'admin'],
"is_agent": app.is_agent
} }
recommended_apps_result.append(recommended_app_result) recommended_apps_result.append(recommended_app_result)
...@@ -94,29 +83,6 @@ class RecommendedAppListApi(Resource): ...@@ -94,29 +83,6 @@ class RecommendedAppListApi(Resource):
class RecommendedAppApi(Resource): class RecommendedAppApi(Resource):
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
}
app_simple_detail_fields = {
'id': fields.String,
'name': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'mode': fields.String,
'app_model_config': fields.Nested(model_config_fields),
}
@login_required
@account_initialization_required
@marshal_with(app_simple_detail_fields)
def get(self, app_id): def get(self, app_id):
app_id = str(app_id) app_id = str(app_id)
...@@ -130,11 +96,21 @@ class RecommendedAppApi(Resource): ...@@ -130,11 +96,21 @@ class RecommendedAppApi(Resource):
raise AppNotFoundError raise AppNotFoundError
# get app detail # get app detail
app = db.session.query(App).filter(App.id == app_id).first() app_model = db.session.query(App).filter(App.id == app_id).first()
if not app or not app.is_public: if not app_model or not app_model.is_public:
raise AppNotFoundError raise AppNotFoundError
return app app_service = AppService()
export_str = app_service.export_app(app_model)
return {
'id': app_model.id,
'name': app_model.name,
'icon': app_model.icon,
'icon_background': app_model.icon_background,
'mode': app_model.mode,
'export_data': export_str
}
api.add_resource(RecommendedAppListApi, '/explore/apps') api.add_resource(RecommendedAppListApi, '/explore/apps')
......
from flask_restful import Resource
from controllers.console import api
class PingApi(Resource):
def get(self):
"""
For connection health check
"""
return {
"result": "pong"
}
api.add_resource(PingApi, '/ping')
...@@ -16,26 +16,13 @@ from controllers.console.workspace.error import ( ...@@ -16,26 +16,13 @@ from controllers.console.workspace.error import (
) )
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone from libs.helper import TimestampField, timezone
from libs.login import login_required from libs.login import login_required
from models.account import AccountIntegrate, InvitationCode from models.account import AccountIntegrate, InvitationCode
from services.account_service import AccountService from services.account_service import AccountService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
account_fields = {
'id': fields.String,
'name': fields.String,
'avatar': fields.String,
'email': fields.String,
'is_password_set': fields.Boolean,
'interface_language': fields.String,
'interface_theme': fields.String,
'timezone': fields.String,
'last_login_at': TimestampField,
'last_login_ip': fields.String,
'created_at': TimestampField
}
class AccountInitApi(Resource): class AccountInitApi(Resource):
......
from flask import current_app from flask import current_app
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, abort, fields, marshal_with, reqparse from flask_restful import Resource, abort, marshal_with, reqparse
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from fields.member_fields import account_with_role_list_fields
from libs.login import login_required from libs.login import login_required
from models.account import Account from models.account import Account
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError from services.errors.account import AccountAlreadyInTenantError
account_fields = {
'id': fields.String,
'name': fields.String,
'avatar': fields.String,
'email': fields.String,
'last_login_at': TimestampField,
'created_at': TimestampField,
'role': fields.String,
'status': fields.String,
}
account_list_fields = {
'accounts': fields.List(fields.Nested(account_fields))
}
class MemberListApi(Resource): class MemberListApi(Resource):
"""List all members of current tenant.""" """List all members of current tenant."""
...@@ -35,7 +20,7 @@ class MemberListApi(Resource): ...@@ -35,7 +20,7 @@ class MemberListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_list_fields) @marshal_with(account_with_role_list_fields)
def get(self): def get(self):
members = TenantService.get_tenant_members(current_user.current_tenant) members = TenantService.get_tenant_members(current_user.current_tenant)
return {'result': 'success', 'accounts': members}, 200 return {'result': 'success', 'accounts': members}, 200
......
...@@ -4,9 +4,10 @@ from flask import current_app ...@@ -4,9 +4,10 @@ from flask import current_app
from flask_restful import fields, marshal_with, Resource from flask_restful import fields, marshal_with, Resource
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, AppModelConfig from models.model import App, AppModelConfig, AppMode
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
...@@ -46,31 +47,55 @@ class AppParameterApi(Resource): ...@@ -46,31 +47,55 @@ class AppParameterApi(Resource):
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
def get(self, app_model: App): def get(self, app_model: App):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model_config = app_model.app_model_config if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form
else:
app_model_config = app_model.app_model_config
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get('user_input_form', [])
return { return {
'opening_statement': app_model_config.opening_statement, 'opening_statement': features_dict.get('opening_statement'),
'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions': features_dict.get('suggested_questions', []),
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
'speech_to_text': app_model_config.speech_to_text_dict, {"enabled": False}),
'text_to_speech': app_model_config.text_to_speech_dict, 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
'retriever_resource': app_model_config.retriever_resource_dict, 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
'annotation_reply': app_model_config.annotation_reply_dict, 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
'more_like_this': app_model_config.more_like_this_dict, 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
'user_input_form': app_model_config.user_input_form_list, 'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, 'user_input_form': user_input_form,
'file_upload': app_model_config.file_upload_dict, 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
{"enabled": False, "type": "", "configs": []}),
'file_upload': features_dict.get('file_upload', {"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"]
}}),
'system_parameters': { 'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
} }
} }
class AppMetaApi(Resource): class AppMetaApi(Resource):
@validate_app_token @validate_app_token
def get(self, app_model: App): def get(self, app_model: App):
"""Get app meta""" """Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config:
return {
'tool_icons': {}
}
agent_config = app_model_config.agent_mode_dict or {} agent_config = app_model_config.agent_mode_dict or {}
meta = { meta = {
'tool_icons': {} 'tool_icons': {}
......
...@@ -20,7 +20,7 @@ from controllers.service_api.app.error import ( ...@@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig, EndUser from models.model import App, EndUser
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
...@@ -33,16 +33,11 @@ from services.errors.audio import ( ...@@ -33,16 +33,11 @@ from services.errors.audio import (
class AudioApi(Resource): class AudioApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser): def post(self, app_model: App, end_user: EndUser):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file'] file = request.files['file']
try: try:
response = AudioService.transcript_asr( response = AudioService.transcript_asr(
tenant_id=app_model.tenant_id, app_model=app_model,
file=file, file=file,
end_user=end_user end_user=end_user
) )
...@@ -84,10 +79,9 @@ class TextApi(Resource): ...@@ -84,10 +79,9 @@ class TextApi(Resource):
try: try:
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id, app_model=app_model,
text=args['text'], text=args['text'],
end_user=end_user, end_user=end_user,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming'] streaming=args['streaming']
) )
......
...@@ -19,8 +19,8 @@ from controllers.service_api.app.error import ( ...@@ -19,8 +19,8 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
...@@ -85,7 +85,7 @@ class CompletionStopApi(Resource): ...@@ -85,7 +85,7 @@ class CompletionStopApi(Resource):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise AppUnavailableError() raise AppUnavailableError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
...@@ -147,7 +147,7 @@ class ChatStopApi(Resource): ...@@ -147,7 +147,7 @@ class ChatStopApi(Resource):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
...@@ -4,9 +4,10 @@ from flask import current_app ...@@ -4,9 +4,10 @@ from flask import current_app
from flask_restful import fields, marshal_with from flask_restful import fields, marshal_with
from controllers.web import api from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, AppModelConfig from models.model import App, AppModelConfig, AppMode
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
...@@ -44,30 +45,52 @@ class AppParameterApi(WebApiResource): ...@@ -44,30 +45,52 @@ class AppParameterApi(WebApiResource):
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
def get(self, app_model: App, end_user): def get(self, app_model: App, end_user):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model_config = app_model.app_model_config if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form
else:
app_model_config = app_model.app_model_config
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get('user_input_form', [])
return { return {
'opening_statement': app_model_config.opening_statement, 'opening_statement': features_dict.get('opening_statement'),
'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions': features_dict.get('suggested_questions', []),
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
'speech_to_text': app_model_config.speech_to_text_dict, {"enabled": False}),
'text_to_speech': app_model_config.text_to_speech_dict, 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
'retriever_resource': app_model_config.retriever_resource_dict, 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
'annotation_reply': app_model_config.annotation_reply_dict, 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
'more_like_this': app_model_config.more_like_this_dict, 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
'user_input_form': app_model_config.user_input_form_list, 'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, 'user_input_form': user_input_form,
'file_upload': app_model_config.file_upload_dict, 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
{"enabled": False, "type": "", "configs": []}),
'file_upload': features_dict.get('file_upload', {"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"]
}}),
'system_parameters': { 'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
} }
} }
class AppMeta(WebApiResource): class AppMeta(WebApiResource):
def get(self, app_model: App, end_user): def get(self, app_model: App, end_user):
"""Get app meta""" """Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config:
raise AppUnavailableError()
agent_config = app_model_config.agent_mode_dict or {} agent_config = app_model_config.agent_mode_dict or {}
meta = { meta = {
'tool_icons': {} 'tool_icons': {}
......
...@@ -19,7 +19,7 @@ from controllers.web.error import ( ...@@ -19,7 +19,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig from models.model import App
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
...@@ -31,16 +31,11 @@ from services.errors.audio import ( ...@@ -31,16 +31,11 @@ from services.errors.audio import (
class AudioApi(WebApiResource): class AudioApi(WebApiResource):
def post(self, app_model: App, end_user): def post(self, app_model: App, end_user):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file'] file = request.files['file']
try: try:
response = AudioService.transcript_asr( response = AudioService.transcript_asr(
tenant_id=app_model.tenant_id, app_model=app_model,
file=file, file=file,
end_user=end_user end_user=end_user
) )
...@@ -74,17 +69,11 @@ class AudioApi(WebApiResource): ...@@ -74,17 +69,11 @@ class AudioApi(WebApiResource):
class TextApi(WebApiResource): class TextApi(WebApiResource):
def post(self, app_model: App, end_user): def post(self, app_model: App, end_user):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.text_to_speech_dict['enabled']:
raise AppUnavailableError()
try: try:
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id, app_model=app_model,
text=request.form['text'], text=request.form['text'],
end_user=end_user.external_user_id, end_user=end_user.external_user_id,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False streaming=False
) )
......
...@@ -20,8 +20,8 @@ from controllers.web.error import ( ...@@ -20,8 +20,8 @@ from controllers.web.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
...@@ -84,7 +84,7 @@ class CompletionStopApi(WebApiResource): ...@@ -84,7 +84,7 @@ class CompletionStopApi(WebApiResource):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
...@@ -144,7 +144,7 @@ class ChatStopApi(WebApiResource): ...@@ -144,7 +144,7 @@ class ChatStopApi(WebApiResource):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
...@@ -21,7 +21,7 @@ from controllers.web.error import ( ...@@ -21,7 +21,7 @@ from controllers.web.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.entities.application_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
......
...@@ -83,7 +83,3 @@ class AppSiteInfo: ...@@ -83,7 +83,3 @@ class AppSiteInfo:
'remove_webapp_brand': remove_webapp_brand, 'remove_webapp_brand': remove_webapp_brand,
'replace_webapp_logo': replace_webapp_logo, 'replace_webapp_logo': replace_webapp_logo,
} }
if app.enable_site and site.prompt_public:
app_model_config = app.app_model_config
self.model_config = app_model_config
...@@ -5,18 +5,17 @@ from datetime import datetime ...@@ -5,18 +5,17 @@ from datetime import datetime
from mimetypes import guess_extension from mimetypes import guess_extension
from typing import Optional, Union, cast from typing import Optional, Union, cast
from core.app_runner.app_runner import AppRunner from core.agent.entities import AgentEntity, AgentToolEntity
from core.application_queue_manager import ApplicationQueueManager from core.app.app_queue_manager import AppQueueManager
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.app.apps.base_app_runner import AppRunner
from core.entities.application_entities import ( from core.app.entities.app_invoke_entities import (
AgentEntity, AgentChatAppGenerateEntity,
AgentToolEntity,
ApplicationGenerateEntity,
AppOrchestrationConfigEntity,
InvokeFrom, InvokeFrom,
ModelConfigEntity, ModelConfigWithCredentialsEntity,
) )
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file.message_file_parser import FileTransferMethod from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
...@@ -48,13 +47,13 @@ from models.tools import ToolConversationVariables ...@@ -48,13 +47,13 @@ from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseAssistantApplicationRunner(AppRunner): class BaseAgentRunner(AppRunner):
def __init__(self, tenant_id: str, def __init__(self, tenant_id: str,
application_generate_entity: ApplicationGenerateEntity, application_generate_entity: AgentChatAppGenerateEntity,
app_orchestration_config: AppOrchestrationConfigEntity, app_config: AgentChatAppConfig,
model_config: ModelConfigEntity, model_config: ModelConfigWithCredentialsEntity,
config: AgentEntity, config: AgentEntity,
queue_manager: ApplicationQueueManager, queue_manager: AppQueueManager,
message: Message, message: Message,
user_id: str, user_id: str,
memory: Optional[TokenBufferMemory] = None, memory: Optional[TokenBufferMemory] = None,
...@@ -66,7 +65,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -66,7 +65,7 @@ class BaseAssistantApplicationRunner(AppRunner):
""" """
Agent runner Agent runner
:param tenant_id: tenant id :param tenant_id: tenant id
:param app_orchestration_config: app orchestration config :param app_config: app generate entity
:param model_config: model config :param model_config: model config
:param config: dataset config :param config: dataset config
:param queue_manager: queue manager :param queue_manager: queue manager
...@@ -78,7 +77,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -78,7 +77,7 @@ class BaseAssistantApplicationRunner(AppRunner):
""" """
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.app_orchestration_config = app_orchestration_config self.app_config = app_config
self.model_config = model_config self.model_config = model_config
self.config = config self.config = config
self.queue_manager = queue_manager self.queue_manager = queue_manager
...@@ -97,16 +96,16 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -97,16 +96,16 @@ class BaseAssistantApplicationRunner(AppRunner):
# init dataset tools # init dataset tools
hit_callback = DatasetIndexToolCallbackHandler( hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=queue_manager, queue_manager=queue_manager,
app_id=self.application_generate_entity.app_id, app_id=self.app_config.app_id,
message_id=message.id, message_id=message.id,
user_id=user_id, user_id=user_id,
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
) )
self.dataset_tools = DatasetRetrieverTool.get_dataset_tools( self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
tenant_id=tenant_id, tenant_id=tenant_id,
dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [], dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None, retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_orchestration_config.show_retrieve_source, return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback hit_callback=hit_callback
) )
...@@ -123,14 +122,15 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -123,14 +122,15 @@ class BaseAssistantApplicationRunner(AppRunner):
else: else:
self.stream_tool_call = False self.stream_tool_call = False
def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity: def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> AgentChatAppGenerateEntity:
""" """
Repack app orchestration config Repack app generate entity
""" """
if app_orchestration_config.prompt_template.simple_prompt_template is None: if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
app_orchestration_config.prompt_template.simple_prompt_template = '' app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
return app_orchestration_config return app_generate_entity
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
""" """
...@@ -156,7 +156,7 @@ class BaseAssistantApplicationRunner(AppRunner): ...@@ -156,7 +156,7 @@ class BaseAssistantApplicationRunner(AppRunner):
""" """
tool_entity = ToolManager.get_tool_runtime( tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name, provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tenant_id=self.application_generate_entity.tenant_id, tenant_id=self.app_config.tenant_id,
agent_callback=self.agent_callback agent_callback=self.agent_callback
) )
tool_entity.load_variables(self.variables_pool) tool_entity.load_variables(self.variables_pool)
......
...@@ -3,9 +3,9 @@ import re ...@@ -3,9 +3,9 @@ import re
from collections.abc import Generator from collections.abc import Generator
from typing import Literal, Union from typing import Literal, Union
from core.application_queue_manager import PublishFrom from core.agent.base_agent_runner import BaseAgentRunner
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
from core.features.assistant_base_runner import BaseAssistantApplicationRunner from core.app.app_queue_manager import PublishFrom
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
...@@ -27,7 +27,7 @@ from core.tools.errors import ( ...@@ -27,7 +27,7 @@ from core.tools.errors import (
from models.model import Conversation, Message from models.model import Conversation, Message
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): class CotAgentRunner(BaseAgentRunner):
def run(self, conversation: Conversation, def run(self, conversation: Conversation,
message: Message, message: Message,
query: str, query: str,
...@@ -36,32 +36,34 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -36,32 +36,34 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
""" """
Run Cot agent application Run Cot agent application
""" """
app_orchestration_config = self.app_orchestration_config app_generate_entity = self.application_generate_entity
self._repack_app_orchestration_config(app_orchestration_config) self._repack_app_generate_entity(app_generate_entity)
agent_scratchpad: list[AgentScratchpadUnit] = [] agent_scratchpad: list[AgentScratchpadUnit] = []
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages) self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
# check model mode # check model mode
if self.app_orchestration_config.model_config.mode == "completion": if app_generate_entity.model_config.mode == "completion":
# TODO: stop words # TODO: stop words
if 'Observation' not in app_orchestration_config.model_config.stop: if 'Observation' not in app_generate_entity.model_config.stop:
app_orchestration_config.model_config.stop.append('Observation') app_generate_entity.model_config.stop.append('Observation')
app_config = self.app_config
# override inputs # override inputs
inputs = inputs or {} inputs = inputs or {}
instruction = self.app_orchestration_config.prompt_template.simple_prompt_template instruction = app_config.prompt_template.simple_prompt_template
instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1 iteration_step = 1
max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
prompt_messages = self.history_prompt_messages prompt_messages = self.history_prompt_messages
# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
prompt_messages_tools: list[PromptMessageTool] = [] prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {} tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: for tool in app_config.agent.tools if app_config.agent else []:
try: try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception: except Exception:
...@@ -121,11 +123,11 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -121,11 +123,11 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# update prompt messages # update prompt messages
prompt_messages = self._organize_cot_prompt_messages( prompt_messages = self._organize_cot_prompt_messages(
mode=app_orchestration_config.model_config.mode, mode=app_generate_entity.model_config.mode,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=prompt_messages_tools, tools=prompt_messages_tools,
agent_scratchpad=agent_scratchpad, agent_scratchpad=agent_scratchpad,
agent_prompt_message=app_orchestration_config.agent.prompt, agent_prompt_message=app_config.agent.prompt,
instruction=instruction, instruction=instruction,
input=query input=query
) )
...@@ -135,9 +137,9 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -135,9 +137,9 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# invoke model # invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters, model_parameters=app_generate_entity.model_config.parameters,
tools=[], tools=[],
stop=app_orchestration_config.model_config.stop, stop=app_generate_entity.model_config.stop,
stream=True, stream=True,
user=self.user_id, user=self.user_id,
callbacks=[], callbacks=[],
...@@ -542,7 +544,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): ...@@ -542,7 +544,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
""" """
convert agent scratchpad list to str convert agent scratchpad list to str
""" """
next_iteration = self.app_orchestration_config.agent.prompt.next_iteration next_iteration = self.app_config.agent.prompt.next_iteration
result = '' result = ''
for scratchpad in agent_scratchpad: for scratchpad in agent_scratchpad:
......
from enum import Enum
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel
class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api"]
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""
first_prompt: str
next_iteration: str
class AgentScratchpadUnit(BaseModel):
"""
Agent First Prompt Entity.
"""
class Action(BaseModel):
"""
Action Entity.
"""
action_name: str
action_input: Union[dict, str]
agent_response: Optional[str] = None
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
class AgentEntity(BaseModel):
"""
Agent Entity.
"""
class Strategy(Enum):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'
provider: str
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
tools: list[AgentToolEntity] = None
max_iteration: int = 5
...@@ -3,8 +3,8 @@ import logging ...@@ -3,8 +3,8 @@ import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Union from typing import Any, Union
from core.application_queue_manager import PublishFrom from core.agent.base_agent_runner import BaseAgentRunner
from core.features.assistant_base_runner import BaseAssistantApplicationRunner from core.app.app_queue_manager import PublishFrom
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
...@@ -26,7 +26,7 @@ from models.model import Conversation, Message, MessageAgentThought ...@@ -26,7 +26,7 @@ from models.model import Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, conversation: Conversation, def run(self, conversation: Conversation,
message: Message, message: Message,
query: str, query: str,
...@@ -34,9 +34,11 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ...@@ -34,9 +34,11 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
""" """
Run FunctionCall agent application Run FunctionCall agent application
""" """
app_orchestration_config = self.app_orchestration_config app_generate_entity = self.application_generate_entity
prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or '' app_config = self.app_config
prompt_template = app_config.prompt_template.simple_prompt_template or ''
prompt_messages = self.history_prompt_messages prompt_messages = self.history_prompt_messages
prompt_messages = self.organize_prompt_messages( prompt_messages = self.organize_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
...@@ -47,7 +49,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ...@@ -47,7 +49,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
prompt_messages_tools: list[PromptMessageTool] = [] prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {} tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: for tool in app_config.agent.tools if app_config.agent else []:
try: try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception: except Exception:
...@@ -67,7 +69,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ...@@ -67,7 +69,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
tool_instances[dataset_tool.identity.name] = dataset_tool tool_instances[dataset_tool.identity.name] = dataset_tool
iteration_step = 1 iteration_step = 1
max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call # continue to run until there is not any tool call
function_call_state = True function_call_state = True
...@@ -110,9 +112,9 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ...@@ -110,9 +112,9 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
# invoke model # invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters, model_parameters=app_generate_entity.model_config.parameters,
tools=prompt_messages_tools, tools=prompt_messages_tools,
stop=app_orchestration_config.model_config.stop, stop=app_generate_entity.model_config.stop,
stream=self.stream_tool_call, stream=self.stream_tool_call,
user=self.user_id, user=self.user_id,
callbacks=[], callbacks=[],
......
from typing import Optional, Union
from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import AppModelConfig
class BaseAppConfigManager:
@classmethod
def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom,
app_model_config: Union[AppModelConfig, dict],
config_dict: Optional[dict] = None) -> dict:
"""
Convert app model config to config dict
:param config_from: app model config from
:param app_model_config: app model config
:param config_dict: app model config dict
:return:
"""
if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
return config_dict
@classmethod
def convert_features(cls, config_dict: dict) -> AppAdditionalFeatures:
"""
Convert app config to app model config
:param config_dict: app config
"""
config_dict = config_dict.copy()
additional_features = AppAdditionalFeatures()
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
config=config_dict
)
additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict
)
additional_features.opening_statement, additional_features.suggested_questions = \
OpeningStatementConfigManager.convert(
config=config_dict
)
additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
config=config_dict
)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(
config=config_dict
)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(
config=config_dict
)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(
config=config_dict
)
return additional_features
from typing import Optional
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
if not sensitive_word_avoidance_dict:
return None
if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
)
else:
return None
@classmethod
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {
"enabled": False
}
if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
config["sensitive_word_avoidance"]["enabled"] = False
if config["sensitive_word_avoidance"]["enabled"]:
if not config["sensitive_word_avoidance"].get("type"):
raise ValueError("sensitive_word_avoidance.type is required")
if not only_structure_validate:
typ = config["sensitive_word_avoidance"]["type"]
config = config["sensitive_word_avoidance"]["config"]
ModerationFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=config
)
return config, ["sensitive_word_avoidance"]
from typing import Optional
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
class AgentConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[AgentEntity]:
"""
Convert model config to model config
:param config: model config args
"""
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode'] \
and config['agent_mode']['enabled']:
agent_dict = config.get('agent_mode', {})
agent_strategy = agent_dict.get('strategy', 'cot')
if agent_strategy == 'function_call':
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy == 'cot' or agent_strategy == 'react':
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config['model']['provider'] == 'openai':
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = []
for tool in agent_dict.get('tools', []):
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
continue
agent_tool_properties = {
'provider_type': tool['provider_type'],
'provider_id': tool['provider_id'],
'tool_name': tool['tool_name'],
'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
if 'strategy' in config['agent_mode'] and \
config['agent_mode']['strategy'] not in ['react_router', 'router']:
agent_prompt = agent_dict.get('prompt', None) or {}
# check model mode
model_mode = config.get('model', {}).get('mode', 'completion')
if model_mode == 'completion':
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['completion'][
'agent_scratchpad']),
)
else:
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
)
return AgentEntity(
provider=config['model']['provider'],
model=config['model']['name'],
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=agent_dict.get('max_iteration', 5)
)
return None
from typing import Optional
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.entities.agent_entities import PlanningStrategy
from models.model import AppMode
from services.dataset_service import DatasetService
class DatasetConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[DatasetEntity]:
"""
Convert model config to model config
:param config: model config args
"""
dataset_ids = []
if 'datasets' in config.get('dataset_configs', {}):
datasets = config.get('dataset_configs', {}).get('datasets', {
'strategy': 'router',
'datasets': []
})
for dataset in datasets.get('datasets', []):
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != 'dataset':
continue
dataset = dataset['dataset']
if 'enabled' not in dataset or not dataset['enabled']:
continue
dataset_id = dataset.get('id', None)
if dataset_id:
dataset_ids.append(dataset_id)
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode'] \
and config['agent_mode']['enabled']:
agent_dict = config.get('agent_mode', {})
for tool in agent_dict.get('tools', []):
keys = tool.keys()
if len(keys) == 1:
# old standard
key = list(tool.keys())[0]
if key != 'dataset':
continue
tool_item = tool[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
dataset_id = tool_item['id']
dataset_ids.append(dataset_id)
if len(dataset_ids) == 0:
return None
# dataset configs
dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
query_variable = config.get('dataset_query_variable')
if dataset_configs['retrieval_model'] == 'single':
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
)
)
)
else:
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
top_k=dataset_configs.get('top_k'),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model')
)
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for dataset feature
:param tenant_id: tenant ID
:param app_mode: app mode
:param config: app model config args
"""
# Extract dataset config for legacy compatibility
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
# dataset_configs
if not config.get("dataset_configs"):
config["dataset_configs"] = {'retrieval_model': 'single'}
if not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {
"strategy": "router",
"datasets": []
}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if config["dataset_configs"]['retrieval_model'] == 'multiple':
if not config["dataset_configs"]['reranking_model']:
raise ValueError("reranking_model has not been set")
if not isinstance(config["dataset_configs"]['reranking_model'], dict):
raise ValueError("reranking_model must be of object type")
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
dataset_query_variable = config.get("dataset_query_variable")
if not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist")
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict:
"""
Extract dataset config for legacy compatibility
:param tenant_id: tenant ID
:param app_mode: app mode
:param config: app model config args
"""
# Extract dataset config for legacy compatibility
if not config.get("agent_mode"):
config["agent_mode"] = {
"enabled": False,
"tools": []
}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
# enabled
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
config["agent_mode"]["enabled"] = False
if not isinstance(config["agent_mode"]["enabled"], bool):
raise ValueError("enabled in agent_mode must be of boolean type")
# tools
if not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")
# strategy
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0]
if key == "dataset":
# old style, use tool name as key
tool_item = tool[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
tool_item["enabled"] = False
if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if 'id' not in tool_item:
raise ValueError("id is required in dataset")
try:
uuid.UUID(tool_item["id"])
except ValueError:
raise ValueError("id in dataset must be of UUID type")
if not cls.is_dataset_exists(tenant_id, tool_item["id"]):
raise ValueError("Dataset ID does not exist, please check your permission.")
has_datasets = True
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
dataset_query_variable = config.get("dataset_query_variable")
if not dataset_query_variable:
raise ValueError("Dataset query variable is required when dataset is exist")
return config
@classmethod
def is_dataset_exists(cls, tenant_id: str, dataset_id: str) -> bool:
# verify if the dataset ID exists
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
return False
if dataset.tenant_id != tenant_id:
return False
return True
from typing import cast
from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.provider_manager import ProviderManager
class ModelConfigConverter:
@classmethod
def convert(cls, app_config: EasyUIBasedAppConfig,
skip_check: bool = False) \
-> ModelConfigWithCredentialsEntity:
"""
Convert app model config dict to entity.
:param app_config: app config
:param skip_check: skip check
:raises ProviderTokenNotInitError: provider token not init error
:return: app orchestration config entity
"""
model_config = app_config.model
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id,
provider=model_config.provider,
model_type=ModelType.LLM
)
provider_name = provider_model_bundle.configuration.provider.provider
model_name = model_config.model
model_type_instance = provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_config.model
)
if model_credentials is None:
if not skip_check:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
else:
model_credentials = {}
if not skip_check:
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model,
model_type=ModelType.LLM
)
if provider_model is None:
model_name = model_config.model
raise ValueError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = model_config.parameters
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = model_config.mode
if not model_mode:
mode_enum = model_type_instance.get_model_mode(
model=model_config.model,
credentials=model_credentials
)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(
model_config.model,
model_credentials
)
if not skip_check and not model_schema:
raise ValueError(f"Model {model_name} not exist.")
return ModelConfigWithCredentialsEntity(
provider=model_config.provider,
model=model_config.model,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.provider_manager import ProviderManager
class ModelConfigManager:
@classmethod
def convert(cls, config: dict) -> ModelConfigEntity:
"""
Convert model config to model config
:param config: model config args
"""
# model config
model_config = config.get('model')
if not model_config:
raise ValueError("model is required")
completion_params = model_config.get('completion_params')
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = model_config.get('mode')
return ModelConfigEntity(
provider=config['model']['provider'],
model=config['model']['name'],
mode=model_mode,
parameters=completion_params,
stop=stop,
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for model config
:param tenant_id: tenant id
:param config: app model config args
"""
if 'model' not in config:
raise ValueError("model is required")
if not isinstance(config["model"], dict):
raise ValueError("model must be of object type")
# model.provider
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name
if 'name' not in config["model"]:
raise ValueError("model.name is required")
provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"],
model_type=ModelType.LLM
)
if not models:
raise ValueError("model.name must be in the specified model list")
model_ids = [m.model for m in models]
if config["model"]["name"] not in model_ids:
raise ValueError("model.name must be in the specified model list")
model_mode = None
for model in models:
if model.model == config["model"]["name"]:
model_mode = model.model_properties.get(ModelPropertyKey.MODE)
break
# model.mode
if model_mode:
config['model']["mode"] = model_mode
else:
config['model']["mode"] = "completion"
# model.completion_params
if 'completion_params' not in config["model"]:
raise ValueError("model.completion_params is required")
config["model"]["completion_params"] = cls.validate_model_completion_params(
config["model"]["completion_params"]
)
return config, ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict) -> dict:
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")
# stop
if 'stop' not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")
if len(cp["stop"]) > 4:
raise ValueError("stop sequences must be less than 4")
return cp
from core.app.app_config.entities import (
AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity,
PromptTemplateEntity,
)
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.simple_prompt_transform import ModelMode
from models.model import AppMode
class PromptTemplateConfigManager:
@classmethod
def convert(cls, config: dict) -> PromptTemplateEntity:
if not config.get("prompt_type"):
raise ValueError("prompt_type is required")
prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = config.get("pre_prompt", "")
return PromptTemplateEntity(
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
else:
advanced_chat_prompt_template = None
chat_prompt_config = config.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
})
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
messages=chat_prompt_messages
)
advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
'prompt': completion_prompt_config['prompt']['text'],
}
if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
**completion_prompt_template_params
)
return PromptTemplateEntity(
prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template
)
@classmethod
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
"""
Validate pre_prompt and set defaults for prompt feature
depending on the config['model']
:param app_mode: app mode
:param config: app model config args
"""
if not config.get("prompt_type"):
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config['prompt_type'] not in prompt_type_vals:
raise ValueError(f"prompt_type must be in {prompt_type_vals}")
# chat_prompt_config
if not config.get("chat_prompt_config"):
config["chat_prompt_config"] = {}
if not isinstance(config["chat_prompt_config"], dict):
raise ValueError("chat_prompt_config must be of object type")
# completion_prompt_config
if not config.get("completion_prompt_config"):
config["completion_prompt_config"] = {}
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required "
"when prompt_type is advanced")
model_mode_vals = [mode.value for mode in ModelMode]
if config['model']["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
if not user_prefix:
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']
if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")
else:
# pre_prompt, for simple mode
if not config.get("pre_prompt"):
config["pre_prompt"] = ""
if not isinstance(config["pre_prompt"], str):
raise ValueError("pre_prompt must be of string type")
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict:
"""
Validate post_prompt and set defaults for prompt feature
:param config: app model config args
"""
# post_prompt
if not config.get("post_prompt"):
config["post_prompt"] = ""
if not isinstance(config["post_prompt"], str):
raise ValueError("post_prompt must be of string type")
return config
import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
class BasicVariablesConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
"""
Convert model config to model config
:param config: model config args
"""
external_data_variables = []
variables = []
# old external_data_tools
external_data_tools = config.get('external_data_tools', [])
for external_data_tool in external_data_tools:
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=external_data_tool['variable'],
type=external_data_tool['type'],
config=external_data_tool['config']
)
)
# variables and external_data_tools
for variable in config.get('user_input_form', []):
typ = list(variable.keys())[0]
if typ == 'external_data_tool':
val = variable[typ]
external_data_variables.append(
ExternalDataVariableEntity(
variable=val['variable'],
type=val['type'],
config=val['config']
)
)
elif typ in [
VariableEntity.Type.TEXT_INPUT.value,
VariableEntity.Type.PARAGRAPH.value,
VariableEntity.Type.NUMBER.value,
]:
variables.append(
VariableEntity(
type=VariableEntity.Type.value_of(typ),
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
max_length=variable[typ].get('max_length'),
default=variable[typ].get('default'),
)
)
elif typ == VariableEntity.Type.SELECT.value:
variables.append(
VariableEntity(
type=VariableEntity.Type.SELECT,
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
options=variable[typ].get('options'),
default=variable[typ].get('default'),
)
)
return variables, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for user input form
:param tenant_id: workspace id
:param config: app model config args
"""
related_config_keys = []
config, current_related_config_keys = cls.validate_variables_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
config, current_related_config_keys = cls.validate_external_data_tools_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
return config, related_config_keys
@classmethod
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for user input form
:param config: app model config args
"""
if not config.get("user_input_form"):
config["user_input_form"] = []
if not isinstance(config["user_input_form"], list):
raise ValueError("user_input_form must be a list of objects")
variables = []
for item in config["user_input_form"]:
key = list(item.keys())[0]
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
form_item = item[key]
if 'label' not in form_item:
raise ValueError("label is required in user_input_form")
if not isinstance(form_item["label"], str):
raise ValueError("label in user_input_form must be of string type")
if 'variable' not in form_item:
raise ValueError("variable is required in user_input_form")
if not isinstance(form_item["variable"], str):
raise ValueError("variable in user_input_form must be of string type")
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
if pattern.match(form_item["variable"]) is None:
raise ValueError("variable in user_input_form must be a string, "
"and cannot start with a number")
variables.append(form_item["variable"])
if 'required' not in form_item or not form_item["required"]:
form_item["required"] = False
if not isinstance(form_item["required"], bool):
raise ValueError("required in user_input_form must be of boolean type")
if key == "select":
if 'options' not in form_item or not form_item["options"]:
form_item["options"] = []
if not isinstance(form_item["options"], list):
raise ValueError("options in user_input_form must be a list of strings")
if "default" in form_item and form_item['default'] \
and form_item["default"] not in form_item["options"]:
raise ValueError("default value in user_input_form must be in the options list")
return config, ["user_input_form"]
@classmethod
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for external data fetch feature
:param tenant_id: workspace id
:param config: app model config args
"""
if not config.get("external_data_tools"):
config["external_data_tools"] = []
if not isinstance(config["external_data_tools"], list):
raise ValueError("external_data_tools must be of list type")
for tool in config["external_data_tools"]:
if "enabled" not in tool or not tool["enabled"]:
tool["enabled"] = False
if not tool["enabled"]:
continue
if "type" not in tool or not tool["type"]:
raise ValueError("external_data_tools[].type is required")
typ = tool["type"]
config = tool["config"]
ExternalDataToolFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=config
)
return config, ["external_data_tools"]
\ No newline at end of file
from enum import Enum from enum import Enum
from typing import Any, Literal, Optional, Union from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.entities.provider_configuration import ProviderModelBundle
from core.file.file_obj import FileObj
from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import AIModelEntity from models.model import AppMode
class ModelConfigEntity(BaseModel): class ModelConfigEntity(BaseModel):
...@@ -15,10 +13,7 @@ class ModelConfigEntity(BaseModel): ...@@ -15,10 +13,7 @@ class ModelConfigEntity(BaseModel):
""" """
provider: str provider: str
model: str model: str
model_schema: AIModelEntity mode: Optional[str] = None
mode: str
provider_model_bundle: ProviderModelBundle
credentials: dict[str, Any] = {}
parameters: dict[str, Any] = {} parameters: dict[str, Any] = {}
stop: list[str] = [] stop: list[str] = []
...@@ -86,6 +81,39 @@ class PromptTemplateEntity(BaseModel): ...@@ -86,6 +81,39 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
class Type(Enum):
TEXT_INPUT = 'text-input'
SELECT = 'select'
PARAGRAPH = 'paragraph'
NUMBER = 'number'
@classmethod
def value_of(cls, value: str) -> 'VariableEntity.Type':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid variable type value {value}')
variable: str
label: str
description: Optional[str] = None
type: Type
required: bool = False
max_length: Optional[int] = None
options: Optional[list[str]] = None
default: Optional[str] = None
class ExternalDataVariableEntity(BaseModel): class ExternalDataVariableEntity(BaseModel):
""" """
External Data Variable Entity. External Data Variable Entity.
...@@ -124,7 +152,6 @@ class DatasetRetrieveConfigEntity(BaseModel): ...@@ -124,7 +152,6 @@ class DatasetRetrieveConfigEntity(BaseModel):
query_variable: Optional[str] = None # Only when app mode is completion query_variable: Optional[str] = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy retrieve_strategy: RetrieveStrategy
single_strategy: Optional[str] = None # for temp
top_k: Optional[int] = None top_k: Optional[int] = None
score_threshold: Optional[float] = None score_threshold: Optional[float] = None
reranking_model: Optional[dict] = None reranking_model: Optional[dict] = None
...@@ -162,148 +189,53 @@ class FileUploadEntity(BaseModel): ...@@ -162,148 +189,53 @@ class FileUploadEntity(BaseModel):
image_config: Optional[dict[str, Any]] = None image_config: Optional[dict[str, Any]] = None
class AgentToolEntity(BaseModel): class AppAdditionalFeatures(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api"]
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""
first_prompt: str
next_iteration: str
class AgentScratchpadUnit(BaseModel):
"""
Agent First Prompt Entity.
"""
class Action(BaseModel):
"""
Action Entity.
"""
action_name: str
action_input: Union[dict, str]
agent_response: Optional[str] = None
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
class AgentEntity(BaseModel):
"""
Agent Entity.
"""
class Strategy(Enum):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'
provider: str
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
tools: list[AgentToolEntity] = None
max_iteration: int = 5
class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
"""
model_config: ModelConfigEntity
prompt_template: PromptTemplateEntity
external_data_variables: list[ExternalDataVariableEntity] = []
agent: Optional[AgentEntity] = None
# features
dataset: Optional[DatasetEntity] = None
file_upload: Optional[FileUploadEntity] = None file_upload: Optional[FileUploadEntity] = None
opening_statement: Optional[str] = None opening_statement: Optional[str] = None
suggested_questions: list[str] = []
suggested_questions_after_answer: bool = False suggested_questions_after_answer: bool = False
show_retrieve_source: bool = False show_retrieve_source: bool = False
more_like_this: bool = False more_like_this: bool = False
speech_to_text: bool = False speech_to_text: bool = False
text_to_speech: dict = {} text_to_speech: Optional[TextToSpeechEntity] = None
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
class InvokeFrom(Enum): class AppConfig(BaseModel):
""" """
Invoke From. Application Config Entity.
""" """
SERVICE_API = 'service-api' tenant_id: str
WEB_APP = 'web-app' app_id: str
EXPLORE = 'explore' app_mode: AppMode
DEBUGGER = 'debugger' additional_features: AppAdditionalFeatures
variables: list[VariableEntity] = []
@classmethod sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
def value_of(cls, value: str) -> 'InvokeFrom':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid invoke from value {value}')
def to_source(self) -> str:
"""
Get source of invoke from.
:return: source class EasyUIBasedAppModelConfigFrom(Enum):
""" """
if self == InvokeFrom.WEB_APP: App Model Config From.
return 'web_app' """
elif self == InvokeFrom.DEBUGGER: ARGS = 'args'
return 'dev' APP_LATEST_CONFIG = 'app-latest-config'
elif self == InvokeFrom.EXPLORE: CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
return 'explore_app'
elif self == InvokeFrom.SERVICE_API:
return 'api'
return 'dev'
class ApplicationGenerateEntity(BaseModel): class EasyUIBasedAppConfig(AppConfig):
""" """
Application Generate Entity. Easy UI Based App Config Entity.
""" """
task_id: str app_model_config_from: EasyUIBasedAppModelConfigFrom
tenant_id: str
app_id: str
app_model_config_id: str app_model_config_id: str
# for save
app_model_config_dict: dict app_model_config_dict: dict
app_model_config_override: bool model: ModelConfigEntity
prompt_template: PromptTemplateEntity
# Converted from app_model_config to Entity object, or directly covered by external input dataset: Optional[DatasetEntity] = None
app_orchestration_config_entity: AppOrchestrationConfigEntity external_data_variables: list[ExternalDataVariableEntity] = []
conversation_id: Optional[str] = None
inputs: dict[str, str] class WorkflowUIBasedAppConfig(AppConfig):
query: Optional[str] = None """
files: list[FileObj] = [] Workflow UI Based App Config Entity.
user_id: str """
# extras workflow_id: str
stream: bool
invoke_from: InvokeFrom
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
from typing import Optional
from core.app.app_config.entities import FileUploadEntity
class FileUploadConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[FileUploadEntity]:
"""
Convert model config to model config
:param config: model config args
"""
file_upload_dict = config.get('file_upload')
if file_upload_dict:
if 'image' in file_upload_dict and file_upload_dict['image']:
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
return FileUploadEntity(
image_config={
'number_limits': file_upload_dict['image']['number_limits'],
'detail': file_upload_dict['image']['detail'],
'transfer_methods': file_upload_dict['image']['transfer_methods']
}
)
return None
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for file upload feature
:param config: app model config args
"""
if not config.get("file_upload"):
config["file_upload"] = {}
if not isinstance(config["file_upload"], dict):
raise ValueError("file_upload must be of dict type")
# check image config
if not config["file_upload"].get("image"):
config["file_upload"]["image"] = {"enabled": False}
if config['file_upload']['image']['enabled']:
number_limits = config['file_upload']['image']['number_limits']
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
detail = config['file_upload']['image']['detail']
if detail not in ['high', 'low']:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config['file_upload']['image']['transfer_methods']
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in ['remote_url', 'local_file']:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"]
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
:param config: model config args
"""
more_like_this = False
more_like_this_dict = config.get('more_like_this')
if more_like_this_dict:
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
more_like_this = True
return more_like_this
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for more like this feature
:param config: app model config args
"""
if not config.get("more_like_this"):
config["more_like_this"] = {
"enabled": False
}
if not isinstance(config["more_like_this"], dict):
raise ValueError("more_like_this must be of dict type")
if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
config["more_like_this"]["enabled"] = False
if not isinstance(config["more_like_this"]["enabled"], bool):
raise ValueError("enabled in more_like_this must be of boolean type")
return config, ["more_like_this"]
class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[str, list]:
"""
Convert model config to model config
:param config: model config args
"""
# opening statement
opening_statement = config.get('opening_statement')
# suggested questions
suggested_questions_list = config.get('suggested_questions')
return opening_statement, suggested_questions_list
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for opening statement feature
:param config: app model config args
"""
if not config.get("opening_statement"):
config["opening_statement"] = ""
if not isinstance(config["opening_statement"], str):
raise ValueError("opening_statement must be of string type")
# suggested_questions
if not config.get("suggested_questions"):
config["suggested_questions"] = []
if not isinstance(config["suggested_questions"], list):
raise ValueError("suggested_questions must be of list type")
for question in config["suggested_questions"]:
if not isinstance(question, str):
raise ValueError("Elements in suggested_questions list must be of string type")
return config, ["opening_statement", "suggested_questions"]
class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get('retriever_resource')
if retriever_resource_dict:
if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
show_retrieve_source = True
return show_retrieve_source
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for retriever resource feature
:param config: app model config args
"""
if not config.get("retriever_resource"):
config["retriever_resource"] = {
"enabled": False
}
if not isinstance(config["retriever_resource"], dict):
raise ValueError("retriever_resource must be of dict type")
if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]:
config["retriever_resource"]["enabled"] = False
if not isinstance(config["retriever_resource"]["enabled"], bool):
raise ValueError("enabled in retriever_resource must be of boolean type")
return config, ["retriever_resource"]
class SpeechToTextConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
:param config: model config args
"""
speech_to_text = False
speech_to_text_dict = config.get('speech_to_text')
if speech_to_text_dict:
if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
speech_to_text = True
return speech_to_text
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for speech to text feature
:param config: app model config args
"""
if not config.get("speech_to_text"):
config["speech_to_text"] = {
"enabled": False
}
if not isinstance(config["speech_to_text"], dict):
raise ValueError("speech_to_text must be of dict type")
if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]:
config["speech_to_text"]["enabled"] = False
if not isinstance(config["speech_to_text"]["enabled"], bool):
raise ValueError("enabled in speech_to_text must be of boolean type")
return config, ["speech_to_text"]
class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
:param config: model config args
"""
suggested_questions_after_answer = False
suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
if suggested_questions_after_answer_dict:
if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
suggested_questions_after_answer = True
return suggested_questions_after_answer
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for suggested questions feature
:param config: app model config args
"""
if not config.get("suggested_questions_after_answer"):
config["suggested_questions_after_answer"] = {
"enabled": False
}
if not isinstance(config["suggested_questions_after_answer"], dict):
raise ValueError("suggested_questions_after_answer must be of dict type")
if "enabled" not in config["suggested_questions_after_answer"] or not \
config["suggested_questions_after_answer"]["enabled"]:
config["suggested_questions_after_answer"]["enabled"] = False
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
return config, ["suggested_questions_after_answer"]
from core.app.app_config.entities import TextToSpeechEntity
class TextToSpeechConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
:param config: model config args
"""
text_to_speech = False
text_to_speech_dict = config.get('text_to_speech')
if text_to_speech_dict:
if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
text_to_speech = TextToSpeechEntity(
enabled=text_to_speech_dict.get('enabled'),
voice=text_to_speech_dict.get('voice'),
language=text_to_speech_dict.get('language'),
)
return text_to_speech
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for text to speech feature
:param config: app model config args
"""
if not config.get("text_to_speech"):
config["text_to_speech"] = {
"enabled": False,
"voice": "",
"language": ""
}
if not isinstance(config["text_to_speech"], dict):
raise ValueError("text_to_speech must be of dict type")
if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]:
config["text_to_speech"]["enabled"] = False
config["text_to_speech"]["voice"] = ""
config["text_to_speech"]["language"] = ""
if not isinstance(config["text_to_speech"]["enabled"], bool):
raise ValueError("enabled in text_to_speech must be of boolean type")
return config, ["text_to_speech"]
from core.app.app_config.entities import VariableEntity
from models.workflow import Workflow
class WorkflowVariablesConfigManager:
@classmethod
def convert(cls, workflow: Workflow) -> list[VariableEntity]:
"""
Convert workflow start variables to variables
:param workflow: workflow instance
"""
variables = []
# find start node
user_input_form = workflow.user_input_form()
# variables
for variable in user_input_form:
variables.append(VariableEntity(**variable))
return variables
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.model import App, AppMode
from models.workflow import Workflow
class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
"""
Advanced Chatbot App Config Entity.
"""
pass
class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=AppMode.value_of(app_model.mode),
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict)
)
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for advanced chat app model config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: if True, only structure validation will be performed
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# opening_statement
config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# return retriever resource
config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment