Commit ff1aadf9 authored by Jyong's avatar Jyong

add notion sync

parent 8c6f5add
import datetime import datetime
import json import json
from cachetools import TTLCache from cachetools import TTLCache
from flask import request, current_app from flask import request, current_app
from flask_login import login_required, current_user from flask_login import login_required, current_user
...@@ -28,7 +27,6 @@ PREVIEW_WORDS_LIMIT = 3000 ...@@ -28,7 +27,6 @@ PREVIEW_WORDS_LIMIT = 3000
class DataSourceApi(Resource): class DataSourceApi(Resource):
integrate_page_fields = { integrate_page_fields = {
'page_name': fields.String, 'page_name': fields.String,
'page_id': fields.String, 'page_id': fields.String,
...@@ -233,7 +231,25 @@ class DataSourceNotionApi(Resource): ...@@ -233,7 +231,25 @@ class DataSourceNotionApi(Resource):
return response, 200 return response, 200
class DataSourceNotionSyncApi(Resource):
    """REST resource that triggers a Notion re-sync for a dataset.

    Registered below at '/datasets/<uuid:dataset_id>/notion/sync'.
    """

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        # Normalize the UUID path parameter to its string form, since the
        # service layer compares against string ids.
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        # Fetch all enabled documents of the dataset (see
        # DocumentService.get_document_by_dataset_id).
        documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
        for document in documents:
            # NOTE(review): the loop body looks truncated in this diff view —
            # returning inside the loop makes iterating `documents` pointless.
            # Presumably a per-document sync task dispatch belongs here before
            # the response is returned; confirm against the full commit.
            return 200
api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>') api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages') api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
api.add_resource(DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/preview', api.add_resource(DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/preview',
'/datasets/notion-indexing-estimate') '/datasets/notion-indexing-estimate')
api.add_resource(DataSourceNotionSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
"""Notion reader.""" """Notion reader."""
import logging import logging
import os import os
from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests # type: ignore import requests # type: ignore
...@@ -12,6 +13,7 @@ INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN" ...@@ -12,6 +13,7 @@ INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search" SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -233,6 +235,17 @@ class NotionPageReader(BaseReader): ...@@ -233,6 +235,17 @@ class NotionPageReader(BaseReader):
return docs return docs
def get_page_last_edited_time(self, page_id: str) -> str:
    """Return the raw ``last_edited_time`` string for a Notion page.

    Calls the Notion "retrieve a page" endpoint and returns the
    ``last_edited_time`` field exactly as the API reports it (an
    ISO-8601 timestamp string); no parsing is performed.

    Args:
        page_id: UUID of the Notion page to query.

    Returns:
        The page's ``last_edited_time`` value from the API response.

    Raises:
        KeyError: if the response JSON lacks ``last_edited_time``
            (e.g. the API returned an error object).
    """
    retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
    # An empty JSON body mirrors the other Notion calls in this reader;
    # the retrieve-page endpoint ignores the request body on GET.
    res = requests.request(
        "GET", retrieve_page_url, headers=self.headers, json={}
    )
    data = res.json()
    return data["last_edited_time"]
if __name__ == "__main__": if __name__ == "__main__":
reader = NotionPageReader() reader = NotionPageReader()
......
...@@ -280,7 +280,21 @@ class IndexingRunner: ...@@ -280,7 +280,21 @@ class IndexingRunner:
if not data_source_info or 'notion_page_id' not in data_source_info \ if not data_source_info or 'notion_page_id' not in data_source_info \
or 'notion_workspace_id' not in data_source_info: or 'notion_workspace_id' not in data_source_info:
raise ValueError("no notion page found") raise ValueError("no notion page found")
text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id'], document.tenant_id) workspace_id = data_source_info['notion_workspace_id']
page_id = data_source_info['notion_page_id']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == document.tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
# add page last_edited_time to data_source_info
self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
text_docs = self._load_data_from_notion(page_id, data_source_binding.access_token)
# update document status to splitting # update document status to splitting
self._update_document_index_status( self._update_document_index_status(
document_id=document.id, document_id=document.id,
...@@ -320,22 +334,24 @@ class IndexingRunner: ...@@ -320,22 +334,24 @@ class IndexingRunner:
return text_docs return text_docs
def _load_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
    """Load a single Notion page as Document objects.

    Args:
        page_id: UUID of the Notion page to import.
        access_token: Notion integration token for the workspace.

    Returns:
        The documents produced by the Notion reader for that page.
    """
    notion_reader = NotionPageReader(integration_token=access_token)
    return notion_reader.load_data_as_documents(page_ids=[page_id])
def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
    """Fetch the page's last-edited timestamp from Notion and persist it.

    Stores the value under the 'last_edited_time' key of the document's
    data_source_info JSON column, then commits the session.

    Args:
        page_id: UUID of the Notion page.
        access_token: Notion integration token for the workspace.
        document: The Document row whose data_source_info is updated.
    """
    notion_reader = NotionPageReader(integration_token=access_token)
    edited_at = notion_reader.get_page_last_edited_time(page_id)
    # Round-trip the JSON column: read the dict view, amend it, write it back.
    info = document.data_source_info_dict
    info['last_edited_time'] = edited_at
    Document.query.filter_by(id=document.id).update(
        {Document.data_source_info: json.dumps(info)}
    )
    db.session.commit()
def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser: def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
""" """
Get the NodeParser object according to the processing rule. Get the NodeParser object according to the processing rule.
......
...@@ -62,14 +62,27 @@ class NotionOAuth(OAuthDataSource): ...@@ -62,14 +62,27 @@ class NotionOAuth(OAuthDataSource):
'total': len(pages) 'total': len(pages)
} }
# save data source binding # save data source binding
data_source_binding = DataSourceBinding( data_source_binding = DataSourceBinding.query.filter(
tenant_id=current_user.current_tenant_id, db.and_(
access_token=access_token, DataSourceBinding.tenant_id == current_user.current_tenant_id,
source_info=source_info, DataSourceBinding.provider == 'notion',
provider='notion' DataSourceBinding.access_token == access_token
) )
db.session.add(data_source_binding) ).first()
db.session.commit() if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
db.session.add(data_source_binding)
db.session.commit()
else:
new_data_source_binding = DataSourceBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
provider='notion'
)
db.session.add(new_data_source_binding)
db.session.commit()
def get_authorized_pages(self, access_token: str): def get_authorized_pages(self, access_token: str):
pages = [] pages = []
......
...@@ -277,6 +277,15 @@ class DocumentService: ...@@ -277,6 +277,15 @@ class DocumentService:
return document return document
@staticmethod
def get_document_by_dataset_id(dataset_id: str) -> List[Document]:
    """Return every enabled document belonging to the given dataset."""
    query = db.session.query(Document).filter(
        Document.dataset_id == dataset_id,
        Document.enabled == True  # noqa: E712 — SQLAlchemy requires == here
    )
    return query.all()
@staticmethod @staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> List[Document]: def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:
documents = db.session.query(Document).filter( documents = db.session.query(Document).filter(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment