Commit 42c4ab73 authored by Jyong's avatar Jyong

add data source binding init and check dataset index status

parent f1f5d45d
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import random import random
from datetime import datetime from datetime import datetime
from typing import List
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
...@@ -83,6 +84,22 @@ class DocumentResource(Resource): ...@@ -83,6 +84,22 @@ class DocumentResource(Resource):
return document return document
def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
    """Fetch every document belonging to one upload batch of a dataset.

    Raises:
        NotFound: the dataset does not exist, or the batch has no documents.
        Forbidden: the current user lacks permission on the dataset.
    """
    dataset = DatasetService.get_dataset(dataset_id)
    if not dataset:
        raise NotFound('Dataset not found.')
    try:
        # permission check is delegated to the service layer
        DatasetService.check_dataset_permission(dataset, current_user)
    except services.errors.account.NoPermissionError as e:
        raise Forbidden(str(e))
    documents = DocumentService.get_batch_documents(dataset_id, batch)
    if not documents:
        raise NotFound('Documents not found.')
    return documents
class GetProcessRuleApi(Resource): class GetProcessRuleApi(Resource):
@setup_required @setup_required
...@@ -340,23 +357,25 @@ class DocumentIndexingStatusApi(DocumentResource): ...@@ -340,23 +357,25 @@ class DocumentIndexingStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, batch):
    """Report per-document indexing progress for every document in a batch."""
    dataset_id = str(dataset_id)
    batch = str(batch)
    documents = self.get_batch_documents(dataset_id, batch)

    documents_status = []
    for document in documents:
        # Segments with a completed_at timestamp have finished indexing;
        # compare against the total segment count for this document.
        doc_id = str(document.id)
        completed_segments = DocumentSegment.query.filter(
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.document_id == doc_id,
        ).count()
        total_segments = DocumentSegment.query.filter_by(
            document_id=doc_id,
        ).count()

        document.completed_segments = completed_segments
        document.total_segments = total_segments
        documents_status.append(marshal(document, self.document_status_fields))

    return documents_status
class DocumentDetailApi(DocumentResource): class DocumentDetailApi(DocumentResource):
...@@ -676,7 +695,7 @@ api.add_resource(DatasetInitApi, ...@@ -676,7 +695,7 @@ api.add_resource(DatasetInitApi,
api.add_resource(DocumentIndexingEstimateApi, api.add_resource(DocumentIndexingEstimateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate') '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate')
api.add_resource(DocumentIndexingStatusApi, api.add_resource(DocumentIndexingStatusApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status') '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status')
api.add_resource(DocumentDetailApi, api.add_resource(DocumentDetailApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>') '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DocumentProcessingApi, api.add_resource(DocumentProcessingApi,
......
...@@ -141,7 +141,7 @@ class NotionPageReader(BaseReader): ...@@ -141,7 +141,7 @@ class NotionPageReader(BaseReader):
def read_page_as_documents(self, page_id: str) -> List[str]: def read_page_as_documents(self, page_id: str) -> List[str]:
"""Read a page as documents.""" """Read a page as documents."""
return self._read_block(page_id) return self._read_parent_blocks(page_id)
def query_database( def query_database(
self, database_id: str, query_dict: Dict[str, Any] = {} self, database_id: str, query_dict: Dict[str, Any] = {}
...@@ -212,6 +212,26 @@ class NotionPageReader(BaseReader): ...@@ -212,6 +212,26 @@ class NotionPageReader(BaseReader):
return docs return docs
def load_data_as_documents(
        self, page_ids: Optional[List[str]] = None, database_id: Optional[str] = None
) -> List[Document]:
    """Load Notion pages as Document objects.

    Args:
        page_ids: Notion page ids to read. When ``database_id`` is given,
            any caller-supplied page ids are ignored and the database's
            pages are used instead.
        database_id: optional database whose pages are loaded; each page
            becomes one Document.

    Returns:
        A list of Document objects, each tagged with its ``page_id`` in
        ``extra_info``.

    Raises:
        ValueError: neither ``page_ids`` nor ``database_id`` is given.
    """
    # NOTE: the previous signature used a mutable default (`page_ids=[]`),
    # a classic Python pitfall; `None` is the safe sentinel.
    page_ids = page_ids if page_ids is not None else []
    if not page_ids and not database_id:
        raise ValueError("Must specify either `page_ids` or `database_id`.")
    docs = []
    if database_id is not None:
        # get all the pages in the database; one Document per page
        page_ids = self.query_database(database_id)
        for page_id in page_ids:
            page_text = self.read_page(page_id)
            docs.append(Document(page_text, extra_info={"page_id": page_id}))
    else:
        for page_id in page_ids:
            # a page may split into several text chunks; each chunk
            # becomes its own Document, all sharing the same page_id
            page_text_list = self.read_page_as_documents(page_id)
            for page_text in page_text_list:
                docs.append(Document(page_text, extra_info={"page_id": page_id}))
    return docs
if __name__ == "__main__": if __name__ == "__main__":
reader = NotionPageReader() reader = NotionPageReader()
......
...@@ -332,7 +332,7 @@ class IndexingRunner: ...@@ -332,7 +332,7 @@ class IndexingRunner:
raise ValueError('Data source binding not found.') raise ValueError('Data source binding not found.')
page_ids = [page_id] page_ids = [page_id]
reader = NotionPageReader(integration_token=data_source_binding.access_token) reader = NotionPageReader(integration_token=data_source_binding.access_token)
text_docs = reader.load_data(page_ids=page_ids) text_docs = reader.load_data_as_documents(page_ids=page_ids)
return text_docs return text_docs
def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser: def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
......
...@@ -6,7 +6,7 @@ from sqlalchemy.dialects.postgresql import JSONB ...@@ -6,7 +6,7 @@ from sqlalchemy.dialects.postgresql import JSONB
class DataSourceBinding(db.Model): class DataSourceBinding(db.Model):
__tablename__ = 'data_source_bindings' __tablename__ = 'data_source_bindings'
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='app_pkey'), db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
db.Index('app_tenant_id_idx', 'tenant_id') db.Index('app_tenant_id_idx', 'tenant_id')
) )
......
...@@ -3,7 +3,7 @@ import logging ...@@ -3,7 +3,7 @@ import logging
import datetime import datetime
import time import time
import random import random
from typing import Optional from typing import Optional, List
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from flask_login import current_user from flask_login import current_user
...@@ -278,6 +278,15 @@ class DocumentService: ...@@ -278,6 +278,15 @@ class DocumentService:
return document return document
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
    """Return the current tenant's documents matching a batch id within a dataset."""
    batch_query = db.session.query(Document).filter_by(
        batch=batch,
        dataset_id=dataset_id,
        tenant_id=current_user.current_tenant_id,
    )
    return batch_query.all()
@staticmethod
def get_document_file_detail(file_id: str): def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile). \ file_detail = db.session.query(UploadFile). \
filter(UploadFile.id == file_id). \ filter(UploadFile.id == file_id). \
...@@ -376,6 +385,7 @@ class DocumentService: ...@@ -376,6 +385,7 @@ class DocumentService:
db.session.add(dataset_process_rule) db.session.add(dataset_process_rule)
db.session.commit() db.session.commit()
position = DocumentService.get_documents_position(dataset.id) position = DocumentService.get_documents_position(dataset.id)
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
document_ids = [] document_ids = []
documents = [] documents = []
if document_data["data_source"]["type"] == "upload_file": if document_data["data_source"]["type"] == "upload_file":
...@@ -398,7 +408,7 @@ class DocumentService: ...@@ -398,7 +408,7 @@ class DocumentService:
document = DocumentService.save_document(dataset, dataset_process_rule.id, document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"], document_data["data_source"]["type"],
data_source_info, created_from, position, data_source_info, created_from, position,
account, file_name) account, file_name, batch)
db.session.add(document) db.session.add(document)
db.session.flush() db.session.flush()
document_ids.append(document.id) document_ids.append(document.id)
...@@ -426,7 +436,7 @@ class DocumentService: ...@@ -426,7 +436,7 @@ class DocumentService:
document = DocumentService.save_document(dataset, dataset_process_rule.id, document = DocumentService.save_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"], document_data["data_source"]["type"],
data_source_info, created_from, position, data_source_info, created_from, position,
account, page['page_name']) account, page['page_name'], batch)
db.session.add(document) db.session.add(document)
db.session.flush() db.session.flush()
document_ids.append(document.id) document_ids.append(document.id)
...@@ -442,7 +452,7 @@ class DocumentService: ...@@ -442,7 +452,7 @@ class DocumentService:
@staticmethod @staticmethod
def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict, def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
created_from: str, position: int, account: Account, name: str): created_from: str, position: int, account: Account, name: str, batch: str):
document = Document( document = Document(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
dataset_id=dataset.id, dataset_id=dataset.id,
...@@ -450,7 +460,7 @@ class DocumentService: ...@@ -450,7 +460,7 @@ class DocumentService:
data_source_type=data_source_type, data_source_type=data_source_type,
data_source_info=json.dumps(data_source_info), data_source_info=json.dumps(data_source_info),
dataset_process_rule_id=process_rule_id, dataset_process_rule_id=process_rule_id,
batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)), batch=batch,
name=name, name=name,
created_from=created_from, created_from=created_from,
created_by=account.id, created_by=account.id,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment