Commit 01baae87 authored by John Wang

feat: recreate dataset when origin dataset format is detected

parent f33056f4
......@@ -4,8 +4,14 @@ from typing import List, Any
from langchain.schema import Document, BaseRetriever
from models.dataset import Dataset
class BaseIndex(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@abstractmethod
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
......
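The hunk above moves dataset storage into `BaseIndex.__init__`, so every subclass is expected to call `super().__init__(dataset)` and read `self.dataset` instead of keeping its own `_dataset` copy (the pattern applied to the keyword table, Qdrant and Weaviate indexes below). A minimal sketch of that contract; the subclass name and its in-memory store are invented for illustration:

```python
from abc import ABC, abstractmethod


class BaseIndex(ABC):
    def __init__(self, dataset):
        # The base class now owns the dataset reference.
        self.dataset = dataset

    @abstractmethod
    def create(self, texts: list, **kwargs) -> "BaseIndex":
        raise NotImplementedError


class InMemoryIndex(BaseIndex):  # hypothetical subclass, not part of the commit
    def __init__(self, dataset):
        super().__init__(dataset)  # no private self._dataset copy
        self._docs: list = []

    def create(self, texts: list, **kwargs) -> "BaseIndex":
        self._docs.extend(texts)
        return self
```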
......@@ -17,7 +17,7 @@ class KeywordTableConfig(BaseModel):
class KeywordTableIndex(BaseIndex):
def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
self._dataset = dataset
super().__init__(dataset)
self._config = config
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
......@@ -29,11 +29,11 @@ class KeywordTableIndex(BaseIndex):
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self._dataset.id,
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
......@@ -70,7 +70,7 @@ class KeywordTableIndex(BaseIndex):
def delete_by_document_id(self, document_id: str):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id,
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()
......@@ -98,7 +98,7 @@ class KeywordTableIndex(BaseIndex):
documents = []
for chunk_index in sorted_chunk_indices:
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id,
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.index_node_id == chunk_index
).first()
......@@ -115,7 +115,7 @@ class KeywordTableIndex(BaseIndex):
return documents
def delete(self) -> None:
dataset_keyword_table = self._dataset.dataset_keyword_table
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
......@@ -124,26 +124,26 @@ class KeywordTableIndex(BaseIndex):
keyword_table_dict = {
'__type__': 'keyword_table',
'__data__': {
"index_id": self._dataset.id,
"index_id": self.dataset.id,
"summary": None,
"table": keyword_table
}
}
self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
db.session.commit()
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self._dataset.dataset_keyword_table
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
if dataset_keyword_table.keyword_table_dict:
return dataset_keyword_table.keyword_table_dict['__data__']['table']
else:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self._dataset.id,
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
......
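For reference, the keyword index persists its table as a small JSON envelope on `DatasetKeywordTable.keyword_table`, keyed by the dataset id, in exactly the shape built above. A sketch of that payload with placeholder values; the real table maps keywords to segment node ids and is serialized with the project's `SetEncoder` because the values are sets:

```python
import json

keyword_table_dict = {
    '__type__': 'keyword_table',
    '__data__': {
        "index_id": "9f8b7c6d-0000-0000-0000-000000000000",  # self.dataset.id (placeholder)
        "summary": None,
        "table": {
            # keyword -> index node ids; illustrative entries only
            "invoice": ["node-1", "node-7"],
            "refund": ["node-2"],
        },
    },
}
serialized = json.dumps(keyword_table_dict)
```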
import json
import logging
from abc import abstractmethod
from typing import List, Any, Tuple, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from core.index.base import BaseIndex
from models.dataset import Dataset
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
......@@ -69,6 +80,9 @@ class BaseVectorIndex(BaseIndex):
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document], **kwargs):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
......@@ -85,6 +99,9 @@ class BaseVectorIndex(BaseIndex):
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
......@@ -96,3 +113,45 @@ class BaseVectorIndex(BaseIndex):
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
return False
def recreate_dataset(self, dataset: Dataset):
logging.debug(f"Recreating dataset {dataset.id}")
self.delete()
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
self.create(documents)
dataset.index_struct = json.dumps(self.to_index_struct())
db.session.commit()
self.dataset = dataset
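From the caller's point of view the rebuild is transparent: each write path (`add_texts`, `delete_by_ids`, and the `delete_by_document_id` overrides below) first checks `_is_origin()` and, if the dataset still uses the original index format, runs `recreate_dataset()` before touching the vector store. An illustrative usage sketch, assuming the project's `IndexBuilder.get_index()` factory used later in this diff; `chunk_documents` and `stale_index_node_ids` are placeholders:

```python
# Illustrative only; relies on the surrounding project code.
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
    # If dataset.index_struct still points at an origin-format collection,
    # this call deletes it, re-reads all completed segments, re-creates the
    # index under the new naming scheme, and only then adds the new texts.
    index.add_texts(chunk_documents)

    # The same guard protects deletions.
    index.delete_by_ids(stale_index_node_ids)
```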
......@@ -36,17 +36,15 @@ class QdrantConfig(BaseModel):
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
self._dataset = dataset
super().__init__(dataset, embeddings)
self._client_config = config
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if self._dataset.index_struct_dict:
return self._dataset.index_struct_dict['vector_store']['collection_name']
if self.dataset.index_struct_dict:
return self.dataset.index_struct_dict['vector_store']['collection_name']
dataset_id = dataset.id
return "Index_" + dataset_id.replace("-", "_")
......@@ -54,7 +52,7 @@ class QdrantVectorIndex(BaseVectorIndex):
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self._dataset)}
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
......@@ -62,7 +60,7 @@ class QdrantVectorIndex(BaseVectorIndex):
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self._dataset),
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='text',
**self._client_config.to_qdrant_params()
......@@ -81,7 +79,7 @@ class QdrantVectorIndex(BaseVectorIndex):
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self._dataset),
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='text'
)
......@@ -90,6 +88,9 @@ class QdrantVectorIndex(BaseVectorIndex):
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
......@@ -98,15 +99,15 @@ class QdrantVectorIndex(BaseVectorIndex):
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="doc_id" if self._is_origin() else "metadata.document_id",
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
def _is_origin(self):
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['collection_name']
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
if not class_prefix.strip('Vector_'):
# original class_prefix
return True
......
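Under the new scheme the Qdrant collection name is derived from the dataset id, while `_is_origin()` inspects the stored `collection_name` to decide whether the dataset predates this scheme and needs a rebuild. A quick sketch of the naming rule with a placeholder id:

```python
dataset_id = "1c0ffee0-aaaa-bbbb-cccc-000000000000"  # placeholder
collection_name = "Index_" + dataset_id.replace("-", "_")
assert collection_name == "Index_1c0ffee0_aaaa_bbbb_cccc_000000000000"
```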
......@@ -26,10 +26,8 @@ class WeaviateConfig(BaseModel):
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
self._dataset = dataset
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
self._embeddings = embeddings
self._vector_store = None
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
......@@ -59,8 +57,8 @@ class WeaviateVectorIndex(BaseVectorIndex):
return 'weaviate'
def get_index_name(self, dataset: Dataset) -> str:
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
......@@ -73,7 +71,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self._dataset)}
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
......@@ -82,7 +80,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self._dataset),
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
......@@ -96,11 +94,11 @@ class WeaviateVectorIndex(BaseVectorIndex):
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id', 'ref_doc_id']
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self._dataset),
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
......@@ -111,18 +109,21 @@ class WeaviateVectorIndex(BaseVectorIndex):
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["doc_id" if self._is_origin() else "document_id"],
"path": ["document_id"],
"valueText": document_id
})
def _is_origin(self):
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
......
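The Weaviate variant encodes the same distinction in the class prefix: new-format datasets carry a `class_prefix` ending in `_Node`, `get_index_name()` appends the suffix when it is missing, and `_is_origin()` flags a prefix without it as the original format. A small self-contained sketch of that convention (the helper names are invented):

```python
def is_origin_class_prefix(class_prefix: str) -> bool:
    # Original-format datasets stored a prefix without the "_Node" suffix.
    return not class_prefix.endswith('_Node')


def normalized_index_name(class_prefix: str) -> str:
    # Mirrors get_index_name(): append the suffix for legacy prefixes.
    return class_prefix if class_prefix.endswith('_Node') else class_prefix + '_Node'


assert is_origin_class_prefix("Vector_index_abc123")
assert normalized_index_name("Vector_index_abc123") == "Vector_index_abc123_Node"
assert not is_origin_class_prefix("Vector_index_abc123_Node")
```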
......@@ -488,28 +488,8 @@ class IndexingRunner:
"""
Build the index for the document.
"""
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
keyword_table_index = KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
......@@ -526,14 +506,11 @@ class IndexingRunner:
)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(chunk_documents)
if vector_index:
vector_index.add_texts(chunk_documents)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(chunk_documents)
keyword_table_index.add_texts(chunk_documents)
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
......
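The runner hunk above replaces the hand-built embedding/VectorIndex/KeywordTableIndex wiring with two calls to `IndexBuilder.get_index()`, one per indexing technique, and reuses the returned handles for the same batch of chunk documents. A hedged condensation of the resulting flow; variable names follow the runner, and `chunk_documents` is produced earlier in the method:

```python
# Illustrative condensation of the new runner flow; not runnable on its own.
vector_index = IndexBuilder.get_index(dataset, 'high_quality')    # vector index, may be None
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')  # keyword table index

if vector_index:
    vector_index.add_texts(chunk_documents)     # save vector index

keyword_table_index.add_texts(chunk_documents)  # save keyword index
```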
......@@ -54,7 +54,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
# delete from keyword index
if index_node_ids:
vector_index.delete_by_ids(index_node_ids)
kw_index.delete_by_ids(index_node_ids)
for segment in segments:
db.session.delete(segment)
......