Unverified commit 269a465f authored by Jyong, committed by GitHub

Feat/improve vector database logic (#1193)

Co-authored-by: jyong <jyong@dify.ai>
parent 60e0bbd7
@@ -4,6 +4,7 @@ import math
import random
import string
import time
import uuid
import click
from tqdm import tqdm
@@ -23,7 +24,7 @@ from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
-from models.dataset import Dataset, DatasetQuery, Document
+from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
import secrets
import base64
@@ -239,7 +240,13 @@ def clean_unused_dataset_indexes():
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
-    vector_index.delete()
+    if dataset.collection_binding_id:
+        vector_index.delete_by_group_id(dataset.id)
+    else:
+        if dataset.collection_binding_id:
+            vector_index.delete_by_group_id(dataset.id)
+        else:
+            vector_index.delete()
kw_index.delete()
# update document
update_params = {
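Note: the inner and outer branches above test the same condition, so the effective behaviour is simply "remove only this dataset's vectors when it shares a collection, otherwise drop the whole index". A minimal flattened sketch of that logic (not part of the commit, names as in the diff):

# Sketch: equivalent behaviour of the branch above.
if vector_index:
    if dataset.collection_binding_id:
        # dataset lives in a shared collection: remove only its own points
        vector_index.delete_by_group_id(dataset.id)
    else:
        # dataset owns its collection: drop the index entirely
        vector_index.delete()
kw_index.delete()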
@@ -346,7 +353,8 @@ def create_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
-embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                  model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -364,7 +372,8 @@ def create_qdrant_indexes():
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
-"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
+"vector_store": {
+    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
@@ -373,7 +382,8 @@ def create_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
-    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@@ -414,7 +424,8 @@ def update_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
-embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                  model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -435,11 +446,104 @@ def update_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
-    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                fg='red'))
continue
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
click.echo(click.style('Start normalization collections.', fg='green'))
normalization_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
if not dataset.collection_binding_id:
try:
click.echo('restore dataset index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
DatasetCollectionBinding.model_name == embedding_model.name). \
order_by(DatasetCollectionBinding.created_at). \
first()
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=embedding_model.model_provider.provider_name,
model_name=embedding_model.name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
)
db.session.add(dataset_collection_binding)
db.session.commit()
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.restore_dataset_in_one(dataset, dataset_collection_binding)
else:
click.echo('passed.')
original_index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if original_index:
original_index.delete_original_collection(dataset, dataset_collection_binding)
normalization_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
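The new command is registered on the Flask CLI in register_commands below, so it would presumably be invoked like the other maintenance commands (flask normalization-collections). A hedged sketch using Flask's CLI test runner, assuming the project's usual application factory:

# Sketch: running the command programmatically; `create_app` is assumed to be
# the project's Flask application factory.
from app import create_app

app = create_app()
runner = app.test_cli_runner()
result = runner.invoke(args=['normalization-collections'])
print(result.output)  # 'Start normalization collections.' followed by per-dataset progress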
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
@@ -473,7 +577,7 @@ def update_app_model_configs(batch_size):
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.count()
if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
@@ -485,14 +589,14 @@ def update_app_model_configs(batch_size):
offset = i * batch_size
limit = min(batch_size, total_records - offset)
-click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
+click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.order_by(App.created_at) \
.offset(offset).limit(limit).all()
if not data_batch:
click.secho("No more data to migrate.", fg='green')
break
@@ -512,7 +616,7 @@ def update_app_model_configs(batch_size):
app_data = db.session.query(App) \
.filter(App.id == data.app_id) \
.one()
account_data = db.session.query(Account) \
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
.filter(TenantAccountJoin.role == 'owner') \
@@ -534,13 +638,15 @@ def update_app_model_configs(batch_size):
db.session.commit()
except Exception as e:
-click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
+click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
+            fg='red')
continue
-click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
+click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@@ -551,4 +657,5 @@ def register_commands(app):
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)
-app.cli.add_command(update_app_model_configs)
\ No newline at end of file
+app.cli.add_command(update_app_model_configs)
+app.cli.add_command(normalization_collections)
@@ -16,6 +16,10 @@ class BaseIndex(ABC):
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@@ -28,6 +32,10 @@ class BaseIndex(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError
......
@@ -46,6 +46,32 @@ class KeywordTableIndex(BaseIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
@@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
db.session.delete(dataset_keyword_table)
db.session.commit()
def delete_by_group_id(self, group_id: str) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
......
@@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
-from models.dataset import Dataset, DocumentSegment
+from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument
@@ -110,6 +110,12 @@ class BaseVectorIndex(BaseIndex):
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -243,3 +249,53 @@ class BaseVectorIndex(BaseIndex):
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"delete original collection: {dataset.id}")
self.delete()
dataset.collection_binding_id = dataset_collection_binding.id
db.session.add(dataset)
db.session.commit()
logging.info(f"Dataset {dataset.id} recreate successfully.")
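These two helpers are intended to run in sequence, as the normalization-collections command above does: first copy the dataset's completed segments into the shared collection, then drop the original per-dataset collection and persist the binding. A minimal sketch of that order (qdrant_config stands for the QdrantConfig built from QDRANT_URL / QDRANT_API_KEY, exactly as in the command):

# Sketch: migration order used by normalization_collections.
index = QdrantVectorIndex(dataset=dataset, config=qdrant_config, embeddings=embeddings)
index.restore_dataset_in_one(dataset, dataset_collection_binding)   # re-index into the shared collection

original_index = QdrantVectorIndex(dataset=dataset, config=qdrant_config, embeddings=embeddings)
original_index.delete_original_collection(dataset, dataset_collection_binding)  # drop old collection, store binding id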
@@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=collection_name,
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
......
@@ -28,6 +28,7 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType
if TYPE_CHECKING:
from qdrant_client import grpc  # noqa
@@ -84,6 +85,7 @@ class Qdrant(VectorStore):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR_NAME = None
def __init__(
@@ -93,9 +95,12 @@ class Qdrant(VectorStore):
embeddings: Optional[Embeddings] = None,
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME,
embedding_function: Optional[Callable] = None,  # deprecated
is_new_collection: bool = False
):
"""Initialize with necessary components."""
try:
@@ -129,7 +134,10 @@ class Qdrant(VectorStore):
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
self.group_payload_key = group_payload_key or self.GROUP_KEY
self.vector_name = vector_name or self.VECTOR_NAME
self.group_id = group_id
self.is_new_collection= is_new_collection
if embedding_function is not None:
warnings.warn(
@@ -170,6 +178,8 @@ class Qdrant(VectorStore):
batch_size:
How many vectors upload per-request.
Default: 64
group_id:
collection group
Returns:
List of ids from adding the texts into the vectorstore.
@@ -182,7 +192,11 @@ class Qdrant(VectorStore):
collection_name=self.collection_name, points=points, **kwargs
)
added_ids.extend(batch_ids)
# if is new collection, create payload index on group_id
if self.is_new_collection:
self.client.create_payload_index(self.collection_name, self.group_payload_key,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
return added_ids
@sync_call_fallback
@@ -970,6 +984,8 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
@@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
metadata_payload_key:
A payload key used to store the metadata of the document.
Default: "metadata"
group_payload_key:
A payload key used to store the content of the document.
Default: "group_id"
group_id:
collection group id
vector_name:
Name of the vector to be used internally in Qdrant.
Default: None
@@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
distance_func,
content_payload_key,
metadata_payload_key,
group_payload_key,
group_id,
vector_name,
shard_number,
replication_factor,
@@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
@@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
is_new_collection = False
client = qdrant_client.QdrantClient(
location=location,
url=url,
@@ -1454,6 +1480,7 @@ class Qdrant(VectorStore):
init_from=init_from,
timeout=timeout,  # type: ignore[arg-type]
)
is_new_collection = True
qdrant = cls(
client=client,
collection_name=collection_name,
@@ -1462,6 +1489,9 @@ class Qdrant(VectorStore):
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
group_id=group_id,
group_payload_key=group_payload_key,
is_new_collection=is_new_collection
)
return qdrant
@@ -1516,6 +1546,8 @@ class Qdrant(VectorStore):
metadatas: Optional[List[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> List[dict]:
payloads = []
for i, text in enumerate(texts):
@@ -1529,6 +1561,7 @@ class Qdrant(VectorStore):
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
@@ -1578,7 +1611,7 @@ class Qdrant(VectorStore):
else:
out.append(
rest.FieldCondition(
-    key=f"{self.metadata_payload_key}.{key}",
+    key=key,
match=rest.MatchValue(value=value),
)
)
@@ -1654,6 +1687,7 @@ class Qdrant(VectorStore):
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
@@ -1684,6 +1718,8 @@ class Qdrant(VectorStore):
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
self.group_id,
self.group_payload_key
),
)
]
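Taken together, each point in a shared collection now carries the dataset id twice: once inside metadata and once as a top-level group_id payload field, and new collections get a keyword payload index on that field so per-dataset filtering stays cheap. A hedged sketch of the resulting payload shape and of the equivalent standalone index call (collection name and ids are placeholders):

# Sketch: payload written by _build_payloads for one point.
payload = {
    'page_content': 'segment text ...',
    'metadata': {
        'doc_id': '<index_node_id>',
        'doc_hash': '<index_node_hash>',
        'document_id': '<document_id>',
        'dataset_id': '<dataset_id>',
    },
    'group_id': '<dataset_id>',  # same value for every point of one dataset
}

# Sketch: the payload index created for new collections, expressed as a direct
# qdrant_client call (assumes a locally reachable Qdrant instance).
from qdrant_client import QdrantClient
from qdrant_client.http.models import PayloadSchemaType

client = QdrantClient(url='http://localhost:6333')
client.create_payload_index(
    collection_name='Vector_index_<uuid>_Node',
    field_name='group_id',
    field_schema=PayloadSchemaType.KEYWORD,
)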
......
@@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
-from models.dataset import Dataset
+from extensions.ext_database import db
+from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
@@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
-    if self.dataset.index_struct_dict:
-        class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-        if not class_prefix.endswith('_Node'):
-            # original class_prefix
-            class_prefix += '_Node'
-        return class_prefix
+    if dataset.collection_binding_id:
+        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+            one_or_none()
+        if dataset_collection_binding:
+            return dataset_collection_binding.collection_name
+        else:
+            raise ValueError('Dataset Collection Bindings is not exist!')
+    else:
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
@@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
@@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
-if self._is_origin():
-    attributes = ['doc_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
@@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
-content_payload_key='page_content'
+content_payload_key='page_content',
+group_id=self.dataset.id,
+group_payload_key='group_id'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
-if self._is_origin():
-    self.recreate_dataset(self.dataset)
-    return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -114,9 +139,6 @@ class QdrantVectorIndex(BaseVectorIndex):
))
def delete_by_ids(self, ids: list[str]) -> None:
-if self._is_origin():
-    self.recreate_dataset(self.dataset)
-    return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -132,6 +154,22 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=group_id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
......
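The same group_id payload filter used by delete_by_group_id can also be used to check how many points of a shared collection belong to a given dataset, which is a handy sanity check after running normalization-collections. A hedged sketch with qdrant_client (collection name and dataset id are placeholders):

# Sketch: count the points of one dataset inside a shared collection.
from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url='http://localhost:6333')  # assumed local instance
result = client.count(
    collection_name='Vector_index_<uuid>_Node',
    count_filter=models.Filter(
        must=[
            models.FieldCondition(
                key='group_id',
                match=models.MatchValue(value='<dataset-id>'),
            ),
        ],
    ),
    exact=True,
)
print(result.count)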
@@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
......
@@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool):
return_resource: str
retriever_from: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
@@ -94,7 +93,10 @@ class DatasetRetrieverTool(BaseTool):
query,
search_type='similarity_score_threshold',
search_kwargs={
-    'k': self.k
+    'k': self.k,
+    'filter': {
+        'group_id': [dataset.id]
+    }
}
)
else:
......
@@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
self.client.delete_collection(collection_name=self.collection_name)
def delete_group(self):
self._reload_if_needed()
self.client.delete_collection(collection_name=self.collection_name)
@classmethod
def _document_from_scored_point(
cls,
......
"""add_dataset_collection_binding
Revision ID: 6e2cfb077b04
Revises: 77e83833755c
Create Date: 2023-09-13 22:16:48.027810
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6e2cfb077b04'
down_revision = '77e83833755c'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_collection_bindings',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('provider_name', sa.String(length=40), nullable=False),
sa.Column('model_name', sa.String(length=40), nullable=False),
sa.Column('collection_name', sa.String(length=64), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey')
)
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('collection_binding_id')
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.drop_index('provider_model_name_idx')
op.drop_table('dataset_collection_bindings')
# ### end Alembic commands ###
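After this migration runs, resolving a dataset's physical Qdrant collection goes through the new nullable collection_binding_id column: bound datasets share a collection, unbound ones keep their legacy per-dataset collection (this mirrors get_index_name in QdrantVectorIndex). A hedged sketch of that lookup:

# Sketch: resolving the collection a dataset's vectors live in.
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding

dataset = db.session.query(Dataset).filter(Dataset.id == '<dataset-id>').one()
if dataset.collection_binding_id:
    binding = db.session.query(DatasetCollectionBinding) \
        .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) \
        .one_or_none()
    collection_name = binding.collection_name  # shared collection
else:
    # legacy layout: one collection per dataset
    collection_name = 'Vector_index_' + dataset.id.replace('-', '_') + '_Node'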
@@ -38,6 +38,8 @@ class Dataset(db.Model):
server_default=db.text('CURRENT_TIMESTAMP(0)'))
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True)
@property
def dataset_keyword_table(self):
@@ -445,3 +447,19 @@ class Embedding(db.Model):
def get_embedding(self) -> list[float]:
return pickle.loads(self.embedding)
class DatasetCollectionBinding(db.Model):
__tablename__ = 'dataset_collection_bindings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
db.Index('provider_model_name_idx', 'provider_name', 'model_name')
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -20,7 +20,8 @@ from events.document_event import document_was_deleted
from extensions.ext_database import db
from libs import helper
from models.account import Account
-from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
+from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
+    DatasetCollectionBinding
from models.model import UploadFile
from models.source import DataSourceBinding
from services.errors.account import NoPermissionError
@@ -147,6 +148,7 @@ class DatasetService:
action = 'remove'
filtered_data['embedding_model'] = None
filtered_data['embedding_model_provider'] = None
filtered_data['collection_binding_id'] = None
elif data['indexing_technique'] == 'high_quality':
action = 'add'
# get embedding model setting
@@ -156,6 +158,11 @@ class DatasetService:
)
filtered_data['embedding_model'] = embedding_model.name
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
f"No Embedding Model available. Please configure a valid provider "
@@ -464,7 +471,11 @@ class DocumentService:
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
)
dataset.collection_binding_id = dataset_collection_binding.id
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -720,10 +731,16 @@ class DocumentService:
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
embedding_model = None
dataset_collection_binding_id = None
if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
)
dataset_collection_binding_id = dataset_collection_binding.id
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
@@ -732,7 +749,8 @@ class DocumentService:
indexing_technique=document_data["indexing_technique"],
created_by=account.id,
embedding_model=embedding_model.name if embedding_model else None,
-embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
+embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
collection_binding_id=dataset_collection_binding_id
)
db.session.add(dataset)
@@ -1069,3 +1087,23 @@ class SegmentService:
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
db.session.commit()
class DatasetCollectionBindingService:
@classmethod
def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name). \
order_by(DatasetCollectionBinding.created_at). \
first()
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=provider_name,
model_name=model_name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
)
db.session.add(dataset_collection_binding)
db.session.flush()
return dataset_collection_binding
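Callers such as DatasetService and DocumentService above use the service as a get-or-create lookup and then store the returned id on the dataset; a minimal usage sketch (the caller is expected to commit, since the service only flushes):

# Sketch: wiring a dataset to its shared collection binding.
binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    provider_name=embedding_model.model_provider.provider_name,
    model_name=embedding_model.name,
)
dataset.collection_binding_id = binding.id
db.session.add(dataset)
db.session.commit()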
@@ -47,7 +47,10 @@ class HitTestingService:
query,
search_type='similarity_score_threshold',
search_kwargs={
-    'k': 10
+    'k': 10,
'filter': {
'group_id': [dataset.id]
}
}
)
end = time.perf_counter()
......