Commit 0578c1b6 authored by John Wang

feat: replace using new index builder

parent fb5118f0
@@ -4,6 +4,7 @@ from __future__ import annotations
 from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
 from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from pydantic import root_validator
@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
            raise ValueError
     def _call(
         self,
-        inputs: Dict[str, Any]
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         output = cast(
             Dict[str, Any],
...
-from typing import List, Dict
+from typing import List, Dict, Optional, Any
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
                 return self.canned_response
         return text
-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         return {self.output_key: output}
-from typing import List, Dict
+from typing import List, Dict, Optional, Any
+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool
@@ -30,12 +31,20 @@ class ToolChain(Chain):
         """
         return [self.output_key]
-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}
-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)
...
@@ -7,11 +7,11 @@ from langchain.schema import Document, BaseRetriever
 class BaseIndex(ABC):
     @abstractmethod
-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        raise NotImplementedError
     @abstractmethod
-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
        raise NotImplementedError
     @abstractmethod
...
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
+from models.dataset import Dataset
+class IndexBuilder:
+    @classmethod
+    def get_index(cls, dataset: Dataset, high_quality: str):
+        if high_quality == "high_quality":
+            if dataset.indexing_technique != 'high_quality':
+                return None
+            model_credentials = LLMBuilder.get_model_credentials(
+                tenant_id=dataset.tenant_id,
+                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+                model_name='text-embedding-ada-002'
+            )
+            embeddings = CacheEmbedding(OpenAIEmbeddings(
+                **model_credentials
+            ))
+            return VectorIndex(
+                dataset=dataset,
+                config=current_app.config,
+                embeddings=embeddings
+            )
+        elif high_quality == "economy":
+            return KeywordTableIndex(
+                dataset=dataset,
+                config=KeywordTableConfig(
+                    max_keywords_per_chunk=10
+                )
+            )
+        else:
+            raise ValueError('Unknown indexing technique')
\ No newline at end of file
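
Note: a minimal usage sketch of the new builder. The `dataset` variable and the sample document below are placeholders assumed for illustration, not part of this commit; the index classes and methods come from the diff itself.

    from langchain.schema import Document
    from core.index.index import IndexBuilder

    # Assumed: `dataset` is a models.dataset.Dataset row loaded elsewhere.
    documents = [Document(page_content="hello", metadata={"doc_id": "1", "doc_hash": "h1"})]

    # Returns a VectorIndex only when dataset.indexing_technique == 'high_quality', otherwise None.
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    if vector_index:
        vector_index.add_texts(documents, duplicate_check=True)

    # The keyword table index is returned for the 'economy' key.
    kw_index = IndexBuilder.get_index(dataset, 'economy')
    if kw_index:
        kw_index.add_texts(documents)
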
@@ -20,7 +20,7 @@ class KeywordTableIndex(BaseIndex):
         self._dataset = dataset
         self._config = config
-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         keyword_table_handler = JiebaKeywordTableHandler()
         keyword_table = {}
         for text in texts:
@@ -37,7 +37,7 @@ class KeywordTableIndex(BaseIndex):
         return self
-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         keyword_table_handler = JiebaKeywordTableHandler()
         keyword_table = self._get_dataset_keyword_table()
...
@@ -67,11 +67,13 @@ class BaseVectorIndex(BaseIndex):
         return vector_store.as_retriever(**kwargs)
-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
-        texts = self._filter_duplicate_texts(texts)
+        if kwargs.get('duplicate_check', False):
+            texts = self._filter_duplicate_texts(texts)
         uuids = self._get_uuids(texts)
         vector_store.add_documents(texts, uuids=uuids)
...
@@ -53,7 +53,7 @@ class QdrantVectorIndex(BaseVectorIndex):
             "vector_store": {"collection_name": self.get_index_name(self._dataset.get_id())}
         }
-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         uuids = self._get_uuids(texts)
         self._vector_store = QdrantVectorStore.from_documents(
             texts,
...
@@ -51,14 +51,14 @@ class VectorIndex:
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         if not self._dataset.index_struct_dict:
-            self._vector_index.create(texts)
+            self._vector_index.create(texts, **kwargs)
             self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
             db.session.commit()
             return
-        self._vector_index.add_texts(texts)
+        self._vector_index.add_texts(texts, **kwargs)
     def __getattr__(self, name):
         if self._vector_index is not None:
...
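
Note: a sketch of the wrapper's create-or-add behavior after this change. `dataset`, `documents` and `more_documents` are assumed placeholders (lists of langchain Documents); the wrapper instance comes from IndexBuilder as shown above.

    index = IndexBuilder.get_index(dataset, 'high_quality')
    if index:
        # First call on an empty dataset: create() builds the store and the
        # serialized index_struct is committed to the dataset row.
        index.add_texts(documents)
        # Later calls append; duplicate filtering is now opt-in via **kwargs.
        index.add_texts(more_documents, duplicate_check=True)
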
@@ -62,7 +62,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
             "vector_store": {"class_prefix": self.get_index_name(self._dataset.get_id())}
         }
-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         uuids = self._get_uuids(texts)
         self._vector_store = WeaviateVectorStore.from_documents(
             texts,
...
 import datetime
 import json
+import logging
 import re
 import time
 import uuid
@@ -15,8 +16,10 @@ from core.data_loader.file_extractor import FileExtractor
 from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.embedding.cached_embedding import CacheEmbedding
+from core.index.index import IndexBuilder
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.llm_builder import LLMBuilder
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.llm.token_calculator import TokenCalculator
@@ -39,6 +42,58 @@ class IndexingRunner:
     def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
         for dataset_document in dataset_documents:
+            try:
+                # get dataset
+                dataset = Dataset.query.filter_by(
+                    id=dataset_document.dataset_id
+                ).first()
+                if not dataset:
+                    raise ValueError("no dataset found")
+                # load file
+                text_docs = self._load_data(dataset_document)
+                # get the process rule
+                processing_rule = db.session.query(DatasetProcessRule). \
+                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+                    first()
+                # get splitter
+                splitter = self._get_splitter(processing_rule)
+                # split to documents
+                documents = self._step_split(
+                    text_docs=text_docs,
+                    splitter=splitter,
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    processing_rule=processing_rule
+                )
+                # build index
+                self._build_index(
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    documents=documents
+                )
+            except DocumentIsPausedException:
+                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+            except ProviderTokenNotInitError as e:
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e.description)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+            except Exception as e:
+                logging.exception("consume document failed")
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+    def run_in_splitting_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is splitting."""
+        try:
             # get dataset
             dataset = Dataset.query.filter_by(
                 id=dataset_document.dataset_id
@@ -47,6 +102,15 @@ class IndexingRunner:
             if not dataset:
                 raise ValueError("no dataset found")
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+            db.session.delete(document_segments)
+            db.session.commit()
             # load file
             text_docs = self._load_data(dataset_document)
@@ -73,92 +137,73 @@ class IndexingRunner:
                 dataset_document=dataset_document,
                 documents=documents
             )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
-    def run_in_splitting_status(self, dataset_document: DatasetDocument):
-        """Run the indexing process when the index_status is splitting."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=dataset_document.dataset_id
-        ).first()
-        if not dataset:
-            raise ValueError("no dataset found")
-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=dataset_document.id
-        ).all()
-        db.session.delete(document_segments)
-        db.session.commit()
-        # load file
-        text_docs = self._load_data(dataset_document)
-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
-            first()
-        # get splitter
-        splitter = self._get_splitter(processing_rule)
-        # split to documents
-        documents = self._step_split(
-            text_docs=text_docs,
-            splitter=splitter,
-            dataset=dataset,
-            dataset_document=dataset_document,
-            processing_rule=processing_rule
-        )
-        # build index
-        self._build_index(
-            dataset=dataset,
-            dataset_document=dataset_document,
-            documents=documents
-        )
     def run_in_indexing_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is indexing."""
+        try:
             # get dataset
             dataset = Dataset.query.filter_by(
                 id=dataset_document.dataset_id
             ).first()
             if not dataset:
                 raise ValueError("no dataset found")
             # get exist document_segment list and delete
             document_segments = DocumentSegment.query.filter_by(
                 dataset_id=dataset.id,
                 document_id=dataset_document.id
             ).all()
             documents = []
             if document_segments:
                 for document_segment in document_segments:
                     # transform segment to node
                     if document_segment.status != "completed":
                         document = Document(
                             page_content=document_segment.content,
                             metadata={
                                 "doc_id": document_segment.index_node_id,
                                 "doc_hash": document_segment.index_node_hash,
                                 "document_id": document_segment.document_id,
                                 "dataset_id": document_segment.dataset_id,
                             }
                         )
                         documents.append(document)
             # build index
             self._build_index(
                 dataset=dataset,
                 dataset_document=dataset_document,
                 documents=documents
             )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
     def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """
@@ -481,11 +526,14 @@ class IndexingRunner:
             )
             # save vector index
-            if dataset.indexing_technique == "high_quality":
-                vector_index.add_texts(chunk_documents)
+            index = IndexBuilder.get_index(dataset, 'high_quality')
+            if index:
+                index.add_texts(chunk_documents)
             # save keyword index
-            keyword_table_index.add_texts(chunk_documents)
+            index = IndexBuilder.get_index(dataset, 'economy')
+            if index:
+                index.add_texts(chunk_documents)
             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
...
+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -69,7 +70,11 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         return message_tokens
     def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         self.callback_manager.on_llm_start(
             {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
@@ -87,7 +92,11 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         return chat_result
     async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         if self.callback_manager.is_async:
             await self.callback_manager.on_llm_start(
...
 import os
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import ChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -71,7 +72,11 @@ class StreamableChatOpenAI(ChatOpenAI):
         return message_tokens
     def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
    ) -> ChatResult:
         self.callback_manager.on_llm_start(
             {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
@@ -88,7 +93,11 @@ class StreamableChatOpenAI(ChatOpenAI):
         return chat_result
     async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
    ) -> ChatResult:
         if self.callback_manager.is_async:
             await self.callback_manager.on_llm_start(
...
-from llama_index import QueryKeywordExtractPrompt
 CONVERSATION_TITLE_PROMPT = (
     "Human:{query}\n-----\n"
     "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "[\"question1\",\"question2\",\"question3\"]\n"
 )
-QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
-    "A question is provided below. Given the question, extract up to {max_keywords} "
-    "keywords from the text. Focus on extracting the keywords that we can use "
-    "to best lookup answers to the question. Avoid stopwords."
-    "I am not sure which language the following question is in. "
-    "If the user asked the question in Chinese, please return the keywords in Chinese. "
-    "If the user asked the question in English, please return the keywords in English.\n"
-    "---------------------\n"
-    "{question}\n"
-    "---------------------\n"
-    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
-)
-QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
-    QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
-)
 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
 the model prompt that best suits the input.
 You will be provided with the prompt, variables, and an opening statement.
...
@@ -4,7 +4,6 @@ import uuid
 from core.constant import llm_constant
 from models.account import Account
 from services.dataset_service import DatasetService
-from services.errors.account import NoPermissionError
 class AppModelConfigService:
...
@@ -4,96 +4,81 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment, Document
+from models.dataset import DocumentSegment
+from models.dataset import Document as DatasetDocument
 @shared_task
-def add_document_to_index_task(document_id: str):
+def add_document_to_index_task(dataset_document_id: str):
     """
     Async Add document to index
     :param document_id:
     Usage: add_document_to_index.delay(document_id)
     """
-    logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green'))
+    logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
     start_at = time.perf_counter()
-    document = db.session.query(Document).filter(Document.id == document_id).first()
-    if not document:
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
+    if not dataset_document:
         raise NotFound('Document not found')
-    if document.indexing_status != 'completed':
+    if dataset_document.indexing_status != 'completed':
         return
-    indexing_cache_key = 'document_{}_indexing'.format(document.id)
+    indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)
     try:
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.document_id == document.id,
+            DocumentSegment.document_id == dataset_document.id,
             DocumentSegment.enabled == True
         ) \
             .order_by(DocumentSegment.position.asc()).all()
-        nodes = []
-        previous_node = None
+        documents = []
         for segment in segments:
-            relationships = {
-                DocumentRelationship.SOURCE: document.id
-            }
-            if previous_node:
-                relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-                previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-            node = Node(
-                doc_id=segment.index_node_id,
-                doc_hash=segment.index_node_hash,
-                text=segment.content,
-                extra_info=None,
-                node_info=None,
-                relationships=relationships
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
             )
-            previous_node = node
-            nodes.append(node)
-        dataset = document.dataset
+            documents.append(document)
+        dataset = dataset_document.dataset
         if not dataset:
             raise Exception('Document has no dataset')
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=nodes,
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents)
         # save keyword index
-        keyword_table_index.add_nodes(nodes)
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts(documents)
         end_at = time.perf_counter()
         logging.info(
-            click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+            click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
     except Exception as e:
         logging.exception("add document to index failed")
-        document.enabled = False
-        document.disabled_at = datetime.datetime.utcnow()
-        document.status = 'error'
-        document.error = str(e)
+        dataset_document.enabled = False
+        dataset_document.disabled_at = datetime.datetime.utcnow()
+        dataset_document.status = 'error'
+        dataset_document.error = str(e)
         db.session.commit()
     finally:
         redis_client.delete(indexing_cache_key)
@@ -4,12 +4,10 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -36,25 +34,14 @@ def add_segment_to_index_task(segment_id: str):
     indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
     try:
-        relationships = {
-            DocumentRelationship.SOURCE: segment.document_id,
-        }
-        previous_segment = segment.previous_segment
-        if previous_segment:
-            relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-        next_segment = segment.next_segment
-        if next_segment:
-            relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-        node = Node(
-            doc_id=segment.index_node_id,
-            doc_hash=segment.index_node_hash,
-            text=segment.content,
-            extra_info=None,
-            node_info=None,
-            relationships=relationships
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            }
         )
         dataset = segment.dataset
@@ -62,18 +49,15 @@ def add_segment_to_index_task(segment_id: str):
         if not dataset:
             raise Exception('Segment has no dataset')
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=[node],
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts([document], duplicate_check=True)
         # save keyword index
-        keyword_table_index.add_nodes([node])
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts([document])
         end_at = time.perf_counter()
         logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
...
@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin
@@ -33,19 +32,19 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         index_struct=index_struct
     )
-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
     documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
     index_doc_ids = [document.id for document in documents]
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
     index_node_ids = [segment.index_node_id for segment in segments]
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')
     # delete from vector index
-    if dataset.indexing_technique == "high_quality":
+    if vector_index:
         for index_doc_id in index_doc_ids:
             try:
-                vector_index.del_doc(index_doc_id)
+                vector_index.delete_by_document_id(index_doc_id)
             except Exception:
                 logging.exception("Delete doc index failed when dataset deleted.")
                 continue
@@ -53,7 +52,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
     # delete from keyword index
     if index_node_ids:
         try:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
         except Exception:
             logging.exception("Delete nodes index failed when dataset deleted.")
...
@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset
@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)
         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
         for segment in segments:
             db.session.delete(segment)
         db.session.commit()
         end_at = time.perf_counter()
         logging.info(
...
@@ -5,8 +5,7 @@ from typing import List
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, Document
@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
         for document_id in document_ids:
             document = db.session.query(Document).filter(
                 Document.id == document_id
             ).first()
             db.session.delete(document)
             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]
             # delete from vector index
-            vector_index.del_nodes(index_node_ids)
+            if vector_index:
+                vector_index.delete_by_document_id(document_id)
             # delete from keyword index
             if index_node_ids:
-                keyword_table_index.del_nodes(index_node_ids)
+                kw_index.delete_by_ids(index_node_ids)
             for segment in segments:
                 db.session.delete(segment)
...
@@ -3,10 +3,12 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs.node_v2 import DocumentRelationship, Node
+from langchain.schema import Document
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
-from models.dataset import DocumentSegment, Document, Dataset
+from models.dataset import DocumentSegment, Dataset
+from models.dataset import Document as DatasetDocument
 @shared_task
@@ -26,48 +28,41 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
     ).first()
     if not dataset:
         raise Exception('Dataset not found')
-    documents = Document.query.filter_by(dataset_id=dataset_id).all()
-    if documents:
-        vector_index = VectorIndex(dataset=dataset)
-        for document in documents:
-            # delete from vector index
-            if action == "remove":
-                vector_index.del_doc(document.id)
-            elif action == "add":
-                segments = db.session.query(DocumentSegment).filter(
-                    DocumentSegment.document_id == document.id,
-                    DocumentSegment.enabled == True
-                ) .order_by(DocumentSegment.position.asc()).all()
-                nodes = []
-                previous_node = None
-                for segment in segments:
-                    relationships = {
-                        DocumentRelationship.SOURCE: document.id
-                    }
-                    if previous_node:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-                        previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-                    node = Node(
-                        doc_id=segment.index_node_id,
-                        doc_hash=segment.index_node_hash,
-                        text=segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
-                    )
-                    previous_node = node
-                    nodes.append(node)
-                # save vector index
-                vector_index.add_nodes(
-                    nodes=nodes,
-                    duplicate_check=True
-                )
+    dataset_documents = DatasetDocument.query.filter_by(dataset_id=dataset_id).all()
+    if dataset_documents:
+        # save vector index
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            for dataset_document in dataset_documents:
+                # delete from vector index
+                if action == "remove":
+                    index.delete_by_document_id(dataset_document.id)
+                elif action == "add":
+                    segments = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.document_id == dataset_document.id,
+                        DocumentSegment.enabled == True
+                    ) .order_by(DocumentSegment.position.asc()).all()
+                    documents = []
+                    for segment in segments:
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
+                        )
+                        documents.append(document)
+                    # save vector index
+                    index.add_texts(
+                        documents,
+                        duplicate_check=True
+                    )
     end_at = time.perf_counter()
     logging.info(
         click.style('Deal dataset vector index: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
...
@@ -7,10 +7,8 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound
 from core.data_loader.loader.notion import NotionLoader
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
 from models.source import DataSourceBinding
@@ -77,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)
         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
         for segment in segments:
             db.session.delete(segment)
@@ -98,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
     except Exception:
         logging.exception("Cleaned document when document update data source or process rule failed")
     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
@@ -7,7 +7,6 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document
@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
     documents = []
+    start_at = time.perf_counter()
     for document_id in document_ids:
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-        start_at = time.perf_counter()
         document = db.session.query(Document).filter(
             Document.id == document_id,
@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()
-        logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
@@ -6,10 +6,8 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_ids(index_node_ids)
         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            vector_index.delete_by_ids(index_node_ids)
         for segment in segments:
             db.session.delete(segment)
@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
     except Exception:
         logging.exception("Cleaned document when document update data source or process rule failed")
     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
-import datetime
 import logging
 import time
@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
         indexing_runner.run_in_indexing_status(document)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment, Document
@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
         # delete from vector index
-        vector_index.del_doc(document.id)
+        vector_index.delete_by_document_id(document.id)
         # delete from keyword index
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
         end_at = time.perf_counter()
         logging.info(
...
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -38,15 +37,15 @@ def remove_segment_from_index_task(segment_id: str):
         if not dataset:
             raise Exception('Segment has no dataset')
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
         # delete from vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.del_nodes([segment.index_node_id])
+        if vector_index:
+            vector_index.delete_by_ids([segment.index_node_id])
         # delete from keyword index
-        keyword_table_index.del_nodes([segment.index_node_id])
+        kw_index.delete_by_ids([segment.index_node_id])
         end_at = time.perf_counter()
         logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
...
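
Note: the task changes above converge on one cleanup pattern. A minimal sketch, assuming `dataset`, `document_id` and `index_node_ids` are already loaded by the surrounding task (placeholders, not part of this commit):

    vector_index = IndexBuilder.get_index(dataset, 'high_quality')  # None unless indexing_technique == 'high_quality'
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    if vector_index:
        # drop every chunk that belongs to the document from the vector store
        vector_index.delete_by_document_id(document_id)

    if index_node_ids:
        # drop individual segments from the keyword table index
        kw_index.delete_by_ids(index_node_ids)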