Commit 8c6f5add authored by Jyong

fix notion import bugs

parent 1268f3bb
......@@ -185,6 +185,9 @@ class Config:
# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
class CloudEditionConfig(Config):
......
......@@ -9,10 +9,10 @@ api = ExternalApi(bp)
from .app import app, site, explore, completion, model_config, statistic, conversation, message
# Import auth controllers
from .auth import login, oauth
from .auth import login, oauth, data_source_oauth
# Import datasets controllers
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
# Import other controllers
from . import setup, version, apikey
......
......@@ -9,7 +9,7 @@ from flask_login import current_user, login_required
from flask_restful import Resource
from werkzeug.exceptions import Forbidden
from libs.oauth_data_source import NotionOAuth
from .. import api
from controllers.console import api
from ..setup import setup_required
from ..wraps import account_initialization_required
......@@ -29,9 +29,6 @@ def get_oauth_providers():
class OAuthDataSource(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
# The role of the current user in the table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
......@@ -66,5 +63,5 @@ class OAuthDataSourceCallback(Resource):
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
api.add_resource(OAuthDataSource, '/oauth/data-source/<provider>')
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<provider>')
api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
import datetime
import hashlib
import json
import tempfile
import time
import uuid
from pathlib import Path
from cachetools import TTLCache
from flask import request, current_app
from flask_login import login_required, current_user
from flask_restful import Resource, marshal_with, fields, reqparse
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
UnsupportedFileTypeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_source.notion import NotionPageReader
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.indexing_runner import IndexingRunner
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.oauth_data_source import NotionOAuth
from models.dataset import Document
from models.model import UploadFile
from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
......@@ -39,9 +29,35 @@ PREVIEW_WORDS_LIMIT = 3000
class DataSourceApi(Resource):
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.String,
'total': fields.Integer
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_fields = {
'id': fields.String,
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'disabled': fields.Boolean,
'link': fields.String,
'source_info': fields.Nested(integrate_workspace_fields)
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
# get workspace data source integrates
data_source_integrates = db.session.query(DataSourceBinding).filter(
......@@ -76,8 +92,7 @@ class DataSourceApi(Resource):
'disabled': None,
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
})
return {'data': integrate_data}
return {'data': integrate_data}, 200
@setup_required
@login_required
......@@ -110,10 +125,25 @@ class DataSourceApi(Resource):
class DataSourceNotionListApi(Resource):
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
def get(self):
dataset_id = request.args.get('dataset_id', default=None, type=str)
exist_page_ids = []
......@@ -143,9 +173,14 @@ class DataSourceNotionListApi(Resource):
raise NotFound('Data source binding not found.')
pre_import_info_list = []
for data_source_binding in data_source_bindings:
pages = NotionOAuth.get_authorized_pages(data_source_binding.access_token)
notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
client_secret=current_app.config.get(
'NOTION_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_URL') + '/console/api/oauth/data-source/authorize/notion')
pages = notion_oauth.get_authorized_pages(data_source_binding.access_token)
# Filter out already bound pages
filter_pages = filter(lambda page: page['page_id'] not in exist_page_ids, pages)
filter_pages = [page for page in pages if page['page_id'] not in exist_page_ids]
source_info = json.loads(data_source_binding.source_info)
pre_import_info = {
'workspace_name': source_info['workspace_name'],
......@@ -165,12 +200,14 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def get(self, workspace_id, page_id):
workspace_id = str(workspace_id)
page_id = str(page_id)
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == workspace_id
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
......@@ -185,9 +222,8 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def post(self):
notion_import_info = request.get_json()
parser = reqparse.RequestParser()
parser.add_argument('notion_info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args()
# validate args
......@@ -197,7 +233,7 @@ class DataSourceNotionApi(Resource):
return response, 200
api.add_resource(DataSourceApi, '/oauth/data-source/integrates')
api.add_resource(DataSourceApi, '/oauth/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
api.add_resource(DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/preview')
api.add_resource(DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/preview',
'/datasets/notion-indexing-estimate')
......@@ -251,7 +251,7 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource):
dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields),
'document': fields.Nested(document_fields)
'documents': fields.List(fields.Nested(document_fields))
}
@setup_required
......
......@@ -126,6 +126,7 @@ class NotionPageReader(BaseReader):
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
result_lines_arr.append(cur_result_text)
if data["next_cursor"] is None:
......@@ -204,11 +205,11 @@ class NotionPageReader(BaseReader):
page_ids = self.query_database(database_id)
for page_id in page_ids:
page_text = self.read_page(page_id)
docs.append(Document(page_text, extra_info={"page_id": page_id}))
docs.append(Document(page_text))
else:
for page_id in page_ids:
page_text = self.read_page(page_id)
docs.append(Document(page_text, extra_info={"page_id": page_id}))
docs.append(Document(page_text))
return docs
......@@ -223,12 +224,12 @@ class NotionPageReader(BaseReader):
page_ids = self.query_database(database_id)
for page_id in page_ids:
page_text = self.read_page(page_id)
docs.append(Document(page_text, extra_info={"page_id": page_id}))
docs.append(Document(page_text))
else:
for page_id in page_ids:
page_text_list = self.read_page_as_documents(page_id)
for page_text in page_text_list:
docs.append(Document(page_text, extra_info={"page_id": page_id}))
docs.append(Document(page_text))
return docs
......
......@@ -215,12 +215,13 @@ class IndexingRunner:
preview_texts = []
total_segments = 0
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == notion_info['workspace_id']
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
......@@ -228,7 +229,7 @@ class IndexingRunner:
reader = NotionPageReader(integration_token=data_source_binding.access_token)
for page in notion_info['pages']:
page_ids = [page['page_id']]
documents = reader.load_data(page_ids=page_ids)
documents = reader.load_data_as_documents(page_ids=page_ids)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
......@@ -279,7 +280,7 @@ class IndexingRunner:
if not data_source_info or 'notion_page_id' not in data_source_info \
or 'notion_workspace_id' not in data_source_info:
raise ValueError("no notion page found")
text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id'])
text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id'], document.tenant_id)
# update document status to splitting
self._update_document_index_status(
document_id=document.id,
......@@ -319,13 +320,13 @@ class IndexingRunner:
return text_docs
def _load_data_from_notion(self, workspace_id: str, page_id: str) -> List[Document]:
def _load_data_from_notion(self, workspace_id: str, page_id: str, tenant_id: str) -> List[Document]:
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.tenant_id == tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == workspace_id
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
......
......@@ -65,7 +65,7 @@ class NotionOAuth(OAuthDataSource):
data_source_binding = DataSourceBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=json.dumps(source_info),
source_info=source_info,
provider='notion'
)
db.session.add(data_source_binding)
......
......@@ -190,7 +190,7 @@ class Document(db.Model):
doc_type = db.Column(db.String(40), nullable=True)
doc_metadata = db.Column(db.JSON, nullable=True)
DATA_SOURCES = ['upload_file']
DATA_SOURCES = ['upload_file', 'notion_import']
@property
def display_status(self):
......
......@@ -7,7 +7,8 @@ class DataSourceBinding(db.Model):
__tablename__ = 'data_source_bindings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
db.Index('source_binding_tenant_id_idx', 'tenant_id')
db.Index('source_binding_tenant_id_idx', 'tenant_id'),
db.Index('source_info_idx', "source_info", postgresql_using='gin')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
......
......@@ -423,7 +423,7 @@ class DocumentService:
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == workspace_id
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
......@@ -581,7 +581,7 @@ class DocumentService:
if 'notion_info_list' not in args or not args['notion_info_list']:
raise ValueError("Notion info is required")
if not isinstance(args['notion_info_list'], dict):
if not isinstance(args['notion_info_list'], list):
raise ValueError("Notion info is invalid")
if 'process_rule' not in args or not args['process_rule']:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment