Commit 8c6f5add authored by Jyong

fix notion import bugs

parent 1268f3bb
...@@ -185,6 +185,9 @@ class Config: ...@@ -185,6 +185,9 @@ class Config:
# For temp use only # For temp use only
# set default LLM provider, default is 'openai', support `azure_openai` # set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER') self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
class CloudEditionConfig(Config): class CloudEditionConfig(Config):
......
...@@ -9,10 +9,10 @@ api = ExternalApi(bp) ...@@ -9,10 +9,10 @@ api = ExternalApi(bp)
from .app import app, site, explore, completion, model_config, statistic, conversation, message from .app import app, site, explore, completion, model_config, statistic, conversation, message
# Import auth controllers # Import auth controllers
from .auth import login, oauth from .auth import login, oauth, data_source_oauth
# Import datasets controllers # Import datasets controllers
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
# Import other controllers # Import other controllers
from . import setup, version, apikey from . import setup, version, apikey
......
...@@ -9,7 +9,7 @@ from flask_login import current_user, login_required ...@@ -9,7 +9,7 @@ from flask_login import current_user, login_required
from flask_restful import Resource from flask_restful import Resource
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
from .. import api from controllers.console import api
from ..setup import setup_required from ..setup import setup_required
from ..wraps import account_initialization_required from ..wraps import account_initialization_required
...@@ -29,9 +29,6 @@ def get_oauth_providers(): ...@@ -29,9 +29,6 @@ def get_oauth_providers():
class OAuthDataSource(Resource): class OAuthDataSource(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str): def get(self, provider: str):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']: if current_user.current_tenant.current_role not in ['admin', 'owner']:
...@@ -66,5 +63,5 @@ class OAuthDataSourceCallback(Resource): ...@@ -66,5 +63,5 @@ class OAuthDataSourceCallback(Resource):
return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success') return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success')
api.add_resource(OAuthDataSource, '/oauth/data-source/<provider>') api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<provider>') api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
import datetime import datetime
import hashlib
import json import json
import tempfile
import time
import uuid
from pathlib import Path
from cachetools import TTLCache from cachetools import TTLCache
from flask import request, current_app from flask import request, current_app
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_restful import Resource, marshal_with, fields, reqparse from flask_restful import Resource, marshal_with, fields, reqparse, marshal
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
UnsupportedFileTypeError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.data_source.notion import NotionPageReader from core.data_source.notion import NotionPageReader
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
from models.dataset import Document from models.dataset import Document
from models.model import UploadFile
from models.source import DataSourceBinding from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
...@@ -39,9 +29,35 @@ PREVIEW_WORDS_LIMIT = 3000 ...@@ -39,9 +29,35 @@ PREVIEW_WORDS_LIMIT = 3000
class DataSourceApi(Resource): class DataSourceApi(Resource):
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.String,
'total': fields.Integer
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_fields = {
'id': fields.String,
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'disabled': fields.Boolean,
'link': fields.String,
'source_info': fields.Nested(integrate_workspace_fields)
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
}
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(integrate_list_fields)
def get(self): def get(self):
# get workspace data source integrates # get workspace data source integrates
data_source_integrates = db.session.query(DataSourceBinding).filter( data_source_integrates = db.session.query(DataSourceBinding).filter(
...@@ -76,8 +92,7 @@ class DataSourceApi(Resource): ...@@ -76,8 +92,7 @@ class DataSourceApi(Resource):
'disabled': None, 'disabled': None,
'link': f'{base_url}{data_source_oauth_base_path}/{provider}' 'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
}) })
return {'data': integrate_data}, 200
return {'data': integrate_data}
@setup_required @setup_required
@login_required @login_required
...@@ -110,10 +125,25 @@ class DataSourceApi(Resource): ...@@ -110,10 +125,25 @@ class DataSourceApi(Resource):
class DataSourceNotionListApi(Resource): class DataSourceNotionListApi(Resource):
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
}
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
def get(self): def get(self):
dataset_id = request.args.get('dataset_id', default=None, type=str) dataset_id = request.args.get('dataset_id', default=None, type=str)
exist_page_ids = [] exist_page_ids = []
...@@ -143,9 +173,14 @@ class DataSourceNotionListApi(Resource): ...@@ -143,9 +173,14 @@ class DataSourceNotionListApi(Resource):
raise NotFound('Data source binding not found.') raise NotFound('Data source binding not found.')
pre_import_info_list = [] pre_import_info_list = []
for data_source_binding in data_source_bindings: for data_source_binding in data_source_bindings:
pages = NotionOAuth.get_authorized_pages(data_source_binding.access_token) notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
client_secret=current_app.config.get(
'NOTION_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_URL') + '/console/api/oauth/data-source/authorize/notion')
pages = notion_oauth.get_authorized_pages(data_source_binding.access_token)
# Filter out already bound pages # Filter out already bound pages
filter_pages = filter(lambda page: page['page_id'] not in exist_page_ids, pages) filter_pages = [page for page in pages if page['page_id'] not in exist_page_ids]
source_info = json.loads(data_source_binding.source_info) source_info = json.loads(data_source_binding.source_info)
pre_import_info = { pre_import_info = {
'workspace_name': source_info['workspace_name'], 'workspace_name': source_info['workspace_name'],
...@@ -165,12 +200,14 @@ class DataSourceNotionApi(Resource): ...@@ -165,12 +200,14 @@ class DataSourceNotionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, workspace_id, page_id): def get(self, workspace_id, page_id):
workspace_id = str(workspace_id)
page_id = str(page_id)
data_source_binding = DataSourceBinding.query.filter( data_source_binding = DataSourceBinding.query.filter(
db.and_( db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id, DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion', DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False, DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == workspace_id DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
) )
).first() ).first()
if not data_source_binding: if not data_source_binding:
...@@ -185,9 +222,8 @@ class DataSourceNotionApi(Resource): ...@@ -185,9 +222,8 @@ class DataSourceNotionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
notion_import_info = request.get_json()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('notion_info_list', type=dict, required=True, nullable=True, location='json') parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
...@@ -197,7 +233,7 @@ class DataSourceNotionApi(Resource): ...@@ -197,7 +233,7 @@ class DataSourceNotionApi(Resource):
return response, 200 return response, 200
api.add_resource(DataSourceApi, '/oauth/data-source/integrates') api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceApi, '/oauth/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages') api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
api.add_resource(DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/preview') api.add_resource(DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/preview',
'/datasets/notion-indexing-estimate')
...@@ -251,7 +251,7 @@ class DatasetDocumentListApi(Resource): ...@@ -251,7 +251,7 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource): class DatasetInitApi(Resource):
dataset_and_document_fields = { dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields), 'dataset': fields.Nested(dataset_fields),
'document': fields.Nested(document_fields) 'documents': fields.List(fields.Nested(document_fields))
} }
@setup_required @setup_required
......
...@@ -126,6 +126,7 @@ class NotionPageReader(BaseReader): ...@@ -126,6 +126,7 @@ class NotionPageReader(BaseReader):
cur_result_text_arr.append(children_text) cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr) cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
result_lines_arr.append(cur_result_text) result_lines_arr.append(cur_result_text)
if data["next_cursor"] is None: if data["next_cursor"] is None:
...@@ -204,11 +205,11 @@ class NotionPageReader(BaseReader): ...@@ -204,11 +205,11 @@ class NotionPageReader(BaseReader):
page_ids = self.query_database(database_id) page_ids = self.query_database(database_id)
for page_id in page_ids: for page_id in page_ids:
page_text = self.read_page(page_id) page_text = self.read_page(page_id)
docs.append(Document(page_text, extra_info={"page_id": page_id})) docs.append(Document(page_text))
else: else:
for page_id in page_ids: for page_id in page_ids:
page_text = self.read_page(page_id) page_text = self.read_page(page_id)
docs.append(Document(page_text, extra_info={"page_id": page_id})) docs.append(Document(page_text))
return docs return docs
...@@ -223,12 +224,12 @@ class NotionPageReader(BaseReader): ...@@ -223,12 +224,12 @@ class NotionPageReader(BaseReader):
page_ids = self.query_database(database_id) page_ids = self.query_database(database_id)
for page_id in page_ids: for page_id in page_ids:
page_text = self.read_page(page_id) page_text = self.read_page(page_id)
docs.append(Document(page_text, extra_info={"page_id": page_id})) docs.append(Document(page_text))
else: else:
for page_id in page_ids: for page_id in page_ids:
page_text_list = self.read_page_as_documents(page_id) page_text_list = self.read_page_as_documents(page_id)
for page_text in page_text_list: for page_text in page_text_list:
docs.append(Document(page_text, extra_info={"page_id": page_id})) docs.append(Document(page_text))
return docs return docs
......
...@@ -215,12 +215,13 @@ class IndexingRunner: ...@@ -215,12 +215,13 @@ class IndexingRunner:
preview_texts = [] preview_texts = []
total_segments = 0 total_segments = 0
for notion_info in notion_info_list: for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter( data_source_binding = DataSourceBinding.query.filter(
db.and_( db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id, DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion', DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False, DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == notion_info['workspace_id'] DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
) )
).first() ).first()
if not data_source_binding: if not data_source_binding:
...@@ -228,7 +229,7 @@ class IndexingRunner: ...@@ -228,7 +229,7 @@ class IndexingRunner:
reader = NotionPageReader(integration_token=data_source_binding.access_token) reader = NotionPageReader(integration_token=data_source_binding.access_token)
for page in notion_info['pages']: for page in notion_info['pages']:
page_ids = [page['page_id']] page_ids = [page['page_id']]
documents = reader.load_data(page_ids=page_ids) documents = reader.load_data_as_documents(page_ids=page_ids)
processing_rule = DatasetProcessRule( processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], mode=tmp_processing_rule["mode"],
...@@ -279,7 +280,7 @@ class IndexingRunner: ...@@ -279,7 +280,7 @@ class IndexingRunner:
if not data_source_info or 'notion_page_id' not in data_source_info \ if not data_source_info or 'notion_page_id' not in data_source_info \
or 'notion_workspace_id' not in data_source_info: or 'notion_workspace_id' not in data_source_info:
raise ValueError("no notion page found") raise ValueError("no notion page found")
text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id']) text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id'], document.tenant_id)
# update document status to splitting # update document status to splitting
self._update_document_index_status( self._update_document_index_status(
document_id=document.id, document_id=document.id,
...@@ -319,13 +320,13 @@ class IndexingRunner: ...@@ -319,13 +320,13 @@ class IndexingRunner:
return text_docs return text_docs
def _load_data_from_notion(self, workspace_id: str, page_id: str) -> List[Document]: def _load_data_from_notion(self, workspace_id: str, page_id: str, tenant_id: str) -> List[Document]:
data_source_binding = DataSourceBinding.query.filter( data_source_binding = DataSourceBinding.query.filter(
db.and_( db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id, DataSourceBinding.tenant_id == tenant_id,
DataSourceBinding.provider == 'notion', DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False, DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == workspace_id DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
) )
).first() ).first()
if not data_source_binding: if not data_source_binding:
......
...@@ -65,7 +65,7 @@ class NotionOAuth(OAuthDataSource): ...@@ -65,7 +65,7 @@ class NotionOAuth(OAuthDataSource):
data_source_binding = DataSourceBinding( data_source_binding = DataSourceBinding(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
access_token=access_token, access_token=access_token,
source_info=json.dumps(source_info), source_info=source_info,
provider='notion' provider='notion'
) )
db.session.add(data_source_binding) db.session.add(data_source_binding)
......
...@@ -190,7 +190,7 @@ class Document(db.Model): ...@@ -190,7 +190,7 @@ class Document(db.Model):
doc_type = db.Column(db.String(40), nullable=True) doc_type = db.Column(db.String(40), nullable=True)
doc_metadata = db.Column(db.JSON, nullable=True) doc_metadata = db.Column(db.JSON, nullable=True)
DATA_SOURCES = ['upload_file'] DATA_SOURCES = ['upload_file', 'notion_import']
@property @property
def display_status(self): def display_status(self):
......
...@@ -7,7 +7,8 @@ class DataSourceBinding(db.Model): ...@@ -7,7 +7,8 @@ class DataSourceBinding(db.Model):
__tablename__ = 'data_source_bindings' __tablename__ = 'data_source_bindings'
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='source_binding_pkey'), db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
db.Index('source_binding_tenant_id_idx', 'tenant_id') db.Index('source_binding_tenant_id_idx', 'tenant_id'),
db.Index('source_info_idx', "source_info", postgresql_using='gin')
) )
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
......
...@@ -423,7 +423,7 @@ class DocumentService: ...@@ -423,7 +423,7 @@ class DocumentService:
DataSourceBinding.tenant_id == current_user.current_tenant_id, DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion', DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False, DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == workspace_id DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
) )
).first() ).first()
if not data_source_binding: if not data_source_binding:
...@@ -581,7 +581,7 @@ class DocumentService: ...@@ -581,7 +581,7 @@ class DocumentService:
if 'notion_info_list' not in args or not args['notion_info_list']: if 'notion_info_list' not in args or not args['notion_info_list']:
raise ValueError("Notion info is required") raise ValueError("Notion info is required")
if not isinstance(args['notion_info_list'], dict): if not isinstance(args['notion_info_list'], list):
raise ValueError("Notion info is invalid") raise ValueError("Notion info is invalid")
if 'process_rule' not in args or not args['process_rule']: if 'process_rule' not in args or not args['process_rule']:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment