Unverified Commit 0620fa30 authored by Jyong's avatar Jyong Committed by GitHub

Feat/vdb migrate command (#2562)

Co-authored-by: 's avatarjyong <jyong@dify.ai>
parent d93288f7
import base64 import base64
import json import json
import secrets import secrets
from typing import cast
import click import click
from flask import current_app from flask import current_app
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding from core.rag.datasource.vdb.vector_factory import Vector
from core.model_manager import ModelManager from core.rag.models.document import Document
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email as email_validate from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair from libs.rsa import generate_key_pair
from models.account import Tenant from models.account import Tenant
from models.dataset import Dataset from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account from models.model import Account
from models.provider import Provider, ProviderModel from models.provider import Provider, ProviderModel
...@@ -124,14 +125,15 @@ def reset_encrypt_key_pair(): ...@@ -124,14 +125,15 @@ def reset_encrypt_key_pair():
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green')) 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.') @click.command('vdb-migrate', help='migrate vector db.')
def create_qdrant_indexes(): def vdb_migrate():
""" """
Migrate other vector database datas to Qdrant. Migrate vector database datas to target vector database .
""" """
click.echo(click.style('Start create qdrant indexes.', fg='green')) click.echo(click.style('Start migrate vector db.', fg='green'))
create_count = 0 create_count = 0
config = cast(dict, current_app.config)
vector_type = config.get('VECTOR_STORE')
page = 1 page = 1
while True: while True:
try: try:
...@@ -140,50 +142,97 @@ def create_qdrant_indexes(): ...@@ -140,50 +142,97 @@ def create_qdrant_indexes():
except NotFound: except NotFound:
break break
model_manager = ModelManager()
page += 1 page += 1
for dataset in datasets: for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try: try:
embedding_model = model_manager.get_model_instance( click.echo('Create dataset vdb index: {}'.format(dataset.id))
tenant_id=dataset.tenant_id, if dataset.index_struct_dict:
provider=dataset.embedding_model_provider, if dataset.index_struct_dict['type'] == vector_type:
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except Exception:
continue continue
embeddings = CacheEmbedding(embedding_model) if vector_type == "weaviate":
dataset_id = dataset.id
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
index = QdrantVectorIndex( "type": 'weaviate',
dataset=dataset, "vector_store": {"class_prefix": collection_name}
config=QdrantConfig( }
endpoint=current_app.config.get('QDRANT_URL'), dataset.index_struct = json.dumps(index_struct_dict)
api_key=current_app.config.get('QDRANT_API_KEY'), elif vector_type == "qdrant":
root_path=current_app.root_path if dataset.collection_binding_id:
), dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
embeddings=embeddings filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
) one_or_none()
if index: if dataset_collection_binding:
index.create_qdrant_dataset(dataset) collection_name = dataset_collection_binding.collection_name
index_struct = { else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'qdrant', "type": 'qdrant',
"vector_store": { "vector_store": {"class_prefix": collection_name}
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']} }
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "milvus":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
} }
dataset.index_struct = json.dumps(index_struct) dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
vector = Vector(dataset)
click.echo(f"vdb_migrate {dataset.id}")
try:
vector.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
vector.create(documents)
except Exception as e:
raise e
click.echo(f"Dataset {dataset.id} create successfully.")
db.session.add(dataset)
db.session.commit() db.session.commit()
create_count += 1 create_count += 1
else:
click.echo('passed.')
except Exception as e: except Exception as e:
db.session.rollback()
click.echo( click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red')) fg='red'))
...@@ -196,4 +245,4 @@ def register_commands(app): ...@@ -196,4 +245,4 @@ def register_commands(app):
app.cli.add_command(reset_password) app.cli.add_command(reset_password)
app.cli.add_command(reset_email) app.cli.add_command(reset_email)
app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(create_qdrant_indexes) app.cli.add_command(vdb_migrate)
...@@ -664,6 +664,7 @@ class IndexingRunner: ...@@ -664,6 +664,7 @@ class IndexingRunner:
) )
# load index # load index
index_processor.load(dataset, chunk_documents) index_processor.load(dataset, chunk_documents)
db.session.add(dataset)
document_ids = [document.metadata['doc_id'] for document in chunk_documents] document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter( db.session.query(DocumentSegment).filter(
......
...@@ -127,9 +127,15 @@ class MilvusVector(BaseVector): ...@@ -127,9 +127,15 @@ class MilvusVector(BaseVector):
self._client.delete(collection_name=self._collection_name, pks=doc_ids) self._client.delete(collection_name=self._collection_name, pks=doc_ids)
def delete(self) -> None: def delete(self) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
from pymilvus import utility from pymilvus import utility
utility.drop_collection(self._collection_name, None) utility.drop_collection(self._collection_name, None, using=alias)
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
......
import json
from typing import Any, cast from typing import Any, cast
from flask import current_app from flask import current_app
...@@ -39,6 +40,11 @@ class Vector: ...@@ -39,6 +40,11 @@ class Vector:
else: else:
dataset_id = self._dataset.id dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return WeaviateVector( return WeaviateVector(
collection_name=collection_name, collection_name=collection_name,
config=WeaviateConfig( config=WeaviateConfig(
...@@ -66,6 +72,13 @@ class Vector: ...@@ -66,6 +72,13 @@ class Vector:
dataset_id = self._dataset.id dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
if not self._dataset.index_struct_dict:
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return QdrantVector( return QdrantVector(
collection_name=collection_name, collection_name=collection_name,
group_id=self._dataset.id, group_id=self._dataset.id,
...@@ -84,6 +97,11 @@ class Vector: ...@@ -84,6 +97,11 @@ class Vector:
else: else:
dataset_id = self._dataset.id dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return MilvusVector( return MilvusVector(
collection_name=collection_name, collection_name=collection_name,
config=MilvusConfig( config=MilvusConfig(
......
...@@ -127,6 +127,9 @@ class WeaviateVector(BaseVector): ...@@ -127,6 +127,9 @@ class WeaviateVector(BaseVector):
) )
def delete(self): def delete(self):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
self._client.schema.delete_class(self._collection_name) self._client.schema.delete_class(self._collection_name)
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment