Commit f1f5d45d authored by Jyong

support multiple files and notion pages

parent e2ef272f
@@ -90,12 +90,61 @@ class NotionPageReader(BaseReader):
        result_lines = "\n".join(result_lines_arr)
        return result_lines
+    def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
+        """Read a block."""
+        done = False
+        result_lines_arr = []
+        cur_block_id = block_id
+        while not done:
+            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
+            query_dict: Dict[str, Any] = {}
+
+            res = requests.request(
+                "GET", block_url, headers=self.headers, json=query_dict
+            )
+            data = res.json()
+
+            for result in data["results"]:
+                result_type = result["type"]
+                result_obj = result[result_type]
+
+                cur_result_text_arr = []
+                if "rich_text" in result_obj:
+                    for rich_text in result_obj["rich_text"]:
+                        # skip if doesn't have text object
+                        if "text" in rich_text:
+                            text = rich_text["text"]["content"]
+                            prefix = "\t" * num_tabs
+                            cur_result_text_arr.append(prefix + text)
+
+                result_block_id = result["id"]
+                has_children = result["has_children"]
+                if has_children:
+                    children_text = self._read_block(
+                        result_block_id, num_tabs=num_tabs + 1
+                    )
+                    cur_result_text_arr.append(children_text)
+
+                cur_result_text = "\n".join(cur_result_text_arr)
+                result_lines_arr.append(cur_result_text)
+
+            if data["next_cursor"] is None:
+                done = True
+                break
+            else:
+                cur_block_id = data["next_cursor"]
+        return result_lines_arr
+
    def read_page(self, page_id: str) -> str:
        """Read a page."""
        return self._read_block(page_id)
+
+    def read_page_as_documents(self, page_id: str) -> List[str]:
+        """Read a page as documents."""
+        return self._read_parent_blocks(page_id)

    def query_database(
        self, database_id: str, query_dict: Dict[str, Any] = {}
    ) -> List[str]:
        """Get all the pages from a Notion database."""
        res = requests.post(
@@ -136,7 +185,7 @@ class NotionPageReader(BaseReader):
        return page_ids

    def load_data(
        self, page_ids: List[str] = [], database_id: Optional[str] = None
    ) -> List[Document]:
        """Load data from the input directory.
...
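The reader changes above add cursor-based pagination over Notion's block children endpoint (following next_cursor until it is null) and a read_page_as_documents() helper that returns a page as a list of block-level strings instead of one flattened string. A minimal usage sketch follows; the import path, integration token, and page id are placeholders, not taken from this commit.

# Hypothetical usage of the reader API added above; uncomment and fill in real values.
# from notion_reader_module import NotionPageReader  # placeholder import path

def preview_page(reader, page_id: str) -> None:
    # read_page() returns the whole page as one string, while
    # read_page_as_documents() keeps one entry per top-level block.
    blocks = reader.read_page_as_documents(page_id)
    for i, block in enumerate(blocks):
        print(f"block {i}: {block[:80]!r}")

# reader = NotionPageReader(integration_token="secret_xxx")        # placeholder token
# preview_page(reader, "00000000-0000-0000-0000-000000000000")     # placeholder page id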
@@ -252,7 +252,7 @@ class IndexingRunner:
            tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())

        return {
-            "total_segments": len(total_segments),
+            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
            "currency": TokenCalculator.get_currency(self.embedding_model_name),
@@ -261,25 +261,30 @@ class IndexingRunner:
    def _load_data(self, document: Document) -> List[Document]:
        # load file
-        if document.data_source_type != "upload_file":
+        if document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = document.data_source_info_dict
-        if not data_source_info or 'upload_file_id' not in data_source_info:
-            raise ValueError("no upload file found")
-
-        file_detail = db.session.query(UploadFile). \
-            filter(UploadFile.id == data_source_info['upload_file_id']). \
-            one_or_none()
-
-        text_docs = self._load_data_from_file(file_detail)
+        text_docs = []
+        if document.data_source_type == 'upload_file':
+            if not data_source_info or 'upload_file_id' not in data_source_info:
+                raise ValueError("no upload file found")
+
+            file_detail = db.session.query(UploadFile). \
+                filter(UploadFile.id == data_source_info['upload_file_id']). \
+                one_or_none()
+
+            text_docs = self._load_data_from_file(file_detail)
+        elif document.data_source_type == 'notion_import':
+            if not data_source_info or 'notion_page_id' not in data_source_info \
+                    or 'notion_workspace_id' not in data_source_info:
+                raise ValueError("no notion page found")
+            text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id'])

        # update document status to splitting
        self._update_document_index_status(
            document_id=document.id,
            after_indexing_status="splitting",
            extra_update_params={
-                Document.file_id: file_detail.id,
                Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
                Document.parsing_completed_at: datetime.datetime.utcnow()
            }
@@ -314,6 +319,22 @@ class IndexingRunner:
        return text_docs

+    def _load_data_from_notion(self, workspace_id: str, page_id: str) -> List[Document]:
+        data_source_binding = DataSourceBinding.query.filter(
+            db.and_(
+                DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                DataSourceBinding.provider == 'notion',
+                DataSourceBinding.disabled == False,
+                DataSourceBinding.source_info['workspace_id'] == workspace_id
+            )
+        ).first()
+        if not data_source_binding:
+            raise ValueError('Data source binding not found.')
+
+        page_ids = [page_id]
+        reader = NotionPageReader(integration_token=data_source_binding.access_token)
+        text_docs = reader.load_data(page_ids=page_ids)
+        return text_docs
+
    def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
        """
        Get the NodeParser object according to the processing rule.
...
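For reference, these are the two data_source_info shapes the reworked _load_data accepts, inferred only from the keys it checks in the hunks above; the values are placeholders.

# Placeholder payloads; only the key names come from the diff.
upload_file_info = {
    "upload_file_id": "<UploadFile.id of the uploaded file>",
}

notion_import_info = {
    "notion_workspace_id": "<workspace id stored in DataSourceBinding.source_info>",
    "notion_page_id": "<id of the Notion page selected for import>",
}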
@@ -420,6 +420,7 @@ class DocumentService:
                    raise ValueError('Data source binding not found.')
                for page in notion_info['pages']:
                    data_source_info = {
+                        "notion_workspace_id": workspace_id,
                        "notion_page_id": page['page_id'],
                    }
                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
...
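The notion_workspace_id key written here is what later lets IndexingRunner._load_data_from_notion resolve the matching DataSourceBinding for the tenant. A rough round-trip sketch, assuming the dict is persisted as JSON on the Document row and read back through data_source_info_dict:

import json

# Placeholder ids; the JSON round trip is an assumption about how save_document
# stores data_source_info and how data_source_info_dict exposes it again.
data_source_info = {
    "notion_workspace_id": "wsp_123",
    "notion_page_id": "page_456",
}
stored = json.dumps(data_source_info)
restored = json.loads(stored)
assert restored["notion_workspace_id"] == "wsp_123"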