Commit 42c4ab73 authored by Jyong

add data source binding init and check dataset index status

parent f1f5d45d
# -*- coding:utf-8 -*-
import random
from datetime import datetime
from typing import List
from flask import request
from flask_login import login_required, current_user
@@ -83,6 +84,22 @@ class DocumentResource(Resource):
         return document
 
+    def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        documents = DocumentService.get_batch_documents(dataset_id, batch)
+        if not documents:
+            raise NotFound('Documents not found.')
+        return documents
+
 
 class GetProcessRuleApi(Resource):
     @setup_required
@@ -340,23 +357,25 @@ class DocumentIndexingStatusApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, dataset_id, document_id):
+    def get(self, dataset_id, batch):
         dataset_id = str(dataset_id)
-        document_id = str(document_id)
-
-        document = self.get_document(dataset_id, document_id)
-
-        completed_segments = DocumentSegment.query \
-            .filter(DocumentSegment.completed_at.isnot(None),
-                    DocumentSegment.document_id == str(document_id)) \
-            .count()
-        total_segments = DocumentSegment.query \
-            .filter_by(document_id=str(document_id)) \
-            .count()
+        batch = str(batch)
+        documents = self.get_batch_documents(dataset_id, batch)
+        documents_status = []
+        for document in documents:
+            completed_segments = DocumentSegment.query \
+                .filter(DocumentSegment.completed_at.isnot(None),
+                        DocumentSegment.document_id == str(document.id)) \
+                .count()
+            total_segments = DocumentSegment.query \
+                .filter_by(document_id=str(document.id)) \
+                .count()
-        document.completed_segments = completed_segments
-        document.total_segments = total_segments
+            document.completed_segments = completed_segments
+            document.total_segments = total_segments
+            documents_status.append(marshal(document, self.document_status_fields))
-        return marshal(document, self.document_status_fields)
+        return documents_status
 
 
 class DocumentDetailApi(DocumentResource):
@@ -676,7 +695,7 @@ api.add_resource(DatasetInitApi,
 api.add_resource(DocumentIndexingEstimateApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate')
 api.add_resource(DocumentIndexingStatusApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status')
+                 '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status')
 api.add_resource(DocumentDetailApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
 api.add_resource(DocumentProcessingApi,
......
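For context, the route change above means indexing progress is now polled per upload batch rather than per document. Below is a minimal client-side polling sketch, assuming a console API base URL and an authenticated requests session (both placeholders, not part of this commit), and assuming the response body is the marshalled list built in DocumentIndexingStatusApi.get:

# Hypothetical polling helper; base URL and auth handling are placeholders.
import time
import requests

CONSOLE_API_BASE = 'http://localhost:5001/console/api'  # assumed deployment URL

def wait_for_batch(session: requests.Session, dataset_id: str, batch: str) -> list:
    """Poll the batch indexing-status route until every document's segments complete."""
    while True:
        resp = session.get(
            f'{CONSOLE_API_BASE}/datasets/{dataset_id}/batch/{batch}/indexing-status')
        resp.raise_for_status()
        statuses = resp.json()  # one status dict per document in the batch
        if statuses and all(s['completed_segments'] == s['total_segments']
                            for s in statuses):
            return statuses
        time.sleep(2)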
@@ -141,7 +141,7 @@ class NotionPageReader(BaseReader):
     def read_page_as_documents(self, page_id: str) -> List[str]:
         """Read a page as documents."""
-        return self._read_block(page_id)
+        return self._read_parent_blocks(page_id)
 
     def query_database(
         self, database_id: str, query_dict: Dict[str, Any] = {}
@@ -212,6 +212,26 @@ class NotionPageReader(BaseReader):
         return docs
 
+    def load_data_as_documents(
+        self, page_ids: List[str] = [], database_id: Optional[str] = None
+    ) -> List[Document]:
+        if not page_ids and not database_id:
+            raise ValueError("Must specify either `page_ids` or `database_id`.")
+        docs = []
+        if database_id is not None:
+            # get all the pages in the database
+            page_ids = self.query_database(database_id)
+            for page_id in page_ids:
+                page_text = self.read_page(page_id)
+                docs.append(Document(page_text, extra_info={"page_id": page_id}))
+        else:
+            for page_id in page_ids:
+                page_text_list = self.read_page_as_documents(page_id)
+                for page_text in page_text_list:
+                    docs.append(Document(page_text, extra_info={"page_id": page_id}))
+        return docs
+
 
 if __name__ == "__main__":
     reader = NotionPageReader()
......
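The loader change is worth spelling out: load_data returns one Document per Notion page (whole-page text via read_page), while the new load_data_as_documents returns one Document per top-level block, because read_page_as_documents now delegates to _read_parent_blocks. A small sketch of the difference; the integration token and page ID are placeholders:

# Placeholders: a real Notion integration token and page id are required.
reader = NotionPageReader(integration_token='secret_...')

page_docs = reader.load_data(page_ids=['<page-id>'])
# -> [Document]               one Document carrying the whole page text

block_docs = reader.load_data_as_documents(page_ids=['<page-id>'])
# -> [Document, Document, ...] one Document per top-level block of the page

IndexingRunner._load_data_from_notion (next hunk) switches to the block-level loader, so each parent block becomes its own source document for segmentation.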
@@ -332,7 +332,7 @@ class IndexingRunner:
             raise ValueError('Data source binding not found.')
         page_ids = [page_id]
         reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        text_docs = reader.load_data(page_ids=page_ids)
+        text_docs = reader.load_data_as_documents(page_ids=page_ids)
         return text_docs
 
     def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
......
@@ -6,7 +6,7 @@ from sqlalchemy.dialects.postgresql import JSONB
 class DataSourceBinding(db.Model):
     __tablename__ = 'data_source_bindings'
     __table_args__ = (
-        db.PrimaryKeyConstraint('id', name='app_pkey'),
+        db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
         db.Index('app_tenant_id_idx', 'tenant_id')
     )
......
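Renaming the PrimaryKeyConstraint in the model only updates SQLAlchemy metadata; an existing database still carries the old app_pkey name until a migration renames it. This excerpt does not show a migration, so the following is only a hedged Alembic sketch of what one could look like (revision identifiers are placeholders; PostgreSQL is assumed, consistent with the JSONB import above):

# Hypothetical Alembic migration, not part of this commit.
from alembic import op

revision = '<new-revision>'        # placeholder
down_revision = '<prev-revision>'  # placeholder

def upgrade():
    # PostgreSQL can rename a constraint in place.
    op.execute('ALTER TABLE data_source_bindings '
               'RENAME CONSTRAINT app_pkey TO source_binding_pkey')

def downgrade():
    op.execute('ALTER TABLE data_source_bindings '
               'RENAME CONSTRAINT source_binding_pkey TO app_pkey')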
@@ -3,7 +3,7 @@ import logging
 import datetime
 import time
 import random
-from typing import Optional
+from typing import Optional, List
 
 from extensions.ext_redis import redis_client
 from flask_login import current_user
@@ -278,6 +278,15 @@ class DocumentService:
         return document
 
+    @staticmethod
+    def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
+        documents = db.session.query(Document).filter(
+            Document.batch == batch,
+            Document.dataset_id == dataset_id,
+            Document.tenant_id == current_user.current_tenant_id
+        ).all()
+        return documents
+
     @staticmethod
     def get_document_file_detail(file_id: str):
         file_detail = db.session.query(UploadFile). \
             filter(UploadFile.id == file_id). \
@@ -376,6 +385,7 @@ class DocumentService:
             db.session.add(dataset_process_rule)
             db.session.commit()
         position = DocumentService.get_documents_position(dataset.id)
+        batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         document_ids = []
         documents = []
         if document_data["data_source"]["type"] == "upload_file":
@@ -398,7 +408,7 @@ class DocumentService:
                 document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                          document_data["data_source"]["type"],
                                                          data_source_info, created_from, position,
-                                                         account, file_name)
+                                                         account, file_name, batch)
                 db.session.add(document)
                 db.session.flush()
                 document_ids.append(document.id)
@@ -426,7 +436,7 @@ class DocumentService:
                 document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                          document_data["data_source"]["type"],
                                                          data_source_info, created_from, position,
-                                                         account, page['page_name'])
+                                                         account, page['page_name'], batch)
                 db.session.add(document)
                 db.session.flush()
                 document_ids.append(document.id)
@@ -442,7 +452,7 @@ class DocumentService:
     @staticmethod
     def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
-                      created_from: str, position: int, account: Account, name: str):
+                      created_from: str, position: int, account: Account, name: str, batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -450,7 +460,7 @@ class DocumentService:
             data_source_type=data_source_type,
             data_source_info=json.dumps(data_source_info),
             dataset_process_rule_id=process_rule_id,
-            batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
+            batch=batch,
             name=name,
             created_from=created_from,
             created_by=account.id,
......
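The net effect of the DocumentService changes: the batch identifier is now generated once per save call and passed into save_document for every document created in that call, instead of being generated per document as before, so all documents of one upload share a value the new status endpoint can query by. The format itself is unchanged, a 14-digit timestamp plus six random digits:

# Sketch of the batch id format; e.g. '20230601093000482913'.
import random
import time

batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
assert len(batch) == 20

Since the suffix is random rather than sequential, collisions within the same second are unlikely but not impossible; the lookup in get_batch_documents additionally filters on dataset_id and tenant_id, which narrows the scope of any collision.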