Commit e2ef272f authored by Jyong

support notion import documents

parent 201d9943
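This commit threads a list of documents through the whole create-and-index pipeline, so one request can register several uploaded files or Notion pages at once. As a quick orientation before the diff, here is a hedged sketch of the two `data_source` payload shapes the service layer accepts after this change; the key names are taken from the code below, while the concrete IDs are placeholders:

# Hedged sketch, not part of the commit: request payload shapes accepted by
# DocumentService.save_document_with_dataset_id after this change.
upload_payload = {
    'data_source': {
        'type': 'upload_file',
        'info': [
            {'upload_file_id': 'file-id-1'},   # one entry per previously uploaded file
            {'upload_file_id': 'file-id-2'},
        ]
    }
}

notion_payload = {
    'data_source': {
        'type': 'notion_import',
        'info': [
            {
                'workspace_id': 'notion-workspace-id',   # must match an enabled DataSourceBinding
                'pages': [
                    {'page_id': 'notion-page-id', 'page_name': 'Page title'},
                ]
            }
        ]
    }
}

Each entry in `info` becomes one Document row, and the whole batch is handed to a single indexing task.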
@@ -220,7 +220,7 @@ class DatasetDocumentListApi(Resource):
         DocumentService.document_create_args_validate(args)

         try:
-            document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+            documents = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
         except QuotaExceededError:
@@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()

-        return document
+        return documents


 class DatasetInitApi(Resource):
@@ -257,7 +257,7 @@ class DatasetInitApi(Resource):
         DocumentService.document_create_args_validate(args)

         try:
-            dataset, document = DocumentService.save_document_without_dataset_id(
+            dataset, documents = DocumentService.save_document_without_dataset_id(
                 tenant_id=current_user.current_tenant_id,
                 document_data=args,
                 account=current_user
@@ -271,7 +271,7 @@ class DatasetInitApi(Resource):
         response = {
             'dataset': dataset,
-            'document': document
+            'documents': documents
         }

         return response
...
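For API consumers, the dataset-init response now carries a `documents` list instead of a single `document`. A hedged client-side sketch (the serialized document fields are assumed, not shown in this diff):

# Hedged consumer sketch; fetch_dataset_init_response is a hypothetical helper
# that returns the parsed JSON body of the endpoint above.
response = fetch_dataset_init_response()
dataset = response['dataset']
documents = response['documents']      # previously a single object under 'document'
first_name = documents[0]['name']      # 'name' field assumed from the Document model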
@@ -69,12 +69,16 @@ class DocumentListApi(DatasetApiResource):
         document_data = {
             'data_source': {
                 'type': 'upload_file',
-                'info': upload_file.id
+                'info': [
+                    {
+                        'upload_file_id': upload_file.id
+                    }
+                ]
             }
         }

         try:
-            document = DocumentService.save_document_with_dataset_id(
+            documents = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
                 document_data=document_data,
                 account=dataset.created_by_account,
@@ -83,7 +87,7 @@ class DocumentListApi(DatasetApiResource):
             )
         except ProviderTokenNotInitError:
             raise ProviderNotInitializeError()
+        document = documents[0]

         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
...
@@ -38,8 +38,9 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name

-    def run(self, document: Document):
+    def run(self, documents: List[Document]):
         """Run the indexing process."""
+        for document in documents:
             # get dataset
             dataset = Dataset.query.filter_by(
                 id=document.dataset_id
@@ -362,7 +363,7 @@ class IndexingRunner:
             embedding_model_name=self.embedding_model_name,
             document_id=document.id
         )
+        # add document segments
         doc_store.add_documents(nodes)

         # update document status to indexing
...
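A detail the hunk above implies but does not show: the new `run` signature is annotated with `List[Document]`, so the module needs the typing import if it is not already present (the import block is outside the visible range of this diff):

from typing import List  # assumed to be added or already present; needed for the List[Document] annotation

With that in place, the caller simply passes the batch it has loaded, as the task change at the end of this diff does with `indexing_runner.run(documents)`.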
@@ -14,6 +14,7 @@ from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin
 from models.model import UploadFile
+from models.source import DataSourceBinding
 from services.errors.account import NoPermissionError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
@@ -374,11 +375,13 @@ class DocumentService:
             )
             db.session.add(dataset_process_rule)
             db.session.commit()

+        position = DocumentService.get_documents_position(dataset.id)
-        file_name = ''
-        data_source_info = {}
+        document_ids = []
+        documents = []
         if document_data["data_source"]["type"] == "upload_file":
-            file_id = document_data["data_source"]["info"]
+            upload_file_list = document_data["data_source"]["info"]
+            for upload_file in upload_file_list:
+                file_id = upload_file["upload_file_id"]
                 file = db.session.query(UploadFile).filter(
                     UploadFile.tenant_id == dataset.tenant_id,
                     UploadFile.id == file_id
@@ -392,29 +395,65 @@ class DocumentService:
                 data_source_info = {
                     "upload_file_id": file_id,
                 }
+                document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                         document_data["data_source"]["type"],
+                                                         data_source_info, created_from, position,
+                                                         account, file_name)
+                db.session.add(document)
+                db.session.flush()
+                document_ids.append(document.id)
+                documents.append(document)
+                position += 1
+        elif document_data["data_source"]["type"] == "notion_import":
+            notion_info_list = document_data["data_source"]['info']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                data_source_binding = DataSourceBinding.query.filter(
+                    db.and_(
+                        DataSourceBinding.tenant_id == current_user.current_tenant_id,
+                        DataSourceBinding.provider == 'notion',
+                        DataSourceBinding.disabled == False,
+                        DataSourceBinding.source_info['workspace_id'] == workspace_id
+                    )
+                ).first()
+                if not data_source_binding:
+                    raise ValueError('Data source binding not found.')
+                for page in notion_info['pages']:
+                    data_source_info = {
+                        "notion_page_id": page['page_id'],
+                    }
+                    document = DocumentService.save_document(dataset, dataset_process_rule.id,
+                                                             document_data["data_source"]["type"],
+                                                             data_source_info, created_from, position,
+                                                             account, page['page_name'])
+                    db.session.add(document)
+                    db.session.flush()
+                    document_ids.append(document.id)
+                    documents.append(document)
+                    position += 1
-        # save document
-        position = DocumentService.get_documents_position(dataset.id)
+        db.session.commit()
+
+        # trigger async task
+        document_indexing_task.delay(dataset.id, document_ids)
+
+        return documents

+    @staticmethod
+    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
+                      created_from: str, position: int, account: Account, name: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
             position=position,
-            data_source_type=document_data["data_source"]["type"],
+            data_source_type=data_source_type,
             data_source_info=json.dumps(data_source_info),
-            dataset_process_rule_id=dataset_process_rule.id,
+            dataset_process_rule_id=process_rule_id,
             batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
-            name=file_name,
+            name=name,
             created_from=created_from,
             created_by=account.id,
             # created_api_request_id = db.Column(UUID, nullable=True)
         )
-        db.session.add(document)
-        db.session.commit()
-
-        # trigger async task
-        document_indexing_task.delay(document.dataset_id, document.id)

         return document

     @staticmethod
@@ -431,15 +470,15 @@ class DocumentService:
         db.session.add(dataset)
         db.session.flush()

-        document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
+        documents = DocumentService.save_document_with_dataset_id(dataset, document_data, account)

         cut_length = 18
-        cut_name = document.name[:cut_length]
-        dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
-        dataset.description = 'useful for when you want to answer queries about the ' + document.name
+        cut_name = documents[0].name[:cut_length]
+        dataset.name = cut_name + '...' if len(documents[0].name) > cut_length else cut_name
+        dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
         db.session.commit()

-        return dataset, document
+        return dataset, documents

     @classmethod
     def document_create_args_validate(cls, args: dict):
...
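Both data-source branches now delegate row construction to the new `save_document` static helper, which builds the Document but leaves session handling to the caller. An illustrative call with placeholder values, mirroring the notion_import branch above:

# Illustrative only; argument values are placeholders, the call shape mirrors the diff.
document = DocumentService.save_document(
    dataset,                          # target Dataset
    dataset_process_rule.id,          # process rule created or reused earlier
    'notion_import',                  # data_source_type
    {'notion_page_id': 'page-id'},    # data_source_info, serialized to JSON by the helper
    created_from,                     # origin of the request, e.g. 'web' (assumed value)
    position,                         # running position counter within the dataset
    account,                          # creating Account
    'Page title',                     # document name (page_name for Notion, file name for uploads)
)
db.session.add(document)
db.session.flush()                    # flush so document.id exists before it is queued for indexing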
@@ -13,14 +13,16 @@ from models.dataset import Document

 @shared_task
-def document_indexing_task(dataset_id: str, document_id: str):
+def document_indexing_task(dataset_id: str, document_ids: list):
     """
     Async process document
     :param dataset_id:
-    :param document_id:
+    :param document_ids:

     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
+    documents = []
+    for document_id in document_ids:
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
         start_at = time.perf_counter()
@@ -34,11 +36,13 @@ def document_indexing_task(dataset_id: str, document_id: str):
         document.indexing_status = 'parsing'
         document.processing_started_at = datetime.datetime.utcnow()
+        documents.append(document)
+        db.session.add(document)
         db.session.commit()

     try:
         indexing_runner = IndexingRunner()
-        indexing_runner.run(document)
+        indexing_runner.run(documents)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
     except DocumentIsPausedException:
...
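On the producer side, `DocumentService` now enqueues one task for the whole batch rather than one per document. A minimal dispatch sketch, matching the `document_indexing_task.delay(dataset.id, document_ids)` call added above (collecting the IDs with a comprehension here is just shorthand for the loop in the diff):

# Minimal producer-side sketch; `documents` is the list returned by
# DocumentService.save_document_with_dataset_id.
document_ids = [document.id for document in documents]
document_indexing_task.delay(dataset.id, document_ids)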