Unverified Commit 0828873b authored by takatost's avatar takatost Committed by GitHub

fix: missing default user for APP service api (#2606)

parent 816b707a
from extensions.ext_database import db
from models.model import EndUser
def create_or_update_end_user_for_user_id(app_model, user_id):
"""
Create or update session terminal based on user ID.
"""
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
import json import json
from flask import current_app from flask import current_app
from flask_restful import fields, marshal_with from flask_restful import fields, marshal_with, Resource
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, AppModelConfig from models.model import App, AppModelConfig
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
class AppParameterApi(AppApiResource): class AppParameterApi(Resource):
"""Resource for app variables.""" """Resource for app variables."""
variable_fields = { variable_fields = {
...@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource): ...@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
'system_parameters': fields.Nested(system_parameters_fields) 'system_parameters': fields.Nested(system_parameters_fields)
} }
@validate_app_token
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
def get(self, app_model: App, end_user): def get(self, app_model: App):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model_config = app_model.app_model_config app_model_config = app_model.app_model_config
...@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource): ...@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
} }
} }
class AppMetaApi(AppApiResource): class AppMetaApi(Resource):
def get(self, app_model: App, end_user): @validate_app_token
def get(self, app_model: App):
"""Get app meta""" """Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config app_model_config: AppModelConfig = app_model.app_model_config
......
import logging import logging
from flask import request from flask import request
from flask_restful import reqparse from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
...@@ -17,10 +17,10 @@ from controllers.service_api.app.error import ( ...@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
UnsupportedAudioTypeError, UnsupportedAudioTypeError,
) )
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig from models.model import App, AppModelConfig, EndUser
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
...@@ -30,8 +30,9 @@ from services.errors.audio import ( ...@@ -30,8 +30,9 @@ from services.errors.audio import (
) )
class AudioApi(AppApiResource): class AudioApi(Resource):
def post(self, app_model: App, end_user): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser):
app_model_config: AppModelConfig = app_model.app_model_config app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']: if not app_model_config.speech_to_text_dict['enabled']:
...@@ -73,11 +74,11 @@ class AudioApi(AppApiResource): ...@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
raise InternalServerError() raise InternalServerError()
class TextApi(AppApiResource): class TextApi(Resource):
def post(self, app_model: App, end_user): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json') parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json') parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
...@@ -85,7 +86,7 @@ class TextApi(AppApiResource): ...@@ -85,7 +86,7 @@ class TextApi(AppApiResource):
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
text=args['text'], text=args['text'],
end_user=args['user'], end_user=end_user,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'), voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming'] streaming=args['streaming']
) )
......
...@@ -4,12 +4,11 @@ from collections.abc import Generator ...@@ -4,12 +4,11 @@ from collections.abc import Generator
from typing import Union from typing import Union
from flask import Response, stream_with_context from flask import Response, stream_with_context
from flask_restful import reqparse from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
AppUnavailableError, AppUnavailableError,
CompletionRequestError, CompletionRequestError,
...@@ -19,17 +18,19 @@ from controllers.service_api.app.error import ( ...@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError, ProviderNotInitializeError,
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import App, EndUser
from services.completion_service import CompletionService from services.completion_service import CompletionService
class CompletionApi(AppApiResource): class CompletionApi(Resource):
def post(self, app_model, end_user): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise AppUnavailableError() raise AppUnavailableError()
...@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource): ...@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='') parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
args['auto_generate_name'] = False args['auto_generate_name'] = False
try: try:
...@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource): ...@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
raise InternalServerError() raise InternalServerError()
class CompletionStopApi(AppApiResource): class CompletionStopApi(Resource):
def post(self, app_model, end_user, task_id): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise AppUnavailableError() raise AppUnavailableError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
class ChatApi(AppApiResource): class ChatApi(Resource):
def post(self, app_model, end_user): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
...@@ -114,7 +102,6 @@ class ChatApi(AppApiResource): ...@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
...@@ -122,9 +109,6 @@ class ChatApi(AppApiResource): ...@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
response = CompletionService.completion( response = CompletionService.completion(
app_model=app_model, app_model=app_model,
...@@ -157,22 +141,12 @@ class ChatApi(AppApiResource): ...@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
raise InternalServerError() raise InternalServerError()
class ChatStopApi(AppApiResource): class ChatStopApi(Resource):
def post(self, app_model, end_user, task_id): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
from flask import request from flask_restful import Resource, marshal_with, reqparse
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services import services
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import App, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
class ConversationApi(AppApiResource): class ConversationApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user): def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args') parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args() args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
except services.errors.conversation.LastConversationNotExistsError: except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource): class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id): def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
user = request.get_json().get('user')
if end_user is None and user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
try: try:
ConversationService.delete(app_model, conversation_id, end_user) ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
...@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource): ...@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
return {"result": "success"}, 204 return {"result": "success"}, 204
class ConversationRenameApi(AppApiResource): class ConversationRenameApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id): def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
...@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource): ...@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json') parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args() args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
return ConversationService.rename( return ConversationService.rename(
app_model, app_model,
......
from flask import request from flask import request
from flask_restful import marshal_with from flask_restful import Resource, marshal_with
import services import services
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
FileTooLargeError, FileTooLargeError,
NoFileUploadedError, NoFileUploadedError,
TooManyFilesError, TooManyFilesError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields from fields.file_fields import file_fields
from models.model import App, EndUser
from services.file_service import FileService from services.file_service import FileService
class FileApi(AppApiResource): class FileApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields) @marshal_with(file_fields)
def post(self, app_model, end_user): def post(self, app_model: App, end_user: EndUser):
file = request.files['file'] file = request.files['file']
user_args = request.form.get('user')
if end_user is None and user_args is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
# check file # check file
if 'file' not in request.files: if 'file' not in request.files:
......
from flask_restful import fields, marshal_with, reqparse from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services import services
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from models.model import EndUser, Message from models.model import App, EndUser
from services.message_service import MessageService from services.message_service import MessageService
class MessageListApi(AppApiResource): class MessageListApi(Resource):
feedback_fields = { feedback_fields = {
'rating': fields.String 'rating': fields.String
} }
...@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource): ...@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
'data': fields.List(fields.Nested(message_fields)) 'data': fields.List(fields.Nested(message_fields))
} }
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user): def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
...@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource): ...@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args') parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args() args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
return MessageService.pagination_by_first_id(app_model, end_user, return MessageService.pagination_by_first_id(app_model, end_user,
args['conversation_id'], args['first_id'], args['limit']) args['conversation_id'], args['first_id'], args['limit'])
...@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource): ...@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
raise NotFound("First Message Not Exists.") raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(AppApiResource): class MessageFeedbackApi(Resource):
def post(self, app_model, end_user, message_id): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
parser.add_argument('user', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
MessageService.create_feedback(app_model, message_id, end_user, args['rating']) MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
except services.errors.message.MessageNotExistsError: except services.errors.message.MessageNotExistsError:
...@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource): ...@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'} return {'result': 'success'}
class MessageSuggestedApi(AppApiResource): class MessageSuggestedApi(Resource):
def get(self, app_model, end_user, message_id): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id) message_id = str(message_id)
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
try: try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()
if end_user is None and message.from_end_user_id is not None:
user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.id == message.from_end_user_id,
EndUser.type == 'service_api'
).first()
else:
user = end_user
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, app_model=app_model,
user=user, user=end_user,
message_id=message_id, message_id=message_id,
check_enabled=False check_enabled=False
) )
......
from collections.abc import Callable
from datetime import datetime from datetime import datetime
from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
from flask_restful import Resource from flask_restful import Resource
from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import _get_user from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin from models.account import Account, Tenant, TenantAccountJoin
from models.model import ApiToken, App from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService from services.feature_service import FeatureService
def validate_app_token(view=None): class WhereisUserArg(Enum):
def decorator(view): """
@wraps(view) Enum for whereis_user_arg.
def decorated(*args, **kwargs): """
QUERY = 'query'
JSON = 'json'
FORM = 'form'
class FetchUserArg(BaseModel):
fetch_from: WhereisUserArg
required: bool = False
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token('app') api_token = validate_and_get_api_token('app')
app_model = db.session.query(App).filter(App.id == api_token.app_id).first() app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
...@@ -29,15 +47,34 @@ def validate_app_token(view=None): ...@@ -29,15 +47,34 @@ def validate_app_token(view=None):
if not app_model.enable_api: if not app_model.enable_api:
raise NotFound() raise NotFound()
return view(app_model, None, *args, **kwargs) kwargs['app_model'] = app_model
return decorated
if view: if not fetch_user_arg:
return decorator(view) # use default-user
user_id = None
else:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get('user')
else:
# use default-user
user_id = None
# if view is None, it means that the decorator is used without parentheses if not user_id and fetch_user_arg.required:
# use the decorator as a function for method_decorators raise ValueError("Arg user must be provided.")
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator return decorator
else:
return decorator(view)
def cloud_edition_billing_resource_check(resource: str, def cloud_edition_billing_resource_check(resource: str,
...@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None): ...@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None):
return api_token return api_token
class AppApiResource(Resource): def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
method_decorators = [validate_app_token] """
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = 'DEFAULT-USER'
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
class DatasetApiResource(Resource): class DatasetApiResource(Resource):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment