Unverified Commit 46154c67 authored by Jyong's avatar Jyong Committed by GitHub

Feat/dataset service api (#1245)

Co-authored-by: 's avatarjyong <jyong@dify.ai>
Co-authored-by: 's avatarStyleZhang <jasonapring2015@outlook.com>
parent 54ff03c3
...@@ -81,6 +81,7 @@ class BaseApiKeyListResource(Resource): ...@@ -81,6 +81,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken() api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id) setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id
api_token.token = key api_token.token = key
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)
......
...@@ -19,41 +19,13 @@ from core.model_providers.model_factory import ModelFactory ...@@ -19,41 +19,13 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
app_detail_fields_with_site
from libs.helper import TimestampField from libs.helper import TimestampField
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, AppModelConfig, Site from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService from services.app_model_config_service import AppModelConfigService
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
}
app_detail_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'created_at': TimestampField
}
def _get_app(app_id, tenant_id): def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
...@@ -63,35 +35,6 @@ def _get_app(app_id, tenant_id): ...@@ -63,35 +35,6 @@ def _get_app(app_id, tenant_id):
class AppListApi(Resource): class AppListApi(Resource):
prompt_config_fields = {
'prompt_template': fields.String,
}
model_config_partial_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
app_partial_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
'created_at': TimestampField
}
app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
}
@setup_required @setup_required
@login_required @login_required
...@@ -238,18 +181,6 @@ class AppListApi(Resource): ...@@ -238,18 +181,6 @@ class AppListApi(Resource):
class AppTemplateApi(Resource): class AppTemplateApi(Resource):
template_fields = {
'name': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'mode': fields.String,
'model_config': fields.Nested(model_config_fields),
}
template_list_fields = {
'data': fields.List(fields.Nested(template_fields)),
}
@setup_required @setup_required
@login_required @login_required
...@@ -268,38 +199,6 @@ class AppTemplateApi(Resource): ...@@ -268,38 +199,6 @@ class AppTemplateApi(Resource):
class AppApi(Resource): class AppApi(Resource):
site_fields = {
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'app_base_url': fields.String,
}
app_detail_fields_with_site = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'site': fields.Nested(site_fields),
'api_base_url': fields.String,
'created_at': TimestampField
}
@setup_required @setup_required
@login_required @login_required
......
...@@ -13,107 +13,14 @@ from controllers.console import api ...@@ -13,107 +13,14 @@ from controllers.console import api
from controllers.console.app import _get_app from controllers.console.app import _get_app
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
conversation_message_detail_fields, conversation_with_summary_pagination_fields
from libs.helper import TimestampField, datetime_string, uuid_value from libs.helper import TimestampField, datetime_string, uuid_value
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Message, MessageAnnotation, Conversation from models.model import Message, MessageAnnotation, Conversation
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(account_fields, allow_null=True),
}
annotation_fields = {
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String,
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
}
feedback_stat_fields = {
'like': fields.Integer,
'dislike': fields.Integer
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'model': fields.Raw,
'user_input_form': fields.Raw,
'pre_prompt': fields.String,
'agent_mode': fields.Raw,
}
class CompletionConversationApi(Resource): class CompletionConversationApi(Resource):
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]['text'] if value else ''
simple_configs_fields = {
'prompt_template': fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
simple_message_detail_fields = {
'inputs': fields.Raw,
'query': fields.String,
'message': MessageTextField,
'answer': fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String(),
'from_account_id': fields.String,
'read_at': TimestampField,
'created_at': TimestampField,
'annotation': fields.Nested(annotation_fields, allow_null=True),
'model_config': fields.Nested(simple_model_config_fields),
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
@setup_required @setup_required
@login_required @login_required
...@@ -191,21 +98,11 @@ class CompletionConversationApi(Resource): ...@@ -191,21 +98,11 @@ class CompletionConversationApi(Resource):
class CompletionConversationDetailApi(Resource): class CompletionConversationDetailApi(Resource):
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'model_config': fields.Nested(model_config_fields),
'message': fields.Nested(message_detail_fields, attribute='first_message'),
}
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(conversation_detail_fields) @marshal_with(conversation_message_detail_fields)
def get(self, app_id, conversation_id): def get(self, app_id, conversation_id):
app_id = str(app_id) app_id = str(app_id)
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
...@@ -234,44 +131,11 @@ class CompletionConversationDetailApi(Resource): ...@@ -234,44 +131,11 @@ class CompletionConversationDetailApi(Resource):
class ChatConversationApi(Resource): class ChatConversationApi(Resource):
simple_configs_fields = {
'prompt_template': fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String,
'from_account_id': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(simple_model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(conversation_pagination_fields) @marshal_with(conversation_with_summary_pagination_fields)
def get(self, app_id): def get(self, app_id):
app_id = str(app_id) app_id = str(app_id)
...@@ -356,19 +220,6 @@ class ChatConversationApi(Resource): ...@@ -356,19 +220,6 @@ class ChatConversationApi(Resource):
class ChatConversationDetailApi(Resource): class ChatConversationDetailApi(Resource):
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
@setup_required @setup_required
@login_required @login_required
......
...@@ -17,6 +17,7 @@ from controllers.console.wraps import account_initialization_required ...@@ -17,6 +17,7 @@ from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.login.login import login_required from core.login.login import login_required
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value, TimestampField from libs.helper import uuid_value, TimestampField
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db from extensions.ext_database import db
...@@ -27,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError ...@@ -27,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.message_service import MessageService from services.message_service import MessageService
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(account_fields, allow_null=True),
}
annotation_fields = {
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String,
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
}
class ChatMessageListApi(Resource): class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = { message_infinite_scroll_pagination_fields = {
......
...@@ -8,26 +8,11 @@ from controllers.console import api ...@@ -8,26 +8,11 @@ from controllers.console import api
from controllers.console.app import _get_app from controllers.console.app import _get_app
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from fields.app_fields import app_site_fields
from libs.helper import supported_language from libs.helper import supported_language
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Site from models.model import Site
app_site_fields = {
'app_id': fields.String,
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}
def parse_app_site_args(): def parse_app_site_args():
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
......
...@@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required ...@@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required
from core.data_loader.loader.notion import NotionLoader from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from extensions.ext_database import db from extensions.ext_database import db
from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
from libs.helper import TimestampField from libs.helper import TimestampField
from models.dataset import Document from models.dataset import Document
from models.source import DataSourceBinding from models.source import DataSourceBinding
...@@ -24,37 +25,6 @@ cache = TTLCache(maxsize=None, ttl=30) ...@@ -24,37 +25,6 @@ cache = TTLCache(maxsize=None, ttl=30)
class DataSourceApi(Resource): class DataSourceApi(Resource):
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields)),
'total': fields.Integer
}
integrate_fields = {
'id': fields.String,
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'disabled': fields.Boolean,
'link': fields.String,
'source_info': fields.Nested(integrate_workspace_fields)
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
}
@setup_required @setup_required
@login_required @login_required
...@@ -131,28 +101,6 @@ class DataSourceApi(Resource): ...@@ -131,28 +101,6 @@ class DataSourceApi(Resource):
class DataSourceNotionListApi(Resource): class DataSourceNotionListApi(Resource):
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'is_bound': fields.Boolean,
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
}
@setup_required @setup_required
@login_required @login_required
......
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from flask import request import flask_restful
from flask import request, current_app
from flask_login import current_user from flask_login import current_user
from controllers.console.apikey import api_key_list, api_key_fields
from core.login.login import login_required from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal, marshal_with from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden from werkzeug.exceptions import NotFound, Forbidden
...@@ -12,45 +15,16 @@ from controllers.console.setup import setup_required ...@@ -12,45 +15,16 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType from core.model_providers.models.entity.model_params import ModelType
from libs.helper import TimestampField from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment, Document from models.dataset import DocumentSegment, Document
from models.model import UploadFile from models.model import UploadFile, ApiToken
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService from services.provider_service import ProviderService
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'provider': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'app_count': fields.Integer,
'document_count': fields.Integer,
'word_count': fields.Integer,
'created_by': fields.String,
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,
"created_by": fields.String,
"created_at": TimestampField
}
def _validate_name(name): def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
...@@ -82,7 +56,8 @@ class DatasetListApi(Resource): ...@@ -82,7 +56,8 @@ class DatasetListApi(Resource):
# check embedding setting # check embedding setting
provider_service = ProviderService() provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value) valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
# if len(valid_model_list) == 0: # if len(valid_model_list) == 0:
# raise ProviderNotInitializeError( # raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider " # f"No Embedding Model available. Please configure a valid provider "
...@@ -157,7 +132,8 @@ class DatasetApi(Resource): ...@@ -157,7 +132,8 @@ class DatasetApi(Resource):
# check embedding setting # check embedding setting
provider_service = ProviderService() provider_service = ProviderService()
# get valid model list # get valid model list
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value) valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
model_names = [] model_names = []
for valid_model in valid_model_list: for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}") model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
...@@ -271,7 +247,8 @@ class DatasetIndexingEstimateApi(Resource): ...@@ -271,7 +247,8 @@ class DatasetIndexingEstimateApi(Resource):
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json') parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
...@@ -320,18 +297,6 @@ class DatasetIndexingEstimateApi(Resource): ...@@ -320,18 +297,6 @@ class DatasetIndexingEstimateApi(Resource):
class DatasetRelatedAppListApi(Resource): class DatasetRelatedAppListApi(Resource):
app_detail_kernel_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
}
related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer,
}
@setup_required @setup_required
@login_required @login_required
...@@ -363,24 +328,6 @@ class DatasetRelatedAppListApi(Resource): ...@@ -363,24 +328,6 @@ class DatasetRelatedAppListApi(Resource):
class DatasetIndexingStatusApi(Resource): class DatasetIndexingStatusApi(Resource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}
@setup_required @setup_required
@login_required @login_required
...@@ -400,16 +347,97 @@ class DatasetIndexingStatusApi(Resource): ...@@ -400,16 +347,97 @@ class DatasetIndexingStatusApi(Resource):
DocumentSegment.status != 're_segment').count() DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
documents_status.append(marshal(document, self.document_status_fields)) documents_status.append(marshal(document, document_status_fields))
data = { data = {
'data': documents_status 'data': documents_status
} }
return data return data
class DatasetApiKeyApi(Resource):
max_keys = 10
token_prefix = 'dataset-'
resource_type = 'dataset'
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
keys = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
all()
return {"items": keys}
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
current_key_count = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
count()
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code='max_keys_exceeded'
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, api_key_id):
api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
key = db.session.query(ApiToken). \
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id). \
first()
if key is None:
flask_restful.abort(404, message='API key not found')
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
return {'result': 'success'}, 204
class DatasetApiBaseUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
return {
'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
else request.host_url.rstrip('/')) + '/v1'
}
api.add_resource(DatasetListApi, '/datasets') api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>') api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries') api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps') api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status') api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
...@@ -23,6 +23,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE ...@@ -23,6 +23,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
LLMBadRequestError LLMBadRequestError
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.document_fields import document_with_segments_fields, document_fields, \
dataset_and_document_fields, document_status_fields
from libs.helper import TimestampField from libs.helper import TimestampField
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DatasetProcessRule, Dataset from models.dataset import DatasetProcessRule, Dataset
...@@ -32,64 +34,6 @@ from services.dataset_service import DocumentService, DatasetService ...@@ -32,64 +34,6 @@ from services.dataset_service import DocumentService, DatasetService
from tasks.add_document_to_index_task import add_document_to_index_task from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task
dataset_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
document_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'doc_form': fields.String,
}
document_with_segments_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'completed_segments': fields.Integer,
'total_segments': fields.Integer
}
class DocumentResource(Resource): class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document: def get_document(self, dataset_id: str, document_id: str) -> Document:
...@@ -303,11 +247,6 @@ class DatasetDocumentListApi(Resource): ...@@ -303,11 +247,6 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource): class DatasetInitApi(Resource):
dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields),
'documents': fields.List(fields.Nested(document_fields)),
'batch': fields.String
}
@setup_required @setup_required
@login_required @login_required
...@@ -504,24 +443,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): ...@@ -504,24 +443,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
class DocumentBatchIndexingStatusApi(DocumentResource): class DocumentBatchIndexingStatusApi(DocumentResource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}
@setup_required @setup_required
@login_required @login_required
...@@ -541,7 +462,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource): ...@@ -541,7 +462,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused: if document.is_paused:
document.indexing_status = 'paused' document.indexing_status = 'paused'
documents_status.append(marshal(document, self.document_status_fields)) documents_status.append(marshal(document, document_status_fields))
data = { data = {
'data': documents_status 'data': documents_status
} }
...@@ -549,20 +470,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource): ...@@ -549,20 +470,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
class DocumentIndexingStatusApi(DocumentResource): class DocumentIndexingStatusApi(DocumentResource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
@setup_required @setup_required
@login_required @login_required
...@@ -586,7 +493,7 @@ class DocumentIndexingStatusApi(DocumentResource): ...@@ -586,7 +493,7 @@ class DocumentIndexingStatusApi(DocumentResource):
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused: if document.is_paused:
document.indexing_status = 'paused' document.indexing_status = 'paused'
return marshal(document, self.document_status_fields) return marshal(document, document_status_fields)
class DocumentDetailApi(DocumentResource): class DocumentDetailApi(DocumentResource):
......
...@@ -3,7 +3,7 @@ import uuid ...@@ -3,7 +3,7 @@ import uuid
from datetime import datetime from datetime import datetime
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, reqparse, fields, marshal from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import NotFound, Forbidden from werkzeug.exceptions import NotFound, Forbidden
import services import services
...@@ -17,6 +17,7 @@ from core.model_providers.model_factory import ModelFactory ...@@ -17,6 +17,7 @@ from core.model_providers.model_factory import ModelFactory
from core.login.login import login_required from core.login.login import login_required
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
from libs.helper import TimestampField from libs.helper import TimestampField
...@@ -26,36 +27,6 @@ from tasks.disable_segment_from_index_task import disable_segment_from_index_tas ...@@ -26,36 +27,6 @@ from tasks.disable_segment_from_index_task import disable_segment_from_index_tas
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
import pandas as pd import pandas as pd
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField
}
segment_list_response = {
'data': fields.List(fields.Nested(segment_fields)),
'has_more': fields.Boolean,
'limit': fields.Integer
}
class DatasetDocumentSegmentListApi(Resource): class DatasetDocumentSegmentListApi(Resource):
@setup_required @setup_required
......
import datetime
import hashlib
import tempfile
import chardet
import time
import uuid
from pathlib import Path
from cachetools import TTLCache from cachetools import TTLCache
from flask import request, current_app from flask import request, current_app
from flask_login import current_user
import services
from core.login.login import login_required from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields from flask_restful import Resource, marshal_with, fields
from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
UnsupportedFileTypeError UnsupportedFileTypeError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.data_loader.file_extractor import FileExtractor from fields.file_fields import upload_config_fields, file_fields
from extensions.ext_storage import storage
from libs.helper import TimestampField from services.file_service import FileService
from extensions.ext_database import db
from models.model import UploadFile
cache = TTLCache(maxsize=None, ttl=30) cache = TTLCache(maxsize=None, ttl=30)
...@@ -31,10 +22,6 @@ PREVIEW_WORDS_LIMIT = 3000 ...@@ -31,10 +22,6 @@ PREVIEW_WORDS_LIMIT = 3000
class FileApi(Resource): class FileApi(Resource):
upload_config_fields = {
'file_size_limit': fields.Integer,
'batch_count_limit': fields.Integer
}
@setup_required @setup_required
@login_required @login_required
...@@ -48,16 +35,6 @@ class FileApi(Resource): ...@@ -48,16 +35,6 @@ class FileApi(Resource):
'batch_count_limit': batch_count_limit 'batch_count_limit': batch_count_limit
}, 200 }, 200
file_fields = {
'id': fields.String,
'name': fields.String,
'size': fields.Integer,
'extension': fields.String,
'mime_type': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
...@@ -73,45 +50,13 @@ class FileApi(Resource): ...@@ -73,45 +50,13 @@ class FileApi(Resource):
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
try:
file_content = file.read() upload_file = FileService.upload_file(file)
file_size = len(file_content) except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024 except services.errors.file.UnsupportedFileTypeError:
if file_size > file_size_limit:
message = "({file_size} > {file_size_limit})"
raise FileTooLargeError(message)
extension = file.filename.split('.')[-1]
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
# save file to storage
storage.save(file_key, file_content)
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=file.filename,
size=file_size,
extension=extension,
mime_type=file.mimetype,
created_by=current_user.id,
created_at=datetime.datetime.utcnow(),
used=False,
hash=hashlib.sha3_256(file_content).hexdigest()
)
db.session.add(upload_file)
db.session.commit()
return upload_file, 201 return upload_file, 201
...@@ -121,26 +66,7 @@ class FilePreviewApi(Resource): ...@@ -121,26 +66,7 @@ class FilePreviewApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
text = FileService.get_file_preview(file_id)
key = file_id + request.path
cached_response = cache.get(key)
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
return cached_response['response']
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
if not upload_file:
raise NotFound("File not found")
# extract text from file
extension = upload_file.extension
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
text = FileExtractor.load(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return {'content': text} return {'content': text}
......
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
from flask_login import current_user from flask_login import current_user
from core.login.login import login_required from core.login.login import login_required
from flask_restful import Resource, reqparse, marshal, fields from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
import services import services
...@@ -14,48 +14,10 @@ from controllers.console.setup import setup_required ...@@ -14,48 +14,10 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError LLMBadRequestError
from libs.helper import TimestampField from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
document_fields = {
'id': fields.String,
'data_source_type': fields.String,
'name': fields.String,
'doc_type': fields.String,
}
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'document': fields.Nested(document_fields),
}
hit_testing_record_fields = {
'segment': fields.Nested(segment_fields),
'score': fields.Float,
'tsne_position': fields.Raw
}
class HitTestingApi(Resource): class HitTestingApi(Resource):
......
...@@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound ...@@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.explore.error import NotChatAppError from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationListApi(InstalledAppResource): class ConversationListApi(InstalledAppResource):
...@@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource): ...@@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource):
class ConversationRenameApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource):
@marshal_with(conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id): def post(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != 'chat': if app_model.mode != 'chat':
......
...@@ -11,32 +11,11 @@ from controllers.console import api ...@@ -11,32 +11,11 @@ from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.helper import TimestampField from libs.helper import TimestampField
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
app_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String
}
installed_app_fields = {
'id': fields.String,
'app': fields.Nested(app_fields),
'app_owner_tenant_id': fields.String,
'is_pinned': fields.Boolean,
'last_used_at': TimestampField,
'editable': fields.Boolean,
'uninstallable': fields.Boolean,
}
installed_app_list_fields = {
'installed_apps': fields.List(fields.Nested(installed_app_fields))
}
class InstalledAppsListApi(Resource): class InstalledAppsListApi(Resource):
@login_required @login_required
......
...@@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste ...@@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs.helper import uuid_value, TimestampField from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
...@@ -26,45 +27,6 @@ from services.message_service import MessageService ...@@ -26,45 +27,6 @@ from services.message_service import MessageService
class MessageListApi(InstalledAppResource): class MessageListApi(InstalledAppResource):
feedback_fields = {
'rating': fields.String
}
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField
}
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
......
...@@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound ...@@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource from controllers.console.universal_chat.wraps import UniversalChatResource
from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \
conversation_with_model_config_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField,
'model_config': fields.Raw,
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class UniversalChatConversationListApi(UniversalChatResource): class UniversalChatConversationListApi(UniversalChatResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
def get(self, universal_app): def get(self, universal_app):
app_model = universal_app app_model = universal_app
...@@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource): ...@@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource):
class UniversalChatConversationRenameApi(UniversalChatResource): class UniversalChatConversationRenameApi(UniversalChatResource):
@marshal_with(conversation_fields) @marshal_with(conversation_with_model_config_fields)
def post(self, universal_app, c_id): def post(self, universal_app, c_id):
app_model = universal_app app_model = universal_app
conversation_id = str(c_id) conversation_id = str(c_id)
......
...@@ -9,4 +9,4 @@ api = ExternalApi(bp) ...@@ -9,4 +9,4 @@ api = ExternalApi(bp)
from .app import completion, app, conversation, message, audio from .app import completion, app, conversation, message, audio
from .dataset import document from .dataset import document, segment, dataset
...@@ -8,25 +8,11 @@ from controllers.service_api import api ...@@ -8,25 +8,11 @@ from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import AppApiResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
import services import services
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationApi(AppApiResource): class ConversationApi(AppApiResource):
...@@ -50,7 +36,7 @@ class ConversationApi(AppApiResource): ...@@ -50,7 +36,7 @@ class ConversationApi(AppApiResource):
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource): class ConversationDetailApi(AppApiResource):
@marshal_with(conversation_fields) @marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id): def delete(self, app_model, end_user, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
...@@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource): ...@@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource):
class ConversationRenameApi(AppApiResource): class ConversationRenameApi(AppApiResource):
@marshal_with(conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id): def post(self, app_model, end_user, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
......
from flask import request
from flask_restful import reqparse, marshal
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from core.model_providers.models.entity.model_params import ModelType
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from models.account import Account, TenantAccountJoin
from models.dataset import Dataset
from services.dataset_service import DatasetService
from services.provider_service import ProviderService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError('Name must be between 1 to 40 characters.')
return name
class DatasetApi(DatasetApiResource):
"""Resource for get datasets."""
def get(self, tenant_id):
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
provider = request.args.get('provider', default="vendor")
datasets, total = DatasetService.get_datasets(page, limit, provider,
tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['indexing_technique'] == 'high_quality':
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item['embedding_available'] = True
else:
item['embedding_available'] = False
else:
item['embedding_available'] = True
response = {
'data': data,
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
"""Resource for datasets."""
def post(self, tenant_id):
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='type is required. Name must be between 1 to 40 characters.',
type=_validate_name)
parser.add_argument('indexing_technique', type=str, location='json',
choices=('high_quality', 'economy'),
help='Invalid indexing technique.')
args = parser.parse_args()
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=args['name'],
indexing_technique=args['indexing_technique'],
account=current_user
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 200
api.add_resource(DatasetApi, '/datasets')
import datetime import datetime
import json
import uuid import uuid
from flask import current_app from flask import current_app, request
from flask_restful import reqparse from flask_restful import reqparse, marshal
from sqlalchemy import desc
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services.dataset_service import services.dataset_service
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
DatasetNotInitedError NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from core.model_providers.error import ProviderTokenNotInitError from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile from models.model import UploadFile
from services.dataset_service import DocumentService from services.dataset_service import DocumentService
from services.file_service import FileService
class DocumentListApi(DatasetApiResource): class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents.""" """Resource for documents."""
def post(self, dataset): def post(self, tenant_id, dataset_id):
"""Create document.""" """Create document by text."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json') parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument('text', type=str, required=True, nullable=False, location='json') parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('doc_type', type=str, location='json') parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
parser.add_argument('doc_metadata', type=dict, location='json') parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
location='json')
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset is not exist.')
if not dataset.indexing_technique and not args['indexing_technique']:
raise ValueError('indexing_technique is required.')
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
data_source = {
'type': 'upload_file',
'info_list': {
'data_source_type': 'upload_file',
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
DocumentService.document_create_args_validate(args)
if not dataset.indexing_technique: try:
raise DatasetNotInitedError("Dataset indexing technique must be set.") documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
doc_type = args.get('doc_type') document_data=args,
doc_metadata = args.get('doc_metadata') account=current_user,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: created_from='api'
raise ValueError('Invalid doc_type.')
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + dataset.tenant_id + '/' + file_uuid + '.txt'
# save file to storage
storage.save(file_key, args.get('text'))
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=dataset.tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=args.get('name') + '.txt',
size=len(args.get('text')),
extension='txt',
mime_type='text/plain',
created_by=dataset.created_by,
created_at=datetime.datetime.utcnow(),
used=True,
used_by=dataset.created_by,
used_at=datetime.datetime.utcnow()
) )
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
def post(self, tenant_id, dataset_id, document_id):
"""Update document by text."""
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('text', type=str, required=False, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
db.session.add(upload_file) if not dataset:
db.session.commit() raise ValueError('Dataset is not exist.')
document_data = { if args['text']:
'data_source': { upload_file = FileService.upload_text(args.get('text'), args.get('name'))
data_source = {
'type': 'upload_file', 'type': 'upload_file',
'info': [ 'info_list': {
{ 'data_source_type': 'upload_file',
'upload_file_id': upload_file.id 'file_info_list': {
'file_ids': [upload_file.id]
} }
]
} }
} }
args['data_source'] = data_source
# validate args
args['original_document_id'] = str(document_id)
DocumentService.document_create_args_validate(args)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=document_data, document_data=args,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}
if 'data' in request.form:
args = json.loads(request.form['data'])
if 'doc_form' not in args:
args['doc_form'] = 'text_model'
if 'doc_language' not in args:
args['doc_language'] = 'English'
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset is not exist.')
if not dataset.indexing_technique and not args['indexing_technique']:
raise ValueError('indexing_technique is required.')
# save file info
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file)
data_source = {
'type': 'upload_file',
'info_list': {
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=dataset.created_by_account, account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule, dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api' created_from='api'
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
document = documents[0] document = documents[0]
if doc_type and doc_metadata: documents_and_batch_fields = {
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] 'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
document.doc_metadata = {}
for key, value_type in metadata_schema.items(): class DocumentUpdateByFileApi(DatasetApiResource):
value = doc_metadata.get(key) """Resource for update documents."""
if value is not None and isinstance(value, value_type):
document.doc_metadata[key] = value
document.doc_type = doc_type def post(self, tenant_id, dataset_id, document_id):
document.updated_at = datetime.datetime.utcnow() """Update document by upload file."""
db.session.commit() args = {}
if 'data' in request.form:
args = json.loads(request.form['data'])
if 'doc_form' not in args:
args['doc_form'] = 'text_model'
if 'doc_language' not in args:
args['doc_language'] = 'English'
return {'id': document.id} # get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset is not exist.')
if 'file' in request.files:
# save file info
file = request.files['file']
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file)
data_source = {
'type': 'upload_file',
'info_list': {
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
args['original_document_id'] = str(document_id)
DocumentService.document_create_args_validate(args)
class DocumentApi(DatasetApiResource): try:
def delete(self, dataset, document_id): documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentDeleteApi(DatasetApiResource):
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document.""" """Delete document."""
document_id = str(document_id) document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset is not exist.')
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)
...@@ -126,8 +303,85 @@ class DocumentApi(DatasetApiResource): ...@@ -126,8 +303,85 @@ class DocumentApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError: except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError('Cannot delete document during indexing.') raise DocumentIndexingError('Cannot delete document during indexing.')
return {'result': 'success'}, 204 return {'result': 'success'}, 200
class DocumentListApi(DatasetApiResource):
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
search = request.args.get('keyword', default=None, type=str)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
query = Document.query.filter_by(
dataset_id=str(dataset_id), tenant_id=tenant_id)
if search:
search = f'%{search}%'
query = query.filter(Document.name.like(search))
query = query.order_by(desc(Document.created_at))
paginated_documents = query.paginate(
page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
response = {
'data': marshal(documents, document_fields),
'has_more': len(documents) == limit,
'limit': limit,
'total': paginated_documents.total,
'page': page
}
return response
class DocumentIndexingStatusApi(DatasetApiResource):
def get(self, tenant_id, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# get documents
documents = DocumentService.get_batch_documents(dataset_id, batch)
if not documents:
raise NotFound('Documents not found.')
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
documents_status.append(marshal(document, document_status_fields))
data = {
'data': documents_status
}
return data
api.add_resource(DocumentListApi, '/documents') api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
api.add_resource(DocumentApi, '/documents/<uuid:document_id>') api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = 'file_too_large'
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415
class HighQualityDatasetOnlyError(BaseHTTPException):
error_code = 'high_quality_dataset_only'
description = "Current operation only supports 'high-quality' datasets."
code = 400
class DatasetNotInitializedError(BaseHTTPException):
error_code = 'dataset_not_initialized'
description = "The dataset is still being initialized or indexing. Please wait a moment."
code = 400
class ArchivedDocumentImmutableError(BaseHTTPException): class ArchivedDocumentImmutableError(BaseHTTPException):
error_code = 'archived_document_immutable' error_code = 'archived_document_immutable'
description = "Cannot operate when document was archived." description = "The archived document is not editable."
code = 403 code = 403
class DatasetNameDuplicateError(BaseHTTPException):
error_code = 'dataset_name_duplicate'
description = "The dataset name already exists. Please modify your dataset name."
code = 409
class InvalidActionError(BaseHTTPException):
error_code = 'invalid_action'
description = "Invalid action."
code = 400
class DocumentAlreadyFinishedError(BaseHTTPException):
error_code = 'document_already_finished'
description = "The document has been processed. Please refresh the page or go to the document details."
code = 400
class DocumentIndexingError(BaseHTTPException): class DocumentIndexingError(BaseHTTPException):
error_code = 'document_indexing' error_code = 'document_indexing'
description = "Cannot operate document during indexing." description = "The document is being processed and cannot be edited."
code = 403 code = 400
class DatasetNotInitedError(BaseHTTPException): class InvalidMetadataError(BaseHTTPException):
error_code = 'dataset_not_inited' error_code = 'invalid_metadata'
description = "The dataset is still being initialized or indexing. Please wait a moment." description = "The metadata content is incorrect. Please check and verify."
code = 403 code = 400
from flask_login import current_user
from flask_restful import reqparse, marshal
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset
from services.dataset_service import DocumentService, SegmentService
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
def post(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
...@@ -2,11 +2,14 @@ ...@@ -2,11 +2,14 @@
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
from flask import request from flask import request, current_app
from flask_login import user_logged_in
from flask_restful import Resource from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from core.login.login import _get_user
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, Account
from models.dataset import Dataset from models.dataset import Dataset
from models.model import ApiToken, App from models.model import ApiToken, App
...@@ -43,12 +46,24 @@ def validate_dataset_token(view=None): ...@@ -43,12 +46,24 @@ def validate_dataset_token(view=None):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
api_token = validate_and_get_api_token('dataset') api_token = validate_and_get_api_token('dataset')
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first() .filter(Tenant.id == api_token.tenant_id) \
if not dataset: .filter(TenantAccountJoin.tenant_id == Tenant.id) \
raise NotFound() .filter(TenantAccountJoin.role == 'owner') \
.one_or_none()
return view(dataset, *args, **kwargs) if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account)
user_logged_in.send(current_app._get_current_object(), user=_get_user())
else:
raise Unauthorized("Tenant owner account is not exist.")
else:
raise Unauthorized("Tenant is not exist.")
return view(api_token.tenant_id, *args, **kwargs)
return decorated return decorated
if view: if view:
......
...@@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound ...@@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound
from controllers.web import api from controllers.web import api
from controllers.web.error import NotChatAppError from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationListApi(WebApiResource): class ConversationListApi(WebApiResource):
...@@ -73,7 +59,7 @@ class ConversationApi(WebApiResource): ...@@ -73,7 +59,7 @@ class ConversationApi(WebApiResource):
class ConversationRenameApi(WebApiResource): class ConversationRenameApi(WebApiResource):
@marshal_with(conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id): def post(self, app_model, end_user, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
......
...@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) ...@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search" SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
......
...@@ -246,11 +246,28 @@ class KeywordTableIndex(BaseIndex): ...@@ -246,11 +246,28 @@ class KeywordTableIndex(BaseIndex):
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table) self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data['segment']
if pre_segment_data['keywords']:
segment.keywords = pre_segment_data['keywords']
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
pre_segment_data['keywords'])
else:
keywords = keyword_table_handler.extract_keywords(segment.content,
self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: List[str]): def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
keyword_table = self._get_dataset_keyword_table() keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table) self._save_dataset_keyword_table(keyword_table)
class KeywordTableRetriever(BaseRetriever, BaseModel): class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex index: KeywordTableIndex
search_kwargs: dict = Field(default_factory=dict) search_kwargs: dict = Field(default_factory=dict)
......
from flask_restful import fields
from libs.helper import TimestampField
app_detail_kernel_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
}
related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer,
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
}
app_detail_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'created_at': TimestampField
}
prompt_config_fields = {
'prompt_template': fields.String,
}
model_config_partial_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
app_partial_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
'created_at': TimestampField
}
app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
}
template_fields = {
'name': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'mode': fields.String,
'model_config': fields.Nested(model_config_fields),
}
template_list_fields = {
'data': fields.List(fields.Nested(template_fields)),
}
site_fields = {
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'app_base_url': fields.String,
}
app_detail_fields_with_site = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'site': fields.Nested(site_fields),
'api_base_url': fields.String,
'created_at': TimestampField
}
app_site_fields = {
'app_id': fields.String,
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}
\ No newline at end of file
from flask_restful import fields
from libs.helper import TimestampField
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]['text'] if value else ''
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(account_fields, allow_null=True),
}
annotation_fields = {
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String,
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
}
feedback_stat_fields = {
'like': fields.Integer,
'dislike': fields.Integer
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'model': fields.Raw,
'user_input_form': fields.Raw,
'pre_prompt': fields.String,
'agent_mode': fields.Raw,
}
simple_configs_fields = {
'prompt_template': fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
simple_message_detail_fields = {
'inputs': fields.Raw,
'query': fields.String,
'message': MessageTextField,
'answer': fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String(),
'from_account_id': fields.String,
'read_at': TimestampField,
'created_at': TimestampField,
'annotation': fields.Nested(annotation_fields, allow_null=True),
'model_config': fields.Nested(simple_model_config_fields),
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
conversation_message_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'model_config': fields.Nested(model_config_fields),
'message': fields.Nested(message_detail_fields, attribute='first_message'),
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
conversation_with_summary_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String,
'from_account_id': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(simple_model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
conversation_with_summary_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items')
}
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
simple_conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(simple_conversation_fields))
}
conversation_with_model_config_fields = {
**simple_conversation_fields,
'model_config': fields.Raw,
}
conversation_with_model_config_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_with_model_config_fields))
}
\ No newline at end of file
from flask_restful import fields
from libs.helper import TimestampField
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'is_bound': fields.Boolean,
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
}
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields)),
'total': fields.Integer
}
integrate_fields = {
'id': fields.String,
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'disabled': fields.Boolean,
'link': fields.String,
'source_info': fields.Nested(integrate_workspace_fields)
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
}
\ No newline at end of file
from flask_restful import fields
from libs.helper import TimestampField
dataset_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'provider': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'app_count': fields.Integer,
'document_count': fields.Integer,
'word_count': fields.Integer,
'created_by': fields.String,
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,
"created_by": fields.String,
"created_at": TimestampField
}
from flask_restful import fields
from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField
document_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'doc_form': fields.String,
}
document_with_segments_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'completed_segments': fields.Integer,
'total_segments': fields.Integer
}
dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields),
'documents': fields.List(fields.Nested(document_fields)),
'batch': fields.String
}
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}
\ No newline at end of file
from flask_restful import fields
from libs.helper import TimestampField
upload_config_fields = {
'file_size_limit': fields.Integer,
'batch_count_limit': fields.Integer
}
file_fields = {
'id': fields.String,
'name': fields.String,
'size': fields.Integer,
'extension': fields.String,
'mime_type': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
\ No newline at end of file
from flask_restful import fields
from libs.helper import TimestampField
document_fields = {
'id': fields.String,
'data_source_type': fields.String,
'name': fields.String,
'doc_type': fields.String,
}
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'document': fields.Nested(document_fields),
}
hit_testing_record_fields = {
'segment': fields.Nested(segment_fields),
'score': fields.Float,
'tsne_position': fields.Raw
}
\ No newline at end of file
from flask_restful import fields
from libs.helper import TimestampField
app_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String
}
installed_app_fields = {
'id': fields.String,
'app': fields.Nested(app_fields),
'app_owner_tenant_id': fields.String,
'is_pinned': fields.Boolean,
'last_used_at': TimestampField,
'editable': fields.Boolean,
'uninstallable': fields.Boolean,
}
installed_app_list_fields = {
'installed_apps': fields.List(fields.Nested(installed_app_fields))
}
\ No newline at end of file
from flask_restful import fields
from libs.helper import TimestampField
feedback_fields = {
'rating': fields.String
}
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField
}
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}
\ No newline at end of file
from flask_restful import fields
from libs.helper import TimestampField
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField
}
segment_list_response = {
'data': fields.List(fields.Nested(segment_fields)),
'has_more': fields.Boolean,
'limit': fields.Integer
}
"""add_tenant_id_in_api_token
Revision ID: 2e9819ca5b28
Revises: 6e2cfb077b04
Create Date: 2023-09-22 15:41:01.243183
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '2e9819ca5b28'
down_revision = 'ab23c11305d4'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
batch_op.drop_column('dataset_id')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
batch_op.drop_index('api_token_tenant_idx')
batch_op.drop_column('tenant_id')
# ### end Alembic commands ###
...@@ -629,12 +629,13 @@ class ApiToken(db.Model): ...@@ -629,12 +629,13 @@ class ApiToken(db.Model):
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='api_token_pkey'), db.PrimaryKeyConstraint('id', name='api_token_pkey'),
db.Index('api_token_app_id_type_idx', 'app_id', 'type'), db.Index('api_token_app_id_type_idx', 'app_id', 'type'),
db.Index('api_token_token_idx', 'token', 'type') db.Index('api_token_token_idx', 'token', 'type'),
db.Index('api_token_tenant_idx', 'tenant_id', 'type')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=True) app_id = db.Column(UUID, nullable=True)
dataset_id = db.Column(UUID, nullable=True) tenant_id = db.Column(UUID, nullable=True)
type = db.Column(db.String(16), nullable=False) type = db.Column(db.String(16), nullable=False)
token = db.Column(db.String(255), nullable=False) token = db.Column(db.String(255), nullable=False)
last_used_at = db.Column(db.DateTime, nullable=True) last_used_at = db.Column(db.DateTime, nullable=True)
......
...@@ -96,7 +96,7 @@ class DatasetService: ...@@ -96,7 +96,7 @@ class DatasetService:
embedding_model = None embedding_model = None
if indexing_technique == 'high_quality': if indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model( embedding_model = ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id tenant_id=tenant_id
) )
dataset = Dataset(name=name, indexing_technique=indexing_technique) dataset = Dataset(name=name, indexing_technique=indexing_technique)
# dataset = Dataset(name=name, provider=provider, config=config) # dataset = Dataset(name=name, provider=provider, config=config)
...@@ -477,6 +477,7 @@ class DocumentService: ...@@ -477,6 +477,7 @@ class DocumentService:
) )
dataset.collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding.id
documents = [] documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
if 'original_document_id' in document_data and document_data["original_document_id"]: if 'original_document_id' in document_data and document_data["original_document_id"]:
...@@ -626,6 +627,9 @@ class DocumentService: ...@@ -626,6 +627,9 @@ class DocumentService:
document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
if document.display_status != 'available': if document.display_status != 'available':
raise ValueError("Document is not available") raise ValueError("Document is not available")
# update document name
if 'name' in document_data and document_data['name']:
document.name = document_data['name']
# save process rule # save process rule
if 'process_rule' in document_data and document_data['process_rule']: if 'process_rule' in document_data and document_data['process_rule']:
process_rule = document_data["process_rule"] process_rule = document_data["process_rule"]
...@@ -1014,6 +1018,66 @@ class SegmentService: ...@@ -1014,6 +1018,66 @@ class SegmentService:
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
return segment return segment
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
embedding_model = None
if dataset.indexing_technique == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == document.id
).scalar()
pre_segment_data_list = []
segment_data_list = []
for segment_item in segments:
content = segment_item['content']
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == 'high_quality' and embedding_model:
# calc embedding use tokens
tokens = embedding_model.get_num_tokens(content)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
status='completed',
indexing_at=datetime.datetime.utcnow(),
completed_at=datetime.datetime.utcnow(),
created_by=current_user.id
)
if document.doc_form == 'qa_model':
segment_document.answer = segment_item['answer']
db.session.add(segment_document)
segment_data_list.append(segment_document)
pre_segment_data = {
'segment': segment_document,
'keywords': segment_item['keywords']
}
pre_segment_data_list.append(pre_segment_data)
try:
# save vector index
VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
except Exception as e:
logging.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = datetime.datetime.utcnow()
segment_document.status = 'error'
segment_document.error = str(e)
db.session.commit()
return segment_data_list
@classmethod @classmethod
def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
indexing_cache_key = 'segment_{}_indexing'.format(segment.id) indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
......
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
__all__ = [ __all__ = [
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
'app', 'completion', 'audio' 'app', 'completion', 'audio', 'file'
] ]
from . import * from . import *
...@@ -3,3 +3,11 @@ from services.errors.base import BaseServiceError ...@@ -3,3 +3,11 @@ from services.errors.base import BaseServiceError
class FileNotExistsError(BaseServiceError): class FileNotExistsError(BaseServiceError):
pass pass
class FileTooLargeError(BaseServiceError):
description = "{message}"
class UnsupportedFileTypeError(BaseServiceError):
pass
import datetime
import hashlib
import time
import uuid
from cachetools import TTLCache
from flask import request, current_app
from flask_login import current_user
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from core.data_loader.file_extractor import FileExtractor
from extensions.ext_storage import storage
from extensions.ext_database import db
from models.model import UploadFile
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
PREVIEW_WORDS_LIMIT = 3000
cache = TTLCache(maxsize=None, ttl=30)
class FileService:
@staticmethod
def upload_file(file: FileStorage) -> UploadFile:
# read file content
file_content = file.read()
# get file size
file_size = len(file_content)
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
if file_size > file_size_limit:
message = f'File size exceeded. {file_size} > {file_size_limit}'
raise FileTooLargeError(message)
extension = file.filename.split('.')[-1]
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
# save file to storage
storage.save(file_key, file_content)
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=file.filename,
size=file_size,
extension=extension,
mime_type=file.mimetype,
created_by=current_user.id,
created_at=datetime.datetime.utcnow(),
used=False,
hash=hashlib.sha3_256(file_content).hexdigest()
)
db.session.add(upload_file)
db.session.commit()
return upload_file
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'
# save file to storage
storage.save(file_key, text.encode('utf-8'))
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=text_name + '.txt',
size=len(text),
extension='txt',
mime_type='text/plain',
created_by=current_user.id,
created_at=datetime.datetime.utcnow(),
used=True,
used_by=current_user.id,
used_at=datetime.datetime.utcnow()
)
db.session.add(upload_file)
db.session.commit()
return upload_file
@staticmethod
def get_file_preview(file_id: str) -> str:
# get file storage key
key = file_id + request.path
cached_response = cache.get(key)
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
return cached_response['response']
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
if not upload_file:
raise NotFound("File not found")
# extract text from file
extension = upload_file.extension
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
text = FileExtractor.load(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return text
...@@ -35,6 +35,32 @@ class VectorService: ...@@ -35,6 +35,32 @@ class VectorService:
else: else:
index.add_texts([document]) index.add_texts([document])
@classmethod
def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
documents = []
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data['segment']
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents, duplicate_check=True)
# save keyword index
keyword_index = IndexBuilder.get_index(dataset, 'economy')
if keyword_index:
keyword_index.multi_create_segment_keywords(pre_segment_data_list)
@classmethod @classmethod
def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset): def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
# update segment index task # update segment index task
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment