Commit e2ef272f authored by Jyong

support notion import documents

parent 201d9943
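The change replaces every single-document code path with a batch: `data_source.info` becomes a list, the service layer returns a list of documents, and the async indexing task receives a list of document ids. A hedged sketch of the two payload shapes the new code appears to accept (key names are taken from the diff below; the ids and page names are hypothetical placeholders):

# upload_file: 'info' is now a list of file references rather than a single id
document_data = {
    'data_source': {
        'type': 'upload_file',
        'info': [
            {'upload_file_id': 'file-id-1'},  # hypothetical id
            {'upload_file_id': 'file-id-2'},  # hypothetical id
        ]
    }
}

# notion_import: one entry per bound workspace, each listing the pages to import
document_data = {
    'data_source': {
        'type': 'notion_import',
        'info': [
            {
                'workspace_id': 'workspace-id-1',  # hypothetical id
                'pages': [
                    {'page_id': 'page-id-1', 'page_name': 'Roadmap'}  # hypothetical values
                ]
            }
        ]
    }
}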
@@ -220,7 +220,7 @@ class DatasetDocumentListApi(Resource):
DocumentService.document_create_args_validate(args)
try:
- document = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+ documents = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
@@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
- return document
+ return documents
class DatasetInitApi(Resource):
@@ -257,7 +257,7 @@ class DatasetInitApi(Resource):
DocumentService.document_create_args_validate(args)
try:
- dataset, document = DocumentService.save_document_without_dataset_id(
+ dataset, documents = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id,
document_data=args,
account=current_user
@@ -271,7 +271,7 @@ class DatasetInitApi(Resource):
response = {
'dataset': dataset,
- 'document': document
+ 'documents': documents
}
return response
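Net effect on the console API: both endpoints now return the whole batch, so the init response's 'document' key becomes 'documents'. Clients that assumed a single document would presumably read the first element (a hypothetical client-side adjustment, not part of this commit):

# before: doc = response['document']
# after: the key holds the list of created documents
docs = response['documents']
first_doc = docs[0]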
@@ -69,12 +69,16 @@ class DocumentListApi(DatasetApiResource):
document_data = {
'data_source': {
'type': 'upload_file',
- 'info': upload_file.id
+ 'info': [
+ {
+ 'upload_file_id': upload_file.id
+ }
+ ]
}
}
try:
- document = DocumentService.save_document_with_dataset_id(
+ documents = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=document_data,
account=dataset.created_by_account,
@@ -83,7 +87,7 @@ class DocumentListApi(DatasetApiResource):
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
+ document = documents[0]
if doc_type and doc_metadata:
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
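Because this service-API handler uploads exactly one file, `document = documents[0]` keeps the metadata code above working on a single record while reusing the batch-oriented service method. A condensed sketch of the adapted call, assuming the surrounding handler variables:

documents = DocumentService.save_document_with_dataset_id(
    dataset=dataset,
    document_data=document_data,
    account=dataset.created_by_account,
    # ... remaining keyword arguments elided in the hunk above
)
document = documents[0]  # single-file upload, so the batch holds exactly one document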
@@ -38,8 +38,9 @@ class IndexingRunner:
self.storage = storage
self.embedding_model_name = embedding_model_name
- def run(self, document: Document):
+ def run(self, documents: List[Document]):
"""Run the indexing process."""
+ for document in documents:
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
@@ -362,7 +363,7 @@ class IndexingRunner:
embedding_model_name=self.embedding_model_name,
document_id=document.id
)
# add document segments
doc_store.add_documents(nodes)
# update document status to indexing
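The runner's entry point now takes the whole batch and wraps the original per-document pipeline in a loop. A minimal sketch of the new control flow, assuming `List` is imported from `typing` and `Document`/`Dataset` are the module's existing model imports:

from typing import List

class IndexingRunner:
    def run(self, documents: List[Document]):
        """Run the indexing process."""
        for document in documents:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=document.dataset_id
            ).first()
            # ... parsing, splitting and index building proceed per document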
@@ -14,6 +14,7 @@ from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin
from models.model import UploadFile
+ from models.source import DataSourceBinding
from services.errors.account import NoPermissionError
from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
@@ -374,11 +375,13 @@ class DocumentService:
)
db.session.add(dataset_process_rule)
db.session.commit()
- file_name = ''
data_source_info = {}
+ position = DocumentService.get_documents_position(dataset.id)
+ document_ids = []
+ documents = []
if document_data["data_source"]["type"] == "upload_file":
- file_id = document_data["data_source"]["info"]
+ upload_file_list = document_data["data_source"]["info"]
+ for upload_file in upload_file_list:
+ file_id = upload_file["upload_file_id"]
file = db.session.query(UploadFile).filter(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id == file_id
@@ -392,29 +395,65 @@ class DocumentService:
data_source_info = {
"upload_file_id": file_id,
}
+ document = DocumentService.save_document(dataset, dataset_process_rule.id,
+ document_data["data_source"]["type"],
+ data_source_info, created_from, position,
+ account, file_name)
+ db.session.add(document)
+ db.session.flush()
+ document_ids.append(document.id)
+ documents.append(document)
+ position += 1
+ elif document_data["data_source"]["type"] == "notion_import":
+ notion_info_list = document_data["data_source"]['info']
+ for notion_info in notion_info_list:
+ workspace_id = notion_info['workspace_id']
+ data_source_binding = DataSourceBinding.query.filter(
+ db.and_(
+ DataSourceBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceBinding.provider == 'notion',
+ DataSourceBinding.disabled == False,
+ DataSourceBinding.source_info['workspace_id'] == workspace_id
+ )
+ ).first()
+ if not data_source_binding:
+ raise ValueError('Data source binding not found.')
+ for page in notion_info['pages']:
+ data_source_info = {
+ "notion_page_id": page['page_id'],
+ }
+ document = DocumentService.save_document(dataset, dataset_process_rule.id,
+ document_data["data_source"]["type"],
+ data_source_info, created_from, position,
+ account, page['page_name'])
+ db.session.add(document)
+ db.session.flush()
+ document_ids.append(document.id)
+ documents.append(document)
+ position += 1
- # save document
- position = DocumentService.get_documents_position(dataset.id)
+ db.session.commit()
+ # trigger async task
+ document_indexing_task.delay(dataset.id, document_ids)
+ return documents
+ @staticmethod
+ def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
+ created_from: str, position: int, account: Account, name: str):
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
- data_source_type=document_data["data_source"]["type"],
+ data_source_type=data_source_type,
data_source_info=json.dumps(data_source_info),
- dataset_process_rule_id=dataset_process_rule.id,
+ dataset_process_rule_id=process_rule_id,
batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)),
- name=file_name,
+ name=name,
created_from=created_from,
created_by=account.id,
# created_api_request_id = db.Column(UUID, nullable=True)
)
- db.session.add(document)
- db.session.commit()
- # trigger async task
- document_indexing_task.delay(document.dataset_id, document.id)
return document
@staticmethod
@@ -431,15 +470,15 @@ class DocumentService:
db.session.add(dataset)
db.session.flush()
- document = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
+ documents = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
cut_length = 18
- cut_name = document.name[:cut_length]
- dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name
- dataset.description = 'useful for when you want to answer queries about the ' + document.name
+ cut_name = documents[0].name[:cut_length]
+ dataset.name = cut_name + '...' if len(documents[0].name) > cut_length else cut_name
+ dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name
db.session.commit()
- return dataset, document
+ return dataset, documents
@classmethod
def document_create_args_validate(cls, args: dict):
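The refactor extracts document construction into a `save_document` helper that only builds the ORM object; the caller owns persistence, which is why each loop iteration adds and flushes (flushing assigns `document.id` without ending the transaction) and a single commit closes the whole batch. A hedged usage sketch of the helper with the notion branch's values from the hunk above:

document = DocumentService.save_document(
    dataset=dataset,
    process_rule_id=dataset_process_rule.id,
    data_source_type='notion_import',
    data_source_info={'notion_page_id': page['page_id']},
    created_from=created_from,
    position=position,
    account=account,
    name=page['page_name'],
)
db.session.add(document)
db.session.flush()                 # assigns document.id before the batch commit
document_ids.append(document.id)
documents.append(document)
position += 1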
@@ -13,14 +13,16 @@ from models.dataset import Document
@shared_task
- def document_indexing_task(dataset_id: str, document_id: str):
+ def document_indexing_task(dataset_id: str, document_ids: list):
"""
Async process document
:param dataset_id:
- :param document_id:
+ :param document_ids:
Usage: document_indexing_task.delay(dataset_id, document_ids)
"""
+ documents = []
+ for document_id in document_ids:
logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
start_at = time.perf_counter()
@@ -34,11 +36,13 @@ def document_indexing_task(dataset_id: str, document_id: str):
document.indexing_status = 'parsing'
document.processing_started_at = datetime.datetime.utcnow()
+ documents.append(document)
db.session.add(document)
db.session.commit()
try:
indexing_runner = IndexingRunner()
- indexing_runner.run(document)
+ indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
except DocumentIsPausedException:
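Putting the task together: it flips every requested document to 'parsing', commits once, then hands the whole batch to a single IndexingRunner pass. A condensed sketch of the resulting task body (the per-document lookup and the pause handling are elided in the hunks above, so both are assumptions here):

@shared_task
def document_indexing_task(dataset_id: str, document_ids: list):
    documents = []
    for document_id in document_ids:
        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()  # assumption: the exact lookup is elided in the hunk above
        if document:
            document.indexing_status = 'parsing'
            document.processing_started_at = datetime.datetime.utcnow()
            documents.append(document)
            db.session.add(document)
    db.session.commit()
    try:
        IndexingRunner().run(documents)  # one runner pass over the whole batch
    except DocumentIsPausedException:
        pass  # assumption: the original pause handling is elided in the hunk above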