Commit e2ef272f authored by Jyong

support notion import documents

parent 201d9943
@@ -220,7 +220,7 @@ class DatasetDocumentListApi(Resource):
         DocumentService.document_create_args_validate(args)
 
         try:
-            document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+            documents = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
         except QuotaExceededError:
@@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
 
-        return document
+        return documents
 
 
 class DatasetInitApi(Resource):
@@ -257,7 +257,7 @@ class DatasetInitApi(Resource):
         DocumentService.document_create_args_validate(args)
 
         try:
-            dataset, document = DocumentService.save_document_without_dataset_id(
+            dataset, documents = DocumentService.save_document_without_dataset_id(
                 tenant_id=current_user.current_tenant_id,
                 document_data=args,
                 account=current_user
@@ -271,7 +271,7 @@ class DatasetInitApi(Resource):
 
         response = {
             'dataset': dataset,
-            'document': document
+            'documents': documents
         }
 
         return response
......
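Note on the console endpoints above: save_document_with_dataset_id and save_document_without_dataset_id now return a list, so the dataset-init response carries the created documents under a 'documents' key instead of a single 'document'. A rough sketch of the new response shape, with placeholder values that are not taken from this commit:

    # illustrative only; actual marshalled fields come from the endpoint's field definitions
    response = {
        'dataset': {'id': '<dataset-id>', 'name': '...'},
        'documents': [
            {'id': '<document-id-1>', 'name': 'file-a.pdf'},
            {'id': '<document-id-2>', 'name': 'file-b.pdf'}
        ]
    }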
@@ -69,12 +69,16 @@ class DocumentListApi(DatasetApiResource):
         document_data = {
             'data_source': {
                 'type': 'upload_file',
-                'info': upload_file.id
+                'info': [
+                    {
+                        'upload_file_id': upload_file.id
+                    }
+                ]
             }
         }
 
         try:
-            document = DocumentService.save_document_with_dataset_id(
+            documents = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
                 document_data=document_data,
                 account=dataset.created_by_account,
@@ -83,7 +87,7 @@ class DocumentListApi(DatasetApiResource):
             )
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
+        document = documents[0]
 
         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
......
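The service-API change above follows the new data_source contract: 'info' is a list of dicts rather than a bare upload-file id, which is what allows one request to create several documents, while the endpoint keeps its single-document response by taking documents[0]. A minimal sketch of the new shape with hypothetical file ids:

    # hypothetical ids; structure mirrors the document_data built above
    document_data = {
        'data_source': {
            'type': 'upload_file',
            'info': [
                {'upload_file_id': '9f3c0a42-...'},
                {'upload_file_id': '1d7e55b8-...'}
            ]
        }
    }
    documents = DocumentService.save_document_with_dataset_id(
        dataset=dataset,
        document_data=document_data,
        account=dataset.created_by_account
    )
    document = documents[0]  # first created document, as the endpoint does above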
@@ -38,42 +38,43 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name
 
-    def run(self, document: Document):
+    def run(self, documents: List[Document]):
         """Run the indexing process."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
-
-        if not dataset:
-            raise ValueError("no dataset found")
-
-        # load file
-        text_docs = self._load_data(document)
-
-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
-            first()
-
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
-
-        # split to nodes
-        nodes = self._step_split(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            dataset=dataset,
-            document=document,
-            processing_rule=processing_rule
-        )
-
-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+        for document in documents:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=document.dataset_id
+            ).first()
+
+            if not dataset:
+                raise ValueError("no dataset found")
+
+            # load file
+            text_docs = self._load_data(document)
+
+            # get the process rule
+            processing_rule = db.session.query(DatasetProcessRule). \
+                filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                first()
+
+            # get node parser for splitting
+            node_parser = self._get_node_parser(processing_rule)
+
+            # split to nodes
+            nodes = self._step_split(
+                text_docs=text_docs,
+                node_parser=node_parser,
+                dataset=dataset,
+                document=document,
+                processing_rule=processing_rule
+            )
+
+            # build index
+            self._build_index(
+                dataset=dataset,
+                document=document,
+                nodes=nodes
+            )
 
     def run_in_splitting_status(self, document: Document):
         """Run the indexing process when the index_status is splitting."""
@@ -362,7 +363,7 @@ class IndexingRunner:
             embedding_model_name=self.embedding_model_name,
             document_id=document.id
         )
-
+        # add document segments
         doc_store.add_documents(nodes)
 
         # update document status to indexing
......
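IndexingRunner.run now accepts the whole batch and loops over it, resolving each document's dataset and process rule before splitting and building the index. A minimal usage sketch, assuming the List annotation is imported from typing and that the caller has already loaded the Document rows (the filter below is illustrative, not part of this commit):

    from typing import List

    # hypothetical selection of pending documents for one dataset
    docs: List[Document] = db.session.query(Document).filter(
        Document.dataset_id == dataset.id,
        Document.indexing_status == 'parsing'
    ).all()
    IndexingRunner().run(docs)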
@@ -14,6 +14,7 @@ from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin
 from models.model import UploadFile
+from models.source import DataSourceBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -374,47 +375,85 @@ class DocumentService:
             )
             db.session.add(dataset_process_rule)
             db.session.commit()
 
-        file_name = ''
-        data_source_info = {}
+        position = DocumentService.get_documents_position(dataset.id)
+        document_ids = []
+        documents = []
         if document_data["data_source"]["type"] == "upload_file":
-            file_id = document_data["data_source"]["info"]
-            file = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == dataset.tenant_id,
-                UploadFile.id == file_id
-            ).first()
-
-            # raise error if file not found
-            if not file:
-                raise FileNotExistsError()
-
-            file_name = file.name
-            data_source_info = {
-                "upload_file_id": file_id,
-            }
-
-        # save document
-        position = DocumentService.get_documents_position(dataset.id)
+            upload_file_list = document_data["data_source"]["info"]
+            for upload_file in upload_file_list:
+                file_id = upload_file["upload_file_id"]
+                file = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == dataset.tenant_id,
+                    UploadFile.id == file_id
+                ).first()
+
+                # raise error if file not found
+                if not file:
+                    raise FileNotExistsError()
+
+                file_name = file.name
+                data_source_info = {
+                    "upload_file_id": file_id,
+                }
+                document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                          document_data["data_source"]["type"],
+                                                          data_source_info, created_from, position,
+                                                          account, file_name)
+                db.session.add(document)
+                db.session.flush()
+                document_ids.append(document.id)
+                documents.append(document)
+                position += 1
+        elif document_data["data_source"]["type"] == "notion_import":
+            notion_info_list = document_data["data_source"]['info']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                data_source_binding = DataSourceBinding.query.filter(
+                    db.and_(
+                        DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                        DataSourceBinding.provider == 'notion',
+                        DataSourceBinding.disabled == False,
+                        DataSourceBinding.source_info['workspace_id'] == workspace_id
+                    )
+                ).first()
+                if not data_source_binding:
+                    raise ValueError('Data source binding not found.')
+                for page in notion_info['pages']:
+                    data_source_info = {
+                        "notion_page_id": page['page_id'],
+                    }
+                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                              document_data["data_source"]["type"],
+                                                              data_source_info, created_from, position,
+                                                              account, page['page_name'])
+                    db.session.add(document)
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
+        db.session.commit()
 
+        # trigger async task
+        document_indexing_task.delay(dataset.id, document_ids)
+
+        return documents
+
+    @staticmethod
+    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
+                      created_from: str, position: int, account: Account, name: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
             position=position,
-            data_source_type=document_data["data_source"]["type"],
+            data_source_type=data_source_type,
             data_source_info=json.dumps(data_source_info),
-            dataset_process_rule_id=dataset_process_rule.id,
+            dataset_process_rule_id=process_rule_id,
             batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
-            name=file_name,
+            name=name,
             created_from=created_from,
             created_by=account.id,
             # created_api_request_id = db.Column(UUID, nullable=True)
         )
 
-        db.session.add(document)
-        db.session.commit()
-
-        # trigger async task
-        document_indexing_task.delay(document.dataset_id, document.id)
-
         return document
 
     @staticmethod
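The reworked save_document_with_dataset_id above branches on the data_source type: 'upload_file' iterates a list of upload_file_id entries, while the new 'notion_import' branch looks up the tenant's Notion DataSourceBinding per workspace and creates one document per selected page; all created documents share one batch and are enqueued together. A sketch of the document_data a Notion import would carry, inferred from the keys the branch reads (workspace and page ids are placeholders):

    # illustrative payload; process_rule and other arguments are handled elsewhere in the service
    document_data = {
        'data_source': {
            'type': 'notion_import',
            'info': [
                {
                    'workspace_id': '<notion-workspace-id>',
                    'pages': [
                        {'page_id': '<page-id-1>', 'page_name': 'Team Handbook'},
                        {'page_id': '<page-id-2>', 'page_name': 'Release Notes'}
                    ]
                }
            ]
        }
    }
    documents = DocumentService.save_document_with_dataset_id(dataset, document_data, current_user)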
@@ -431,15 +470,15 @@ class DocumentService:
             db.session.add(dataset)
             db.session.flush()
 
-        document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
+        documents = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
 
         cut_length = 18
-        cut_name = document.name[:cut_length]
-        dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
-        dataset.description = 'useful for when you want to answer queries about the ' + document.name
+        cut_name = documents[0].name[:cut_length]
+        dataset.name = cut_name + '...' if len(documents[0].name) > cut_length else cut_name
+        dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
         db.session.commit()
 
-        return dataset, document
+        return dataset, documents
 
     @classmethod
     def document_create_args_validate(cls, args: dict):
......
@@ -13,32 +13,36 @@ from models.dataset import Document
 @shared_task
-def document_indexing_task(dataset_id: str, document_id: str):
+def document_indexing_task(dataset_id: str, document_ids: list):
     """
     Async process document
     :param dataset_id:
-    :param document_id:
+    :param document_ids:
 
     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
-    logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-    start_at = time.perf_counter()
-
-    document = db.session.query(Document).filter(
-        Document.id == document_id,
-        Document.dataset_id == dataset_id
-    ).first()
-
-    if not document:
-        raise NotFound('Document not found')
-
-    document.indexing_status = 'parsing'
-    document.processing_started_at = datetime.datetime.utcnow()
+    documents = []
+    for document_id in document_ids:
+        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
+        start_at = time.perf_counter()
+
+        document = db.session.query(Document).filter(
+            Document.id == document_id,
+            Document.dataset_id == dataset_id
+        ).first()
+
+        if not document:
+            raise NotFound('Document not found')
+
+        document.indexing_status = 'parsing'
+        document.processing_started_at = datetime.datetime.utcnow()
+        documents.append(document)
+        db.session.add(document)
     db.session.commit()
 
     try:
         indexing_runner = IndexingRunner()
-        indexing_runner.run(document)
+        indexing_runner.run(documents)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
     except DocumentIsPausedException:
......
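document_indexing_task now receives the list of document ids produced by a single save call, marks each document as parsing, and hands the whole batch to one IndexingRunner run (note the Usage line in the docstring still shows the old single-id form). Enqueueing mirrors the call added in DocumentService; a minimal sketch:

    # ids of the documents created by save_document_with_dataset_id
    document_ids = [document.id for document in documents]
    document_indexing_task.delay(dataset.id, document_ids)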