Commit ff1aadf9 authored by Jyong

add notion sync

parent 8c6f5add
import datetime
import json
from cachetools import TTLCache
from flask import request, current_app
from flask_login import login_required, current_user
@@ -28,7 +27,6 @@ PREVIEW_WORDS_LIMIT = 3000
class DataSourceApi(Resource):
    integrate_page_fields = {
        'page_name': fields.String,
        'page_id': fields.String,
@@ -233,7 +231,25 @@ class DataSourceNotionApi(Resource):
        return response, 200


class DataSourceNotionSyncApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
        for document in documents:
            ...  # per-document sync dispatch, collapsed in this view
        return 200
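The loop body is collapsed in the view above. A minimal sketch of the dispatch that belongs there, assuming an async Celery task named document_indexing_sync_task (the task name and its signature are assumptions, not shown in this diff):

# Hypothetical async dispatch; the Celery task is assumed, not shown here.
for document in documents:
    document_indexing_sync_task.delay(dataset_id_str, document.id)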
api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
api.add_resource(DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/preview',
                 '/datasets/notion-indexing-estimate')
api.add_resource(DataSourceNotionSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
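Once registered, a sync can be triggered with a plain GET against the console API. A hypothetical call; the base URL prefix and the authentication cookie are assumptions that depend on the deployment:

import requests

BASE = 'http://localhost:5001/console/api'  # assumed console API prefix
DATASET_ID = '00000000-0000-0000-0000-000000000000'  # placeholder UUID

resp = requests.get(f'{BASE}/datasets/{DATASET_ID}/notion/sync',
                    cookies={'session': '<console session cookie>'})
print(resp.status_code)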
"""Notion reader."""
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
import requests # type: ignore
@@ -12,6 +13,7 @@ INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
logger = logging.getLogger(__name__)
@@ -233,6 +235,17 @@ class NotionPageReader(BaseReader):
        return docs

    def get_page_last_edited_time(self, page_id: str) -> str:
        """Retrieve the page object from the Notion API and return its
        last_edited_time, an ISO-8601 string such as "2023-05-25T03:57:00.000Z"."""
        retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
        query_dict: Dict[str, Any] = {}

        res = requests.request(
            "GET", retrieve_page_url, headers=self.headers, json=query_dict
        )
        data = res.json()
        # last_edited_time = datetime.fromisoformat(data["last_edited_time"])
        return data["last_edited_time"]
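The method returns the raw string rather than a datetime. Notion timestamps end in a literal "Z" (for example "2023-05-25T03:57:00.000Z"), which datetime.fromisoformat rejects before Python 3.11; that is presumably why the conversion is left commented out above. A parsing sketch that also works on older interpreters:

from datetime import datetime

edited_at = '2023-05-25T03:57:00.000Z'  # example value from the Notion API
edited_dt = datetime.fromisoformat(edited_at.replace('Z', '+00:00'))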
if __name__ == "__main__":
    reader = NotionPageReader()
......
@@ -280,7 +280,21 @@ class IndexingRunner:
            if not data_source_info or 'notion_page_id' not in data_source_info \
                    or 'notion_workspace_id' not in data_source_info:
                raise ValueError("no notion page found")
            text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id'], document.tenant_id)
            workspace_id = data_source_info['notion_workspace_id']
            page_id = data_source_info['notion_page_id']
            data_source_binding = DataSourceBinding.query.filter(
                db.and_(
                    DataSourceBinding.tenant_id == document.tenant_id,
                    DataSourceBinding.provider == 'notion',
                    DataSourceBinding.disabled == False,
                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
                )
            ).first()
            if not data_source_binding:
                raise ValueError('Data source binding not found.')

            # add page last_edited_time to data_source_info
            self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
            text_docs = self._load_data_from_notion(page_id, data_source_binding.access_token)

        # update document status to splitting
        self._update_document_index_status(
            document_id=document.id,
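One detail in the binding lookup above is easy to misread: source_info is a JSON column, and indexing it with ['workspace_id'] renders as the Postgres -> operator, which yields JSON rather than text. The right-hand literal therefore has to be JSON-encoded itself, which is why the workspace id is wrapped in double quotes. Extracting the element as text is the usual alternative; a sketch, assuming SQLAlchemy's JSON type on Postgres:

# Equivalent filter via ->> (text extraction), no hand-quoted literal needed:
DataSourceBinding.source_info['workspace_id'].astext == workspace_id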
@@ -320,22 +334,24 @@
        return text_docs

    def _load_data_from_notion(self, workspace_id: str, page_id: str, tenant_id: str) -> List[Document]:
        data_source_binding = DataSourceBinding.query.filter(
            db.and_(
                DataSourceBinding.tenant_id == tenant_id,
                DataSourceBinding.provider == 'notion',
                DataSourceBinding.disabled == False,
                DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
            )
        ).first()
        if not data_source_binding:
            raise ValueError('Data source binding not found.')

    def _load_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
        page_ids = [page_id]
        reader = NotionPageReader(integration_token=data_source_binding.access_token)
        reader = NotionPageReader(integration_token=access_token)
        text_docs = reader.load_data_as_documents(page_ids=page_ids)
        return text_docs
    def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
        reader = NotionPageReader(integration_token=access_token)
        last_edited_time = reader.get_page_last_edited_time(page_id)
        data_source_info = document.data_source_info_dict
        data_source_info['last_edited_time'] = last_edited_time

        update_params = {
            Document.data_source_info: json.dumps(data_source_info)
        }

        Document.query.filter_by(id=document.id).update(update_params)
        db.session.commit()
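Persisting last_edited_time into data_source_info is what gives the sync endpoint something to compare against later: when the stored value no longer matches the live page, the document needs re-indexing. A sketch of such a staleness check, not part of this diff (the helper name is hypothetical):

def _is_notion_page_stale(document: Document, access_token: str) -> bool:
    # Both values are ISO-8601 strings from the Notion API, so a plain
    # inequality test is enough to detect an edit.
    reader = NotionPageReader(integration_token=access_token)
    info = document.data_source_info_dict
    live = reader.get_page_last_edited_time(info['notion_page_id'])
    return info.get('last_edited_time') != live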
    def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
        """
        Get the NodeParser object according to the processing rule.
......
@@ -62,14 +62,27 @@ class NotionOAuth(OAuthDataSource):
            'total': len(pages)
        }
        # save data source binding
        data_source_binding = DataSourceBinding(
            tenant_id=current_user.current_tenant_id,
            access_token=access_token,
            source_info=source_info,
            provider='notion'
        )
        db.session.add(data_source_binding)
        db.session.commit()
        data_source_binding = DataSourceBinding.query.filter(
            db.and_(
                DataSourceBinding.tenant_id == current_user.current_tenant_id,
                DataSourceBinding.provider == 'notion',
                DataSourceBinding.access_token == access_token
            )
        ).first()
        if data_source_binding:
            data_source_binding.source_info = source_info
            data_source_binding.disabled = False
            db.session.add(data_source_binding)
            db.session.commit()
        else:
            new_data_source_binding = DataSourceBinding(
                tenant_id=current_user.current_tenant_id,
                access_token=access_token,
                source_info=source_info,
                provider='notion'
            )
            db.session.add(new_data_source_binding)
            db.session.commit()
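The rewritten block makes save_source_info idempotent for a given token: re-running the OAuth callback updates and re-enables the existing binding instead of inserting a duplicate row. The same find-or-create pattern can be compressed; a sketch, assuming filter_by keyword equality is equivalent to the db.and_ filter above:

binding = DataSourceBinding.query.filter_by(
    tenant_id=current_user.current_tenant_id,
    provider='notion',
    access_token=access_token
).first()
if binding is None:
    binding = DataSourceBinding(
        tenant_id=current_user.current_tenant_id,
        access_token=access_token,
        source_info=source_info,
        provider='notion'
    )
binding.source_info = source_info
binding.disabled = False
db.session.add(binding)
db.session.commit()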
    def get_authorized_pages(self, access_token: str):
        pages = []
......
@@ -277,6 +277,15 @@ class DocumentService:
        return document

    @staticmethod
    def get_document_by_dataset_id(dataset_id: str) -> List[Document]:
        # only enabled documents are returned, so disabled ones are skipped by sync
        documents = db.session.query(Document).filter(
            Document.dataset_id == dataset_id,
            Document.enabled == True
        ).all()

        return documents

    @staticmethod
    def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
        documents = db.session.query(Document).filter(
......