Commit 9e9d15ec authored by John Wang's avatar John Wang

feat: replace llama-index to langchain in index build

parent 23ef2262
......@@ -187,11 +187,13 @@ class Config:
# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
class CloudEditionConfig(Config):
......
......@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_source.notion import NotionPageReader
from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.oauth_data_source import NotionOAuth
from models.dataset import Document
from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
......@@ -232,15 +231,16 @@ class DataSourceNotionApi(Resource):
).first()
if not data_source_binding:
raise NotFound('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
if page_type == 'page':
page_content = reader.read_page(page_id)
elif page_type == 'database':
page_content = reader.query_database_data(page_id)
else:
page_content = ""
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type
)
return {
'content': page_content
'content': loader.load_as_text()
}, 200
@setup_required
......
......@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
UnsupportedFileTypeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.index.readers.xlsx_parser import XLSXParser
from core.data_loader.file_extractor import FileExtractor
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db
......@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
if extension not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, filepath)
if extension == 'pdf':
parser = PDFParser({'upload_file': upload_file})
text = parser.parse_file(Path(filepath))
elif extension in ['html', 'htm']:
# Use BeautifulSoup to extract text
parser = HTMLParser()
text = parser.parse_file(Path(filepath))
elif extension == 'xlsx':
parser = XLSXParser()
text = parser.parse_file(filepath)
else:
# ['txt', 'markdown', 'md']
with open(filepath, "rb") as fp:
data = fp.read()
encoding = chardet.detect(data)['encoding']
if encoding:
text = data.decode(encoding=encoding).strip() if data else ''
else:
text = data.decode(encoding='utf-8').strip() if data else ''
text = FileExtractor.load(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return {'content': text}
......
......@@ -3,19 +3,12 @@ from typing import Optional
import langchain
from flask import Flask
from jieba.analyse import default_tfidf
from langchain import set_handler
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from llama_index import IndexStructType, QueryMode
from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
from pydantic import BaseModel
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.index.keyword_table.stopwords import STOPWORDS
from core.prompt.prompt_template import OneLineFormatter
from core.vector_store.vector_store import VectorStore
from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
class HostedOpenAICredential(BaseModel):
......@@ -32,17 +25,6 @@ hosted_llm_credentials = HostedLLMCredentials()
def init_app(app: Flask):
formatter = OneLineFormatter()
DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
default_tfidf.stop_words = STOPWORDS
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
......
import tempfile
from pathlib import Path
from typing import List, Union
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
if input_file.suffix == '.xlxs':
loader = ExcelLoader(file_path)
elif input_file.suffix == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif input_file.suffix in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif input_file.suffix in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif input_file.suffix == '.docx':
loader = Docx2txtLoader(file_path)
elif input_file.suffix == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return loader.load_as_text() if return_text else loader.load()
import logging
from typing import Optional, Dict, List
from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings
from models.dataset import Document
logger = logging.getLogger(__name__)
class CSVLoader(LCCSVLoader):
def __init__(
self,
file_path: str,
source_column: Optional[str] = None,
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
self.file_path = file_path
self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
"""Load data into document objects."""
try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self.autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path)
for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try:
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self.file_path}") from e
return docs
def _read_from_file(self, csvfile):
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else self.file_path
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs
import json
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook
logger = logging.getLogger(__name__)
class ExcelLoader(BaseLoader):
"""Load xlxs files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
data = []
keys = []
wb = load_workbook(filename=self._file_path, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
data.append(json.dumps(dict(zip(keys, list(map(str, row)))), ensure_ascii=False))
metadata = {"source": self._file_path}
return [Document(page_content='\n\n'.join(data), metadata=metadata)]
def load_as_text(self) -> str:
documents = self.load()
return ''.join([document.page_content for document in documents])
import logging
from typing import List
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class HTMLLoader(BaseLoader):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
metadata = {"source": self._file_path}
return [Document(page_content=self.load_as_text(), metadata=metadata)]
def load_as_text(self) -> str:
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
"""Markdown parser.
import logging
import re
from typing import Optional, List, Tuple, cast
Contains parser for md files.
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class MarkdownLoader(BaseLoader):
"""Load md files.
"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from llama_index.readers.file.base_parser import BaseParser
Args:
file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
class MarkdownParser(BaseParser):
"""Markdown parser.
remove_images: Whether to remove images from the text.
Extract text from markdown files.
Returns dictionary with keys as headers and values as the text between headers.
encoding: File encoding to use. If `None`, the file will be loaded
with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
"""
def __init__(
self,
*args: Any,
file_path: str,
remove_hyperlinks: bool = True,
remove_images: bool = True,
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(*args, **kwargs)
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
"""Initialize with file path."""
self._file_path = file_path
self._remove_hyperlinks = remove_hyperlinks
self._remove_images = remove_images
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
tups = self.parse_tups(self._file_path)
documents = []
metadata = {"source": self._file_path}
for header, value in tups:
if header is None:
documents.append(Document(page_content=value, metadata=metadata))
else:
documents.append(Document(page_content=f"\n\n{header}\n{value}", metadata=metadata))
return documents
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary.
......@@ -79,33 +103,32 @@ class MarkdownParser(BaseParser):
content = re.sub(pattern, r"\1", content)
return content
def _init_parser(self) -> Dict:
"""Initialize the parser with the config."""
return {}
def parse_tups(
self, filepath: Path, errors: str = "ignore"
) -> List[Tuple[Optional[str], str]]:
def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
"""Parse file into tuples."""
with open(filepath, "r", encoding="utf-8") as f:
content = f.read()
content = ""
try:
with open(filepath, "r", encoding=self._encoding) as f:
content = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try:
with open(filepath, encoding=encoding.encoding) as f:
content = f.read()
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {filepath}") from e
except Exception as e:
raise RuntimeError(f"Error loading {filepath}") from e
if self._remove_hyperlinks:
content = self.remove_hyperlinks(content)
if self._remove_images:
content = self.remove_images(content)
markdown_tups = self.markdown_to_tups(content)
return markdown_tups
def parse_file(
self, filepath: Path, errors: str = "ignore"
) -> Union[str, List[str]]:
"""Parse file into string."""
tups = self.parse_tups(filepath, errors=errors)
results = []
# TODO: don't include headers right now
for header, value in tups:
if header is None:
results.append(value)
else:
results.append(f"\n\n{header}\n{value}")
return results
return self.markdown_to_tups(content)
"""Notion reader."""
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import List, Dict, Any, Optional
import requests # type: ignore
import requests
from flask import current_app
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding
logger = logging.getLogger(__name__)
INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
logger = logging.getLogger(__name__)
# TODO: Notion DB reader coming soon!
class NotionPageReader(BaseReader):
"""Notion Page reader.
Reads a set of Notion pages.
Args:
integration_token (str): Notion integration token.
"""
def __init__(self, integration_token: Optional[str] = None) -> None:
"""Initialize with parameters."""
if integration_token is None:
integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
class NotionLoader(BaseLoader):
def __init__(
self,
notion_access_token: str,
notion_workspace_id: str,
notion_obj_id: str,
notion_page_type: str,
document_model: Optional[DocumentModel] = None
):
self._document_model = document_model
self._notion_workspace_id = notion_workspace_id
self._notion_obj_id = notion_obj_id
self._notion_page_type = notion_page_type
self._notion_access_token = notion_access_token
if not self._notion_access_token:
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment "
"variable `NOTION_INTEGRATION_TOKEN`."
)
self.token = integration_token
self.headers = {
"Authorization": "Bearer " + self.token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
}
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block."""
done = False
self._notion_access_token = integration_token
@classmethod
def from_document(cls, document_model: DocumentModel):
data_source_info = document_model.data_source_info_dict
if not data_source_info or 'notion_page_id' not in data_source_info \
or 'notion_workspace_id' not in data_source_info:
raise ValueError("no notion page found")
notion_workspace_id = data_source_info['notion_workspace_id']
notion_obj_id = data_source_info['notion_page_id']
notion_page_type = data_source_info['type']
notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
return cls(
notion_access_token=notion_access_token,
notion_workspace_id=notion_workspace_id,
notion_obj_id=notion_obj_id,
notion_page_type=notion_page_type,
document_model=document_model
)
def load(self) -> List[Document]:
self.update_last_edited_time(
self._document_model
)
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
return text_docs
def load_as_text(self) -> str:
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
text = "\n".join([doc.page_content for doc in text_docs])
return text
def _load_data_as_documents(
self, notion_obj_id: str, notion_page_type: str
) -> List[Document]:
docs = []
if notion_page_type == 'database':
# get all the pages in the database
page_text = self._get_notion_database_data(notion_obj_id)
docs.append(Document(page_content=page_text))
elif notion_page_type == 'page':
page_text_list = self._get_notion_block_data(notion_obj_id)
for page_text in page_text_list:
docs.append(Document(page_content=page_text))
else:
raise ValueError("notion page type not supported")
return docs
def _get_notion_database_data(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> str:
"""Get all the pages from a Notion database."""
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict,
)
data = res.json()
database_content_list = []
if 'results' not in data or data["results"] is None:
return ""
for result in data["results"]:
properties = result['properties']
data = {}
for property_name, property_value in properties.items():
type = property_value['type']
if type == 'multi_select':
value = []
multi_select_list = property_value[type]
for multi_select in multi_select_list:
value.append(multi_select['name'])
elif type == 'rich_text' or type == 'title':
if len(property_value[type]) > 0:
value = property_value[type][0]['plain_text']
else:
value = ''
elif type == 'select' or type == 'status':
if property_value[type]:
value = property_value[type]['name']
else:
value = ''
else:
value = property_value[type]
data[property_name] = value
database_content_list.append(json.dumps(data))
return "\n\n".join(database_content_list)
def _get_notion_block_data(self, page_id: str) -> List[str]:
result_lines_arr = []
cur_block_id = block_id
while not done:
cur_block_id = page_id
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", block_url, headers=self.headers, json=query_dict
"GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
if 'results' not in data or data["results"] is None:
done = True
break
# current block's heading
heading = ''
for result in data["results"]:
result_type = result["type"]
......@@ -71,6 +170,7 @@ class NotionPageReader(BaseReader):
if result_type == 'table':
result_block_id = result["id"]
text = self._read_table_rows(result_block_id)
text += "\n\n"
result_lines_arr.append(text)
else:
if "rich_text" in result_obj:
......@@ -78,10 +178,10 @@ class NotionPageReader(BaseReader):
# skip if doesn't have text object
if "text" in rich_text:
text = rich_text["text"]["content"]
prefix = "\t" * num_tabs
cur_result_text_arr.append(prefix + text)
cur_result_text_arr.append(text)
if result_type in HEADING_TYPE:
heading = text
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
......@@ -92,77 +192,39 @@ class NotionPageReader(BaseReader):
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text)
else:
result_lines_arr.append(f'{heading}\n{cur_result_text}')
if data["next_cursor"] is None:
done = True
break
else:
cur_block_id = data["next_cursor"]
result_lines = "\n".join(result_lines_arr)
return result_lines
def _read_table_rows(self, block_id: str) -> str:
"""Read table rows."""
done = False
result_lines_arr = []
cur_block_id = block_id
while not done:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", block_url, headers=self.headers, json=query_dict
)
data = res.json()
# get table headers text
table_header_cell_texts = []
tabel_header_cells = data["results"][0]['table_row']['cells']
for tabel_header_cell in tabel_header_cells:
if tabel_header_cell:
for table_header_cell_text in tabel_header_cell:
text = table_header_cell_text["text"]["content"]
table_header_cell_texts.append(text)
# get table columns text and format
results = data["results"]
for i in range(len(results)-1):
column_texts = []
tabel_column_cells = data["results"][i+1]['table_row']['cells']
for j in range(len(tabel_column_cells)):
if tabel_column_cells[j]:
for table_column_cell_text in tabel_column_cells[j]:
column_text = table_column_cell_text["text"]["content"]
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
cur_result_text = "\n".join(column_texts)
result_lines_arr.append(cur_result_text)
if data["next_cursor"] is None:
done = True
break
else:
cur_block_id = data["next_cursor"]
return result_lines_arr
result_lines = "\n".join(result_lines_arr)
return result_lines
def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block."""
done = False
result_lines_arr = []
cur_block_id = block_id
while not done:
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", block_url, headers=self.headers, json=query_dict
"GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
# current block's heading
if 'results' not in data or data["results"] is None:
break
heading = ''
for result in data["results"]:
result_type = result["type"]
......@@ -171,7 +233,6 @@ class NotionPageReader(BaseReader):
if result_type == 'table':
result_block_id = result["id"]
text = self._read_table_rows(result_block_id)
text += "\n\n"
result_lines_arr.append(text)
else:
if "rich_text" in result_obj:
......@@ -179,10 +240,10 @@ class NotionPageReader(BaseReader):
# skip if doesn't have text object
if "text" in rich_text:
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
prefix = "\t" * num_tabs
cur_result_text_arr.append(prefix + text)
if result_type in HEADING_TYPE:
heading = text
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
......@@ -193,177 +254,121 @@ class NotionPageReader(BaseReader):
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text)
else:
result_lines_arr.append(f'{heading}\n{cur_result_text}')
if data["next_cursor"] is None:
done = True
break
else:
cur_block_id = data["next_cursor"]
return result_lines_arr
def read_page(self, page_id: str) -> str:
"""Read a page."""
return self._read_block(page_id)
def read_page_as_documents(self, page_id: str) -> List[str]:
"""Read a page as documents."""
return self._read_parent_blocks(page_id)
def query_database_data(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> str:
"""Get all the pages from a Notion database."""
res = requests.post\
(
DATABASE_URL_TMPL.format(database_id=database_id),
headers=self.headers,
json=query_dict,
)
data = res.json()
database_content_list = []
if 'results' not in data or data["results"] is None:
return ""
for result in data["results"]:
properties = result['properties']
data = {}
for property_name, property_value in properties.items():
type = property_value['type']
if type == 'multi_select':
value = []
multi_select_list = property_value[type]
for multi_select in multi_select_list:
value.append(multi_select['name'])
elif type == 'rich_text' or type == 'title':
if len(property_value[type]) > 0:
value = property_value[type][0]['plain_text']
else:
value = ''
elif type == 'select' or type == 'status':
if property_value[type]:
value = property_value[type]['name']
else:
value = ''
else:
value = property_value[type]
data[property_name] = value
database_content_list.append(json.dumps(data, ensure_ascii=False))
return "\n\n".join(database_content_list)
def query_database(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> List[str]:
"""Get all the pages from a Notion database."""
res = requests.post\
(
DATABASE_URL_TMPL.format(database_id=database_id),
headers=self.headers,
json=query_dict,
)
data = res.json()
page_ids = []
for result in data["results"]:
page_id = result["id"]
page_ids.append(page_id)
return page_ids
result_lines = "\n".join(result_lines_arr)
return result_lines
def search(self, query: str) -> List[str]:
"""Search Notion page given a text query."""
def _read_table_rows(self, block_id: str) -> str:
"""Read table rows."""
done = False
next_cursor: Optional[str] = None
page_ids = []
result_lines_arr = []
cur_block_id = block_id
while not done:
query_dict = {
"query": query,
}
if next_cursor is not None:
query_dict["start_cursor"] = next_cursor
res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
for result in data["results"]:
page_id = result["id"]
page_ids.append(page_id)
# get table headers text
table_header_cell_texts = []
tabel_header_cells = data["results"][0]['table_row']['cells']
for tabel_header_cell in tabel_header_cells:
if tabel_header_cell:
for table_header_cell_text in tabel_header_cell:
text = table_header_cell_text["text"]["content"]
table_header_cell_texts.append(text)
# get table columns text and format
results = data["results"]
for i in range(len(results) - 1):
column_texts = []
tabel_column_cells = data["results"][i + 1]['table_row']['cells']
for j in range(len(tabel_column_cells)):
if tabel_column_cells[j]:
for table_column_cell_text in tabel_column_cells[j]:
column_text = table_column_cell_text["text"]["content"]
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
cur_result_text = "\n".join(column_texts)
result_lines_arr.append(cur_result_text)
if data["next_cursor"] is None:
done = True
break
else:
next_cursor = data["next_cursor"]
return page_ids
cur_block_id = data["next_cursor"]
def load_data(
self, page_ids: List[str] = [], database_id: Optional[str] = None
) -> List[Document]:
"""Load data from the input directory.
result_lines = "\n".join(result_lines_arr)
return result_lines
Args:
page_ids (List[str]): List of page ids to load.
def update_last_edited_time(self, document_model: DocumentModel):
if not document_model:
return
Returns:
List[Document]: List of documents.
last_edited_time = self.get_notion_last_edited_time()
data_source_info = document_model.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
DocumentModel.data_source_info: json.dumps(data_source_info)
}
"""
if not page_ids and not database_id:
raise ValueError("Must specify either `page_ids` or `database_id`.")
docs = []
if database_id is not None:
# get all the pages in the database
page_ids = self.query_database(database_id)
for page_id in page_ids:
page_text = self.read_page(page_id)
docs.append(Document(page_text))
else:
for page_id in page_ids:
page_text = self.read_page(page_id)
docs.append(Document(page_text))
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
db.session.commit()
return docs
def load_data_as_documents(
self, page_ids: List[str] = [], database_id: Optional[str] = None
) -> List[Document]:
if not page_ids and not database_id:
raise ValueError("Must specify either `page_ids` or `database_id`.")
docs = []
if database_id is not None:
# get all the pages in the database
page_text = self.query_database_data(database_id)
docs.append(Document(page_text))
def get_notion_last_edited_time(self) -> str:
obj_id = self._notion_obj_id
page_type = self._notion_page_type
if page_type == 'database':
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
else:
for page_id in page_ids:
page_text_list = self.read_page_as_documents(page_id)
for page_text in page_text_list:
docs.append(Document(page_text))
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
return docs
def get_page_last_edited_time(self, page_id: str) -> str:
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", retrieve_page_url, headers=self.headers, json=query_dict
"GET",
retrieve_page_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
return data["last_edited_time"]
def get_database_last_edited_time(self, database_id: str) -> str:
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", retrieve_page_url, headers=self.headers, json=query_dict
)
data = res.json()
return data["last_edited_time"]
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
)
).first()
if not data_source_binding:
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
f'and notion workspace {notion_workspace_id}')
if __name__ == "__main__":
reader = NotionPageReader()
logger.info(reader.search("What I"))
return data_source_binding.access_token
from pathlib import Path
from typing import Dict
import logging
from typing import List, Optional
from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from pypdf import PdfReader
from extensions.ext_storage import storage
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PDFParser(BaseParser):
"""PDF parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
class PdfLoader(BaseLoader):
"""Load pdf files.
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
if not current_app.config.get('PDF_PREVIEW', True):
return ''
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> List[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
upload_file: UploadFile = self._parser_config['upload_file']
if upload_file.hash:
plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
......@@ -35,7 +43,7 @@ class PDFParser(BaseParser):
pass
text_list = []
with open(file, "rb") as fp:
with open(self._file_path, "rb") as fp:
# Create a PDF object
pdf = PdfReader(fp)
......@@ -53,4 +61,9 @@ class PDFParser(BaseParser):
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return text
metadata = {"source": self._file_path}
return [Document(page_content=text, metadata=metadata)]
def load_as_text(self) -> str:
documents = self.load()
return '\n'.join([document.page_content for document in documents])
from typing import Any, Dict, Optional, Sequence
import tiktoken
from llama_index.data_structs import Node
from llama_index.docstore.types import BaseDocumentStore
from llama_index.docstore.utils import json_to_doc
from llama_index.schema import BaseDocument
from langchain.schema import Document
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
......@@ -12,7 +8,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
class DatesetDocumentStore(BaseDocumentStore):
class DatesetDocumentStore:
def __init__(
self,
dataset: Dataset,
......@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
return self._embedding_model_name
@property
def docs(self) -> Dict[str, BaseDocument]:
def docs(self) -> Dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id
).all()
......@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
output = {}
for document_segment in document_segments:
doc_id = document_segment.index_node_id
result = self.segment_to_dict(document_segment)
output[doc_id] = json_to_doc(result)
output[doc_id] = Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
return output
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
self, docs: Sequence[Document], allow_update: bool = True
) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document == self._document_id
......@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
max_position = 0
for doc in docs:
if doc.is_doc_id_none:
raise ValueError("doc_id not set")
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
if not isinstance(doc, Node):
raise ValueError("doc must be a Node")
segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
# NOTE: doc could already exist in the store, but we overwrite it
if not allow_update and segment_document:
raise ValueError(
f"doc_id {doc.get_doc_id()} already exists. "
f"doc_id {doc.metadata['doc_id']} already exists. "
"Set allow_update to True to overwrite."
)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
if not segment_document:
max_position += 1
......@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
index_node_id=doc.get_doc_id(),
index_node_hash=doc.get_doc_hash(),
index_node_id=doc.metadata['doc_id'],
index_node_hash=doc.metadata['doc_hash'],
position=max_position,
content=doc.get_text(),
word_count=len(doc.get_text()),
content=doc.page_content,
word_count=len(doc.page_content),
tokens=tokens,
created_by=self._user_id,
)
db.session.add(segment_document)
else:
segment_document.content = doc.get_text()
segment_document.index_node_hash = doc.get_doc_hash()
segment_document.word_count = len(doc.get_text())
segment_document.content = doc.page_content
segment_document.index_node_hash = doc.metadata['doc_hash']
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
db.session.commit()
......@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
) -> Optional[Document]:
document_segment = self.get_document_segment(doc_id)
if document_segment is None:
......@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
else:
return None
result = self.segment_to_dict(document_segment)
return json_to_doc(result)
return Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
document_segment = self.get_document_segment(doc_id)
......@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
return document_segment.index_node_hash
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
def get_document_segment(self, doc_id: str) -> DocumentSegment:
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id,
......@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
).first()
return document_segment
def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
return {
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"text": segment.content,
"__type__": Node.get_type()
}
import logging
from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
class CacheEmbedding(Embeddings):
def __init__(self, embeddings: Embeddings):
self._embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
# use doc embedding cache or store if not exists
text_embeddings = []
embedding_queue_texts = []
for text in texts:
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
text_embeddings.append(embedding.embedding)
else:
embedding_queue_texts.append(text)
embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
i += 1
embedding_queue_texts.extend(embedding_results)
return embedding_queue_texts
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
return embedding.embedding
embedding_results = self._embeddings.embed_query(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
except:
logging.exception('Failed to add embedding to db')
return embedding_results
from typing import Optional, Any, List
import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
_TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
text: str,
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[float]:
"""Get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
float]:
"""Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
"embedding"
]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
list_of_text: List[str],
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[List[float]]:
"""Get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
"""Asynchronously get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
class OpenAIEmbedding(BaseEmbedding):
def __init__(
self,
mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
deployment_name: Optional[str] = None,
openai_api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Init params."""
new_kwargs = {}
if 'embed_batch_size' in kwargs:
new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
if 'tokenizer' in kwargs:
new_kwargs['tokenizer'] = kwargs['tokenizer']
super().__init__(**new_kwargs)
self.mode = OpenAIEmbeddingMode(mode)
self.model = OpenAIEmbeddingModelType(model)
self.deployment_name = deployment_name
self.openai_api_key = openai_api_key
self.openai_api_type = kwargs.get('openai_api_type')
self.openai_api_version = kwargs.get('openai_api_version')
self.openai_api_base = kwargs.get('openai_api_base')
@handle_llm_exceptions
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _QUERY_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _QUERY_MODE_MODEL_DICT[key]
return get_embedding(query, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return get_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings.
By default, this is a wrapper around _get_text_embedding.
Can be overriden for batch queries.
"""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(self._get_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(await self._aget_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
from __future__ import annotations
from abc import abstractmethod, ABC
from typing import List, Any
from langchain.schema import Document, BaseRetriever
class BaseIndex(ABC):
@abstractmethod
def create(self, texts: list[Document]) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document]):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError
@abstractmethod
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
raise NotImplementedError
@abstractmethod
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts:
doc_id = text.metadata['doc_id']
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder
class IndexBuilder:
@classmethod
def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
# set number of output tokens
num_output = 512
# only for verbose
callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='text-davinci-003',
temperature=0,
max_tokens=num_output,
callback_manager=callback_manager,
)
llm_predictor = LLMPredictor(llm=llm)
# These parameters here will affect the logic of segmenting the final synthesized response.
# The number of refinement iterations in the synthesis process depends
# on whether the length of the segmented output exceeds the max_input_size.
prompt_helper = PromptHelper(
max_input_size=3500,
num_output=num_output,
max_chunk_overlap=20
)
provider = LLMBuilder.get_default_provider(tenant_id)
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=tenant_id,
model_provider=provider,
model_name='text-embedding-ada-002'
)
return ServiceContext.from_defaults(
llm_predictor=llm_predictor,
prompt_helper=prompt_helper,
embed_model=OpenAIEmbedding(**model_credentials),
)
@classmethod
def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='fake'
)
return ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
import re
from typing import (
Any,
Dict,
List,
Set,
Optional
)
import jieba.analyse
from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
def jieba_extract_keywords(
text_chunk: str,
max_keywords: Optional[int] = None,
expand_with_subtokens: bool = True,
) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text_chunk,
topK=max_keywords,
)
if expand_with_subtokens:
return set(expand_tokens_with_subtokens(keywords))
else:
return set(keywords)
def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
"""GPT JIEBA Keyword Table Index.
This index uses a JIEBA keyword extractor to extract keywords from the text.
"""
def _extract_keywords(self, text: str) -> Set[str]:
"""Extract keywords from text."""
return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)
@classmethod
def get_query_map(self) -> QueryMap:
"""Get query map."""
super_map = super().get_query_map()
super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
return super_map
def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document."""
# get set of ids that correspond to node
node_idxs_to_delete = {doc_id}
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in self._index_struct.table.items():
if node_idxs_to_delete.intersection(node_idxs):
self._index_struct.table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not self._index_struct.table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del self._index_struct.table[keyword]
class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
"""GPT Keyword Table Index JIEBA Query.
Extracts keywords using JIEBA keyword extractor.
Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.
.. code-block:: python
response = index.query("<query_str>", mode="jieba")
See BaseGPTKeywordTableQuery for arguments.
"""
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)
def _get_keywords(self, query_str: str) -> List[str]:
"""Extract keywords."""
return list(
jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
)
import json
from typing import List, Optional
from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class KeywordTableIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
index_struct = KeywordTable()
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node in nodes:
keywords = index._extract_keywords(node.get_text())
self.update_segment_keywords(node.doc_id, list(keywords))
index._index_struct.add_node(list(keywords), node)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
def del_nodes(self, node_ids: List[str]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node_id in node_ids:
index.delete(node_id)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
@property
def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
docstore = DatesetDocumentStore(
dataset=self._dataset,
user_id=self._dataset.created_by,
embedding_model_name="text-embedding-ada-002"
)
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return None
index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)
return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)
def get_keyword_table(self):
dataset_keyword_table = self._dataset.dataset_keyword_table
if dataset_keyword_table:
return dataset_keyword_table
return None
def update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
import re
from typing import Set
import jieba
from jieba.analyse import default_tfidf
from core.index.keyword_table_index.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
)
return set(self._expand_tokens_with_subtokens(keywords))
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
\ No newline at end of file
import json
from collections import defaultdict
from typing import Any, List, Optional, Dict
from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field
from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
class KeywordTableIndex(BaseIndex):
def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
self._dataset = dataset
self._config = config
def create(self, texts: list[Document]) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(keyword_table)
)
db.session.add(dataset_keyword_table)
db.session.commit()
return self
def add_texts(self, texts: list[Document]):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table)
db.session.commit()
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table)
db.session.commit()
def delete_by_document_id(self, document_id: str):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id,
DocumentSegment.document_id == document_id
).all()
ids = [segment.id for segment in segments]
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table)
db.session.commit()
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
keyword_table = self._get_dataset_keyword_table()
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
k = search_kwargs.get('k') if search_kwargs.get('k') else 4
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id,
DocumentSegment.index_node_id == chunk_index
).first()
if segment:
documents.append(Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
))
return documents
def _get_dataset_keyword_table(self) -> Optional[dict]:
keyword_table_dict = self._dataset.dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return keyword_table_dict
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in keyword_table.items():
if node_idxs_to_delete.intersection(node_idxs):
keyword_table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not keyword_table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del keyword_table[keyword]
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
# go through text chunks in order of most matching keywords
chunk_indices_count: Dict[str, int] = defaultdict(int)
keywords = [k for k in keywords if k in set(keyword_table.keys())]
for k in keywords:
for node_id in keyword_table[k]:
chunk_indices_count[node_id] += 1
sorted_chunk_indices = sorted(
list(chunk_indices_count.keys()),
key=lambda x: chunk_indices_count[x],
reverse=True,
)
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex
search_kwargs: dict = Field(default_factory=dict)
def get_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
return self.index.search(query, **self.search_kwargs)
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError("KeywordTableRetriever does not support async")
from typing import (
Any,
Dict,
Optional, Sequence,
)
from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE
class EnhanceResponseSynthesizer(ResponseSynthesizer):
@classmethod
def from_args(
cls,
service_context: ServiceContext,
streaming: bool = False,
use_async: bool = False,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_mode: ResponseMode = ResponseMode.DEFAULT,
response_kwargs: Optional[Dict] = None,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
) -> "ResponseSynthesizer":
response_builder: Optional[BaseResponseBuilder] = None
if response_mode != ResponseMode.NO_TEXT:
if response_mode == 'no_synthesizer':
response_builder = NoSynthesizer(
service_context=service_context,
simple_template=simple_template,
streaming=streaming,
)
else:
response_builder = get_response_builder(
service_context,
text_qa_template,
refine_template,
simple_template,
response_mode,
use_async=use_async,
streaming=streaming,
)
return cls(response_builder, response_mode, response_kwargs, optimizer)
class NoSynthesizer(BaseResponseBuilder):
def __init__(
self,
service_context: ServiceContext,
simple_template: Optional[SimpleInputPrompt] = None,
streaming: bool = False,
) -> None:
super().__init__(service_context, streaming)
async def aget_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
def get_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
\ No newline at end of file
from pathlib import Path
from typing import Dict
from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser
class HTMLParser(BaseParser):
"""HTML parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
with open(file, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
from pathlib import Path
import json
from typing import Dict
from openpyxl import load_workbook
from llama_index.readers.file.base_parser import BaseParser
from flask import current_app
class XLSXParser(BaseParser):
"""XLSX parser."""
def _init_parser(self) -> Dict:
"""Init parser"""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
data = []
keys = []
with open(file, "r") as fp:
wb = load_workbook(filename=file, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
row_dict = dict(zip(keys, row))
row_dict = {k: v for k, v in row_dict.items() if v}
data.append(json.dumps(row_dict, ensure_ascii=False))
return '\n\n'.join(data)
import json
import logging
from typing import List, Optional
from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type
from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding
class VectorIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
if not self._dataset.index_struct_dict:
index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
db.session.commit()
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
if duplicate_check:
nodes = self._filter_duplicate_nodes(index, nodes)
embedding_queue_nodes = []
embedded_nodes = []
for node in nodes:
node_hash = node.doc_hash
# if node hash in cached embedding tables, use cached embedding
embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
if embedding:
node.embedding = embedding.get_embedding()
embedded_nodes.append(node)
else:
embedding_queue_nodes.append(node)
if embedding_queue_nodes:
embedding_results = index._get_node_embedding_results(
embedding_queue_nodes,
set(),
)
# pre embed nodes for cached embedding
for embedding_result in embedding_results:
node = embedding_result.node
node.embedding = embedding_result.embedding
try:
embedding = Embedding(hash=node.doc_hash)
embedding.set_embedding(node.embedding)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
embedded_nodes.append(node)
self.index_insert_nodes(index, embedded_nodes)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
index.insert_nodes(nodes)
def del_nodes(self, node_ids: List[str]):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
for node_id in node_ids:
self.index_delete_node(index, node_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
index.delete_node(node_id)
def del_doc(self, doc_id: str):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
self.index_delete_doc(index, doc_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
index.delete(doc_id)
@property
def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
if not self._dataset.index_struct_dict:
return None
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
return vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
for node in nodes:
node_id = node.doc_id
exists_duplicate_node = index.exists_by_node_id(node_id)
if exists_duplicate_node:
nodes.remove(node)
return nodes
from abc import abstractmethod
from typing import List, Any, Tuple
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from index.base import BaseIndex
class BaseVectorIndex(BaseIndex):
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod
def get_index_name(self, dataset_id: str) -> str:
raise NotImplementedError
@abstractmethod
def to_index_struct(self) -> dict:
raise NotImplementedError
@abstractmethod
def _get_vector_store(self) -> VectorStore:
raise NotImplementedError
import os
from typing import Optional, Any, List, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
}
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
self._dataset = dataset
self._client_config = config
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset_id: str) -> str:
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self._dataset.get_id())}
}
def create(self, texts: list[Document]) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self._dataset.get_id()),
ids=uuids,
**self._client_config.to_qdrant_params()
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self._dataset.get_id()),
embeddings=self._embeddings
)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
return vector_store.as_retriever(**kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def add_texts(self, texts: list[Document]):
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(QdrantVectorStore, vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
import json
from flask import current_app
from langchain.embeddings.base import Embeddings
from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings)
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError(f"Vector store must be specified.")
if vector_type == "weaviate":
from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_URL'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings,
attributes=['doc_id', 'dataset_id', 'document_id', 'source'],
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
return QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document]):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts)
def __getattr__(self, name):
if self._vector_index is not None:
method = getattr(self._vector_index, name)
if callable(method):
return method
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
from typing import Optional, Any, List, cast
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list[str]):
self._dataset = dataset
self._client = self._init_client(config)
self._embeddings = embeddings
self._attributes = attributes
self._vector_store = None
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_index_name(self, dataset_id: str) -> str:
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self._dataset.get_id())}
}
def create(self, texts: list[Document]) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self._dataset.get_id()),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self._dataset.get_id()),
text_key='text',
embedding=self._embeddings,
attributes=self._attributes,
by_text=False
)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
return vector_store.as_retriever(**kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def add_texts(self, texts: list[Document]):
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(WeaviateVectorStore, vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
import datetime
import json
import re
import tempfile
import time
from pathlib import Path
from typing import Optional, List
import uuid
from typing import Optional, List, cast
from flask import current_app
from flask_login import current_user
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from llama_index import SimpleDirectoryReader
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.node_parser import SimpleNodeParser, NodeParser
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from llama_index.readers.file.markdown_parser import MarkdownParser
from core.data_source.notion import NotionPageReader
from core.index.readers.xlsx_parser import XLSXParser
from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.index.keyword_table_index import KeywordTableIndex
from core.index.readers.html_parser import HTMLParser
from core.index.readers.markdown_parser import MarkdownParser
from core.index.readers.pdf_parser import PDFParser
from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.index.vector_index import VectorIndex
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding
......@@ -40,49 +36,49 @@ class IndexingRunner:
self.storage = storage
self.embedding_model_name = embedding_model_name
def run(self, documents: List[Document]):
def run(self, dataset_documents: List[DatasetDocument]):
"""Run the indexing process."""
for document in documents:
for dataset_document in dataset_documents:
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
id=dataset_document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# load file
text_docs = self._load_data(document)
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._step_split(
# split to documents
documents = self._step_split(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
dataset=dataset,
document=document,
dataset_document=dataset_document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
dataset_document=dataset_document,
documents=documents
)
def run_in_splitting_status(self, document: Document):
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
id=dataset_document.dataset_id
).first()
if not dataset:
......@@ -91,42 +87,44 @@ class IndexingRunner:
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=document.id
document_id=dataset_document.id
).all()
db.session.delete(document_segments)
db.session.commit()
# load file
text_docs = self._load_data(document)
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._step_split(
# split to documents
documents = self._step_split(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
dataset=dataset,
document=document,
dataset_document=dataset_document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
dataset_document=dataset_document,
documents=documents
)
def run_in_indexing_status(self, document: Document):
def run_in_indexing_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is indexing."""
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
id=dataset_document.dataset_id
).first()
if not dataset:
......@@ -135,39 +133,31 @@ class IndexingRunner:
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=document.id
document_id=dataset_document.id
).all()
nodes = []
documents = []
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
relationships = {
DocumentRelationship.SOURCE: document_segment.document_id,
}
previous_segment = document_segment.previous_segment
if previous_segment:
relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
next_segment = document_segment.next_segment
if next_segment:
relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
node = Node(
doc_id=document_segment.index_node_id,
doc_hash=document_segment.index_node_hash,
text=document_segment.content,
extra_info=None,
node_info=None,
relationships=relationships
document = Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
nodes.append(node)
documents.append(document)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
dataset_document=dataset_document,
documents=documents
)
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
......@@ -179,28 +169,28 @@ class IndexingRunner:
total_segments = 0
for file_detail in file_details:
# load data from file
text_docs = self._load_data_from_file(file_detail)
text_docs = FileExtractor.load(file_detail)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._split_to_nodes(
# split to documents
documents = self._split_to_documents(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(nodes)
for node in nodes:
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
return {
"total_segments": total_segments,
......@@ -230,35 +220,36 @@ class IndexingRunner:
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
for page in notion_info['pages']:
if page['type'] == 'page':
page_ids = [page['page_id']]
documents = reader.load_data_as_documents(page_ids=page_ids)
elif page['type'] == 'database':
documents = reader.load_data_as_documents(database_id=page['page_id'])
else:
documents = []
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page['page_id'],
notion_page_type=page['type']
)
documents = loader.load()
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._split_to_nodes(
# split to documents
documents = self._split_to_documents(
text_docs=documents,
node_parser=node_parser,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(nodes)
for node in nodes:
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
return {
"total_segments": total_segments,
......@@ -268,14 +259,14 @@ class IndexingRunner:
"preview": preview_texts
}
def _load_data(self, document: Document) -> List[Document]:
def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
# load file
if document.data_source_type not in ["upload_file", "notion_import"]:
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
data_source_info = document.data_source_info_dict
data_source_info = dataset_document.data_source_info_dict
text_docs = []
if document.data_source_type == 'upload_file':
if dataset_document.data_source_type == 'upload_file':
if not data_source_info or 'upload_file_id' not in data_source_info:
raise ValueError("no upload file found")
......@@ -283,47 +274,28 @@ class IndexingRunner:
filter(UploadFile.id == data_source_info['upload_file_id']). \
one_or_none()
text_docs = self._load_data_from_file(file_detail)
elif document.data_source_type == 'notion_import':
if not data_source_info or 'notion_page_id' not in data_source_info \
or 'notion_workspace_id' not in data_source_info:
raise ValueError("no notion page found")
workspace_id = data_source_info['notion_workspace_id']
page_id = data_source_info['notion_page_id']
page_type = data_source_info['type']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == document.tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
if page_type == 'page':
# add page last_edited_time to data_source_info
self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
elif page_type == 'database':
# add page last_edited_time to data_source_info
self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
text_docs = FileExtractor.load(file_detail)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
# update document status to splitting
self._update_document_index_status(
document_id=document.id,
document_id=dataset_document.id,
after_indexing_status="splitting",
extra_update_params={
Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
Document.parsing_completed_at: datetime.datetime.utcnow()
DatasetDocument.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
}
)
# replace doc id to document model id
text_docs = cast(List[Document], text_docs)
for text_doc in text_docs:
# remove invalid symbol
text_doc.text = self.filter_string(text_doc.get_text())
text_doc.doc_id = document.id
text_doc.page_content = self.filter_string(text_doc.page_content)
text_doc.metadata['document_id'] = dataset_document.id
text_doc.metadata['dataset_id'] = dataset_document.dataset_id
return text_docs
......@@ -331,61 +303,7 @@ class IndexingRunner:
pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
return pattern.sub('', text)
def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
self.storage.download(upload_file.key, filepath)
file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
file_extractor[".markdown"] = MarkdownParser()
file_extractor[".md"] = MarkdownParser()
file_extractor[".html"] = HTMLParser()
file_extractor[".htm"] = HTMLParser()
file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
file_extractor[".xlsx"] = XLSXParser()
loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
text_docs = loader.load_data()
return text_docs
def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
page_ids = [page_id]
reader = NotionPageReader(integration_token=access_token)
text_docs = reader.load_data_as_documents(page_ids=page_ids)
return text_docs
def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
reader = NotionPageReader(integration_token=access_token)
text_docs = reader.load_data_as_documents(database_id=database_id)
return text_docs
def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
reader = NotionPageReader(integration_token=access_token)
last_edited_time = reader.get_page_last_edited_time(page_id)
data_source_info = document.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
Document.data_source_info: json.dumps(data_source_info)
}
Document.query.filter_by(id=document.id).update(update_params)
db.session.commit()
def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
reader = NotionPageReader(integration_token=access_token)
last_edited_time = reader.get_database_last_edited_time(page_id)
data_source_info = document.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
Document.data_source_info: json.dumps(data_source_info)
}
Document.query.filter_by(id=document.id).update(update_params)
db.session.commit()
def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
......@@ -414,68 +332,83 @@ class IndexingRunner:
separators=["\n\n", "。", ".", " ", ""]
)
return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True)
return character_splitter
def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
-> List[Document]:
"""
Split the text documents into nodes and save them to the document segment.
Split the text documents into documents and save them to the document segment.
"""
nodes = self._split_to_nodes(
documents = self._split_to_documents(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
processing_rule=processing_rule
)
# save node to document segment
doc_store = DatesetDocumentStore(
dataset=dataset,
user_id=document.created_by,
user_id=dataset_document.created_by,
embedding_model_name=self.embedding_model_name,
document_id=document.id
document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(nodes)
doc_store.add_documents(documents)
# update document status to indexing
cur_time = datetime.datetime.utcnow()
self._update_document_index_status(
document_id=document.id,
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
Document.cleaning_completed_at: cur_time,
Document.splitting_completed_at: cur_time,
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
}
)
# update segment status to indexing
self._update_segments_by_document(
document_id=document.id,
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow()
}
)
return nodes
return documents
def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
processing_rule: DatasetProcessRule) -> List[Node]:
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
"""
Split the text documents into nodes.
"""
all_nodes = []
all_documents = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.get_text(), processing_rule)
text_doc.text = document_text
document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.page_content = document_text
# parse document to nodes
nodes = node_parser.get_nodes_from_documents([text_doc])
nodes = [node for node in nodes if node.text is not None and node.text.strip()]
all_nodes.extend(nodes)
documents = splitter.split_documents([text_doc])
split_documents = []
for document in documents:
if document.page_content is None or not document.page_content.strip():
continue
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata['doc_id'] = doc_id
document.metadata['doc_hash'] = hash
return all_nodes
split_documents.append(document)
all_documents.extend(split_documents)
return all_documents
def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
"""
......@@ -506,37 +439,58 @@ class IndexingRunner:
return text
def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
"""
Build the index for the document.
"""
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
keyword_table_index = KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 100
for i in range(0, len(nodes), chunk_size):
for i in range(0, len(documents), chunk_size):
# check document is paused
self._check_document_paused_status(document.id)
chunk_nodes = nodes[i:i + chunk_size]
self._check_document_paused_status(dataset_document.id)
chunk_documents = documents[i:i + chunk_size]
tokens += sum(
TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
for document in chunk_documents
)
# save vector index
if dataset.indexing_technique == "high_quality":
vector_index.add_nodes(chunk_nodes)
vector_index.add_texts(chunk_documents)
# save keyword index
keyword_table_index.add_nodes(chunk_nodes)
keyword_table_index.add_texts(chunk_documents)
node_ids = [node.doc_id for node in chunk_nodes]
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.index_node_id.in_(node_ids),
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing"
).update({
DocumentSegment.status: "completed",
......@@ -549,12 +503,12 @@ class IndexingRunner:
# update document status to completed
self._update_document_index_status(
document_id=document.id,
document_id=dataset_document.id,
after_indexing_status="completed",
extra_update_params={
Document.tokens: tokens,
Document.completed_at: datetime.datetime.utcnow(),
Document.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.utcnow(),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
}
)
......@@ -569,25 +523,25 @@ class IndexingRunner:
"""
Update the document indexing status.
"""
count = Document.query.filter_by(id=document_id, is_paused=True).count()
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
update_params = {
Document.indexing_status: after_indexing_status
DatasetDocument.indexing_status: after_indexing_status
}
if extra_update_params:
update_params.update(extra_update_params)
Document.query.filter_by(id=document_id).update(update_params)
DatasetDocument.query.filter_by(id=document_id).update(update_params)
db.session.commit()
def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
"""
Update the document segment by document id.
"""
DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
......
......@@ -42,7 +42,7 @@ class AzureProvider(BaseProvider):
"""
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
config['deployment_name'] = model_id.replace('.', '') if model_id else None
config['deployment'] = config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
def get_provider_name(self):
......
from abc import ABC, abstractmethod
from typing import Optional
from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore
class BaseVectorStoreClient(ABC):
@abstractmethod
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
raise NotImplementedError
@abstractmethod
def to_index_config(self, index_id: str) -> dict:
raise NotImplementedError
class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
def delete_node(self, node_id: str):
self._vector_store.delete_node(node_id)
def exists_by_node_id(self, node_id: str) -> bool:
return self._vector_store.exists_by_node_id(node_id)
class EnhanceVectorStore(ABC):
@abstractmethod
def delete_node(self, node_id: str):
pass
@abstractmethod
def exists_by_node_id(self, node_id: str) -> bool:
pass
from typing import cast
from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal
class QdrantVectorStore(Qdrant):
def del_texts(self, filter: Filter):
if not filter:
raise ValueError('filter must not be empty')
self._reload_if_needed()
self.client.delete(
collection_name=self.collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def del_text(self, uuid: str) -> None:
self._reload_if_needed()
self.client.delete(
collection_name=self.collection_name,
points_selector=PointIdsList(
points=[uuid],
),
)
def text_exists(self, uuid: str) -> bool:
self._reload_if_needed()
response = self.client.retrieve(
collection_name=self.collection_name,
ids=[uuid]
)
return len(response) > 0
def _reload_if_needed(self):
if isinstance(self.client, QdrantLocal):
self.client = cast(QdrantLocal, self.client)
self.client._load()
import os
from typing import cast, List
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter
import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
class QdrantVectorStoreClient(BaseVectorStoreClient):
def __init__(self, url: str, api_key: str, root_path: str):
self._client = self.init_from_config(url, api_key, root_path)
@classmethod
def init_from_config(cls, url: str, api_key: str, root_path: str):
if url and url.startswith('path:'):
path = url.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(root_path, path)
return qdrant_client.QdrantClient(
path=path
)
else:
return qdrant_client.QdrantClient(
url=url,
api_key=api_key,
)
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = QdrantIndexDict()
if self._client is None:
raise Exception("Vector client is not initialized.")
# {"collection_name": "Gpt_index_xxx"}
collection_name = config.get('collection_name')
if not collection_name:
raise Exception("collection_name cannot be None.")
return GPTQdrantEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=QdrantEnhanceVectorStore(
client=self._client,
collection_name=collection_name
)
)
def to_index_config(self, index_id: str) -> dict:
return {"collection_name": index_id}
class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
pass
class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
def delete_node(self, node_id: str):
"""
Delete node from the index.
:param node_id: node id
"""
from qdrant_client.http import models as rest
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=rest.Filter(
must=[
rest.FieldCondition(
key="id", match=rest.MatchValue(value=node_id)
)
]
),
)
def exists_by_node_id(self, node_id: str) -> bool:
"""
Get node from the index by node id.
:param node_id: node id
"""
self._reload_if_needed()
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[node_id]
)
return len(response) > 0
def query(
self,
query: VectorStoreQuery,
) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
Args:
query (VectorStoreQuery): query
"""
query_embedding = cast(List[float], query.query_embedding)
self._reload_if_needed()
response = self._client.search(
collection_name=self._collection_name,
query_vector=query_embedding,
limit=cast(int, query.similarity_top_k),
query_filter=cast(Filter, self._build_query_filter(query)),
with_vectors=True
)
nodes = []
similarities = []
ids = []
for point in response:
payload = cast(Payload, point.payload)
node = Node(
doc_id=str(point.id),
text=payload.get("text"),
embedding=point.vector,
extra_info=payload.get("extra_info"),
relationships={
DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
},
)
nodes.append(node)
similarities.append(point.score)
ids.append(str(point.id))
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
def _reload_if_needed(self):
if isinstance(self._client._client, QdrantLocal):
self._client._client._load()
from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
class VectorStore:
def __init__(self):
self._vector_store = None
self._client = None
def init_app(self, app: Flask):
if not app.config['VECTOR_STORE']:
return
self._vector_store = app.config['VECTOR_STORE']
if self._vector_store not in SUPPORTED_VECTOR_STORES:
raise ValueError(f"Vector store {self._vector_store} is not supported.")
if self._vector_store == 'weaviate':
self._client = WeaviateVectorStoreClient(
endpoint=app.config['WEAVIATE_ENDPOINT'],
api_key=app.config['WEAVIATE_API_KEY'],
grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
batch_size=app.config['WEAVIATE_BATCH_SIZE']
)
elif self._vector_store == 'qdrant':
self._client = QdrantVectorStoreClient(
url=app.config['QDRANT_URL'],
api_key=app.config['QDRANT_API_KEY'],
root_path=app.root_path
)
app.extensions['vector_store'] = self
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
vector_store_config: dict = index_struct.get('vector_store')
index = self.get_client().get_index(
service_context=service_context,
config=vector_store_config
)
return index
def to_index_struct(self, index_id: str) -> dict:
return {
"type": self._vector_store,
"vector_store": self.get_client().to_index_config(index_id)
}
def get_client(self):
if not self._client:
raise Exception("Vector store client is not initialized.")
return self._client
from llama_index.indices.query.base import IS
from typing import (
Any,
Dict,
List,
Optional
)
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)
from langchain.vectorstores import Weaviate
class WeaviateVectorStore(Weaviate):
def del_texts(self, where_filter: dict):
if not where_filter:
raise ValueError('where_filter must not be empty')
self._client.batch.delete_objects(
class_name=self._index_name,
where=where_filter,
output='minimal'
)
def del_text(self, uuid: str) -> None:
self._client.data_object.delete(
uuid,
class_name=self._index_name
)
def text_exists(self, uuid: str) -> bool:
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": uuid,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][self._index_name]
if len(entries) == 0:
return False
return True
import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
parse_get_response,
validate_client,
)
class WeaviateVectorStoreClient(BaseVectorStoreClient):
def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)
def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
auth_config = weaviate.auth.AuthApiKey(api_key=api_key)
weaviate.connect.connection.has_grpc = grpc_enabled
client = weaviate.Client(
url=endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = WeaviateIndexDict()
if self._client is None:
raise Exception("Vector client is not initialized.")
# {"class_prefix": "Gpt_index_xxx"}
class_prefix = config.get('class_prefix')
if not class_prefix:
raise Exception("class_prefix cannot be None.")
return GPTWeaviateEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=WeaviateWithSimilaritiesVectorStore(
weaviate_client=self._client,
class_prefix=class_prefix
)
)
def to_index_config(self, index_id: str) -> dict:
return {"class_prefix": index_id}
class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes."""
nodes = self.weaviate_query(
self._client,
self._class_prefix,
query,
)
nodes = nodes[: query.similarity_top_k]
node_idxs = [str(i) for i in range(len(nodes))]
similarities = []
for node in nodes:
similarities.append(node.extra_info['similarity'])
del node.extra_info['similarity']
return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)
def weaviate_query(
self,
client: Any,
class_prefix: str,
query_spec: VectorStoreQuery,
) -> List[Node]:
"""Convert to LlamaIndex list."""
validate_client(client)
class_name = _class_name(class_prefix)
prop_names = [p["name"] for p in NODE_SCHEMA]
vector = query_spec.query_embedding
# build query
query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
if query_spec.mode == VectorStoreQueryMode.DEFAULT:
_logger.debug("Using vector search")
if vector is not None:
query = query.with_near_vector(
{
"vector": vector,
}
)
elif query_spec.mode == VectorStoreQueryMode.HYBRID:
_logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
query = query.with_hybrid(
query=query_spec.query_str,
alpha=query_spec.alpha,
vector=vector,
)
query = query.with_limit(query_spec.similarity_top_k)
_logger.debug(f"Using limit of {query_spec.similarity_top_k}")
# execute query
query_result = query.do()
# parse results
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
results = [self._to_node(entry) for entry in entries]
return results
def _to_node(self, entry: Dict) -> Node:
"""Convert to Node."""
extra_info_str = entry["extra_info"]
if extra_info_str == "":
extra_info = None
else:
extra_info = json.loads(extra_info_str)
if 'certainty' in entry['_additional']:
if extra_info:
extra_info['similarity'] = entry['_additional']['certainty']
else:
extra_info = {'similarity': entry['_additional']['certainty']}
node_info_str = entry["node_info"]
if node_info_str == "":
node_info = None
else:
node_info = json.loads(node_info_str)
relationships_str = entry["relationships"]
relationships: Dict[DocumentRelationship, str]
if relationships_str == "":
relationships = field(default_factory=dict)
else:
relationships = {
DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
}
return Node(
text=entry["text"],
doc_id=entry["doc_id"],
embedding=entry["_additional"]["vector"],
extra_info=extra_info,
node_info=node_info,
relationships=relationships,
)
def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document.
Args:
doc_id (str): document id
"""
delete_document(self._client, doc_id, self._class_prefix)
def delete_node(self, node_id: str):
"""
Delete node from the index.
:param node_id: node id
"""
delete_node(self._client, node_id, self._class_prefix)
def exists_by_node_id(self, node_id: str) -> bool:
"""
Get node from the index by node id.
:param node_id: node id
"""
entry = get_by_node_id(self._client, node_id, self._class_prefix)
return True if entry else False
class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
pass
def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
"""Delete entry."""
validate_client(client)
# make sure that each entry
class_name = _class_name(class_prefix)
where_filter = {
"path": ["ref_doc_id"],
"operator": "Equal",
"valueString": ref_doc_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
while len(entries) > 0:
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
"""Delete entry."""
validate_client(client)
# make sure that each entry
class_name = _class_name(class_prefix)
where_filter = {
"path": ["doc_id"],
"operator": "Equal",
"valueString": node_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
"""Delete entry."""
validate_client(client)
# make sure that each entry
class_name = _class_name(class_prefix)
where_filter = {
"path": ["doc_id"],
"operator": "Equal",
"valueString": node_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
if len(entries) == 0:
return None
return entries[0]
......@@ -3,6 +3,7 @@ import re
import subprocess
import uuid
from datetime import datetime
from hashlib import sha256
from zoneinfo import available_timezones
import random
import string
......@@ -147,3 +148,8 @@ def get_remote_ip(request):
return request.headers.getlist("X-Forwarded-For")[0]
else:
return request.remote_addr
def generate_text_hash(text: str) -> str:
hash_text = str(text) + 'None'
return sha256(hash_text.encode()).hexdigest()
......@@ -9,8 +9,7 @@ flask-session2==1.3.1
flask-cors==3.0.10
gunicorn~=20.1.0
gevent~=22.10.2
langchain==0.0.142
llama-index==0.5.27
langchain==0.0.201
openai~=0.27.5
psycopg2-binary~=2.9.6
pycryptodome==3.17
......@@ -31,4 +30,5 @@ celery==5.2.7
redis~=4.5.4
pypdf==3.8.1
openpyxl==3.1.2
chardet~=5.1.0
\ No newline at end of file
chardet~=5.1.0
docx2txt==0.8
\ No newline at end of file
......@@ -7,7 +7,6 @@ from typing import Optional, List
from extensions.ext_redis import redis_client
from flask_login import current_user
from core.index.index_builder import IndexBuilder
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db
......@@ -386,8 +385,6 @@ class DocumentService:
dataset.indexing_technique = document_data["indexing_technique"]
if dataset.indexing_technique == 'high_quality':
IndexBuilder.get_default_service_context(dataset.tenant_id)
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
if 'original_document_id' in document_data and document_data["original_document_id"]:
......
......@@ -6,7 +6,7 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.data_source.notion import NotionPageReader
from core.data_loader.loader.notion import NotionLoader
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.indexing_runner import IndexingRunner, DocumentIsPausedException
......@@ -43,6 +43,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
raise ValueError("no notion page found")
workspace_id = data_source_info['notion_workspace_id']
page_id = data_source_info['notion_page_id']
page_type = data_source_info['type']
page_edited_time = data_source_info['last_edited_time']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
......@@ -54,8 +55,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
last_edited_time = reader.get_page_last_edited_time(page_id)
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type
)
last_edited_time = loader.get_notion_last_edited_time()
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = 'parsing'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment