Commit 3a98c636 authored by Jyong

add notion import indexing estimate interface

parent dbd2babb
@@ -9,7 +9,7 @@ from pathlib import Path
from cachetools import TTLCache
from flask import request, current_app
from flask_login import login_required, current_user
from flask_restful import Resource, marshal_with, fields
from flask_restful import Resource, marshal_with, fields, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api
@@ -20,6 +20,7 @@ from controllers.console.wraps import account_initialization_required
from core.data_source.notion import NotionPageReader
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.indexing_runner import IndexingRunner
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db
@@ -184,10 +185,15 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def post(self):
segment_rule = request.get_json()
notion_import_info = request.get_json()
parser = reqparse.RequestParser()
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args()
# validate args
DocumentService.notion_estimate_args_validate(args)
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(file_detail, segment_rule['process_rule'])
response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
return response, 200
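As a quick illustration of how the new estimate path is driven, here is a minimal sketch that calls the runner directly with the argument shape implied by the loop in IndexingRunner.notion_indexing_estimate further down. The ids, the 'custom' mode, and the 'remove_extra_spaces' rule id are placeholders, and the call assumes an application context plus an enabled Notion DataSourceBinding for the workspace; the sketch itself is not part of this commit.

# Sketch only: placeholder ids/mode/rule id; assumes an app context and an
# enabled Notion DataSourceBinding whose source_info['workspace_id'] matches.
from core.indexing_runner import IndexingRunner

notion_info_list = [
    {
        "workspace_id": "<notion-workspace-id>",
        "pages": [{"page_id": "<notion-page-id>"}],
    }
]
process_rule = {
    "mode": "custom",
    "rules": {
        "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
        "segmentation": {"separator": "\n", "max_tokens": 500},
    },
}

estimate = IndexingRunner().notion_indexing_estimate(notion_info_list, process_rule)
# expected keys: total_segments, tokens, total_price, currency, preview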
@@ -5,6 +5,8 @@ import tempfile
import time
from pathlib import Path
from typing import Optional, List
from flask_login import current_user
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llama_index import SimpleDirectoryReader
@@ -14,6 +16,7 @@ from llama_index.node_parser import SimpleNodeParser, NodeParser
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from llama_index.readers.file.markdown_parser import MarkdownParser
from core.data_source.notion import NotionPageReader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.index.keyword_table_index import KeywordTableIndex
from core.index.readers.html_parser import HTMLParser
@@ -26,6 +29,7 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding
class IndexingRunner:
@@ -201,43 +205,59 @@ class IndexingRunner:
"preview": preview_texts
}
def notion_indexing_estimate(self, notion_info, tmp_processing_rule: dict) -> dict:
def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict:
"""
Estimate the indexing for the given Notion pages.
"""
# load data from file
text_docs = self._load_data_from_file(file_detail)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# split to nodes
nodes = self._split_to_nodes(
text_docs=text_docs,
node_parser=node_parser,
processing_rule=processing_rule
)
# load data from notion
tokens = 0
preview_texts = []
for node in nodes:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
total_segments = 0
for notion_info in notion_info_list:
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == notion_info['workspace_id']
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
for page in notion_info['pages']:
page_ids = [page['page_id']]
documents = reader.load_data(page_ids=page_ids)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# split to nodes
nodes = self._split_to_nodes(
text_docs=documents,
node_parser=node_parser,
processing_rule=processing_rule
)
total_segments += len(nodes)
for node in nodes:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
return {
"total_segments": len(nodes),
"total_segments": len(total_segments),
"tokens": tokens,
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
"preview": preview_texts
}
def _load_data(self, document: Document) -> List[Document]:
# load file
if document.data_source_type != "upload_file":
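Stripped of the Flask and llama_index plumbing, the accumulation in notion_indexing_estimate reduces to the sketch below: segment counts and token totals are summed across every page in every workspace, and at most five preview snippets are kept. The function and parameter names are illustrative only, with a toy whitespace counter standing in for TokenCalculator.

from typing import Callable, Dict, List

def aggregate_estimate(pages_segments: List[List[str]],
                       count_tokens: Callable[[str], int]) -> Dict:
    # pages_segments holds one list of already-split text segments per Notion page
    total_segments = 0
    tokens = 0
    preview_texts: List[str] = []
    for segments in pages_segments:
        total_segments += len(segments)
        for text in segments:
            if len(preview_texts) < 5:  # cap the preview at five snippets overall
                preview_texts.append(text)
            tokens += count_tokens(text)
    return {"total_segments": total_segments, "tokens": tokens, "preview": preview_texts}

# Toy run: two pages, three segments, seven whitespace tokens.
print(aggregate_estimate([["hello world", "foo bar baz"], ["one two"]],
                         count_tokens=lambda t: len(t.split())))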
@@ -519,3 +519,78 @@ class DocumentService:
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
raise ValueError("Process rule segmentation max_tokens is invalid")
@classmethod
def notion_estimate_args_validate(cls, args: dict):
if 'notion_info_list' not in args or not args['notion_info_list']:
raise ValueError("Notion info is required")
if not isinstance(args['notion_info_list'], list):
raise ValueError("Notion info is invalid")
if 'process_rule' not in args or not args['process_rule']:
raise ValueError("Process rule is required")
if not isinstance(args['process_rule'], dict):
raise ValueError("Process rule is invalid")
if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
raise ValueError("Process rule mode is required")
if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if args['process_rule']['mode'] == 'automatic':
args['process_rule']['rules'] = {}
else:
if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
raise ValueError("Process rule rules is required")
if not isinstance(args['process_rule']['rules'], dict):
raise ValueError("Process rule rules is invalid")
if 'pre_processing_rules' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['pre_processing_rules'] is None:
raise ValueError("Process rule pre_processing_rules is required")
if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
raise ValueError("Process rule pre_processing_rules is invalid")
unique_pre_processing_rule_dicts = {}
for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
raise ValueError("Process rule pre_processing_rules id is required")
if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
raise ValueError("Process rule pre_processing_rules id is invalid")
if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
raise ValueError("Process rule pre_processing_rules enabled is required")
if not isinstance(pre_processing_rule['enabled'], bool):
raise ValueError("Process rule pre_processing_rules enabled is invalid")
unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule
args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
if 'segmentation' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['segmentation'] is None:
raise ValueError("Process rule segmentation is required")
if not isinstance(args['process_rule']['rules']['segmentation'], dict):
raise ValueError("Process rule segmentation is invalid")
if 'separator' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['separator']:
raise ValueError("Process rule segmentation separator is required")
if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
raise ValueError("Process rule segmentation separator is invalid")
if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['max_tokens']:
raise ValueError("Process rule segmentation max_tokens is required")
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
raise ValueError("Process rule segmentation max_tokens is invalid")