Unverified Commit a55ba6e6 authored by Jyong, committed by GitHub

Fix/ignore economy dataset (#1043)

Co-authored-by: jyong <jyong@dify.ai>
parent f9bec1ed
@@ -92,11 +92,14 @@ class DatasetListApi(Resource):
                 model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
-            if item_model in model_names:
-                item['embedding_available'] = True
+            if item['indexing_technique'] == 'high_quality':
+                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
+                if item_model in model_names:
+                    item['embedding_available'] = True
+                else:
+                    item['embedding_available'] = False
             else:
-                item['embedding_available'] = False
+                item['embedding_available'] = True
         response = {
             'data': data,
             'has_more': len(datasets) == limit,
@@ -122,14 +125,6 @@ class DatasetListApi(Resource):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
         try:
             dataset = DatasetService.create_empty_dataset(
@@ -167,6 +162,11 @@ class DatasetApi(Resource):
     @account_initialization_required
     def patch(self, dataset_id):
         dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         parser = reqparse.RequestParser()
         parser.add_argument('name', nullable=False,
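The guard added here delegates to DatasetService.check_dataset_model_setting, whose body is not visible in this commit view (the services diff below is collapsed). A plausible sketch of the helper, inferred from the inline checks this commit deletes from the controllers; the raised error type is an assumption:

    @staticmethod
    def check_dataset_model_setting(dataset):
        # only high-quality datasets depend on a configured embedding model
        if dataset.indexing_technique != 'high_quality':
            return
        try:
            ModelFactory.get_embedding_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )
        except LLMBadRequestError:
            raise ValueError(  # assumed error type; the collapsed diff may use a service-level error
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ValueError(ex.description)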
@@ -254,6 +254,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
@@ -275,7 +276,8 @@ class DatasetIndexingEstimateApi(Resource):
         try:
             response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
                                                               args['process_rule'], args['doc_form'],
-                                                              args['doc_language'], args['dataset_id'])
+                                                              args['doc_language'], args['dataset_id'],
+                                                              args['indexing_technique'])
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 f"No Embedding Model available. Please configure a valid provider "
@@ -290,7 +292,8 @@ class DatasetIndexingEstimateApi(Resource):
                 response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                     args['info_list']['notion_info_list'],
                                                                     args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'])
+                                                                    args['doc_language'], args['dataset_id'],
+                                                                    args['indexing_technique'])
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 f"No Embedding Model available. Please configure a valid provider "
......
@@ -285,20 +285,6 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:
@@ -339,15 +325,17 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
         args = parser.parse_args()
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+        if args['indexing_technique'] == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         # validate args
         DocumentService.document_create_args_validate(args)
@@ -729,6 +717,12 @@ class DocumentDeleteApi(DocumentResource):
     def delete(self, dataset_id, document_id):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         document = self.get_document(dataset_id, document_id)

         try:
@@ -791,6 +785,12 @@ class DocumentStatusApi(DocumentResource):
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         document = self.get_document(dataset_id, document_id)

         # The role of the current user in the ta table must be admin or owner
......
@@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
@@ -158,20 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -244,18 +245,19 @@ class DatasetDocumentSegmentAddApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
         # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         try:
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
@@ -284,25 +286,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
@@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
         # get file from request
         file = request.files['file']
         # check file
......
@@ -67,12 +67,13 @@ class DatesetDocumentStore:
         if max_position is None:
             max_position = 0
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=self._dataset.tenant_id,
-            model_provider_name=self._dataset.embedding_model_provider,
-            model_name=self._dataset.embedding_model
-        )
+        embedding_model = None
+        if self._dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=self._dataset.tenant_id,
+                model_provider_name=self._dataset.embedding_model_provider,
+                model_name=self._dataset.embedding_model
+            )

         for doc in docs:
             if not isinstance(doc, Document):
@@ -88,7 +89,7 @@ class DatesetDocumentStore:
             )

             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(doc.page_content)
+            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0

             if not segment_document:
                 max_position += 1
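Both docstore hunks apply the pattern this commit repeats across the codebase: fetch an embedding model only for high-quality datasets, and let token accounting degrade to zero otherwise. Condensed, with names taken from the surrounding diff:

    embedding_model = None
    if dataset.indexing_technique == 'high_quality':
        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=dataset.tenant_id,
            model_provider_name=dataset.embedding_model_provider,
            model_name=dataset.embedding_model
        )

    # economy datasets: no provider call, zero embedding tokens
    tokens = embedding_model.get_num_tokens(text) if embedding_model else 0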
......
+import json
+
 from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings

 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.openai_model import OpenAIModel
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.dataset import Dataset
+from models.provider import Provider, ProviderType


 class IndexBuilder:
@@ -35,4 +43,13 @@ class IndexBuilder:
                 )
             )
         else:
-            raise ValueError('Unknown indexing technique')
\ No newline at end of file
+            raise ValueError('Unknown indexing technique')
+
+    @classmethod
+    def get_default_high_quality_index(cls, dataset: Dataset):
+        embeddings = OpenAIEmbeddings(openai_api_key=' ')
+        return VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
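The new get_default_high_quality_index wraps a LangChain OpenAIEmbeddings instance built with a blank API key, so it is only safe for operations that never actually embed text. The clean_dataset_task hunk further down uses it exactly that way, to address and drop an existing vector collection; a sketch of that call pattern (assuming VectorIndex.delete() touches only the index itself, not the embedding API):

    vector_index = IndexBuilder.get_default_high_quality_index(dataset)
    try:
        vector_index.delete()  # needs the index name/config only; no embeddings are computed
    except Exception:
        pass  # assumed best-effort cleanup, mirroring the task code below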
@@ -217,25 +217,29 @@ class IndexingRunner:
         db.session.commit()

     def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
-                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                               indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )
         tokens = 0
         preview_texts = []
         total_segments = 0
@@ -263,8 +267,8 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
-
-                tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+                if indexing_technique == 'high_quality' or embedding_model:
+                    tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))

         if doc_form and doc_form == 'qa_model':
             text_generation_model = ModelFactory.get_text_generation_model(
@@ -286,32 +290,35 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }

     def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
-                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                                 indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )

         # load data from notion
         tokens = 0
         preview_texts = []
@@ -356,8 +363,8 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
-
-                tokens += embedding_model.get_num_tokens(document.page_content)
+                if indexing_technique == 'high_quality' or embedding_model:
+                    tokens += embedding_model.get_num_tokens(document.page_content)

         if doc_form and doc_form == 'qa_model':
             text_generation_model = ModelFactory.get_text_generation_model(
@@ -379,8 +386,8 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }
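With these fallbacks, estimating an economy dataset returns a zero-cost payload instead of failing on a missing embedding model. Illustrative return values (numbers hypothetical):

    # indexing_technique == 'economy': no embedding model loaded
    {'total_segments': 12, 'tokens': 0, 'total_price': 0, 'currency': 'USD', 'preview': [...]}

    # indexing_technique == 'high_quality': priced via the provider's embedding model
    {'total_segments': 12, 'tokens': 4096, 'total_price': '0.000410', 'currency': 'USD', 'preview': [...]}

Note the asymmetry: total_price is the integer 0 in the economy branch but a '{:f}'-formatted string in the high-quality branch, so clients should not assume a single type.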
@@ -657,12 +664,13 @@ class IndexingRunner:
         """
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
-
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
@@ -672,11 +680,11 @@ class IndexingRunner:
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
             chunk_documents = documents[i:i + chunk_size]
-
-            tokens += sum(
-                embedding_model.get_num_tokens(document.page_content)
-                for document in chunk_documents
-            )
+            if dataset.indexing_technique == 'high_quality' or embedding_model:
+                tokens += sum(
+                    embedding_model.get_num_tokens(document.page_content)
+                    for document in chunk_documents
+                )

             # save vector index
             if vector_index:
......
 from events.dataset_event import dataset_was_deleted
 from events.event_handlers.document_index_event import document_index_created
+from tasks.clean_dataset_task import clean_dataset_task
 import datetime
 import logging
 import time
......
"""update_dataset_model_field_null_available
Revision ID: 4bcffcd64aa4
Revises: 853f9b9cd3b6
Create Date: 2023-08-28 20:58:50.077056
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4bcffcd64aa4'
down_revision = '853f9b9cd3b6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.alter_column('embedding_model',
existing_type=sa.VARCHAR(length=255),
nullable=True,
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
batch_op.alter_column('embedding_model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True,
existing_server_default=sa.text("'openai'::character varying"))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.alter_column('embedding_model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False,
existing_server_default=sa.text("'openai'::character varying"))
batch_op.alter_column('embedding_model',
existing_type=sa.VARCHAR(length=255),
nullable=False,
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
# ### end Alembic commands ###
@@ -36,10 +36,8 @@ class Dataset(db.Model):
     updated_by = db.Column(UUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    embedding_model = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
-    embedding_model_provider = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'openai'::character varying"))
+    embedding_model = db.Column(db.String(255), nullable=True)
+    embedding_model_provider = db.Column(db.String(255), nullable=True)

     @property
     def dataset_keyword_table(self):
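Together with the migration above, this model change makes both embedding columns optional, so an economy dataset can be persisted with no embedding configuration at all. A minimal sketch (field values hypothetical):

    dataset = Dataset(
        tenant_id=tenant_id,              # assumed to be in scope
        name='support-faq',
        indexing_technique='economy',
        embedding_model=None,             # nullable as of this migration
        embedding_model_provider=None,    # nullable as of this migration
    )
    db.session.add(dataset)
    db.session.commit()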
......
@@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
     if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
         raise ValueError('Document is not available.')
     document_segments = []
+    embedding_model = None
+    if dataset.indexing_technique == 'high_quality':
+        embedding_model = ModelFactory.get_embedding_model(
+            tenant_id=dataset.tenant_id,
+            model_provider_name=dataset.embedding_model_provider,
+            model_name=dataset.embedding_model
+        )
     for segment in content:
         content = segment['content']
         doc_id = str(uuid.uuid4())
         segment_hash = helper.generate_text_hash(content)
-
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
-
         # calc embedding use tokens
-        tokens = embedding_model.get_num_tokens(content)
+        tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document_id == dataset_document.id
         ).scalar()
@@ -3,8 +3,10 @@ import time

 import click
 from celery import shared_task
+from flask import current_app

 from core.index.index import IndexBuilder
+from core.index.vector_index.vector_index import VectorIndex
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin, Document
@@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         kw_index = IndexBuilder.get_index(dataset, 'economy')

         # delete from vector index
-        if vector_index:
+        if dataset.indexing_technique == 'high_quality':
+            vector_index = IndexBuilder.get_default_high_quality_index(dataset)
             try:
                 vector_index.delete()
             except Exception:
@@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             raise Exception('Dataset not found')

         if action == "remove":
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
             index.delete()
         elif action == "add":
             dataset_documents = db.session.query(DatasetDocument).filter(
@@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             if dataset_documents:
                 # save vector index
-                index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+                index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
                 documents = []
                 for dataset_document in dataset_documents:
                     # delete from vector index
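Switching ignore_high_quality_check from True to False makes get_index honor the dataset's own indexing technique, so this task no longer builds or touches a vector index for economy datasets. Paraphrasing the relevant guard inside IndexBuilder.get_index (a sketch from the unchanged part of index.py, not a verbatim quote):

    if indexing_technique == "high_quality":
        if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
            return None  # with the flag now False, economy datasets short-circuit here

For economy datasets the "remove" branch would then receive None, so presumably this task is only dispatched for high-quality datasets.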
......