Commit 581e9068 authored by jyong

clean documents when imported Notion pages are not selected

parent b41a4766
@@ -64,6 +64,9 @@ class OAuthDataSourceCallback(Resource):
 class OAuthDataSourceSync(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
     def get(self, provider, binding_id):
         provider = str(provider)
         binding_id = str(binding_id)
...
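The three added decorators gate the Notion sync endpoint behind the same checks as the other console routes: instance setup, an authenticated user, and an initialized account. As a minimal sketch of that guard pattern (an illustration of how such decorators typically work, not the project's actual implementation):

    from functools import wraps

    from flask_login import current_user
    from werkzeug.exceptions import Unauthorized


    def login_required(view):
        # Run the wrapped view only when a user is signed in; otherwise abort with 401.
        @wraps(view)
        def decorated(*args, **kwargs):
            if not current_user.is_authenticated:
                raise Unauthorized('Login required.')
            return view(*args, **kwargs)
        return decorated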
@@ -127,7 +127,8 @@ class DataSourceNotionListApi(Resource):
     integrate_page_fields = {
         'page_name': fields.String,
         'page_id': fields.String,
-        'page_icon': fields.String
+        'page_icon': fields.String,
+        'is_bound': fields.Boolean
     }
     integrate_workspace_fields = {
         'workspace_name': fields.String,
@@ -160,8 +161,9 @@ class DataSourceNotionListApi(Resource):
             enabled=True
         ).all()
         if documents:
-            page_ids = list(map(lambda item: item.data_source_info, documents))
-            exist_page_ids.append(page_ids)
+            for document in documents:
+                data_source_info = json.loads(document.data_source_info)
+                exist_page_ids.append(data_source_info['notion_page_id'])
         # get all authorized pages
         data_source_bindings = DataSourceBinding.query.filter_by(
             tenant_id=current_user.current_tenant_id,
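This hunk fixes a real lookup bug, not just style: `Document.data_source_info` holds a JSON string, and the old code appended the entire mapped list as a single element of `exist_page_ids`, so a later membership test against a page id could never succeed. The new loop parses each row and collects the `notion_page_id` values themselves. A self-contained sketch of the difference, using made-up rows:

    import json

    # Two stored rows as they might look in Document.data_source_info (sample data).
    rows = [
        '{"notion_workspace_id": "ws-1", "notion_page_id": "page-a"}',
        '{"notion_workspace_id": "ws-1", "notion_page_id": "page-b"}',
    ]

    # Old approach: a nested list of raw JSON strings; 'page-a' is never found.
    old_ids = []
    old_ids.append(list(map(lambda row: row, rows)))
    assert 'page-a' not in old_ids

    # New approach: parse each row and collect the page ids.
    new_ids = [json.loads(row)['notion_page_id'] for row in rows]
    assert 'page-a' in new_ids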
@@ -179,13 +181,17 @@ class DataSourceNotionListApi(Resource):
                 'CONSOLE_URL') + '/console/api/oauth/data-source/authorize/notion')
             pages = notion_oauth.get_authorized_pages(data_source_binding.access_token)
-            # Filter out already bound pages
-            filter_pages = [page for page in pages if page['page_id'] not in exist_page_ids]
+            # Flag already bound pages instead of filtering them out
+            for page in pages:
+                if page['page_id'] in exist_page_ids:
+                    page['is_bound'] = True
+                else:
+                    page['is_bound'] = False
             source_info = json.loads(data_source_binding.source_info)
             pre_import_info = {
                 'workspace_name': source_info['workspace_name'],
                 'workspace_icon': source_info['workspace_icon'],
                 'workspace_id': source_info['workspace_id'],
-                'pages': filter_pages,
+                'pages': pages,
             }
             pre_import_info_list.append(pre_import_info)
         return {
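Instead of hiding already-imported pages from the list response, the endpoint now returns every authorized page annotated with `is_bound`, which lets the import UI show bound pages as pre-selected and makes deselection detectable at save time. The four-line if/else could also be collapsed; a sketch of an equivalent, slightly tighter form (a suggestion, not what the commit ships):

    bound = set(exist_page_ids)  # set membership is O(1) per lookup
    for page in pages:
        page['is_bound'] = page['page_id'] in bound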
@@ -226,7 +232,7 @@ class DataSourceNotionApi(Resource):
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
         args = parser.parse_args()
         # validate args
-        DocumentService.notion_estimate_args_validate(args)
+        DocumentService.estimate_args_validate(args)
         indexing_runner = IndexingRunner()
         response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
         return response, 200
...
@@ -19,6 +19,7 @@ from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
 from services.errors.file import FileNotExistsError
+from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
@@ -363,9 +364,9 @@ class DocumentService:
     @staticmethod
     def get_documents_position(dataset_id):
-        documents = Document.query.filter_by(dataset_id=dataset_id).all()
-        if documents:
-            return len(documents) + 1
+        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        if document:
+            return document.position + 1
         else:
             return 1
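The old implementation loaded every document row just to compute `len(documents) + 1`, and that value collides as soon as any document is deleted, because the count shrinks while the highest stored position does not. Asking the database for the maximum `position` avoids both problems. A standalone sketch of the failure mode, with sample positions:

    # Positions of existing documents after one deletion (sample data).
    positions = [1, 2, 4]                # the document at position 3 was removed

    next_by_count = len(positions) + 1   # old logic -> 4, collides with the existing 4
    next_by_max = max(positions) + 1     # new logic -> 5, always past the current maximum

    assert next_by_count in positions    # demonstrates the collision
    assert next_by_max not in positions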
@@ -437,6 +438,19 @@ class DocumentService:
                 position += 1
         elif document_data["data_source"]["type"] == "notion_import":
             notion_info_list = document_data["data_source"]['info']
+            exist_page_ids = []
+            exist_document = dict()
+            documents = Document.query.filter_by(
+                dataset_id=dataset.id,
+                tenant_id=current_user.current_tenant_id,
+                data_source_type='notion',
+                enabled=True
+            ).all()
+            if documents:
+                for document in documents:
+                    data_source_info = json.loads(document.data_source_info)
+                    exist_page_ids.append(data_source_info['notion_page_id'])
+                    exist_document[data_source_info['notion_page_id']] = document.id
             for notion_info in notion_info_list:
                 workspace_id = notion_info['workspace_id']
                 data_source_binding = DataSourceBinding.query.filter(
@@ -450,20 +464,25 @@ class DocumentService:
                 if not data_source_binding:
                     raise ValueError('Data source binding not found.')
                 for page in notion_info['pages']:
-                    data_source_info = {
-                        "notion_workspace_id": workspace_id,
-                        "notion_page_id": page['page_id'],
-                    }
-                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
-                                                             document_data["data_source"]["type"],
-                                                             data_source_info, created_from, position,
-                                                             account, page['page_name'], batch)
-                    db.session.add(document)
-                    db.session.flush()
-                    document_ids.append(document.id)
-                    documents.append(document)
-                    position += 1
+                    if page['page_id'] not in exist_page_ids:
+                        data_source_info = {
+                            "notion_workspace_id": workspace_id,
+                            "notion_page_id": page['page_id']
+                        }
+                        document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                                 document_data["data_source"]["type"],
+                                                                 data_source_info, created_from, position,
+                                                                 account, page['page_name'], batch)
+                        db.session.add(document)
+                        db.session.flush()
+                        document_ids.append(document.id)
+                        documents.append(document)
+                        position += 1
+                    else:
+                        exist_document.pop(page['page_id'])
+            # delete previously imported documents that were not selected this time
+            if len(exist_document) > 0:
+                clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
             db.session.commit()
             # trigger async task
...
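The import path now runs in three steps: snapshot the already-imported pages into `exist_page_ids` and the `exist_document` map, create documents only for newly selected pages, and pop each still-selected page out of `exist_document`, so whatever remains in the map is exactly the set of deselected documents handed to `clean_notion_document_task`. A self-contained sketch of that set logic, with made-up ids:

    # Pages already imported, mapped page_id -> document_id (sample data).
    exist_document = {'page-a': 'doc-1', 'page-b': 'doc-2'}

    # Pages the user selected this time: page-a kept, page-c newly added.
    selected = ['page-a', 'page-c']

    to_create = []
    for page_id in selected:
        if page_id not in exist_document:
            to_create.append(page_id)    # page-c: import as a new document
        else:
            exist_document.pop(page_id)  # page-a: still selected, exempt from cleanup

    # Whatever is left in the map was deselected and should be cleaned up.
    assert to_create == ['page-c']
    assert exist_document == {'page-b': 'doc-2'}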
import logging
import time
from typing import List

import click
from celery import shared_task

from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment


@shared_task
def clean_notion_document_task(document_ids: List[str], dataset_id: str):
    """
    Clean documents when pages imported from Notion are deselected.
    :param document_ids: document ids
    :param dataset_id: dataset id

    Usage: clean_notion_document_task.delay(document_ids, dataset_id)
    """
    logging.info(click.style('Start cleaning documents for deselected Notion pages: {}'.format(dataset_id), fg='green'))
    start_at = time.perf_counter()
    try:
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
        if not dataset:
            raise Exception('Document has no dataset')
        vector_index = VectorIndex(dataset=dataset)
        keyword_table_index = KeywordTableIndex(dataset=dataset)
        for document_id in document_ids:
            document = db.session.query(Document).filter(
                Document.id == document_id
            ).first()
            if not document:
                # skip ids that no longer resolve to a document
                continue
            db.session.delete(document)
            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
            index_node_ids = [segment.index_node_id for segment in segments]
            # delete from vector index
            vector_index.del_nodes(index_node_ids)
            # delete from keyword index
            if index_node_ids:
                keyword_table_index.del_nodes(index_node_ids)
            for segment in segments:
                db.session.delete(segment)
        db.session.commit()
        end_at = time.perf_counter()
        logging.info(
            click.style('Cleaned documents for deselected Notion pages: {} latency: {}'.format(
                dataset_id, end_at - start_at),
                fg='green'))
    except Exception:
        logging.exception('Failed to clean documents for deselected Notion pages')
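The task removes the derived index data (vector nodes and keyword-table nodes, keyed by each segment's `index_node_id`) alongside the database rows, and commits once at the end so a failure rolls the whole cleanup back. A usage sketch with made-up ids; note the explicit `list()` when the caller holds a mapping, since Celery's default JSON serializer cannot encode a `dict_values` view:

    # page_id -> document_id for the deselected pages (illustrative ids).
    deselected = {'page-b': 'doc-2', 'page-d': 'doc-4'}
    clean_notion_document_task.delay(list(deselected.values()), 'dataset-123')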