Commit 0578c1b6 authored by John Wang

feat: replace index usage with the new index builder

parent fb5118f0
......@@ -4,6 +4,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import root_validator
......@@ -52,7 +53,8 @@ class LLMRouterChain(Chain):
def _call(
self,
inputs: Dict[str, Any]
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
output = cast(
Dict[str, Any],
......
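This and the following hunks move each custom chain's _call onto the newer LangChain Chain interface, which passes an optional run_manager. A minimal sketch of a chain written against that signature (EchoChain and its input/output keys are illustrative, not part of this commit):

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class EchoChain(Chain):
    input_key: str = "input"
    output_key: str = "output"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # No real work here; the point is the updated signature with run_manager.
        return {self.output_key: inputs[self.input_key]}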
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
......@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
return self.canned_response
return text
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool
......@@ -30,12 +31,20 @@ class ToolChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
input = inputs[self.input_key]
output = self.tool.run(input, self.verbose)
return {self.output_key: output}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
input = inputs[self.input_key]
output = await self.tool.arun(input, self.verbose)
......
......@@ -7,11 +7,11 @@ from langchain.schema import Document, BaseRetriever
class BaseIndex(ABC):
@abstractmethod
def create(self, texts: list[Document]) -> BaseIndex:
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document]):
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@abstractmethod
......
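Adding **kwargs to the abstract create/add_texts keeps the BaseIndex contract uniform while letting callers pass backend-specific options, such as the duplicate_check flag used by the vector index further down. A toy sketch of an implementation that simply tolerates those kwargs (InMemoryIndex is illustrative only, not part of this commit):

from langchain.schema import Document


class InMemoryIndex:
    """Toy index keyed by doc_id; extra kwargs are accepted and ignored."""

    def __init__(self):
        self._store: dict[str, str] = {}

    def create(self, texts: list[Document], **kwargs) -> "InMemoryIndex":
        self.add_texts(texts, **kwargs)
        return self

    def add_texts(self, texts: list[Document], **kwargs):
        # Options like duplicate_check only matter to some backends; accepting
        # **kwargs keeps every implementation call-compatible.
        for doc in texts:
            self._store[doc.metadata["doc_id"]] = doc.page_content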
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class IndexBuilder:
@classmethod
def get_index(cls, dataset: Dataset, high_quality: str):
if high_quality == "high_quality":
if dataset.indexing_technique != 'high_quality':
return None
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
elif high_quality == "economy":
return KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
else:
raise ValueError('Unknown indexing technique')
\ No newline at end of file
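A brief usage sketch of the new builder, mirroring how the tasks later in this commit consume it; index_for_dataset is an illustrative helper, not part of the commit. get_index returns None when the dataset's indexing technique does not match the requested tier, hence the guards:

from langchain.schema import Document

from core.index.index import IndexBuilder
from models.dataset import Dataset


def index_for_dataset(dataset: Dataset, documents: list[Document]) -> None:
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    if vector_index:
        # Only high-quality datasets get a vector index; de-duplicate on write.
        vector_index.add_texts(documents, duplicate_check=True)

    kw_index = IndexBuilder.get_index(dataset, 'economy')
    if kw_index:
        kw_index.add_texts(documents)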
......@@ -20,7 +20,7 @@ class KeywordTableIndex(BaseIndex):
self._dataset = dataset
self._config = config
def create(self, texts: list[Document]) -> BaseIndex:
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
......@@ -37,7 +37,7 @@ class KeywordTableIndex(BaseIndex):
return self
def add_texts(self, texts: list[Document]):
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
......
......@@ -67,11 +67,13 @@ class BaseVectorIndex(BaseIndex):
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document]):
def add_texts(self, texts: list[Document], **kwargs):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if kwargs.get('duplicate_check', False):
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
......
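The duplicate_check branch relies on _filter_duplicate_texts, which is not shown in this diff. A minimal sketch of what such a filter might look like, assuming documents carry a doc_id in their metadata and that the set of already-indexed ids can be obtained from the store (how that set is fetched is an assumption, not shown here):

from langchain.schema import Document


def filter_duplicate_texts(texts: list[Document], existing_ids: set[str]) -> list[Document]:
    # Keep only documents whose doc_id has not been indexed yet.
    fresh = []
    for doc in texts:
        doc_id = doc.metadata.get("doc_id")
        if doc_id and doc_id not in existing_ids:
            fresh.append(doc)
    return fresh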
......@@ -53,7 +53,7 @@ class QdrantVectorIndex(BaseVectorIndex):
"vector_store": {"collection_name": self.get_index_name(self._dataset.get_id())}
}
def create(self, texts: list[Document]) -> BaseIndex:
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
......
......@@ -51,14 +51,14 @@ class VectorIndex:
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document]):
def add_texts(self, texts: list[Document], **kwargs):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts)
self._vector_index.create(texts, **kwargs)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts)
self._vector_index.add_texts(texts, **kwargs)
def __getattr__(self, name):
if self._vector_index is not None:
......
......@@ -62,7 +62,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
"vector_store": {"class_prefix": self.get_index_name(self._dataset.get_id())}
}
def create(self, texts: list[Document]) -> BaseIndex:
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
......
import datetime
import json
import logging
import re
import time
import uuid
......@@ -15,8 +16,10 @@ from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
......@@ -39,6 +42,7 @@ class IndexingRunner:
def run(self, dataset_documents: List[DatasetDocument]):
"""Run the indexing process."""
for dataset_document in dataset_documents:
try:
# get dataset
dataset = Dataset.query.filter_by(
id=dataset_document.dataset_id
......@@ -73,9 +77,23 @@ class IndexingRunner:
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
try:
# get dataset
dataset = Dataset.query.filter_by(
id=dataset_document.dataset_id
......@@ -119,9 +137,23 @@ class IndexingRunner:
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def run_in_indexing_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is indexing."""
try:
# get dataset
dataset = Dataset.query.filter_by(
id=dataset_document.dataset_id
......@@ -159,6 +191,19 @@ class IndexingRunner:
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
"""
......@@ -481,11 +526,14 @@ class IndexingRunner:
)
# save vector index
if dataset.indexing_technique == "high_quality":
vector_index.add_texts(chunk_documents)
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(chunk_documents)
# save keyword index
keyword_table_index.add_texts(chunk_documents)
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(chunk_documents)
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
......
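The runner now wraps each document in its own try/except, so one failing document is marked as errored and the batch continues, while a paused document still surfaces as DocumentIsPausedException. A condensed sketch of that pattern (run_documents and process_one are illustrative stand-ins for the real extract/split/index steps):

import datetime
import logging

from core.indexing_runner import DocumentIsPausedException
from extensions.ext_database import db


def run_documents(dataset_documents, process_one):
    for dataset_document in dataset_documents:
        try:
            process_one(dataset_document)
        except DocumentIsPausedException:
            # Pausing is deliberate; re-raise with the document id attached.
            raise DocumentIsPausedException(
                'Document paused, document id: {}'.format(dataset_document.id))
        except Exception as e:
            # Any other failure marks only this document as errored.
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()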
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any
......@@ -69,7 +70,11 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
......@@ -87,7 +92,11 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
......
import os
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List, Dict, Any
......@@ -71,7 +72,11 @@ class StreamableChatOpenAI(ChatOpenAI):
return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
......@@ -88,7 +93,11 @@ class StreamableChatOpenAI(ChatOpenAI):
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
......
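Both streamable chat model classes get the newer _generate/_agenerate signature: an optional run_manager plus **kwargs. A hedged sketch of an override that simply forwards everything, assuming the parent ChatOpenAI at this langchain version accepts the same keyword arguments (LoggingChatOpenAI is an illustrative name, not part of this commit):

from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatResult


class LoggingChatOpenAI(ChatOpenAI):
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # A real subclass would hook logging/callbacks in here before delegating.
        return super()._generate(messages, stop=stop, run_manager=run_manager, **kwargs)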
from llama_index import QueryKeywordExtractPrompt
CONVERSATION_TITLE_PROMPT = (
"Human:{query}\n-----\n"
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
......@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"[\"question1\",\"question2\",\"question3\"]\n"
)
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
"A question is provided below. Given the question, extract up to {max_keywords} "
"keywords from the text. Focus on extracting the keywords that we can use "
"to best lookup answers to the question. Avoid stopwords."
"I am not sure which language the following question is in. "
"If the user asked the question in Chinese, please return the keywords in Chinese. "
"If the user asked the question in English, please return the keywords in English.\n"
"---------------------\n"
"{question}\n"
"---------------------\n"
"Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
)
QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
)
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
the model prompt that best suits the input.
You will be provided with the prompt, variables, and an opening statement.
......
......@@ -4,7 +4,6 @@ import uuid
from core.constant import llm_constant
from models.account import Account
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
class AppModelConfigService:
......
......@@ -4,96 +4,81 @@ import time
import click
from celery import shared_task
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment, Document
from models.dataset import DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task
def add_document_to_index_task(document_id: str):
def add_document_to_index_task(dataset_document_id: str):
"""
Async add a document to the index
:param dataset_document_id:
Usage: add_document_to_index_task.delay(dataset_document_id)
"""
logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green'))
logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
start_at = time.perf_counter()
document = db.session.query(Document).filter(Document.id == document_id).first()
if not document:
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
if not dataset_document:
raise NotFound('Document not found')
if document.indexing_status != 'completed':
if dataset_document.indexing_status != 'completed':
return
indexing_cache_key = 'document_{}_indexing'.format(document.id)
indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)
try:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True
) \
.order_by(DocumentSegment.position.asc()).all()
nodes = []
previous_node = None
documents = []
for segment in segments:
relationships = {
DocumentRelationship.SOURCE: document.id
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
if previous_node:
relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
node = Node(
doc_id=segment.index_node_id,
doc_hash=segment.index_node_hash,
text=segment.content,
extra_info=None,
node_info=None,
relationships=relationships
)
previous_node = node
documents.append(document)
nodes.append(node)
dataset = document.dataset
dataset = dataset_document.dataset
if not dataset:
raise Exception('Document has no dataset')
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
# save vector index
if dataset.indexing_technique == "high_quality":
vector_index.add_nodes(
nodes=nodes,
duplicate_check=True
)
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts(documents)
# save keyword index
keyword_table_index.add_nodes(nodes)
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(documents)
end_at = time.perf_counter()
logging.info(
click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
except Exception as e:
logging.exception("add document to index failed")
document.enabled = False
document.disabled_at = datetime.datetime.utcnow()
document.status = 'error'
document.error = str(e)
dataset_document.enabled = False
dataset_document.disabled_at = datetime.datetime.utcnow()
dataset_document.status = 'error'
dataset_document.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
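Throughout the tasks, llama_index Node objects are replaced by langchain Documents, and the mapping from a DocumentSegment is always the same (relationships such as previous/next/source are dropped; ids and hashes move into metadata). A condensed sketch of that mapping (segment_to_document is an illustrative helper name):

from langchain.schema import Document

from models.dataset import DocumentSegment


def segment_to_document(segment: DocumentSegment) -> Document:
    return Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        },
    )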
......@@ -4,12 +4,10 @@ import time
import click
from celery import shared_task
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
......@@ -36,25 +34,14 @@ def add_segment_to_index_task(segment_id: str):
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
try:
relationships = {
DocumentRelationship.SOURCE: segment.document_id,
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
previous_segment = segment.previous_segment
if previous_segment:
relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
next_segment = segment.next_segment
if next_segment:
relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
node = Node(
doc_id=segment.index_node_id,
doc_hash=segment.index_node_hash,
text=segment.content,
extra_info=None,
node_info=None,
relationships=relationships
)
dataset = segment.dataset
......@@ -62,18 +49,15 @@ def add_segment_to_index_task(segment_id: str):
if not dataset:
raise Exception('Segment has no dataset')
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
# save vector index
if dataset.indexing_technique == "high_quality":
vector_index.add_nodes(
nodes=[node],
duplicate_check=True
)
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts([document], duplicate_check=True)
# save keyword index
keyword_table_index.add_nodes([node])
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts([document])
end_at = time.perf_counter()
logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
......
......@@ -4,8 +4,7 @@ import time
import click
from celery import shared_task
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
AppDatasetJoin
......@@ -33,19 +32,19 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
index_struct=index_struct
)
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
index_doc_ids = [document.id for document in documents]
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if dataset.indexing_technique == "high_quality":
if vector_index:
for index_doc_id in index_doc_ids:
try:
vector_index.del_doc(index_doc_id)
vector_index.delete_by_document_id(index_doc_id)
except Exception:
logging.exception("Delete doc index failed when dataset deleted.")
continue
......@@ -53,7 +52,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
# delete from keyword index
if index_node_ids:
try:
keyword_table_index.del_nodes(index_node_ids)
kw_index.delete_by_ids(index_node_ids)
except Exception:
logging.exception("Delete nodes index failed when dataset deleted.")
......
......@@ -4,8 +4,7 @@ import time
import click
from celery import shared_task
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import DocumentSegment, Dataset
......@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
if not dataset:
raise Exception('Document has no dataset')
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
vector_index.del_nodes(index_node_ids)
if vector_index:
vector_index.delete_by_document_id(document_id)
# delete from keyword index
if index_node_ids:
keyword_table_index.del_nodes(index_node_ids)
kw_index.delete_by_ids(index_node_ids)
for segment in segments:
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logging.info(
......
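The same cleanup shape recurs in the remaining tasks: fetch both indexes from IndexBuilder, delete the document from the vector index by document id, and delete its segments from the keyword index by node ids. A condensed sketch of that shared pattern (clean_indexes_for_document is an illustrative helper, not part of this commit):

from core.index.index import IndexBuilder
from models.dataset import Dataset


def clean_indexes_for_document(dataset: Dataset, document_id: str, index_node_ids: list[str]) -> None:
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    # The vector index is None for economy datasets.
    if vector_index:
        vector_index.delete_by_document_id(document_id)

    # Keyword-table entries are removed by segment node id.
    if index_node_ids:
        kw_index.delete_by_ids(index_node_ids)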
......@@ -5,8 +5,7 @@ from typing import List
import click
from celery import shared_task
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import DocumentSegment, Dataset, Document
......@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
if not dataset:
raise Exception('Document has no dataset')
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
for document_id in document_ids:
document = db.session.query(Document).filter(
Document.id == document_id
).first()
db.session.delete(document)
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
vector_index.del_nodes(index_node_ids)
if vector_index:
vector_index.delete_by_document_id(document_id)
# delete from keyword index
if index_node_ids:
keyword_table_index.del_nodes(index_node_ids)
kw_index.delete_by_ids(index_node_ids)
for segment in segments:
db.session.delete(segment)
......
......@@ -3,10 +3,12 @@ import time
import click
from celery import shared_task
from llama_index.data_structs.node_v2 import DocumentRelationship, Node
from core.index.vector_index import VectorIndex
from langchain.schema import Document
from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document, Dataset
from models.dataset import DocumentSegment, Dataset
from models.dataset import Document as DatasetDocument
@shared_task
......@@ -26,45 +28,38 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
).first()
if not dataset:
raise Exception('Dataset not found')
documents = Document.query.filter_by(dataset_id=dataset_id).all()
if documents:
vector_index = VectorIndex(dataset=dataset)
for document in documents:
dataset_documents = DatasetDocument.query.filter_by(dataset_id=dataset_id).all()
if dataset_documents:
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
for dataset_document in dataset_documents:
# delete from vector index
if action == "remove":
vector_index.del_doc(document.id)
index.delete_by_document_id(dataset_document.id)
elif action == "add":
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True
) .order_by(DocumentSegment.position.asc()).all()
nodes = []
previous_node = None
documents = []
for segment in segments:
relationships = {
DocumentRelationship.SOURCE: document.id
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
if previous_node:
relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
node = Node(
doc_id=segment.index_node_id,
doc_hash=segment.index_node_hash,
text=segment.content,
extra_info=None,
node_info=None,
relationships=relationships
)
previous_node = node
nodes.append(node)
documents.append(document)
# save vector index
vector_index.add_nodes(
nodes=nodes,
index.add_texts(
documents,
duplicate_check=True
)
......
......@@ -7,10 +7,8 @@ from celery import shared_task
from werkzeug.exceptions import NotFound
from core.data_loader.loader.notion import NotionLoader
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from core.indexing_runner import IndexingRunner, DocumentIsPausedException
from core.llm.error import ProviderTokenNotInitError
from extensions.ext_database import db
from models.dataset import Document, Dataset, DocumentSegment
from models.source import DataSourceBinding
......@@ -77,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
if not dataset:
raise Exception('Dataset not found')
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
vector_index.del_nodes(index_node_ids)
if vector_index:
vector_index.delete_by_document_id(document_id)
# delete from keyword index
if index_node_ids:
keyword_table_index.del_nodes(index_node_ids)
kw_index.delete_by_ids(index_node_ids)
for segment in segments:
db.session.delete(segment)
......@@ -98,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
except Exception:
logging.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
except DocumentIsPausedException:
logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
except ProviderTokenNotInitError as e:
document.indexing_status = 'error'
document.error = str(e.description)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume update document failed")
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg='yellow'))
except Exception:
pass
......@@ -7,7 +7,6 @@ from celery import shared_task
from werkzeug.exceptions import NotFound
from core.indexing_runner import IndexingRunner, DocumentIsPausedException
from core.llm.error import ProviderTokenNotInitError
from extensions.ext_database import db
from models.dataset import Document
......@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
Usage: document_indexing_task.delay(dataset_id, document_id)
"""
documents = []
start_at = time.perf_counter()
for document_id in document_ids:
logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
start_at = time.perf_counter()
document = db.session.query(Document).filter(
Document.id == document_id,
......@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
except DocumentIsPausedException:
logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
except ProviderTokenNotInitError as e:
document.indexing_status = 'error'
document.error = str(e.description)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg='yellow'))
except Exception:
pass
......@@ -6,10 +6,8 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from core.indexing_runner import IndexingRunner, DocumentIsPausedException
from core.llm.error import ProviderTokenNotInitError
from extensions.ext_database import db
from models.dataset import Document, Dataset, DocumentSegment
......@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
if not dataset:
raise Exception('Dataset not found')
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
vector_index.del_nodes(index_node_ids)
if vector_index:
vector_index.delete_by_ids(index_node_ids)
# delete from keyword index
if index_node_ids:
keyword_table_index.del_nodes(index_node_ids)
kw_index.delete_by_ids(index_node_ids)
for segment in segments:
db.session.delete(segment)
......@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
except Exception:
logging.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
except DocumentIsPausedException:
logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
except ProviderTokenNotInitError as e:
document.indexing_status = 'error'
document.error = str(e.description)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume update document failed")
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg='yellow'))
except Exception:
pass
import datetime
import logging
import time
......@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter()
logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
except DocumentIsPausedException:
logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
except Exception as e:
logging.exception("consume document failed")
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg='yellow'))
except Exception:
pass
......@@ -5,8 +5,7 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment, Document
......@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
if not dataset:
raise Exception('Document has no dataset')
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
vector_index.del_doc(document.id)
vector_index.delete_by_document_id(document.id)
# delete from keyword index
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
keyword_table_index.del_nodes(index_node_ids)
kw_index.delete_by_ids(index_node_ids)
end_at = time.perf_counter()
logging.info(
......
......@@ -5,8 +5,7 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
......@@ -38,15 +37,15 @@ def remove_segment_from_index_task(segment_id: str):
if not dataset:
raise Exception('Segment has no dataset')
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if dataset.indexing_technique == "high_quality":
vector_index.del_nodes([segment.index_node_id])
if vector_index:
vector_index.delete_by_ids([segment.index_node_id])
# delete from keyword index
keyword_table_index.del_nodes([segment.index_node_id])
kw_index.delete_by_ids([segment.index_node_id])
end_at = time.perf_counter()
logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
......