Commit b41a4766 authored by jyong

merge file and notion indexing-estimate

parent e0c1d94f
......@@ -13,7 +13,7 @@ from core.indexing_runner import IndexingRunner
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import UploadFile
from services.dataset_service import DatasetService
from services.dataset_service import DatasetService, DocumentService
dataset_detail_fields = {
'id': fields.String,
......@@ -217,17 +217,30 @@ class DatasetIndexingEstimateApi(Resource):
@login_required
@account_initialization_required
def post(self):
segment_rule = request.get_json()
file_detail = db.session.query(UploadFile).filter(
parser = reqparse.RequestParser()
parser.add_argument('info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
if args['info_list']['data_source_type'] == 'upload_file':
file_details = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id == segment_rule["file_id"]
).first()
UploadFile.id in args['info_list']['file_info_list']['file_ids']
).all()
if file_detail is None:
if file_details is None:
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(file_detail, segment_rule['process_rule'])
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'])
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
args['process_rule'])
else:
raise ValueError('Data source type not support')
return response, 200
......@@ -277,5 +290,5 @@ class DatasetRelatedAppListApi(Resource):
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
......@@ -335,7 +335,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound('File not found.')
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(file, data_process_rule_dict)
response = indexing_runner.file_indexing_estimate(list(file), data_process_rule_dict)
return response
......
......@@ -168,10 +168,14 @@ class IndexingRunner:
nodes=nodes
)
def indexing_estimate(self, file_detail: UploadFile, tmp_processing_rule: dict) -> dict:
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
"""
Estimate the indexing for the document.
"""
tokens = 0
preview_texts = []
total_segments = 0
for file_detail in file_details:
# load data from file
text_docs = self._load_data_from_file(file_detail)
......@@ -189,9 +193,7 @@ class IndexingRunner:
node_parser=node_parser,
processing_rule=processing_rule
)
tokens = 0
preview_texts = []
total_segments += len(nodes)
for node in nodes:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
......@@ -199,7 +201,7 @@ class IndexingRunner:
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
return {
"total_segments": len(nodes),
"total_segments": total_segments,
"tokens": tokens,
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
......
......@@ -707,11 +707,11 @@ class DocumentService:
raise ValueError("Process rule segmentation max_tokens is invalid")
@classmethod
def notion_estimate_args_validate(cls, args: dict):
if 'notion_info_list' not in args or not args['notion_info_list']:
raise ValueError("Notion info is required")
def estimate_args_validate(cls, args: dict):
if 'info_list' not in args or not args['info_list']:
raise ValueError("Data source info is required")
if not isinstance(args['notion_info_list'], list):
if not isinstance(args['info_list'], dict):
raise ValueError("Notion info is invalid")
if 'process_rule' not in args or not args['process_rule']:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment