Commit 3a98c636 authored by Jyong

add notion import indexing estimate interface

parent dbd2babb
@@ -9,7 +9,7 @@ from pathlib import Path
 from cachetools import TTLCache
 from flask import request, current_app
 from flask_login import login_required, current_user
-from flask_restful import Resource, marshal_with, fields
+from flask_restful import Resource, marshal_with, fields, reqparse
 from werkzeug.exceptions import NotFound
 from controllers.console import api
@@ -20,6 +20,7 @@ from controllers.console.wraps import account_initialization_required
 from core.data_source.notion import NotionPageReader
 from core.index.readers.html_parser import HTMLParser
 from core.index.readers.pdf_parser import PDFParser
+from core.indexing_runner import IndexingRunner
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -184,10 +185,15 @@ class DataSourceNotionApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        segment_rule = request.get_json()
+        notion_import_info = request.get_json()
+        parser = reqparse.RequestParser()
+        parser.add_argument('notion_info_list', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        args = parser.parse_args()
+        # validate args
+        DocumentService.notion_estimate_args_validate(args)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(file_detail, segment_rule['process_rule'])
+        response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
         return response, 200
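
For context, a minimal sketch of how a client might exercise the new estimate endpoint. The route path and port are assumptions (the route registration for DataSourceNotionApi is not part of this diff), and all field values are placeholders shaped after the arguments parsed above:

    import requests

    # Hypothetical route; the actual registration is not shown in this diff.
    # Assumes an authenticated console session (the handler is @login_required).
    url = 'http://localhost:5001/console/api/notion/indexing-estimate'

    payload = {
        'notion_info_list': {                                 # declared type=dict above
            'workspace_id': 'workspace-id-placeholder',
            'pages': [{'page_id': 'page-id-placeholder'}],
        },
        'process_rule': {'mode': 'automatic'},
    }

    response = requests.post(url, json=payload)
    print(response.json())  # total_segments, tokens, total_price, currency, preview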
@@ -5,6 +5,8 @@ import tempfile
 import time
 from pathlib import Path
 from typing import Optional, List
+from flask_login import current_user
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from llama_index import SimpleDirectoryReader
@@ -14,6 +16,7 @@ from llama_index.node_parser import SimpleNodeParser, NodeParser
 from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
 from llama_index.readers.file.markdown_parser import MarkdownParser
+from core.data_source.notion import NotionPageReader
 from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.index.keyword_table_index import KeywordTableIndex
 from core.index.readers.html_parser import HTMLParser
@@ -26,6 +29,7 @@ from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
 from models.model import UploadFile
+from models.source import DataSourceBinding


 class IndexingRunner:
@@ -201,43 +205,59 @@ class IndexingRunner:
             "preview": preview_texts
         }

-    def notion_indexing_estimate(self, notion_info, tmp_processing_rule: dict) -> dict:
+    def notion_indexing_estimate(self, notion_info_list: dict, tmp_processing_rule: dict) -> dict:
         """
         Estimate the indexing for the document.
         """
-        # load data from file
-        text_docs = self._load_data_from_file(file_detail)
-
-        processing_rule = DatasetProcessRule(
-            mode=tmp_processing_rule["mode"],
-            rules=json.dumps(tmp_processing_rule["rules"])
-        )
-
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
-
-        # split to nodes
-        nodes = self._split_to_nodes(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            processing_rule=processing_rule
-        )
-
+        # load data from notion
         tokens = 0
         preview_texts = []
-        for node in nodes:
-            if len(preview_texts) < 5:
-                preview_texts.append(node.get_text())
-
-            tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+        total_segments = 0
+        for notion_info in notion_info_list:
+            data_source_binding = DataSourceBinding.query.filter(
+                db.and_(
+                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                    DataSourceBinding.provider == 'notion',
+                    DataSourceBinding.disabled == False,
+                    DataSourceBinding.source_info['workspace_id'] == notion_info['workspace_id']
+                )
+            ).first()
+            if not data_source_binding:
+                raise ValueError('Data source binding not found.')
+            reader = NotionPageReader(integration_token=data_source_binding.access_token)
+            for page in notion_info['pages']:
+                page_ids = [page['page_id']]
+                documents = reader.load_data(page_ids=page_ids)
+                processing_rule = DatasetProcessRule(
+                    mode=tmp_processing_rule["mode"],
+                    rules=json.dumps(tmp_processing_rule["rules"])
+                )
+                # get node parser for splitting
+                node_parser = self._get_node_parser(processing_rule)
+                # split to nodes
+                nodes = self._split_to_nodes(
+                    text_docs=documents,
+                    node_parser=node_parser,
+                    processing_rule=processing_rule
+                )
+                total_segments += len(nodes)
+                for node in nodes:
+                    if len(preview_texts) < 5:
+                        preview_texts.append(node.get_text())
+                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())

         return {
-            "total_segments": len(nodes),
+            "total_segments": total_segments,
             "tokens": tokens,
             "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
             "currency": TokenCalculator.get_currency(self.embedding_model_name),
             "preview": preview_texts
         }

     def _load_data(self, document: Document) -> List[Document]:
         # load file
         if document.data_source_type != "upload_file":
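
To make the new flow concrete, a minimal usage sketch, assuming an active Flask request context with a logged-in user (for current_user) and an existing Notion DataSourceBinding for the workspace. Note that notion_info_list is annotated as dict but the loop consumes it as a list of workspace entries; the sketch follows the loop, and all IDs are placeholders:

    from core.indexing_runner import IndexingRunner

    runner = IndexingRunner()
    estimate = runner.notion_indexing_estimate(
        # one entry per workspace, each carrying the pages to estimate
        notion_info_list=[{
            'workspace_id': 'workspace-id-placeholder',
            'pages': [{'page_id': 'page-id-placeholder'}],
        }],
        tmp_processing_rule={
            'mode': 'custom',
            'rules': {
                'pre_processing_rules': [{'id': 'remove_extra_spaces', 'enabled': True}],
                'segmentation': {'separator': '\n', 'max_tokens': 500},
            },
        },
    )
    # estimate -> {'total_segments': ..., 'tokens': ..., 'total_price': '...',
    #              'currency': '...', 'preview': [...]}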
@@ -519,3 +519,78 @@ class DocumentService:
         if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
             raise ValueError("Process rule segmentation max_tokens is invalid")
+
+    @classmethod
+    def notion_estimate_args_validate(cls, args: dict):
+        if 'notion_info_list' not in args or not args['notion_info_list']:
+            raise ValueError("Notion info is required")
+        if not isinstance(args['notion_info_list'], dict):
+            raise ValueError("Notion info is invalid")
+        if 'process_rule' not in args or not args['process_rule']:
+            raise ValueError("Process rule is required")
+        if not isinstance(args['process_rule'], dict):
+            raise ValueError("Process rule is invalid")
+        if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:
+            raise ValueError("Process rule mode is required")
+        if args['process_rule']['mode'] not in DatasetProcessRule.MODES:
+            raise ValueError("Process rule mode is invalid")
+        if args['process_rule']['mode'] == 'automatic':
+            args['process_rule']['rules'] = {}
+        else:
+            if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:
+                raise ValueError("Process rule rules is required")
+            if not isinstance(args['process_rule']['rules'], dict):
+                raise ValueError("Process rule rules is invalid")
+            if 'pre_processing_rules' not in args['process_rule']['rules'] \
+                    or args['process_rule']['rules']['pre_processing_rules'] is None:
+                raise ValueError("Process rule pre_processing_rules is required")
+            if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
+                raise ValueError("Process rule pre_processing_rules is invalid")
+            unique_pre_processing_rule_dicts = {}
+            for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:
+                if 'id' not in pre_processing_rule or not pre_processing_rule['id']:
+                    raise ValueError("Process rule pre_processing_rules id is required")
+                if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:
+                    raise ValueError("Process rule pre_processing_rules id is invalid")
+                if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:
+                    raise ValueError("Process rule pre_processing_rules enabled is required")
+                if not isinstance(pre_processing_rule['enabled'], bool):
+                    raise ValueError("Process rule pre_processing_rules enabled is invalid")
+                unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule
+            args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
+            if 'segmentation' not in args['process_rule']['rules'] \
+                    or args['process_rule']['rules']['segmentation'] is None:
+                raise ValueError("Process rule segmentation is required")
+            if not isinstance(args['process_rule']['rules']['segmentation'], dict):
+                raise ValueError("Process rule segmentation is invalid")
+            if 'separator' not in args['process_rule']['rules']['segmentation'] \
+                    or not args['process_rule']['rules']['segmentation']['separator']:
+                raise ValueError("Process rule segmentation separator is required")
+            if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
+                raise ValueError("Process rule segmentation separator is invalid")
+            if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
+                    or not args['process_rule']['rules']['segmentation']['max_tokens']:
+                raise ValueError("Process rule segmentation max_tokens is required")
+            if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
+                raise ValueError("Process rule segmentation max_tokens is invalid")