Commit b41a4766 authored by jyong

merge file and notion indexing-estimate

parent e0c1d94f
@@ -13,7 +13,7 @@ from core.indexing_runner import IndexingRunner
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.model import UploadFile
-from services.dataset_service import DatasetService
+from services.dataset_service import DatasetService, DocumentService

 dataset_detail_fields = {
     'id': fields.String,
@@ -217,17 +217,30 @@ class DatasetIndexingEstimateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        segment_rule = request.get_json()
-        file_detail = db.session.query(UploadFile).filter(
-            UploadFile.tenant_id == current_user.current_tenant_id,
-            UploadFile.id == segment_rule["file_id"]
-        ).first()
-
-        if file_detail is None:
-            raise NotFound("File not found.")
-
-        indexing_runner = IndexingRunner()
-        response = indexing_runner.indexing_estimate(file_detail, segment_rule['process_rule'])
+        parser = reqparse.RequestParser()
+        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        args = parser.parse_args()
+        # validate args
+        DocumentService.estimate_args_validate(args)
+        if args['info_list']['data_source_type'] == 'upload_file':
+            file_details = db.session.query(UploadFile).filter(
+                UploadFile.tenant_id == current_user.current_tenant_id,
+                UploadFile.id.in_(args['info_list']['file_info_list']['file_ids'])
+            ).all()
+
+            if not file_details:
+                raise NotFound("File not found.")
+
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'])
+        elif args['info_list']['data_source_type'] == 'notion_import':
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
+                                                                args['process_rule'])
+        else:
+            raise ValueError('Data source type not supported')

         return response, 200
@@ -277,5 +290,5 @@ class DatasetRelatedAppListApi(Resource):
 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
-api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate')
+api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
 api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
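
Note (not part of the commit diff): the merged endpoint accepts a single request body whose info_list['data_source_type'] selects the file or Notion branch. A rough sketch of what a client might POST to the renamed /datasets/indexing-estimate route, based only on the field names visible in the handler above; the nested shapes of 'notion_info_list' and 'process_rule' are not shown in this diff and are assumptions for illustration:

# Hypothetical request bodies for POST /datasets/indexing-estimate.
# Field names come from the handler above; the nested shapes of
# 'notion_info_list' and 'process_rule' are assumed, not taken from the diff.
file_estimate_request = {
    "info_list": {
        "data_source_type": "upload_file",
        "file_info_list": {
            "file_ids": ["<upload-file-uuid-1>", "<upload-file-uuid-2>"]
        }
    },
    "process_rule": {"mode": "automatic", "rules": {}}   # assumed shape
}

notion_estimate_request = {
    "info_list": {
        "data_source_type": "notion_import",
        "notion_info_list": []   # Notion page/workspace entries; structure not shown in this diff
    },
    "process_rule": {"mode": "automatic", "rules": {}}   # assumed shape
}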
@@ -335,7 +335,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 raise NotFound('File not found.')

             indexing_runner = IndexingRunner()
-            response = indexing_runner.indexing_estimate(file, data_process_rule_dict)
+            response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)

             return response
@@ -168,10 +168,14 @@ class IndexingRunner:
             nodes=nodes
         )

-    def indexing_estimate(self, file_detail: UploadFile, tmp_processing_rule: dict) -> dict:
+    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """
         Estimate the indexing for the document.
         """
+        tokens = 0
+        preview_texts = []
+        total_segments = 0
+        for file_detail in file_details:
             # load data from file
             text_docs = self._load_data_from_file(file_detail)
@@ -189,9 +193,7 @@ class IndexingRunner:
                 node_parser=node_parser,
                 processing_rule=processing_rule
             )
-
-        tokens = 0
-        preview_texts = []
+            total_segments += len(nodes)
             for node in nodes:
                 if len(preview_texts) < 5:
                     preview_texts.append(node.get_text())
@@ -199,7 +201,7 @@ class IndexingRunner:
                 tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())

         return {
-            "total_segments": len(nodes),
+            "total_segments": total_segments,
             "tokens": tokens,
             "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
             "currency": TokenCalculator.get_currency(self.embedding_model_name),
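
Note (not part of the commit diff): the dict assembled here, which the /datasets/indexing-estimate endpoint returns, now aggregates totals across every file instead of reporting only the last one. A purely illustrative example of the visible keys, with invented values; the dict may continue beyond the lines shown in this hunk:

# Illustrative only -- values are invented; keys are the ones visible in
# the return statement of file_indexing_estimate above.
example_estimate = {
    "total_segments": 42,        # segments accumulated across all files
    "tokens": 12345,             # embedding tokens summed over all segments
    "total_price": "0.001234",   # '{:f}'-formatted TokenCalculator.get_token_price(...)
    "currency": "USD",           # TokenCalculator.get_currency(...); value assumed
    # ... further keys may follow outside the displayed context
}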
@@ -707,11 +707,11 @@ class DocumentService:
             raise ValueError("Process rule segmentation max_tokens is invalid")

     @classmethod
-    def notion_estimate_args_validate(cls, args: dict):
-        if 'notion_info_list' not in args or not args['notion_info_list']:
-            raise ValueError("Notion info is required")
+    def estimate_args_validate(cls, args: dict):
+        if 'info_list' not in args or not args['info_list']:
+            raise ValueError("Data source info is required")

-        if not isinstance(args['notion_info_list'], list):
+        if not isinstance(args['info_list'], dict):
             raise ValueError("Notion info is invalid")

         if 'process_rule' not in args or not args['process_rule']:
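
Note (not part of the commit diff): the checks visible in this hunk require info_list to be present and a dict, and process_rule to be present; the detailed process_rule validation continues beyond the displayed lines. A minimal usage sketch with made-up argument dicts:

# Sketch only -- exercises just the checks visible in this hunk.
valid_args = {
    "info_list": {"data_source_type": "upload_file"},   # must be present and a dict
    "process_rule": {"mode": "automatic"}               # must be present; deeper checks not shown here
}
DocumentService.estimate_args_validate(valid_args)      # passes the visible checks

missing_info = {"process_rule": {"mode": "automatic"}}
# DocumentService.estimate_args_validate(missing_info)
# -> ValueError("Data source info is required")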