Unverified commit 724e0537 authored by Jyong, committed by GitHub

Fix/qdrant data issue (#1203)

Co-authored-by: jyong <jyong@dify.ai>
parent e409895c
@@ -3,12 +3,13 @@ import json
 import math
 import random
 import string
+import threading
 import time
 import uuid
 import click
 from tqdm import tqdm
-from flask import current_app
+from flask import current_app, Flask
 from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
@@ -456,19 +457,35 @@ def update_qdrant_indexes():
 @click.command('normalization-collections', help='restore all collections in one')
 def normalization_collections():
     click.echo(click.style('Start normalization collections.', fg='green'))
-    normalization_count = 0
+    normalization_count = []
     page = 1
     while True:
         try:
             datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
-                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
+        datasets_result = datasets.items
         page += 1
-        for dataset in datasets:
-            if not dataset.collection_binding_id:
+        for i in range(0, len(datasets_result), 5):
+            threads = []
+            sub_datasets = datasets_result[i:i + 5]
+            for dataset in sub_datasets:
+                document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
+                    'flask_app': current_app._get_current_object(),
+                    'dataset': dataset,
+                    'normalization_count': normalization_count
+                })
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
+    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
+
+
+def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
+    with flask_app.app_context():
         try:
             click.echo('restore dataset index: {}'.format(dataset.id))
             try:
@@ -517,31 +534,15 @@ def normalization_collections():
                 embeddings=embeddings
             )
             if index:
+                # index.delete_by_group_id(dataset.id)
                 index.restore_dataset_in_one(dataset, dataset_collection_binding)
             else:
                 click.echo('passed.')
-
-            original_index = QdrantVectorIndex(
-                dataset=dataset,
-                config=QdrantConfig(
-                    endpoint=current_app.config.get('QDRANT_URL'),
-                    api_key=current_app.config.get('QDRANT_API_KEY'),
-                    root_path=current_app.root_path
-                ),
-                embeddings=embeddings
-            )
-            if original_index:
-                original_index.delete_original_collection(dataset, dataset_collection_binding)
-                normalization_count += 1
-            else:
-                click.echo('passed.')
+            normalization_count.append(1)
         except Exception as e:
             click.echo(
                 click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                             fg='red'))
-            continue
-
-    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))

 @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
...
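Note: the reworked normalization-collections command pages through datasets 100 at a time and restores them in batches of five worker threads, and each worker re-enters the Flask application context because current_app is only bound in the thread that created it. The counter also changes from an int to a list: a worker thread cannot rebind an integer in the caller's scope, but it can append to a shared list, and len(normalization_count) gives the final total. A minimal sketch of that pattern outside Dify's models (restore_one and its body are illustrative stand-ins for the per-dataset work):

import threading
from flask import Flask, current_app

app = Flask(__name__)

def restore_one(flask_app: Flask, item: str, done: list):
    # Each thread pushes its own app context; current_app is not
    # inherited from the parent thread automatically.
    with flask_app.app_context():
        current_app.logger.info('restoring %s', item)
        done.append(1)  # list.append is atomic under CPython's GIL

def restore_all(items: list) -> int:
    done = []
    for i in range(0, len(items), 5):  # batches of 5, as in the command
        threads = [
            threading.Thread(target=restore_one,
                             args=(current_app._get_current_object(), item, done))
            for item in items[i:i + 5]
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()  # wait for the whole batch before moving on
    return len(done)

with app.app_context():
    print(restore_all(['a', 'b', 'c', 'd', 'e', 'f']))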
@@ -113,7 +113,9 @@ class BaseVectorIndex(BaseIndex):
     def delete_by_group_id(self, group_id: str) -> None:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
-        vector_store.delete()
+        if self.dataset.collection_binding_id:
+            vector_store.delete_by_group_id(group_id)
+        else:
+            vector_store.delete()

     def delete(self) -> None:

@@ -283,7 +285,7 @@ class BaseVectorIndex(BaseIndex):
         if documents:
             try:
-                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
+                self.add_texts(documents)
             except Exception as e:
                 raise e
...
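Note: BaseVectorIndex.delete_by_group_id now branches on the storage layout. A dataset with a collection_binding_id shares one Qdrant collection with other datasets, so only its own points (tagged with the dataset id as group_id) may be removed; a legacy dataset still owns a dedicated collection that can be dropped outright. Likewise, restore_dataset_in_one now writes through add_texts so documents follow the index's normal write path into the bound collection. A rough sketch of the dispatch, with hypothetical store classes standing in for the real vector stores:

from typing import Optional

class DedicatedCollectionStore:
    """Legacy layout: one Qdrant collection per dataset."""
    def delete(self):
        print('dropping the whole collection')

class SharedCollectionStore:
    """New layout: many datasets share one collection."""
    def delete_by_group_id(self, group_id: str):
        print(f'deleting only points whose group_id == {group_id}')

def delete_dataset_vectors(dataset_id: str, collection_binding_id: Optional[str], store):
    # Mirrors BaseVectorIndex.delete_by_group_id: scope the delete to the
    # dataset when it shares a collection, otherwise drop its own collection.
    if collection_binding_id:
        store.delete_by_group_id(dataset_id)
    else:
        store.delete()

delete_dataset_vectors('ds-1', 'binding-1', SharedCollectionStore())
delete_dataset_vectors('ds-2', None, DedicatedCollectionStore())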
@@ -1390,8 +1390,39 @@ class Qdrant(VectorStore):
             path=path,
             **kwargs,
         )
-        try:
-            # Skip any validation in case of forced collection recreate.
+        all_collection_name = []
+        collections_response = client.get_collections()
+        collection_list = collections_response.collections
+        for collection in collection_list:
+            all_collection_name.append(collection.name)
+        if collection_name not in all_collection_name:
+            vectors_config = rest.VectorParams(
+                size=vector_size,
+                distance=rest.Distance[distance_func],
+            )
+            # If vector name was provided, we're going to use the named vectors feature
+            # with just a single vector.
+            if vector_name is not None:
+                vectors_config = {  # type: ignore[assignment]
+                    vector_name: vectors_config,
+                }
+            client.recreate_collection(
+                collection_name=collection_name,
+                vectors_config=vectors_config,
+                shard_number=shard_number,
+                replication_factor=replication_factor,
+                write_consistency_factor=write_consistency_factor,
+                on_disk_payload=on_disk_payload,
+                hnsw_config=hnsw_config,
+                optimizers_config=optimizers_config,
+                wal_config=wal_config,
+                quantization_config=quantization_config,
+                init_from=init_from,
+                timeout=timeout,  # type: ignore[arg-type]
+            )
+            is_new_collection = True
         if force_recreate:
             raise ValueError
@@ -1453,34 +1484,6 @@ class Qdrant(VectorStore):
                     f"recreate the collection, set `force_recreate` parameter to "
                     f"`True`."
                 )
-        except (UnexpectedResponse, RpcError, ValueError):
-            vectors_config = rest.VectorParams(
-                size=vector_size,
-                distance=rest.Distance[distance_func],
-            )
-            # If vector name was provided, we're going to use the named vectors feature
-            # with just a single vector.
-            if vector_name is not None:
-                vectors_config = {  # type: ignore[assignment]
-                    vector_name: vectors_config,
-                }
-            client.recreate_collection(
-                collection_name=collection_name,
-                vectors_config=vectors_config,
-                shard_number=shard_number,
-                replication_factor=replication_factor,
-                write_consistency_factor=write_consistency_factor,
-                on_disk_payload=on_disk_payload,
-                hnsw_config=hnsw_config,
-                optimizers_config=optimizers_config,
-                wal_config=wal_config,
-                quantization_config=quantization_config,
-                init_from=init_from,
-                timeout=timeout,  # type: ignore[arg-type]
-            )
-            is_new_collection = True
         qdrant = cls(
             client=client,
             collection_name=collection_name,
...
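Note: the patched Qdrant constructor drops the old try/except (UnexpectedResponse, RpcError, ValueError) flow, which recreated the collection whenever validation failed for any reason, and instead lists the existing collections first, calling recreate_collection only when the target name is genuinely absent. That way a shared collection is never silently emptied because of an unrelated error. A standalone sketch of the same idea with qdrant-client (collection name and vector size are arbitrary examples):

from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

def ensure_collection(client: QdrantClient, collection_name: str, vector_size: int) -> bool:
    """Create the collection only if it does not already exist.

    Returns True when a new collection was created, False when it was
    already there; an existing collection is never recreated (and thus
    never emptied).
    """
    existing = [c.name for c in client.get_collections().collections]
    if collection_name in existing:
        return False
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=rest.VectorParams(
            size=vector_size,
            distance=rest.Distance.COSINE,
        ),
    )
    return True

# Example against a local in-memory instance:
client = QdrantClient(':memory:')
print(ensure_collection(client, 'example_collection', 1536))  # True  (created)
print(ensure_collection(client, 'example_collection', 1536))  # False (already present)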
@@ -169,6 +169,19 @@ class QdrantVectorIndex(BaseVectorIndex):
             ],
         ))

+    def delete(self) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self.dataset.id),
+                ),
+            ],
+        ))
+
     def _is_origin(self):
         if self.dataset.index_struct_dict:
...
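Note: the new QdrantVectorIndex.delete removes only the points whose group_id payload equals the dataset id, instead of dropping the whole collection. del_texts is Dify's vector-store wrapper; with plain qdrant-client the equivalent operation is a filtered delete, roughly as below (the helper name is ours):

from qdrant_client import QdrantClient
from qdrant_client.http import models

def delete_dataset_points(client: QdrantClient, collection_name: str, dataset_id: str) -> None:
    # Delete every point whose payload field "group_id" matches the dataset id,
    # leaving other datasets' points in the shared collection untouched.
    client.delete(
        collection_name=collection_name,
        points_selector=models.FilterSelector(
            filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key='group_id',
                        match=models.MatchValue(value=dataset_id),
                    ),
                ],
            ),
        ),
    )

Only points matching the filter are removed, which is what makes per-dataset cleanup safe once many datasets live in one collection.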
@@ -5,4 +5,5 @@ from tasks.clean_dataset_task import clean_dataset_task
 @dataset_was_deleted.connect
 def handle(sender, **kwargs):
     dataset = sender
-    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
+    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
+                             dataset.index_struct, dataset.collection_binding_id)
@@ -13,13 +13,15 @@ from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, Datase
 @shared_task(queue='dataset')
-def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str):
+def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
+                       index_struct: str, collection_binding_id: str):
     """
     Clean dataset when dataset deleted.
     :param dataset_id: dataset id
     :param tenant_id: tenant id
     :param indexing_technique: indexing technique
     :param index_struct: index struct dict
+    :param collection_binding_id: collection binding id
     Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
     """

@@ -31,9 +33,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
             id=dataset_id,
             tenant_id=tenant_id,
             indexing_technique=indexing_technique,
-            index_struct=index_struct
+            index_struct=index_struct,
+            collection_binding_id=collection_binding_id
         )
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()

@@ -43,7 +45,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         if dataset.indexing_technique == 'high_quality':
             vector_index = IndexBuilder.get_default_high_quality_index(dataset)
             try:
-                vector_index.delete()
+                vector_index.delete_by_group_id(dataset.id)
             except Exception:
                 logging.exception("Delete doc index failed when dataset deleted.")
...
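Note: clean_dataset_task now receives collection_binding_id from the dataset_was_deleted handler because the dataset row is already gone by the time the Celery worker runs; the task rebuilds a transient Dataset from its arguments, and the extra field is what lets delete_by_group_id choose between the shared-collection and dedicated-collection paths. A simplified, Celery-free sketch of that hand-off (the names here are illustrative, not Dify's):

from dataclasses import dataclass
from typing import Optional

@dataclass
class TransientDataset:
    # Stand-in for the Dataset object rebuilt inside the task: the DB row is
    # deleted before the task runs, so every field the cleanup needs must
    # travel with the task message itself.
    id: str
    tenant_id: str
    indexing_technique: str
    index_struct: Optional[str]
    collection_binding_id: Optional[str]

def clean_dataset(dataset_id: str, tenant_id: str, indexing_technique: str,
                  index_struct: Optional[str], collection_binding_id: Optional[str]) -> TransientDataset:
    # Mirrors the task body: rebuild a detached dataset object from the
    # arguments, then hand it to the vector index for a scoped delete.
    return TransientDataset(dataset_id, tenant_id, indexing_technique,
                            index_struct, collection_binding_id)

d = clean_dataset('ds-1', 'tenant-1', 'high_quality', '{}', 'binding-1')
print(d.collection_binding_id)  # the new field that decides the delete path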
@@ -31,8 +31,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
         raise Exception('Dataset not found')
     if action == "remove":
-        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
-        index.delete()
+        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+        index.delete_by_group_id(dataset.id)
     elif action == "add":
         dataset_documents = db.session.query(DatasetDocument).filter(
             DatasetDocument.dataset_id == dataset_id,
...
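Note: with the shared-collection layout the "remove" action must not drop the whole collection, since other datasets' vectors live in it, so it now calls delete_by_group_id(dataset.id). Switching to ignore_high_quality_check=True presumably lets IndexBuilder.get_index still hand back the vector index even if the dataset's indexing_technique has already been changed away from 'high_quality' by the time the task runs; a toy version of that guard (simplified, not Dify's actual IndexBuilder):

from typing import Optional

def get_index(indexing_technique: str, ignore_high_quality_check: bool) -> Optional[str]:
    # Normally the vector index is only handed out for 'high_quality' datasets,
    # but the cleanup path must still obtain it after the dataset was switched.
    if not ignore_high_quality_check and indexing_technique != 'high_quality':
        return None
    return 'vector-index-handle'

# Dataset already flipped to 'economy' before the async "remove" runs:
print(get_index('economy', ignore_high_quality_check=False))  # None -> nothing would be deleted
print(get_index('economy', ignore_high_quality_check=True))   # handle -> scoped delete can run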