Commit b41a4766 authored by jyong

merge file and notion indexing-estimate

parent e0c1d94f
@@ -13,7 +13,7 @@ from core.indexing_runner import IndexingRunner
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.model import UploadFile
-from services.dataset_service import DatasetService
+from services.dataset_service import DatasetService, DocumentService

 dataset_detail_fields = {
     'id': fields.String,
@@ -217,17 +217,30 @@ class DatasetIndexingEstimateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        segment_rule = request.get_json()
-        file_detail = db.session.query(UploadFile).filter(
-            UploadFile.tenant_id == current_user.current_tenant_id,
-            UploadFile.id == segment_rule["file_id"]
-        ).first()
-
-        if file_detail is None:
-            raise NotFound("File not found.")
-
-        indexing_runner = IndexingRunner()
-        response = indexing_runner.indexing_estimate(file_detail, segment_rule['process_rule'])
+        parser = reqparse.RequestParser()
+        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        args = parser.parse_args()
+
+        # validate args
+        DocumentService.estimate_args_validate(args)
+
+        if args['info_list']['data_source_type'] == 'upload_file':
+            file_details = db.session.query(UploadFile).filter(
+                UploadFile.tenant_id == current_user.current_tenant_id,
+                # use SQLAlchemy's in_() here; a bare Python `in` would not build a SQL clause
+                UploadFile.id.in_(args['info_list']['file_info_list']['file_ids'])
+            ).all()
+
+            # .all() returns a list, so check for emptiness rather than None
+            if not file_details:
+                raise NotFound("File not found.")
+
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'])
+        elif args['info_list']['data_source_type'] == 'notion_import':
+            indexing_runner = IndexingRunner()
+            response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
+                                                                args['process_rule'])
+        else:
+            raise ValueError('Data source type is not supported')

         return response, 200
@@ -277,5 +290,5 @@ class DatasetRelatedAppListApi(Resource):
 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
-api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate')
+api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
 api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
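
With the route collapsed to /datasets/indexing-estimate, both data sources share one request shape: an info_list dict carrying data_source_type plus the source-specific fields, alongside a process_rule. A minimal sketch of the two bodies the handler above accepts; the field names are read off the parser and branching logic, while the IDs and rule values are placeholder assumptions:

    # upload_file: estimate over a batch of previously uploaded files
    file_estimate_body = {
        "info_list": {
            "data_source_type": "upload_file",
            "file_info_list": {
                "file_ids": ["00000000-0000-0000-0000-000000000000"]  # hypothetical upload IDs
            }
        },
        "process_rule": {
            "mode": "custom",
            "rules": {"segmentation": {"max_tokens": 500}}  # assumed rule shape
        }
    }

    # notion_import: estimate over selected Notion pages
    notion_estimate_body = {
        "info_list": {
            "data_source_type": "notion_import",
            "notion_info_list": []  # page descriptors consumed by notion_indexing_estimate; shape not shown in this diff
        },
        "process_rule": file_estimate_body["process_rule"]
    }
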
@@ -335,7 +335,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
             raise NotFound('File not found.')

         indexing_runner = IndexingRunner()
-        response = indexing_runner.indexing_estimate(file, data_process_rule_dict)
+        # wrap the single UploadFile in a list; list(file) would try to iterate the model object
+        response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)

         return response
...
@@ -168,38 +168,40 @@ class IndexingRunner:
             nodes=nodes
         )

-    def indexing_estimate(self, file_detail: UploadFile, tmp_processing_rule: dict) -> dict:
+    def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """
         Estimate the indexing for the document.
         """
-        # load data from file
-        text_docs = self._load_data_from_file(file_detail)
-
-        processing_rule = DatasetProcessRule(
-            mode=tmp_processing_rule["mode"],
-            rules=json.dumps(tmp_processing_rule["rules"])
-        )
-
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
-
-        # split to nodes
-        nodes = self._split_to_nodes(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            processing_rule=processing_rule
-        )
-
-        tokens = 0
-        preview_texts = []
-        for node in nodes:
-            if len(preview_texts) < 5:
-                preview_texts.append(node.get_text())
-
-            tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+        tokens = 0
+        preview_texts = []
+        total_segments = 0
+
+        for file_detail in file_details:
+            # load data from file
+            text_docs = self._load_data_from_file(file_detail)
+
+            processing_rule = DatasetProcessRule(
+                mode=tmp_processing_rule["mode"],
+                rules=json.dumps(tmp_processing_rule["rules"])
+            )
+
+            # get node parser for splitting
+            node_parser = self._get_node_parser(processing_rule)
+
+            # split to nodes
+            nodes = self._split_to_nodes(
+                text_docs=text_docs,
+                node_parser=node_parser,
+                processing_rule=processing_rule
+            )
+
+            total_segments += len(nodes)
+            for node in nodes:
+                if len(preview_texts) < 5:
+                    preview_texts.append(node.get_text())
+
+                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())

         return {
-            "total_segments": len(nodes),
+            "total_segments": total_segments,
             "tokens": tokens,
             "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
             "currency": TokenCalculator.get_currency(self.embedding_model_name),
...
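
The behavioural change in this hunk is aggregation: tokens, total_segments, and the five-item preview_texts cap now accumulate across the whole batch instead of a single file. A standalone sketch of that pattern with the Dify-specific loaders stubbed out; split_fn and count_tokens_fn are illustrative stand-ins, not names from the codebase:

    from typing import Callable, List

    def batch_estimate(texts: List[str],
                       split_fn: Callable[[str], List[str]],
                       count_tokens_fn: Callable[[str], int]) -> dict:
        """Aggregate segment and token counts over a batch, keeping a small preview."""
        tokens = 0
        total_segments = 0
        preview_texts: List[str] = []
        for text in texts:
            segments = split_fn(text)
            total_segments += len(segments)
            for segment in segments:
                if len(preview_texts) < 5:  # preview cap is global, not per document
                    preview_texts.append(segment)
                tokens += count_tokens_fn(segment)
        return {"total_segments": total_segments, "tokens": tokens, "preview_texts": preview_texts}
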
@@ -707,11 +707,11 @@ class DocumentService:
             raise ValueError("Process rule segmentation max_tokens is invalid")

     @classmethod
-    def notion_estimate_args_validate(cls, args: dict):
-        if 'notion_info_list' not in args or not args['notion_info_list']:
-            raise ValueError("Notion info is required")
+    def estimate_args_validate(cls, args: dict):
+        if 'info_list' not in args or not args['info_list']:
+            raise ValueError("Data source info is required")

-        if not isinstance(args['notion_info_list'], list):
-            raise ValueError("Notion info is invalid")
+        if not isinstance(args['info_list'], dict):
+            raise ValueError("Data source info is invalid")

         if 'process_rule' not in args or not args['process_rule']:
...
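
For illustration, the renamed validator's two visible checks behave like this; the truncated process_rule validation below them is assumed to continue in the same style, and neither call reaches it:

    # info_list missing entirely -> ValueError("Data source info is required")
    DocumentService.estimate_args_validate({"process_rule": {"mode": "custom", "rules": {}}})

    # info_list present but not a dict -> ValueError("Data source info is invalid")
    DocumentService.estimate_args_validate({
        "info_list": ["upload_file"],
        "process_rule": {"mode": "custom", "rules": {}},
    })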