Unverified Commit 724e0537 authored by Jyong's avatar Jyong Committed by GitHub

Fix/qdrant data issue (#1203)

Co-authored-by: 's avatarjyong <jyong@dify.ai>
parent e409895c
......@@ -3,12 +3,13 @@ import json
import math
import random
import string
import threading
import time
import uuid
import click
from tqdm import tqdm
from flask import current_app
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound
......@@ -456,19 +457,35 @@ def update_qdrant_indexes():
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
click.echo(click.style('Start normalization collections.', fg='green'))
normalization_count = 0
normalization_count = []
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break
datasets_result = datasets.items
page += 1
for dataset in datasets:
if not dataset.collection_binding_id:
for i in range(0, len(datasets_result), 5):
threads = []
sub_datasets = datasets_result[i:i + 5]
for dataset in sub_datasets:
document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'normalization_count': normalization_count
})
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
with flask_app.app_context():
try:
click.echo('restore dataset index: {}'.format(dataset.id))
try:
......@@ -517,31 +534,15 @@ def normalization_collections():
embeddings=embeddings
)
if index:
# index.delete_by_group_id(dataset.id)
index.restore_dataset_in_one(dataset, dataset_collection_binding)
else:
click.echo('passed.')
original_index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if original_index:
original_index.delete_original_collection(dataset, dataset_collection_binding)
normalization_count += 1
else:
click.echo('passed.')
normalization_count.append(1)
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
......
......@@ -113,7 +113,9 @@ class BaseVectorIndex(BaseIndex):
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if self.dataset.collection_binding_id:
vector_store.delete_by_group_id(group_id)
else:
vector_store.delete()
def delete(self) -> None:
......@@ -283,7 +285,7 @@ class BaseVectorIndex(BaseIndex):
if documents:
try:
self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
self.add_texts(documents)
except Exception as e:
raise e
......
......@@ -1390,8 +1390,39 @@ class Qdrant(VectorStore):
path=path,
**kwargs,
)
try:
# Skip any validation in case of forced collection recreate.
all_collection_name = []
collections_response = client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[distance_func],
)
# If vector name was provided, we're going to use the named vectors feature
# with just a single vector.
if vector_name is not None:
vectors_config = { # type: ignore[assignment]
vector_name: vectors_config,
}
client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
shard_number=shard_number,
replication_factor=replication_factor,
write_consistency_factor=write_consistency_factor,
on_disk_payload=on_disk_payload,
hnsw_config=hnsw_config,
optimizers_config=optimizers_config,
wal_config=wal_config,
quantization_config=quantization_config,
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
is_new_collection = True
if force_recreate:
raise ValueError
......@@ -1453,34 +1484,6 @@ class Qdrant(VectorStore):
f"recreate the collection, set `force_recreate` parameter to "
f"`True`."
)
except (UnexpectedResponse, RpcError, ValueError):
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[distance_func],
)
# If vector name was provided, we're going to use the named vectors feature
# with just a single vector.
if vector_name is not None:
vectors_config = { # type: ignore[assignment]
vector_name: vectors_config,
}
client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
shard_number=shard_number,
replication_factor=replication_factor,
write_consistency_factor=write_consistency_factor,
on_disk_payload=on_disk_payload,
hnsw_config=hnsw_config,
optimizers_config=optimizers_config,
wal_config=wal_config,
quantization_config=quantization_config,
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
is_new_collection = True
qdrant = cls(
client=client,
collection_name=collection_name,
......
......@@ -169,6 +169,19 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
......
......@@ -5,4 +5,5 @@ from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect
def handle(sender, **kwargs):
dataset = sender
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
dataset.index_struct, dataset.collection_binding_id)
......@@ -13,13 +13,15 @@ from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, Datase
@shared_task(queue='dataset')
def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str):
def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
index_struct: str, collection_binding_id: str):
"""
Clean dataset when dataset deleted.
:param dataset_id: dataset id
:param tenant_id: tenant id
:param indexing_technique: indexing technique
:param index_struct: index struct dict
:param collection_binding_id: collection binding id
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
"""
......@@ -31,9 +33,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
id=dataset_id,
tenant_id=tenant_id,
indexing_technique=indexing_technique,
index_struct=index_struct
index_struct=index_struct,
collection_binding_id=collection_binding_id
)
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
......@@ -43,7 +45,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
if dataset.indexing_technique == 'high_quality':
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
try:
vector_index.delete()
vector_index.delete_by_group_id(dataset.id)
except Exception:
logging.exception("Delete doc index failed when dataset deleted.")
......
......@@ -31,8 +31,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
raise Exception('Dataset not found')
if action == "remove":
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
index.delete()
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
index.delete_by_group_id(dataset.id)
elif action == "add":
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset_id,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment