Commit 9d694360 authored by John Wang's avatar John Wang

fix: fix dataset del bugs

parent 71981eac
......@@ -33,9 +33,7 @@ class VectorIndex:
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings,
# attributes=['doc_id', 'dataset_id', 'document_id', 'source'],
attributes=['doc_id'],
embeddings=embeddings
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
......
from typing import Optional, Any, List, cast
from typing import Optional, cast
import weaviate
from langchain.embeddings.base import Embeddings
......@@ -25,11 +25,10 @@ class WeaviateConfig(BaseModel):
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list[str]):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
self._dataset = dataset
self._client = self._init_client(config)
self._embeddings = embeddings
self._attributes = attributes
self._vector_store = None
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
......@@ -95,12 +94,16 @@ class WeaviateVectorIndex(BaseVectorIndex):
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id', 'source']
if self._is_origin():
attributes = ['doc_id', 'ref_doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self._dataset),
text_key='text',
embedding=self._embeddings,
attributes=self._attributes,
attributes=attributes,
by_text=False
)
......@@ -113,6 +116,15 @@ class WeaviateVectorIndex(BaseVectorIndex):
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"path": ["doc_id" if self._is_origin() else "document_id"],
"valueText": document_id
})
def _is_origin(self):
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
......@@ -260,7 +260,7 @@ class Document(db.Model):
@property
def dataset(self):
return Dataset.query.get(self.dataset_id)
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()
@property
def segment_count(self):
......@@ -400,8 +400,10 @@ class DatasetKeywordTable(db.Model):
super().__init__(object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, dct):
if "__set__" in dct:
return set(dct["__set__"])
if isinstance(dct, dict):
for keyword, node_idxs in dct.items():
if isinstance(node_idxs, list):
dct[keyword] = set(node_idxs)
return dct
return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment