Commit 71981eac authored by John Wang

fix: kw table bugs

parent ced9fc52
@@ -22,21 +22,21 @@ class FileExtractor:
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
             storage.download(upload_file.key, file_path)
-            input_file = Path(file_path)
-            if input_file.suffix == '.xlxs':
-                loader = ExcelLoader(file_path)
-            elif input_file.suffix == '.pdf':
-                loader = PdfLoader(file_path, upload_file=upload_file)
-            elif input_file.suffix in ['.md', '.markdown']:
-                loader = MarkdownLoader(file_path, autodetect_encoding=True)
-            elif input_file.suffix in ['.htm', '.html']:
-                loader = HTMLLoader(file_path)
-            elif input_file.suffix == '.docx':
-                loader = Docx2txtLoader(file_path)
-            elif input_file.suffix == '.csv':
-                loader = CSVLoader(file_path, autodetect_encoding=True)
-            else:
-                # txt
-                loader = TextLoader(file_path, autodetect_encoding=True)
+            input_file = Path(file_path)
+            if input_file.suffix == '.xlxs':
+                loader = ExcelLoader(file_path)
+            elif input_file.suffix == '.pdf':
+                loader = PdfLoader(file_path, upload_file=upload_file)
+            elif input_file.suffix in ['.md', '.markdown']:
+                loader = MarkdownLoader(file_path, autodetect_encoding=True)
+            elif input_file.suffix in ['.htm', '.html']:
+                loader = HTMLLoader(file_path)
+            elif input_file.suffix == '.docx':
+                loader = Docx2txtLoader(file_path)
+            elif input_file.suffix == '.csv':
+                loader = CSVLoader(file_path, autodetect_encoding=True)
+            else:
+                # txt
+                loader = TextLoader(file_path, autodetect_encoding=True)
-            return loader.load_as_text() if return_text else loader.load()
+            return '\n'.join([document.page_content for document in loader.load()]) if return_text else loader.load()
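Note on the hunk above: the return_text path no longer depends on each loader exposing a load_as_text() method; the extractor now joins the page_content of the loaded documents itself, which is why the per-loader load_as_text methods are removed in the hunks below. A minimal sketch of the new behaviour (loader stands for any of the loaders above):

    docs = loader.load()  # List[Document]
    text = '\n'.join(document.page_content for document in docs)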
@@ -40,7 +40,3 @@ class ExcelLoader(BaseLoader):
         metadata = {"source": self._file_path}
         return [Document(page_content='\n\n'.join(data), metadata=metadata)]
-
-    def load_as_text(self) -> str:
-        documents = self.load()
-        return ''.join([document.page_content for document in documents])
@@ -25,9 +25,9 @@ class HTMLLoader(BaseLoader):
     def load(self) -> List[Document]:
         metadata = {"source": self._file_path}
-        return [Document(page_content=self.load_as_text(), metadata=metadata)]
+        return [Document(page_content=self._load_as_text(), metadata=metadata)]

-    def load_as_text(self) -> str:
+    def _load_as_text(self) -> str:
         with open(self._file_path, "rb") as fp:
             soup = BeautifulSoup(fp, 'html.parser')
             text = soup.get_text()
@@ -64,6 +64,3 @@ class PdfLoader(BaseLoader):
         metadata = {"source": self._file_path}
         return [Document(page_content=text, metadata=metadata)]
-
-    def load_as_text(self) -> str:
-        documents = self.load()
-        return '\n'.join([document.page_content for document in documents])
@@ -46,8 +46,8 @@ class CacheEmbedding(Embeddings):
             i += 1

-        embedding_queue_texts.extend(embedding_results)
-        return embedding_queue_texts
+        text_embeddings.extend(embedding_results)
+        return text_embeddings

     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
@@ -30,11 +30,20 @@ class KeywordTableIndex(BaseIndex):
         dataset_keyword_table = DatasetKeywordTable(
             dataset_id=self._dataset.id,
-            keyword_table=json.dumps(keyword_table)
+            keyword_table=json.dumps({
+                '__type__': 'keyword_table',
+                '__data__': {
+                    "index_id": self._dataset.id,
+                    "summary": None,
+                    "table": {}
+                }
+            }, cls=SetEncoder)
         )
         db.session.add(dataset_keyword_table)
         db.session.commit()

+        self._save_dataset_keyword_table(keyword_table)
+
         return self

     def add_texts(self, texts: list[Document], **kwargs):
@@ -46,8 +55,7 @@ class KeywordTableIndex(BaseIndex):
             self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
             keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

-        self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table)
-        db.session.commit()
+        self._save_dataset_keyword_table(keyword_table)

     def text_exists(self, id: str) -> bool:
         keyword_table = self._get_dataset_keyword_table()
@@ -57,8 +65,7 @@ class KeywordTableIndex(BaseIndex):
         keyword_table = self._get_dataset_keyword_table()
         keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

-        self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table)
-        db.session.commit()
+        self._save_dataset_keyword_table(keyword_table)

     def delete_by_document_id(self, document_id: str):
         # get segment ids by document_id
@@ -72,8 +79,7 @@ class KeywordTableIndex(BaseIndex):
         keyword_table = self._get_dataset_keyword_table()
         keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

-        self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table)
-        db.session.commit()
+        self._save_dataset_keyword_table(keyword_table)

     def get_retriever(self, **kwargs: Any) -> BaseRetriever:
         return KeywordTableRetriever(index=self, **kwargs)
@@ -108,10 +114,38 @@ class KeywordTableIndex(BaseIndex):
         return documents

+    def _save_dataset_keyword_table(self, keyword_table):
+        keyword_table_dict = {
+            '__type__': 'keyword_table',
+            '__data__': {
+                "index_id": self._dataset.id,
+                "summary": None,
+                "table": keyword_table
+            }
+        }
+        self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
+        db.session.commit()
+
     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        keyword_table_dict = self._dataset.dataset_keyword_table.keyword_table_dict
-        if keyword_table_dict:
-            return keyword_table_dict
+        dataset_keyword_table = self._dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            if dataset_keyword_table.keyword_table_dict:
+                return dataset_keyword_table.keyword_table_dict['__data__']['table']
+        else:
+            dataset_keyword_table = DatasetKeywordTable(
+                dataset_id=self._dataset.id,
+                keyword_table=json.dumps({
+                    '__type__': 'keyword_table',
+                    '__data__': {
+                        "index_id": self._dataset.id,
+                        "summary": None,
+                        "table": {}
+                    }
+                }, cls=SetEncoder)
+            )
+            db.session.add(dataset_keyword_table)
+            db.session.commit()

         return {}

     def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
@@ -146,9 +180,9 @@ class KeywordTableIndex(BaseIndex):
         # go through text chunks in order of most matching keywords
         chunk_indices_count: Dict[str, int] = defaultdict(int)
-        keywords = [k for k in keywords if k in set(keyword_table.keys())]
-        for k in keywords:
-            for node_id in keyword_table[k]:
+        keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
+        for keyword in keywords:
+            for node_id in keyword_table[keyword]:
                 chunk_indices_count[node_id] += 1

         sorted_chunk_indices = sorted(
@@ -190,3 +224,9 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
     async def aget_relevant_documents(self, query: str) -> List[Document]:
         raise NotImplementedError("KeywordTableRetriever does not support async")
+
+
+class SetEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, set):
+            return list(obj)
+        return super().default(obj)
\ No newline at end of file
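Note on the new persistence path: _save_dataset_keyword_table wraps the in-memory keyword table (keyword -> set of segment ids) in a typed envelope and serializes it with the new SetEncoder, so set values survive json.dumps. A minimal, self-contained sketch, with made-up keyword and segment ids:

    import json

    class SetEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, set):
                return list(obj)
            return super().default(obj)

    keyword_table = {"invoice": {"seg-1", "seg-2"}, "refund": {"seg-3"}}
    payload = json.dumps({
        '__type__': 'keyword_table',
        '__data__': {"index_id": "dataset-id", "summary": None, "table": keyword_table}
    }, cls=SetEncoder)
    # payload is plain JSON; the set values are emitted as lists, e.g. ["seg-1", "seg-2"]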
@@ -5,6 +5,7 @@ from langchain.schema import Document, BaseRetriever
 from langchain.vectorstores import VectorStore

 from core.index.base import BaseIndex
+from models.dataset import Dataset


 class BaseVectorIndex(BaseIndex):
@@ -12,7 +13,7 @@ class BaseVectorIndex(BaseIndex):
         raise NotImplementedError

     @abstractmethod
-    def get_index_name(self, dataset_id: str) -> str:
+    def get_index_name(self, dataset: Dataset) -> str:
         raise NotImplementedError

     @abstractmethod
@@ -44,13 +44,17 @@ class QdrantVectorIndex(BaseVectorIndex):
     def get_type(self) -> str:
         return 'qdrant'

-    def get_index_name(self, dataset_id: str) -> str:
-        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+    def get_index_name(self, dataset: Dataset) -> str:
+        if self._dataset.index_struct_dict:
+            return self._dataset.index_struct_dict['vector_store']['collection_name']
+
+        dataset_id = dataset.id
+        return "Vector_index_" + dataset_id.replace("-", "_")

     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
-            "vector_store": {"collection_name": self.get_index_name(self._dataset.id)}
+            "vector_store": {"collection_name": self.get_index_name(self._dataset)}
         }

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
@@ -58,7 +62,7 @@ class QdrantVectorIndex(BaseVectorIndex):
         self._vector_store = QdrantVectorStore.from_documents(
             texts,
             self._embeddings,
-            collection_name=self.get_index_name(self._dataset.id),
+            collection_name=self.get_index_name(self._dataset),
             ids=uuids,
             **self._client_config.to_qdrant_params()
         )
@@ -76,7 +80,7 @@ class QdrantVectorIndex(BaseVectorIndex):
         return QdrantVectorStore(
             client=client,
-            collection_name=self.get_index_name(self._dataset.id),
+            collection_name=self.get_index_name(self._dataset),
             embeddings=self._embeddings
         )
@@ -59,13 +59,22 @@ class WeaviateVectorIndex(BaseVectorIndex):
     def get_type(self) -> str:
         return 'weaviate'

-    def get_index_name(self, dataset_id: str) -> str:
+    def get_index_name(self, dataset: Dataset) -> str:
+        if self._dataset.index_struct_dict:
+            class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix
+
+        dataset_id = dataset.id
         return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
-            "vector_store": {"class_prefix": self.get_index_name(self._dataset.id)}
+            "vector_store": {"class_prefix": self.get_index_name(self._dataset)}
         }

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
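Note: get_index_name now takes the Dataset itself; it prefers the class_prefix (or collection_name, for Qdrant) already recorded in index_struct_dict and only falls back to a name derived from the dataset id. A tiny sketch of that fallback naming, with a hypothetical id:

    dataset_id = "0b3f2c9a-1d4e-4f5a-8b6c-7d8e9f0a1b2c"  # hypothetical
    index_name = "Vector_index_" + dataset_id.replace("-", "_") + "_Node"
    # -> "Vector_index_0b3f2c9a_1d4e_4f5a_8b6c_7d8e9f0a1b2c_Node"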
@@ -74,7 +83,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
             texts,
             self._embeddings,
             client=self._client,
-            index_name=self.get_index_name(self._dataset.id),
+            index_name=self.get_index_name(self._dataset),
             uuids=uuids,
             by_text=False
         )
@@ -88,7 +97,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
         return WeaviateVectorStore(
             client=self._client,
-            index_name=self.get_index_name(self._dataset.id),
+            index_name=self.get_index_name(self._dataset),
             text_key='text',
             embedding=self._embeddings,
             attributes=self._attributes,
@@ -329,7 +329,7 @@ class IndexingRunner:
             document_id=dataset_document.id,
             after_indexing_status="splitting",
             extra_update_params={
-                DatasetDocument.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
+                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                 DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
             }
         )
@@ -395,7 +395,16 @@ class DatasetKeywordTable(db.Model):
     @property
     def keyword_table_dict(self):
-        return json.loads(self.keyword_table) if self.keyword_table else None
+        class SetDecoder(json.JSONDecoder):
+            def __init__(self, *args, **kwargs):
+                super().__init__(object_hook=self.object_hook, *args, **kwargs)
+
+            def object_hook(self, dct):
+                if "__set__" in dct:
+                    return set(dct["__set__"])
+                return dct
+
+        return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None


 class Embedding(db.Model):
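Note on the decoding side: keyword_table_dict now parses the stored JSON with a SetDecoder whose object_hook rebuilds any object carrying a "__set__" key into a Python set and returns everything else unchanged. A minimal sketch with an illustrative stored payload:

    import json

    class SetDecoder(json.JSONDecoder):
        def __init__(self, *args, **kwargs):
            super().__init__(object_hook=self.object_hook, *args, **kwargs)

        def object_hook(self, dct):
            if "__set__" in dct:
                return set(dct["__set__"])
            return dct

    stored = '{"table": {"invoice": {"__set__": ["seg-1", "seg-2"]}}}'
    data = json.loads(stored, cls=SetDecoder)
    # data["table"]["invoice"] is now a Python set: {"seg-1", "seg-2"}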