Commit 01baae87 authored by John Wang

feat: recreate dataset index when the origin dataset format is detected

parent f33056f4
@@ -4,8 +4,14 @@ from typing import List, Any
 from langchain.schema import Document, BaseRetriever
+from models.dataset import Dataset

 class BaseIndex(ABC):
+    def __init__(self, dataset: Dataset):
+        self.dataset = dataset
+
     @abstractmethod
     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         raise NotImplementedError
...
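For context, this change moves dataset ownership into the base class: every index implementation now keeps its Dataset on self.dataset instead of a per-subclass self._dataset. A minimal standalone sketch of the intended subclass pattern (DemoIndex and FakeDataset are placeholders for illustration, not part of the codebase):

from abc import ABC, abstractmethod


class BaseIndex(ABC):
    """Mirror of the new base class: the dataset is stored once, as self.dataset."""

    def __init__(self, dataset):
        self.dataset = dataset

    @abstractmethod
    def create(self, texts, **kwargs) -> "BaseIndex":
        raise NotImplementedError


class DemoIndex(BaseIndex):
    """Hypothetical subclass showing the pattern the real indexes now follow."""

    def __init__(self, dataset, config=None):
        super().__init__(dataset)          # was: self._dataset = dataset
        self._config = config or {}

    def create(self, texts, **kwargs) -> "BaseIndex":
        # a real index (KeywordTableIndex, QdrantVectorIndex, ...) would build
        # its backing store from `texts` here; this stub just returns itself
        return self


class FakeDataset:
    """Stand-in for models.dataset.Dataset, used only for this demo."""
    id = "demo-dataset-id"


index = DemoIndex(FakeDataset())
assert index.dataset.id == "demo-dataset-id"   # one attribute name shared by all indexes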
@@ -17,7 +17,7 @@ class KeywordTableConfig(BaseModel):
 class KeywordTableIndex(BaseIndex):
     def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
-        self._dataset = dataset
+        super().__init__(dataset)
         self._config = config

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
@@ -29,11 +29,11 @@ class KeywordTableIndex(BaseIndex):
             keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

         dataset_keyword_table = DatasetKeywordTable(
-            dataset_id=self._dataset.id,
+            dataset_id=self.dataset.id,
             keyword_table=json.dumps({
                 '__type__': 'keyword_table',
                 '__data__': {
-                    "index_id": self._dataset.id,
+                    "index_id": self.dataset.id,
                     "summary": None,
                     "table": {}
                 }
@@ -70,7 +70,7 @@ class KeywordTableIndex(BaseIndex):
     def delete_by_document_id(self, document_id: str):
         # get segment ids by document_id
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == self._dataset.id,
+            DocumentSegment.dataset_id == self.dataset.id,
            DocumentSegment.document_id == document_id
         ).all()
@@ -98,7 +98,7 @@ class KeywordTableIndex(BaseIndex):
         documents = []
         for chunk_index in sorted_chunk_indices:
             segment = db.session.query(DocumentSegment).filter(
-                DocumentSegment.dataset_id == self._dataset.id,
+                DocumentSegment.dataset_id == self.dataset.id,
                 DocumentSegment.index_node_id == chunk_index
             ).first()
@@ -115,7 +115,7 @@ class KeywordTableIndex(BaseIndex):
         return documents

     def delete(self) -> None:
-        dataset_keyword_table = self._dataset.dataset_keyword_table
+        dataset_keyword_table = self.dataset.dataset_keyword_table
         if dataset_keyword_table:
             db.session.delete(dataset_keyword_table)
             db.session.commit()
@@ -124,26 +124,26 @@ class KeywordTableIndex(BaseIndex):
         keyword_table_dict = {
             '__type__': 'keyword_table',
             '__data__': {
-                "index_id": self._dataset.id,
+                "index_id": self.dataset.id,
                 "summary": None,
                 "table": keyword_table
             }
         }
-        self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
+        self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
         db.session.commit()

     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        dataset_keyword_table = self._dataset.dataset_keyword_table
+        dataset_keyword_table = self.dataset.dataset_keyword_table
         if dataset_keyword_table:
             if dataset_keyword_table.keyword_table_dict:
                 return dataset_keyword_table.keyword_table_dict['__data__']['table']
         else:
             dataset_keyword_table = DatasetKeywordTable(
-                dataset_id=self._dataset.id,
+                dataset_id=self.dataset.id,
                 keyword_table=json.dumps({
                     '__type__': 'keyword_table',
                     '__data__': {
-                        "index_id": self._dataset.id,
+                        "index_id": self.dataset.id,
                         "summary": None,
                         "table": {}
                     }
...
+import json
+import logging
 from abc import abstractmethod
 from typing import List, Any, Tuple, cast

+from langchain.embeddings.base import Embeddings
 from langchain.schema import Document, BaseRetriever
 from langchain.vectorstores import VectorStore

 from core.index.base import BaseIndex
-from models.dataset import Dataset
+from extensions.ext_database import db
+from models.dataset import Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument

 class BaseVectorIndex(BaseIndex):
+    def __init__(self, dataset: Dataset, embeddings: Embeddings):
+        super().__init__(dataset)
+        self._embeddings = embeddings
+        self._vector_store = None
+
     def get_type(self) -> str:
         raise NotImplementedError
@@ -69,6 +80,9 @@ class BaseVectorIndex(BaseIndex):
         return vector_store.as_retriever(**kwargs)

     def add_texts(self, texts: list[Document], **kwargs):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -85,6 +99,9 @@ class BaseVectorIndex(BaseIndex):
         return vector_store.text_exists(id)

     def delete_by_ids(self, ids: list[str]) -> None:
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -96,3 +113,45 @@ class BaseVectorIndex(BaseIndex):
         vector_store = cast(self._get_vector_store_class(), vector_store)

         vector_store.delete()
+
+    def _is_origin(self):
+        return False
+
+    def recreate_dataset(self, dataset: Dataset):
+        logging.debug(f"Recreating dataset {dataset.id}")
+
+        self.delete()
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+
+                documents.append(document)
+
+        self.create(documents)
+
+        dataset.index_struct = json.dumps(self.to_index_struct())
+        db.session.commit()
+
+        self.dataset = dataset
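In outline, the write paths now self-heal: whenever a vector index is asked to mutate data while the dataset is still stored in its original format, the whole index is rebuilt from the completed, enabled document segments and the dataset's index_struct is rewritten. A minimal standalone sketch of that control flow, with the database and vector store replaced by in-memory stand-ins (FakeVectorIndex and the dict-based "segments" are placeholders, not part of the codebase):

import json
import logging


class FakeVectorIndex:
    """Toy stand-in for BaseVectorIndex, for illustration only."""

    def __init__(self, dataset, segments):
        self.dataset = dataset
        self._segments = segments          # stand-in for the DocumentSegment query
        self.store = {}                    # stand-in for the vector store

    def _is_origin(self):
        # the real base class returns False; subclasses override this check
        return self.dataset.get("index_struct") is None

    def to_index_struct(self):
        return {"type": "demo", "vector_store": {"collection_name": "Demo_index"}}

    def delete(self):
        self.store.clear()

    def create(self, documents):
        self.store = {doc["doc_id"]: doc["page_content"] for doc in documents}

    def recreate_dataset(self, dataset):
        logging.debug("Recreating dataset %s", dataset["id"])
        self.delete()
        documents = [
            {"doc_id": seg["index_node_id"], "page_content": seg["content"]}
            for seg in self._segments
            if seg["status"] == "completed" and seg["enabled"]
        ]
        self.create(documents)
        dataset["index_struct"] = json.dumps(self.to_index_struct())
        self.dataset = dataset

    def add_texts(self, texts):
        # the new guard: rebuild first if the dataset is still in the old format
        if self._is_origin():
            self.recreate_dataset(self.dataset)
        self.store.update({doc["doc_id"]: doc["page_content"] for doc in texts})


segments = [{"index_node_id": "n1", "content": "hello", "status": "completed", "enabled": True}]
index = FakeVectorIndex(dataset={"id": "ds-1", "index_struct": None}, segments=segments)
index.add_texts([{"doc_id": "n2", "page_content": "world"}])
assert index.dataset["index_struct"] is not None   # dataset migrated on first write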
@@ -36,17 +36,15 @@ class QdrantConfig(BaseModel):
 class QdrantVectorIndex(BaseVectorIndex):
     def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
-        self._dataset = dataset
+        super().__init__(dataset, embeddings)
         self._client_config = config
-        self._embeddings = embeddings
-        self._vector_store = None

     def get_type(self) -> str:
         return 'qdrant'

     def get_index_name(self, dataset: Dataset) -> str:
-        if self._dataset.index_struct_dict:
-            return self._dataset.index_struct_dict['vector_store']['collection_name']
+        if self.dataset.index_struct_dict:
+            return self.dataset.index_struct_dict['vector_store']['collection_name']

         dataset_id = dataset.id
         return "Index_" + dataset_id.replace("-", "_")
@@ -54,7 +52,7 @@ class QdrantVectorIndex(BaseVectorIndex):
     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
-            "vector_store": {"collection_name": self.get_index_name(self._dataset)}
+            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
         }

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
@@ -62,7 +60,7 @@ class QdrantVectorIndex(BaseVectorIndex):
         self._vector_store = QdrantVectorStore.from_documents(
             texts,
             self._embeddings,
-            collection_name=self.get_index_name(self._dataset),
+            collection_name=self.get_index_name(self.dataset),
             ids=uuids,
             content_payload_key='text',
             **self._client_config.to_qdrant_params()
@@ -81,7 +79,7 @@ class QdrantVectorIndex(BaseVectorIndex):
         return QdrantVectorStore(
             client=client,
-            collection_name=self.get_index_name(self._dataset),
+            collection_name=self.get_index_name(self.dataset),
             embeddings=self._embeddings,
             content_payload_key='text'
         )
@@ -90,6 +88,9 @@ class QdrantVectorIndex(BaseVectorIndex):
         return QdrantVectorStore

     def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -98,15 +99,15 @@ class QdrantVectorIndex(BaseVectorIndex):
         vector_store.del_texts(models.Filter(
             must=[
                 models.FieldCondition(
-                    key="doc_id" if self._is_origin() else "metadata.document_id",
+                    key="metadata.document_id",
                     match=models.MatchValue(value=document_id),
                 ),
             ],
         ))

     def _is_origin(self):
-        if self._dataset.index_struct_dict:
-            class_prefix: str = self._dataset.index_struct_dict['vector_store']['collection_name']
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
             if not class_prefix.strip('Vector_'):
                 # original class_prefix
                 return True
...
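One detail worth keeping in mind when reading the Qdrant _is_origin() check above: str.strip() takes a set of characters rather than a literal prefix, so class_prefix.strip('Vector_') only becomes empty (and the branch returns True) when the collection name is built entirely from the characters V, e, c, t, o, r and _. A quick, self-contained illustration of that behaviour (the sample names are made up, not real collection names):

# str.strip('Vector_') removes leading/trailing characters from the set
# {'V', 'e', 'c', 't', 'o', 'r', '_'}; it does not match the word "Vector_".
assert 'Vector_'.strip('Vector_') == ''                  # empty -> the `not ...` check is True
assert 'Vector_index'.strip('Vector_') == 'index'        # leading run stripped, result non-empty
assert 'Index_2eb9'.strip('Vector_') == 'Index_2eb9'     # unchanged -> check is False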
@@ -26,10 +26,8 @@ class WeaviateConfig(BaseModel):
 class WeaviateVectorIndex(BaseVectorIndex):
     def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
-        self._dataset = dataset
+        super().__init__(dataset, embeddings)
         self._client = self._init_client(config)
-        self._embeddings = embeddings
-        self._vector_store = None

     def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
         auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
@@ -59,8 +57,8 @@ class WeaviateVectorIndex(BaseVectorIndex):
         return 'weaviate'

     def get_index_name(self, dataset: Dataset) -> str:
-        if self._dataset.index_struct_dict:
-            class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
             if not class_prefix.endswith('_Node'):
                 # original class_prefix
                 class_prefix += '_Node'
@@ -73,7 +71,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
-            "vector_store": {"class_prefix": self.get_index_name(self._dataset)}
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
         }

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
@@ -82,7 +80,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
             texts,
             self._embeddings,
             client=self._client,
-            index_name=self.get_index_name(self._dataset),
+            index_name=self.get_index_name(self.dataset),
             uuids=uuids,
             by_text=False
         )
@@ -96,11 +94,11 @@ class WeaviateVectorIndex(BaseVectorIndex):
         attributes = ['doc_id', 'dataset_id', 'document_id']
         if self._is_origin():
-            attributes = ['doc_id', 'ref_doc_id']
+            attributes = ['doc_id']

         return WeaviateVectorStore(
             client=self._client,
-            index_name=self.get_index_name(self._dataset),
+            index_name=self.get_index_name(self.dataset),
             text_key='text',
             embedding=self._embeddings,
             attributes=attributes,
@@ -111,18 +109,21 @@ class WeaviateVectorIndex(BaseVectorIndex):
         return WeaviateVectorStore

     def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)

         vector_store.del_texts({
             "operator": "Equal",
-            "path": ["doc_id" if self._is_origin() else "document_id"],
+            "path": ["document_id"],
             "valueText": document_id
         })

     def _is_origin(self):
-        if self._dataset.index_struct_dict:
-            class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
             if not class_prefix.endswith('_Node'):
                 # original class_prefix
                 return True
...
@@ -488,28 +488,8 @@ class IndexingRunner:
         """
         Build the index for the document.
         """
-        model_credentials = LLMBuilder.get_model_credentials(
-            tenant_id=dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
-            model_name='text-embedding-ada-002'
-        )
-
-        embeddings = CacheEmbedding(OpenAIEmbeddings(
-            **model_credentials
-        ))
-
-        vector_index = VectorIndex(
-            dataset=dataset,
-            config=current_app.config,
-            embeddings=embeddings
-        )
-
-        keyword_table_index = KeywordTableIndex(
-            dataset=dataset,
-            config=KeywordTableConfig(
-                max_keywords_per_chunk=10
-            )
-        )
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
@@ -526,14 +506,11 @@ class IndexingRunner:
             )

             # save vector index
-            index = IndexBuilder.get_index(dataset, 'high_quality')
-            if index:
-                index.add_texts(chunk_documents)
+            if vector_index:
+                vector_index.add_texts(chunk_documents)

             # save keyword index
-            index = IndexBuilder.get_index(dataset, 'economy')
-            if index:
-                index.add_texts(chunk_documents)
+            keyword_table_index.add_texts(chunk_documents)

             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
...
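The IndexingRunner change builds both indexes once through IndexBuilder.get_index instead of constructing embeddings and index objects by hand inside the step. Judging by the new `if vector_index:` guard, get_index(dataset, 'high_quality') can apparently return None while the keyword index is always usable. A rough sketch of the resulting flow, with get_index stubbed out as an assumption rather than the real factory (get_index, DummyIndex and build_index_step below are illustrative names only):

from typing import Optional


class DummyIndex:
    def __init__(self, name: str):
        self.name = name
        self.texts = []

    def add_texts(self, documents):
        self.texts.extend(documents)


def get_index(dataset, indexing_technique: str) -> Optional[DummyIndex]:
    """Stub standing in for IndexBuilder.get_index; the real factory wires up
    embeddings and keyword config itself, which the runner no longer has to do."""
    if indexing_technique == 'high_quality' and dataset['indexing_technique'] != 'high_quality':
        return None           # assumed behaviour, inferred from the `if vector_index:` guard
    return DummyIndex(name=indexing_technique)


def build_index_step(dataset, chunk_documents):
    # new shape of IndexingRunner._build_index: indexes are built once, then reused per chunk
    vector_index = get_index(dataset, 'high_quality')
    keyword_table_index = get_index(dataset, 'economy')

    if vector_index:
        vector_index.add_texts(chunk_documents)
    keyword_table_index.add_texts(chunk_documents)


build_index_step({'indexing_technique': 'economy'}, [{'doc_id': 'n1'}])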
@@ -54,7 +54,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         # delete from keyword index
         if index_node_ids:
-            vector_index.delete_by_ids(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)
...