Unverified commit 724e0537, authored by Jyong, committed by GitHub

Fix/qdrant data issue (#1203)

Co-authored-by: jyong <jyong@dify.ai>
parent e409895c
@@ -3,12 +3,13 @@ import json
 import math
 import random
 import string
+import threading
 import time
 import uuid

 import click
 from tqdm import tqdm
-from flask import current_app
+from flask import current_app, Flask
 from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
@@ -456,92 +457,92 @@ def update_qdrant_indexes():
 @click.command('normalization-collections', help='restore all collections in one')
 def normalization_collections():
     click.echo(click.style('Start normalization collections.', fg='green'))
-    normalization_count = 0
+    normalization_count = []
     page = 1
     while True:
         try:
             datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
-                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
+        datasets_result = datasets.items
         page += 1
-        for dataset in datasets:
-            if not dataset.collection_binding_id:
-                try:
-                    click.echo('restore dataset index: {}'.format(dataset.id))
-                    try:
-                        embedding_model = ModelFactory.get_embedding_model(
-                            tenant_id=dataset.tenant_id,
-                            model_provider_name=dataset.embedding_model_provider,
-                            model_name=dataset.embedding_model
-                        )
-                    except Exception:
-                        provider = Provider(
-                            id='provider_id',
-                            tenant_id=dataset.tenant_id,
-                            provider_name='openai',
-                            provider_type=ProviderType.CUSTOM.value,
-                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
-                            is_valid=True,
-                        )
-                        model_provider = OpenAIProvider(provider=provider)
-                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
-                                                          model_provider=model_provider)
-                    embeddings = CacheEmbedding(embedding_model)
-                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
-                               DatasetCollectionBinding.model_name == embedding_model.name). \
-                        order_by(DatasetCollectionBinding.created_at). \
-                        first()
-                    if not dataset_collection_binding:
-                        dataset_collection_binding = DatasetCollectionBinding(
-                            provider_name=embedding_model.model_provider.provider_name,
-                            model_name=embedding_model.name,
-                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
-                        )
-                        db.session.add(dataset_collection_binding)
-                        db.session.commit()
-                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
-                    index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if index:
-                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
-                    else:
-                        click.echo('passed.')
-                    original_index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if original_index:
-                        original_index.delete_original_collection(dataset, dataset_collection_binding)
-                        normalization_count += 1
-                    else:
-                        click.echo('passed.')
-                except Exception as e:
-                    click.echo(
-                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                    fg='red'))
-                    continue
-    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
+        for i in range(0, len(datasets_result), 5):
+            threads = []
+            sub_datasets = datasets_result[i:i + 5]
+            for dataset in sub_datasets:
+                document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
+                    'flask_app': current_app._get_current_object(),
+                    'dataset': dataset,
+                    'normalization_count': normalization_count
+                })
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
+    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
+
+
+def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
+    with flask_app.app_context():
+        try:
+            click.echo('restore dataset index: {}'.format(dataset.id))
+            try:
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except Exception:
+                provider = Provider(
+                    id='provider_id',
+                    tenant_id=dataset.tenant_id,
+                    provider_name='openai',
+                    provider_type=ProviderType.CUSTOM.value,
+                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                    is_valid=True,
+                )
+                model_provider = OpenAIProvider(provider=provider)
+                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                  model_provider=model_provider)
+            embeddings = CacheEmbedding(embedding_model)
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
+                       DatasetCollectionBinding.model_name == embedding_model.name). \
+                order_by(DatasetCollectionBinding.created_at). \
+                first()
+            if not dataset_collection_binding:
+                dataset_collection_binding = DatasetCollectionBinding(
+                    provider_name=embedding_model.model_provider.provider_name,
+                    model_name=embedding_model.name,
+                    collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+                )
+                db.session.add(dataset_collection_binding)
+                db.session.commit()
+            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+            index = QdrantVectorIndex(
+                dataset=dataset,
+                config=QdrantConfig(
+                    endpoint=current_app.config.get('QDRANT_URL'),
+                    api_key=current_app.config.get('QDRANT_API_KEY'),
+                    root_path=current_app.root_path
+                ),
+                embeddings=embeddings
+            )
+            if index:
+                # index.delete_by_group_id(dataset.id)
+                index.restore_dataset_in_one(dataset, dataset_collection_binding)
+            else:
+                click.echo('passed.')
+            normalization_count.append(1)
+        except Exception as e:
+            click.echo(
+                click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                            fg='red'))

 @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
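Note on the restructured command: it now pages through datasets and hands each batch of five to worker threads, passing current_app._get_current_object() so each worker can push its own Flask application context (the current_app proxy itself is not usable across threads). A minimal standalone sketch of the same pattern, with hypothetical process_item/run_batched names:

    import threading

    from flask import Flask, current_app

    def process_item(flask_app: Flask, item: int, results: list):
        # Each worker enters its own app context, so current_app-dependent
        # code (DB sessions, config lookups) behaves as it would in a request.
        with flask_app.app_context():
            results.append(item * 2)

    def run_batched(items: list, batch_size: int = 5) -> list:
        app = current_app._get_current_object()  # unwrap the proxy in the main thread
        results = []  # list.append is atomic under the GIL, like normalization_count above
        for i in range(0, len(items), batch_size):
            threads = [threading.Thread(target=process_item, args=(app, item, results))
                       for item in items[i:i + batch_size]]
            for t in threads:
                t.start()
            for t in threads:  # joining per batch caps concurrency at batch_size
                t.join()
        return results

Joining each batch before starting the next keeps at most five workers alive without needing a thread pool, which is also why a plain list works as the shared counter.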
@@ -113,8 +113,10 @@ class BaseVectorIndex(BaseIndex):
     def delete_by_group_id(self, group_id: str) -> None:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.delete()
+        if self.dataset.collection_binding_id:
+            vector_store.delete_by_group_id(group_id)
+        else:
+            vector_store.delete()

     def delete(self) -> None:
         vector_store = self._get_vector_store()
@@ -283,7 +285,7 @@ class BaseVectorIndex(BaseIndex):
         if documents:
             try:
-                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
+                self.add_texts(documents)
             except Exception as e:
                 raise e
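The new branch in delete_by_group_id encodes the two storage layouts this commit must support: a dataset with a collection_binding_id lives in a shared ("normalized") collection and may only delete its own points, while a legacy dataset still owns a whole collection. A rough sketch of the dispatch (hypothetical helper name):

    def delete_dataset_vectors(index, dataset):
        if dataset.collection_binding_id:
            # Shared collection: other datasets' points sit alongside ours,
            # so scope deletion to this dataset's group_id payload.
            index.delete_by_group_id(dataset.id)
        else:
            # Dedicated collection: dropping everything is safe.
            index.delete()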
@@ -1390,70 +1390,12 @@ class Qdrant(VectorStore):
             path=path,
             **kwargs,
         )
-        try:
-            # Skip any validation in case of forced collection recreate.
-            if force_recreate:
-                raise ValueError
-
-            # Get the vector configuration of the existing collection and vector, if it
-            # was specified. If the old configuration does not match the current one,
-            # an exception is being thrown.
-            collection_info = client.get_collection(collection_name=collection_name)
-            current_vector_config = collection_info.config.params.vectors
-            if isinstance(current_vector_config, dict) and vector_name is not None:
-                if vector_name not in current_vector_config:
-                    raise QdrantException(
-                        f"Existing Qdrant collection {collection_name} does not "
-                        f"contain vector named {vector_name}. Did you mean one of the "
-                        f"existing vectors: {', '.join(current_vector_config.keys())}? "
-                        f"If you want to recreate the collection, set `force_recreate` "
-                        f"parameter to `True`."
-                    )
-                current_vector_config = current_vector_config.get(
-                    vector_name
-                )  # type: ignore[assignment]
-            elif isinstance(current_vector_config, dict) and vector_name is None:
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} uses named vectors. "
-                    f"If you want to reuse it, please set `vector_name` to any of the "
-                    f"existing named vectors: "
-                    f"{', '.join(current_vector_config.keys())}."  # noqa
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-            elif (
-                not isinstance(current_vector_config, dict) and vector_name is not None
-            ):
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} doesn't use named "
-                    f"vectors. If you want to reuse it, please set `vector_name` to "
-                    f"`None`. If you want to recreate the collection, set "
-                    f"`force_recreate` parameter to `True`."
-                )
-            # Check if the vector configuration has the same dimensionality.
-            if current_vector_config.size != vector_size:  # type: ignore[union-attr]
-                raise QdrantException(
-                    f"Existing Qdrant collection is configured for vectors with "
-                    f"{current_vector_config.size} "  # type: ignore[union-attr]
-                    f"dimensions. Selected embeddings are {vector_size}-dimensional. "
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-            current_distance_func = (
-                current_vector_config.distance.name.upper()  # type: ignore[union-attr]
-            )
-            if current_distance_func != distance_func:
-                raise QdrantException(
-                    f"Existing Qdrant collection is configured for "
-                    f"{current_vector_config.distance} "  # type: ignore[union-attr]
-                    f"similarity. Please set `distance_func` parameter to "
-                    f"`{distance_func}` if you want to reuse it. If you want to "
-                    f"recreate the collection, set `force_recreate` parameter to "
-                    f"`True`."
-                )
-        except (UnexpectedResponse, RpcError, ValueError):
+        all_collection_name = []
+        collections_response = client.get_collections()
+        collection_list = collections_response.collections
+        for collection in collection_list:
+            all_collection_name.append(collection.name)
+        if collection_name not in all_collection_name:
             vectors_config = rest.VectorParams(
                 size=vector_size,
                 distance=rest.Distance[distance_func],
@@ -1481,6 +1423,67 @@ class Qdrant(VectorStore):
                 timeout=timeout,  # type: ignore[arg-type]
             )
             is_new_collection = True
+        if force_recreate:
+            raise ValueError
+
+        # Get the vector configuration of the existing collection and vector, if it
+        # was specified. If the old configuration does not match the current one,
+        # an exception is being thrown.
+        collection_info = client.get_collection(collection_name=collection_name)
+        current_vector_config = collection_info.config.params.vectors
+        if isinstance(current_vector_config, dict) and vector_name is not None:
+            if vector_name not in current_vector_config:
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} does not "
+                    f"contain vector named {vector_name}. Did you mean one of the "
+                    f"existing vectors: {', '.join(current_vector_config.keys())}? "
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+            current_vector_config = current_vector_config.get(
+                vector_name
+            )  # type: ignore[assignment]
+        elif isinstance(current_vector_config, dict) and vector_name is None:
+            raise QdrantException(
+                f"Existing Qdrant collection {collection_name} uses named vectors. "
+                f"If you want to reuse it, please set `vector_name` to any of the "
+                f"existing named vectors: "
+                f"{', '.join(current_vector_config.keys())}."  # noqa
+                f"If you want to recreate the collection, set `force_recreate` "
+                f"parameter to `True`."
+            )
+        elif (
+            not isinstance(current_vector_config, dict) and vector_name is not None
+        ):
+            raise QdrantException(
+                f"Existing Qdrant collection {collection_name} doesn't use named "
+                f"vectors. If you want to reuse it, please set `vector_name` to "
+                f"`None`. If you want to recreate the collection, set "
+                f"`force_recreate` parameter to `True`."
+            )
+        # Check if the vector configuration has the same dimensionality.
+        if current_vector_config.size != vector_size:  # type: ignore[union-attr]
+            raise QdrantException(
+                f"Existing Qdrant collection is configured for vectors with "
+                f"{current_vector_config.size} "  # type: ignore[union-attr]
+                f"dimensions. Selected embeddings are {vector_size}-dimensional. "
+                f"If you want to recreate the collection, set `force_recreate` "
+                f"parameter to `True`."
+            )
+        current_distance_func = (
+            current_vector_config.distance.name.upper()  # type: ignore[union-attr]
+        )
+        if current_distance_func != distance_func:
+            raise QdrantException(
+                f"Existing Qdrant collection is configured for "
+                f"{current_vector_config.distance} "  # type: ignore[union-attr]
+                f"similarity. Please set `distance_func` parameter to "
+                f"`{distance_func}` if you want to reuse it. If you want to "
+                f"recreate the collection, set `force_recreate` parameter to "
+                f"`True`."
+            )
         qdrant = cls(
             client=client,
             collection_name=collection_name,
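Replacing the try/except probe with an explicit collection listing changes the recreate semantics: previously an UnexpectedResponse, RpcError, or ValueError was treated as "collection missing" and triggered recreate_collection, so a transient error could wipe live data; now a collection is only created when its name is genuinely absent. A standalone sketch of the check against the qdrant-client API (URL, collection name, and dimensions are assumptions):

    from qdrant_client import QdrantClient
    from qdrant_client.http import models as rest

    client = QdrantClient(url="http://localhost:6333")  # assumed endpoint

    collection_name = "Vector_index_example_Node"
    existing = [c.name for c in client.get_collections().collections]

    if collection_name not in existing:
        # Only reached when the collection truly does not exist, so this
        # cannot clobber a populated collection after a flaky response.
        client.recreate_collection(
            collection_name=collection_name,
            vectors_config=rest.VectorParams(size=1536, distance=rest.Distance.COSINE),
        )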
@@ -169,6 +169,19 @@ class QdrantVectorIndex(BaseVectorIndex):
             ],
         ))

+    def delete(self) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self.dataset.id),
+                ),
+            ],
+        ))
+
     def _is_origin(self):
         if self.dataset.index_struct_dict:
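The added delete override removes points by payload filter instead of dropping the collection, which is what makes a shared collection safe to clean. del_texts is the vendored store's wrapper; with raw qdrant-client the equivalent filter-scoped deletion looks roughly like this (collection name and URL assumed):

    from qdrant_client import QdrantClient
    from qdrant_client.http import models

    client = QdrantClient(url="http://localhost:6333")  # assumed endpoint

    def delete_points_for_dataset(collection_name: str, dataset_id: str) -> None:
        # Removes only the points whose payload group_id matches this dataset;
        # other datasets sharing the collection are untouched.
        client.delete(
            collection_name=collection_name,
            points_selector=models.FilterSelector(
                filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="group_id",
                            match=models.MatchValue(value=dataset_id),
                        ),
                    ],
                ),
            ),
        )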
@@ -5,4 +5,5 @@ from tasks.clean_dataset_task import clean_dataset_task
 @dataset_was_deleted.connect
 def handle(sender, **kwargs):
     dataset = sender
-    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
+    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
+                             dataset.index_struct, dataset.collection_binding_id)
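The handler grows an argument because the Celery worker runs after the row is gone: it cannot re-query the dataset, so every field the cleanup needs must travel as a serialized task argument. A minimal sketch of the same signal-to-task pattern with blinker and Celery (hypothetical names throughout):

    from blinker import signal
    from celery import shared_task

    thing_was_deleted = signal('thing-was-deleted')  # hypothetical signal

    @shared_task(queue='dataset')
    def clean_thing_task(thing_id: str, binding_id: str):
        # Runs later, possibly on another machine; only the serialized
        # arguments are available because the DB row no longer exists.
        print('cleaning {} (binding {})'.format(thing_id, binding_id))

    @thing_was_deleted.connect
    def handle(sender, **kwargs):
        thing = sender
        # Snapshot every field the task will need before the row disappears.
        clean_thing_task.delay(thing.id, thing.collection_binding_id)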
@@ -13,13 +13,15 @@ from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, Datase
 @shared_task(queue='dataset')
-def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str):
+def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
+                       index_struct: str, collection_binding_id: str):
     """
     Clean dataset when dataset deleted.
     :param dataset_id: dataset id
     :param tenant_id: tenant id
     :param indexing_technique: indexing technique
     :param index_struct: index struct dict
+    :param collection_binding_id: collection binding id

-    Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
+    Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct, collection_binding_id)
     """
@@ -31,9 +33,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         id=dataset_id,
         tenant_id=tenant_id,
         indexing_technique=indexing_technique,
-        index_struct=index_struct
+        index_struct=index_struct,
+        collection_binding_id=collection_binding_id
     )
     documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
@@ -43,7 +45,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
     if dataset.indexing_technique == 'high_quality':
         vector_index = IndexBuilder.get_default_high_quality_index(dataset)
         try:
-            vector_index.delete()
+            vector_index.delete_by_group_id(dataset.id)
         except Exception:
             logging.exception("Delete doc index failed when dataset deleted.")
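For callers, the widened signature means the enqueue site must now pass the binding id explicitly, matching the updated docstring; an illustrative call (all ids hypothetical):

    clean_dataset_task.delay(
        'dataset-uuid',        # dataset_id
        'tenant-uuid',         # tenant_id
        'high_quality',        # indexing_technique
        '{"type": "qdrant"}',  # index_struct (illustrative JSON)
        'binding-uuid',        # collection_binding_id
    )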
@@ -31,8 +31,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
         raise Exception('Dataset not found')

     if action == "remove":
-        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
-        index.delete()
+        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+        index.delete_by_group_id(dataset.id)
     elif action == "add":
         dataset_documents = db.session.query(DatasetDocument).filter(
             DatasetDocument.dataset_id == dataset_id,