Commit dbd2babb authored by Jyong's avatar Jyong

notion index estimate

parent d834dece
...@@ -17,6 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles ...@@ -17,6 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
UnsupportedFileTypeError UnsupportedFileTypeError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.data_source.notion import NotionPageReader
from core.index.readers.html_parser import HTMLParser from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser from core.index.readers.pdf_parser import PDFParser
from extensions.ext_storage import storage from extensions.ext_storage import storage
...@@ -107,7 +108,7 @@ class DataSourceApi(Resource): ...@@ -107,7 +108,7 @@ class DataSourceApi(Resource):
return {'result': 'success'}, 200 return {'result': 'success'}, 200
class DataSourceNotionApi(Resource): class DataSourceNotionListApi(Resource):
@setup_required @setup_required
@login_required @login_required
...@@ -157,7 +158,40 @@ class DataSourceNotionApi(Resource): ...@@ -157,7 +158,40 @@ class DataSourceNotionApi(Resource):
}, 200 }, 200
class DataSourceNotionApi(Resource):
    """Notion data-source endpoints: page preview (GET) and indexing estimate (POST)."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, workspace_id, page_id):
        """Return the raw text content of one Notion page for preview.

        Looks up the current tenant's enabled notion binding for the given
        workspace, then reads the page with that binding's access token.

        Raises:
            NotFound: when no matching data source binding exists.
        """
        data_source_binding = DataSourceBinding.query.filter(
            db.and_(
                DataSourceBinding.tenant_id == current_user.current_tenant_id,
                DataSourceBinding.provider == 'notion',
                DataSourceBinding.disabled == False,
                # NOTE(review): JSONB comparison against a UUID route param —
                # presumably relies on implicit string coercion; confirm the
                # stored workspace_id format matches.
                DataSourceBinding.source_info['workspace_id'] == workspace_id
            )
        ).first()
        if not data_source_binding:
            raise NotFound('Data source binding not found.')
        reader = NotionPageReader(integration_token=data_source_binding.access_token)
        page_content = reader.read_page(page_id)
        return {
            'content': page_content
        }, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Estimate segment count / token cost for indexing notion content."""
        args = request.get_json()
        indexing_runner = IndexingRunner()
        # BUG FIX: the original passed the undefined name `file_detail`
        # (NameError at runtime). Pass the notion info from the request
        # payload instead.
        # assumes the payload carries a 'notion_info' key alongside
        # 'process_rule' — TODO confirm against the frontend request body.
        response = indexing_runner.notion_indexing_estimate(
            args['notion_info'],
            args['process_rule']
        )
        return response, 200
api.add_resource(DataSourceApi, '/oauth/data-source/integrates') api.add_resource(DataSourceApi, '/oauth/data-source/integrates')
api.add_resource(DataSourceApi, '/oauth/data-source/integrates/<uuid:binding_id>/<string:action>') api.add_resource(DataSourceApi, '/oauth/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceNotionApi, '/notion/pre-import/pages') api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
api.add_resource(DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/preview')
"""Notion reader."""
import logging
import os
from typing import Any, Dict, List, Optional
import requests # type: ignore
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"
logger = logging.getLogger(__name__)
# TODO: Notion DB reader coming soon!
class NotionPageReader(BaseReader):
    """Notion Page reader.

    Reads a set of Notion pages via the Notion REST API.

    Args:
        integration_token (str): Notion integration token. Falls back to the
            ``NOTION_INTEGRATION_TOKEN`` environment variable when omitted.
    """

    def __init__(self, integration_token: Optional[str] = None) -> None:
        """Initialize with parameters.

        Raises:
            ValueError: if no token is given and the environment variable
                is not set either.
        """
        if integration_token is None:
            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
            if integration_token is None:
                raise ValueError(
                    "Must specify `integration_token` or set environment "
                    "variable `NOTION_INTEGRATION_TOKEN`."
                )
        self.token = integration_token
        self.headers = {
            "Authorization": "Bearer " + self.token,
            "Content-Type": "application/json",
            "Notion-Version": "2022-06-28",
        }

    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
        """Read a block and all of its children, returning the joined text.

        Follows the API's cursor-based pagination until ``next_cursor`` is
        exhausted; recurses into children with one extra tab of indentation.
        """
        done = False
        result_lines_arr = []
        cur_block_id = block_id
        while not done:
            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
            query_dict: Dict[str, Any] = {}

            res = requests.request(
                "GET", block_url, headers=self.headers, json=query_dict
            )
            data = res.json()

            for result in data["results"]:
                result_type = result["type"]
                result_obj = result[result_type]

                cur_result_text_arr = []
                if "rich_text" in result_obj:
                    for rich_text in result_obj["rich_text"]:
                        # skip if doesn't have text object
                        if "text" in rich_text:
                            text = rich_text["text"]["content"]
                            prefix = "\t" * num_tabs
                            cur_result_text_arr.append(prefix + text)

                result_block_id = result["id"]
                has_children = result["has_children"]
                if has_children:
                    children_text = self._read_block(
                        result_block_id, num_tabs=num_tabs + 1
                    )
                    cur_result_text_arr.append(children_text)

                cur_result_text = "\n".join(cur_result_text_arr)
                result_lines_arr.append(cur_result_text)

            if data["next_cursor"] is None:
                done = True
                break
            else:
                cur_block_id = data["next_cursor"]

        result_lines = "\n".join(result_lines_arr)
        return result_lines

    def read_page(self, page_id: str) -> str:
        """Read a page."""
        return self._read_block(page_id)

    def query_database(
        self, database_id: str, query_dict: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Get all the pages from a Notion database.

        Args:
            database_id: id of the database to query.
            query_dict: optional Notion query body; defaults to an empty query.

        Returns:
            List of page ids found in the database.
        """
        # BUG FIX: original used a mutable default argument ({}); use None
        # as the sentinel instead (backward compatible).
        if query_dict is None:
            query_dict = {}
        res = requests.post(
            DATABASE_URL_TMPL.format(database_id=database_id),
            headers=self.headers,
            json=query_dict,
        )
        data = res.json()
        page_ids = []
        for result in data["results"]:
            page_id = result["id"]
            page_ids.append(page_id)

        # NOTE(review): this does not follow data["next_cursor"], so only the
        # first page of results is returned — TODO confirm whether pagination
        # is needed here.
        return page_ids

    def search(self, query: str) -> List[str]:
        """Search Notion page given a text query.

        Follows cursor pagination and returns all matching page ids.
        """
        done = False
        next_cursor: Optional[str] = None
        page_ids = []
        while not done:
            query_dict = {
                "query": query,
            }
            if next_cursor is not None:
                query_dict["start_cursor"] = next_cursor
            res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
            data = res.json()
            for result in data["results"]:
                page_id = result["id"]
                page_ids.append(page_id)

            if data["next_cursor"] is None:
                done = True
                break
            else:
                next_cursor = data["next_cursor"]
        return page_ids

    def load_data(
        self, page_ids: Optional[List[str]] = None, database_id: Optional[str] = None
    ) -> List[Document]:
        """Load data from the input directory.

        Args:
            page_ids (List[str]): List of page ids to load.
            database_id (Optional[str]): database to enumerate pages from;
                takes precedence over ``page_ids`` when given.

        Returns:
            List[Document]: List of documents.

        Raises:
            ValueError: if neither ``page_ids`` nor ``database_id`` is given.
        """
        # BUG FIX: original used a mutable default argument ([]); use None
        # as the sentinel instead (backward compatible).
        if page_ids is None:
            page_ids = []
        if not page_ids and not database_id:
            raise ValueError("Must specify either `page_ids` or `database_id`.")
        docs = []
        if database_id is not None:
            # get all the pages in the database
            page_ids = self.query_database(database_id)
        for page_id in page_ids:
            page_text = self.read_page(page_id)
            docs.append(Document(page_text, extra_info={"page_id": page_id}))
        return docs
if __name__ == "__main__":
    # Manual smoke test — requires NOTION_INTEGRATION_TOKEN in the environment
    # and performs live network calls against the Notion API.
    logger.info(NotionPageReader().search("What I"))
...@@ -201,6 +201,43 @@ class IndexingRunner: ...@@ -201,6 +201,43 @@ class IndexingRunner:
"preview": preview_texts "preview": preview_texts
} }
def notion_indexing_estimate(self, notion_info, tmp_processing_rule: dict) -> dict:
    """
    Estimate the indexing for the document.

    Splits the loaded content with the given processing rule and returns
    segment count, token count, price, currency, and up to 5 preview texts.
    """
    # BUG FIX: the original referenced the undefined name `file_detail`
    # (copy-pasted from the file-based estimate) and never used the
    # `notion_info` parameter; pass the parameter through instead.
    # NOTE(review): loading notion content should presumably go through a
    # notion reader rather than the file loader — TODO confirm intended
    # loading path for notion sources.
    text_docs = self._load_data_from_file(notion_info)

    processing_rule = DatasetProcessRule(
        mode=tmp_processing_rule["mode"],
        rules=json.dumps(tmp_processing_rule["rules"])
    )

    # get node parser for splitting
    node_parser = self._get_node_parser(processing_rule)

    # split to nodes
    nodes = self._split_to_nodes(
        text_docs=text_docs,
        node_parser=node_parser,
        processing_rule=processing_rule
    )

    tokens = 0
    preview_texts = []
    for node in nodes:
        # cap the preview at the first 5 segments, but count tokens for all
        if len(preview_texts) < 5:
            preview_texts.append(node.get_text())

        tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())

    return {
        "total_segments": len(nodes),
        "tokens": tokens,
        "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
        "currency": TokenCalculator.get_currency(self.embedding_model_name),
        "preview": preview_texts
    }
def _load_data(self, document: Document) -> List[Document]: def _load_data(self, document: Document) -> List[Document]:
# load file # load file
if document.data_source_type != "upload_file": if document.data_source_type != "upload_file":
......
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db from extensions.ext_database import db
from sqlalchemy.dialects.postgresql import JSONB
class DataSourceBinding(db.Model): class DataSourceBinding(db.Model):
__tablename__ = 'data_source_bindings' __tablename__ = 'data_source_bindings'
...@@ -14,7 +14,7 @@ class DataSourceBinding(db.Model): ...@@ -14,7 +14,7 @@ class DataSourceBinding(db.Model):
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(UUID, nullable=False)
access_token = db.Column(db.String(255), nullable=False) access_token = db.Column(db.String(255), nullable=False)
provider = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False)
source_info = db.Column(db.Text, nullable=False) source_info = db.Column(JSONB, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment