Unverified Commit a55ba6e6 authored by Jyong, committed by GitHub

Fix/ignore economy dataset (#1043)

Co-authored-by: jyong <jyong@dify.ai>
parent f9bec1ed
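
The changes below all follow one pattern: every embedding-model lookup is gated on the dataset's indexing technique, so "economy" (keyword-index) datasets no longer require a configured embedding provider. A condensed sketch of that guard, using names that appear in the diff (not a verbatim excerpt of any single file):

    embedding_model = None
    if dataset.indexing_technique == 'high_quality':
        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=dataset.tenant_id,
            model_provider_name=dataset.embedding_model_provider,
            model_name=dataset.embedding_model
        )
    # downstream code falls back to cheap defaults when no model is configured
    tokens = embedding_model.get_num_tokens(content) if embedding_model else 0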
@@ -92,11 +92,14 @@ class DatasetListApi(Resource):
             model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
+            if item['indexing_technique'] == 'high_quality':
                 item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                 if item_model in model_names:
                     item['embedding_available'] = True
                 else:
                     item['embedding_available'] = False
+            else:
+                item['embedding_available'] = True
         response = {
             'data': data,
             'has_more': len(datasets) == limit,
@@ -122,14 +125,6 @@ class DatasetListApi(Resource):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
         try:
             dataset = DatasetService.create_empty_dataset(
@@ -167,6 +162,11 @@ class DatasetApi(Resource):
     @account_initialization_required
     def patch(self, dataset_id):
         dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         parser = reqparse.RequestParser()
         parser.add_argument('name', nullable=False,
@@ -254,6 +254,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
@@ -275,7 +276,8 @@ class DatasetIndexingEstimateApi(Resource):
         try:
             response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
                                                               args['process_rule'], args['doc_form'],
-                                                              args['doc_language'], args['dataset_id'])
+                                                              args['doc_language'], args['dataset_id'],
+                                                              args['indexing_technique'])
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 f"No Embedding Model available. Please configure a valid provider "
@@ -290,7 +292,8 @@ class DatasetIndexingEstimateApi(Resource):
             response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                 args['info_list']['notion_info_list'],
                                                                 args['process_rule'], args['doc_form'],
-                                                                args['doc_language'], args['dataset_id'])
+                                                                args['doc_language'], args['dataset_id'],
+                                                                args['indexing_technique'])
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 f"No Embedding Model available. Please configure a valid provider "
...
@@ -285,20 +285,6 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:
@@ -339,7 +325,7 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
         args = parser.parse_args()
+        if args['indexing_technique'] == 'high_quality':
             try:
                 ModelFactory.get_embedding_model(
                     tenant_id=current_user.current_tenant_id
@@ -348,6 +334,8 @@ class DatasetInitApi(Resource):
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
                     f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         # validate args
         DocumentService.document_create_args_validate(args)
@@ -729,6 +717,12 @@ class DocumentDeleteApi(DocumentResource):
     def delete(self, dataset_id, document_id):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         document = self.get_document(dataset_id, document_id)
         try:
@@ -791,6 +785,12 @@ class DocumentStatusApi(DocumentResource):
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         document = self.get_document(dataset_id, document_id)
         # The role of the current user in the ta table must be admin or owner
...
@@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
@@ -158,7 +159,7 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
+        if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
                 ModelFactory.get_embedding_model(
@@ -244,6 +245,7 @@ class DatasetDocumentSegmentAddApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
         # check embedding model setting
+        if dataset.indexing_technique == 'high_quality':
             try:
                 ModelFactory.get_embedding_model(
                     tenant_id=current_user.current_tenant_id,
@@ -284,11 +286,14 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
+        if dataset.indexing_technique == 'high_quality':
             # check embedding model setting
             try:
                 ModelFactory.get_embedding_model(
@@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
@@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
         # get file from request
         file = request.files['file']
         # check file
...
@@ -67,7 +67,8 @@ class DatesetDocumentStore:
         if max_position is None:
             max_position = 0
+        embedding_model = None
+        if self._dataset.indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=self._dataset.tenant_id,
                 model_provider_name=self._dataset.embedding_model_provider,
@@ -88,7 +89,7 @@ class DatesetDocumentStore:
             )
             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(doc.page_content)
+            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
             if not segment_document:
                 max_position += 1
...
+import json

 from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings

 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.openai_model import OpenAIModel
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.dataset import Dataset
+from models.provider import Provider, ProviderType

 class IndexBuilder:
@@ -36,3 +44,12 @@ class IndexBuilder:
             )
         else:
             raise ValueError('Unknown indexing technique')
+
+    @classmethod
+    def get_default_high_quality_index(cls, dataset: Dataset):
+        embeddings = OpenAIEmbeddings(openai_api_key=' ')
+        return VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
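
The new IndexBuilder.get_default_high_quality_index helper above returns a vector-index handle built from OpenAIEmbeddings with a blank API key; that appears safe only because the handle is used for deletion and never embeds any text, which is how clean_dataset_task (further down in this diff) consumes it. A minimal usage sketch mirroring that call site:

    # only datasets indexed with 'high_quality' ever had a vector index to remove
    if dataset.indexing_technique == 'high_quality':
        vector_index = IndexBuilder.get_default_high_quality_index(dataset)
        vector_index.delete()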
@@ -217,22 +217,26 @@ class IndexingRunner:
         db.session.commit()

     def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
-                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                               indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                 embedding_model = ModelFactory.get_embedding_model(
                     tenant_id=dataset.tenant_id,
                     model_provider_name=dataset.embedding_model_provider,
                     model_name=dataset.embedding_model
                 )
         else:
+            if indexing_technique == 'high_quality':
                 embedding_model = ModelFactory.get_embedding_model(
                     tenant_id=tenant_id
                 )
@@ -263,7 +267,7 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
+                if indexing_technique == 'high_quality' or embedding_model:
                     tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
         if doc_form and doc_form == 'qa_model':
@@ -286,32 +290,35 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }

     def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
-                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                                 indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                 embedding_model = ModelFactory.get_embedding_model(
                     tenant_id=dataset.tenant_id,
                     model_provider_name=dataset.embedding_model_provider,
                     model_name=dataset.embedding_model
                 )
         else:
+            if indexing_technique == 'high_quality':
                 embedding_model = ModelFactory.get_embedding_model(
                     tenant_id=tenant_id
                 )
         # load data from notion
         tokens = 0
         preview_texts = []
@@ -356,7 +363,7 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
+                if indexing_technique == 'high_quality' or embedding_model:
                     tokens += embedding_model.get_num_tokens(document.page_content)
         if doc_form and doc_form == 'qa_model':
@@ -379,8 +386,8 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }
@@ -657,7 +664,8 @@ class IndexingRunner:
         """
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=dataset.tenant_id,
                 model_provider_name=dataset.embedding_model_provider,
@@ -672,7 +680,7 @@ class IndexingRunner:
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
             chunk_documents = documents[i:i + chunk_size]
+            if dataset.indexing_technique == 'high_quality' or embedding_model:
                 tokens += sum(
                     embedding_model.get_num_tokens(document.page_content)
                     for document in chunk_documents
...
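
With the guards above, the estimate methods skip the embedding model entirely for economy datasets and fall back to zero-cost defaults. A sketch of the payload such a call now returns (field values are illustrative, not taken from a real run):

    economy_estimate = {
        "total_segments": 12,     # still counted from the split documents
        "tokens": 0,              # nothing is tokenized when no embedding model is set
        "total_price": 0,         # falls back to 0 when embedding_model is None
        "currency": "USD",        # default currency when no model is configured
        "preview": ["first chunk of text", "second chunk of text"],
    }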
from events.dataset_event import dataset_was_deleted
from events.event_handlers.document_index_event import document_index_created
from tasks.clean_dataset_task import clean_dataset_task
import datetime
import logging
import time
...
"""update_dataset_model_field_null_available
Revision ID: 4bcffcd64aa4
Revises: 853f9b9cd3b6
Create Date: 2023-08-28 20:58:50.077056
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4bcffcd64aa4'
down_revision = '853f9b9cd3b6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.alter_column('embedding_model',
existing_type=sa.VARCHAR(length=255),
nullable=True,
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
batch_op.alter_column('embedding_model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True,
existing_server_default=sa.text("'openai'::character varying"))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.alter_column('embedding_model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False,
existing_server_default=sa.text("'openai'::character varying"))
batch_op.alter_column('embedding_model',
existing_type=sa.VARCHAR(length=255),
nullable=False,
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
# ### end Alembic commands ###
@@ -36,10 +36,8 @@ class Dataset(db.Model):
     updated_by = db.Column(UUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    embedding_model = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
-    embedding_model_provider = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'openai'::character varying"))
+    embedding_model = db.Column(db.String(255), nullable=True)
+    embedding_model_provider = db.Column(db.String(255), nullable=True)

     @property
     def dataset_keyword_table(self):
...
@@ -10,6 +10,7 @@ from flask import current_app
 from sqlalchemy import func

 from core.index.index import IndexBuilder
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from flask_login import current_user
@@ -91,6 +92,8 @@ class DatasetService:
         if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(
                 f'Dataset with name {name} already exists.')
+        embedding_model = None
+        if indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=current_user.current_tenant_id
             )
@@ -99,8 +102,8 @@ class DatasetService:
         dataset.created_by = account.id
         dataset.updated_by = account.id
         dataset.tenant_id = tenant_id
-        dataset.embedding_model_provider = embedding_model.model_provider.provider_name
-        dataset.embedding_model = embedding_model.name
+        dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
+        dataset.embedding_model = embedding_model.name if embedding_model else None
         db.session.add(dataset)
         db.session.commit()
         return dataset
@@ -115,6 +118,23 @@ class DatasetService:
         else:
             return dataset

+    @staticmethod
+    def check_dataset_model_setting(dataset):
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ValueError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ValueError(f"The dataset in unavailable, due to: "
+                                 f"{ex.description}")
+
     @staticmethod
     def update_dataset(dataset_id, data, user):
         dataset = DatasetService.get_dataset(dataset_id)
@@ -124,6 +144,19 @@ class DatasetService:
             if data['indexing_technique'] == 'economy':
                 deal_dataset_vector_index_task.delay(dataset_id, 'remove')
             elif data['indexing_technique'] == 'high_quality':
+                # check embedding model setting
+                try:
+                    ModelFactory.get_embedding_model(
+                        tenant_id=current_user.current_tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except LLMBadRequestError:
+                    raise ValueError(
+                        f"No Embedding Model available. Please configure a valid provider "
+                        f"in the Settings -> Model Provider.")
+                except ProviderTokenNotInitError as ex:
+                    raise ValueError(ex.description)
                 deal_dataset_vector_index_task.delay(dataset_id, 'add')

         filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
@@ -397,6 +430,7 @@ class DocumentService:
         # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
+            if 'original_document_id' not in document_data or not document_data['original_document_id']:
                 count = 0
                 if document_data["data_source"]["type"] == "upload_file":
                     upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
@@ -413,7 +447,6 @@ class DocumentService:
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
             dataset.data_source_type = document_data["data_source"]["type"]
-            db.session.commit()

         if not dataset.indexing_technique:
             if 'indexing_technique' not in document_data \
@@ -421,6 +454,13 @@ class DocumentService:
                 raise ValueError("Indexing technique is required")

             dataset.indexing_technique = document_data["indexing_technique"]
+            if document_data["indexing_technique"] == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id
+                )
+                dataset.embedding_model = embedding_model.name
+                dataset.embedding_model_provider = embedding_model.model_provider.provider_name

         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -567,6 +607,7 @@ class DocumentService:
     def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                         account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                         created_from: str = 'web'):
+        DatasetService.check_dataset_model_setting(dataset)
         document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
         if document.display_status != 'available':
             raise ValueError("Document is not available")
@@ -674,6 +715,8 @@ class DocumentService:
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
             if total_count > tenant_document_count:
                 raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
+        embedding_model = None
+        if document_data['indexing_technique'] == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=tenant_id
             )
@@ -684,8 +727,8 @@ class DocumentService:
             data_source_type=document_data["data_source"]["type"],
             indexing_technique=document_data["indexing_technique"],
             created_by=account.id,
-            embedding_model=embedding_model.name,
-            embedding_model_provider=embedding_model.model_provider.provider_name
+            embedding_model=embedding_model.name if embedding_model else None,
+            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
         )

         db.session.add(dataset)
@@ -903,13 +946,13 @@ class SegmentService:
         content = args['content']
         doc_id = str(uuid.uuid4())
         segment_hash = helper.generate_text_hash(content)
+        tokens = 0
+        if dataset.indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=dataset.tenant_id,
                 model_provider_name=dataset.embedding_model_provider,
                 model_name=dataset.embedding_model
             )
             # calc embedding use tokens
             tokens = embedding_model.get_num_tokens(content)
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
@@ -973,7 +1016,8 @@ class SegmentService:
             kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
         else:
             segment_hash = helper.generate_text_hash(content)
+            tokens = 0
+            if dataset.indexing_technique == 'high_quality':
                 embedding_model = ModelFactory.get_embedding_model(
                     tenant_id=dataset.tenant_id,
                     model_provider_name=dataset.embedding_model_provider,
...
@@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
             raise ValueError('Document is not available.')
         document_segments = []
-        for segment in content:
-            content = segment['content']
-            doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content)
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=dataset.tenant_id,
                 model_provider_name=dataset.embedding_model_provider,
                 model_name=dataset.embedding_model
             )
+        for segment in content:
+            content = segment['content']
+            doc_id = str(uuid.uuid4())
+            segment_hash = helper.generate_text_hash(content)
             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
+            tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
             max_position = db.session.query(func.max(DocumentSegment.position)).filter(
                 DocumentSegment.document_id == dataset_document.id
             ).scalar()
...
@@ -3,8 +3,10 @@ import time

 import click
 from celery import shared_task
+from flask import current_app

 from core.index.index import IndexBuilder
+from core.index.vector_index.vector_index import VectorIndex
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin, Document
@@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()

-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         kw_index = IndexBuilder.get_index(dataset, 'economy')

         # delete from vector index
-        if vector_index:
+        if dataset.indexing_technique == 'high_quality':
+            vector_index = IndexBuilder.get_default_high_quality_index(dataset)
             try:
                 vector_index.delete()
             except Exception:
...
@@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
         raise Exception('Dataset not found')

     if action == "remove":
-        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
         index.delete()
     elif action == "add":
         dataset_documents = db.session.query(DatasetDocument).filter(
@@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
         if dataset_documents:
             # save vector index
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
             documents = []
             for dataset_document in dataset_documents:
                 # delete from vector index
...