Unverified Commit 6c4e6bf1 authored by Jyong, committed by GitHub

Feat/dify rag (#2528)

Co-authored-by: jyong <jyong@dify.ai>
parent 97fe8171
@@ -56,6 +56,7 @@ DEFAULTS = {
     'BILLING_ENABLED': 'False',
     'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
+    'KEYWORD_STORE': 'jieba',
     'BATCH_UPLOAD_LIMIT': 20
 }
@@ -183,7 +184,7 @@ class Config:
     # Currently, only support: qdrant, milvus, zilliz, weaviate
     # ------------------------
     self.VECTOR_STORE = get_env('VECTOR_STORE')
+    self.KEYWORD_STORE = get_env('KEYWORD_STORE')
     # qdrant settings
     self.QDRANT_URL = get_env('QDRANT_URL')
     self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
......
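
For context, the DEFAULTS table and get_env pairing this hunk extends works as an environment-first lookup. A minimal runnable sketch, assuming get_env simply falls back to DEFAULTS when the variable is unset (the real helper may differ):

import os

# Assumed fallback behaviour: environment first, then the DEFAULTS table.
DEFAULTS = {
    'ETL_TYPE': 'dify',
    'KEYWORD_STORE': 'jieba',  # default added by this commit
}

def get_env(key: str) -> str:
    return os.environ.get(key, DEFAULTS.get(key, ''))

print(get_env('KEYWORD_STORE'))  # 'jieba' unless KEYWORD_STORE is exported
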
@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.login import login_required
@@ -173,14 +174,14 @@ class DataSourceNotionApi(Resource):
         if not data_source_binding:
             raise NotFound('Data source binding not found.')

-        loader = NotionLoader(
-            notion_access_token=data_source_binding.access_token,
+        extractor = NotionExtractor(
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
-            notion_page_type=page_type
+            notion_page_type=page_type,
+            notion_access_token=data_source_binding.access_token
         )

-        text_docs = loader.load()
+        text_docs = extractor.extract()
         return {
             'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200
@@ -192,11 +193,30 @@ class DataSourceNotionApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        notion_info_list = args['notion_info_list']
+        extract_settings = []
+        for notion_info in notion_info_list:
+            workspace_id = notion_info['workspace_id']
+            for page in notion_info['pages']:
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": workspace_id,
+                        "notion_obj_id": page['page_id'],
+                        "notion_page_type": page['type']
+                    },
+                    document_model=args['doc_form']
+                )
+                extract_settings.append(extract_setting)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                     args['process_rule'], args['doc_form'],
+                                                     args['doc_language'])
         return response, 200
......
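
A standalone sketch of the estimate request handling above, with plain dicts standing in for ExtractSetting (the real class lives in core.rag.extractor.entity.extract_setting); the workspace and page IDs are placeholders:

notion_info_list = [{
    'workspace_id': 'ws-0001',                            # placeholder
    'pages': [{'page_id': 'page-abc', 'type': 'page'}],   # placeholder
}]

extract_settings = []
for notion_info in notion_info_list:
    workspace_id = notion_info['workspace_id']
    for page in notion_info['pages']:
        extract_settings.append({
            'datasource_type': 'notion_import',
            'notion_info': {
                'notion_workspace_id': workspace_id,
                'notion_obj_id': page['page_id'],
                'notion_page_type': page['type'],
            },
            'document_model': 'text_model',  # the doc_form default
        })

print(extract_settings)
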
@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -178,9 +179,9 @@ class DatasetApi(Resource):
                             location='json', store_missing=False,
                             type=_validate_description_length)
         parser.add_argument('indexing_technique', type=str, location='json',
                             choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             nullable=True,
                             help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('indexing_technique', type=str, required=True,
                             choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
@@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        extract_settings = []
         if args['info_list']['data_source_type'] == 'upload_file':
             file_ids = args['info_list']['file_info_list']['file_ids']
             file_details = db.session.query(UploadFile).filter(
@@ -278,37 +280,44 @@ class DatasetIndexingEstimateApi(Resource):
             if file_details is None:
                 raise NotFound("File not found.")

-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  args['process_rule'], args['doc_form'],
-                                                                  args['doc_language'], args['dataset_id'],
-                                                                  args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            if file_details:
+                for file_detail in file_details:
+                    extract_setting = ExtractSetting(
+                        datasource_type="upload_file",
+                        upload_file=file_detail,
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         elif args['info_list']['data_source_type'] == 'notion_import':
-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    args['info_list']['notion_info_list'],
-                                                                    args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'],
-                                                                    args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            notion_info_list = args['info_list']['notion_info_list']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                for page in notion_info['pages']:
+                    extract_setting = ExtractSetting(
+                        datasource_type="notion_import",
+                        notion_info={
+                            "notion_workspace_id": workspace_id,
+                            "notion_obj_id": page['page_id'],
+                            "notion_page_type": page['type']
+                        },
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         else:
             raise ValueError('Data source type not support')
+        indexing_runner = IndexingRunner()
+        try:
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         args['process_rule'], args['doc_form'],
+                                                         args['doc_language'], args['dataset_id'],
+                                                         args['indexing_technique'])
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         return response, 200
@@ -508,4 +517,3 @@
 api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
 api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
 api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
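
For reference, a hedged example of calling the estimate endpoint wired above from a client. The host, route, and auth header are assumptions; the JSON body mirrors the parser arguments:

import requests

payload = {
    'info_list': {
        'data_source_type': 'upload_file',
        'file_info_list': {'file_ids': ['<file-uuid>']},  # placeholder
    },
    'process_rule': {'mode': 'automatic', 'rules': {}},
    'indexing_technique': 'high_quality',
    'doc_form': 'text_model',
    'doc_language': 'English',
}
resp = requests.post(
    'http://localhost:5001/console/api/datasets/indexing-estimate',  # assumed route
    json=payload,
    headers={'Authorization': 'Bearer <console-token>'},             # assumed auth
)
print(resp.status_code, resp.json())
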
@@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.document_fields import (
@@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource):
         req_data = request.args
         document_id = req_data.get('document_id')
         # get default rules
         mode = DocumentService.DEFAULT_RULES['mode']
         rules = DocumentService.DEFAULT_RULES['rules']
@@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource):
             if not file:
                 raise NotFound('File not found.')

+            extract_setting = ExtractSetting(
+                datasource_type="upload_file",
+                upload_file=file,
+                document_model=document.doc_form
+            )
+
             indexing_runner = IndexingRunner()

             try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-                                                                  data_process_rule_dict, None,
-                                                                  'English', dataset_id)
+                response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
+                                                             data_process_rule_dict, document.doc_form,
+                                                             'English', dataset_id)
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
@@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         data_process_rule = documents[0].dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
         info_list = []
+        extract_settings = []
         for document in documents:
             if document.indexing_status in ['completed', 'error']:
                 raise DocumentAlreadyFinishedError()
@@ -424,42 +432,48 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 }
                 info_list.append(notion_info)

-        if dataset.data_source_type == 'upload_file':
-            file_details = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == current_user.current_tenant_id,
-                UploadFile.id.in_(info_list)
-            ).all()
-
-            if file_details is None:
-                raise NotFound("File not found.")
-
-            indexing_runner = IndexingRunner()
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  data_process_rule_dict, None,
-                                                                  'English', dataset_id)
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
-        elif dataset.data_source_type == 'notion_import':
+            if document.data_source_type == 'upload_file':
+                file_id = data_source_info['upload_file_id']
+                file_detail = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == current_user.current_tenant_id,
+                    UploadFile.id == file_id
+                ).first()
+
+                if file_detail is None:
+                    raise NotFound("File not found.")
+
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file_detail,
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
+
+            elif document.data_source_type == 'notion_import':
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": data_source_info['notion_workspace_id'],
+                        "notion_obj_id": data_source_info['notion_page_id'],
+                        "notion_page_type": data_source_info['type']
+                    },
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
+            else:
+                raise ValueError('Data source type not support')
         indexing_runner = IndexingRunner()
         try:
-            response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                info_list,
-                                                                data_process_rule_dict,
-                                                                None, 'English', dataset_id)
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         data_process_rule_dict, document.doc_form,
+                                                         'English', dataset_id)
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 "No Embedding Model available. Please configure a valid provider "
                 "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
-        else:
-            raise ValueError('Data source type not support')
         return response
......
import tempfile
from pathlib import Path
from typing import Optional, Union
import requests
from flask import current_app
from langchain.document_loaders import Docx2txtLoader, TextLoader
from langchain.schema import Document
from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)
return cls.load_from_file(file_path, return_text)
@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None,
is_automatic: bool = False) -> Union[list[Document], str]:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
else MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
elif file_extension == '.eml':
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
elif file_extension == '.ppt':
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
elif file_extension == '.pptx':
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
elif file_extension == '.xml':
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
else TextLoader(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
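
The dispatch above boils down to a suffix lookup. A runnable distillation of the default ('dify') branch, with loader classes represented by name only:

from pathlib import Path

# Loader choices for ETL_TYPE != 'Unstructured', as in load_from_file above.
DEFAULT_LOADERS = {
    '.xlsx': 'ExcelLoader',
    '.pdf': 'PdfLoader',
    '.md': 'MarkdownLoader',
    '.markdown': 'MarkdownLoader',
    '.htm': 'HTMLLoader',
    '.html': 'HTMLLoader',
    '.docx': 'Docx2txtLoader',
    '.csv': 'CSVLoader',
}

def pick_loader(file_path: str) -> str:
    suffix = Path(file_path).suffix.lower()
    return DEFAULT_LOADERS.get(suffix, 'TextLoader')  # txt is the fallback

assert pick_loader('report.PDF') == 'PdfLoader'
assert pick_loader('notes.txt') == 'TextLoader'
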
import logging
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class HTMLLoader(BaseLoader):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> list[Document]:
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
import logging
from typing import Optional
from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from extensions.ext_storage import storage
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfLoader(BaseLoader):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> list[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.0625.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = PyPDFium2Loader(file_path=self._file_path).load()
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return documents
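
load() above is a cache-or-parse pattern: return the stored plaintext if present, otherwise parse the PDF and save the text for next time. A minimal sketch with a dict standing in for the storage extension (all names here are stand-ins):

storage: dict[str, bytes] = {}

def load_pdf_text(key: str, parse) -> str:
    if key in storage:                        # cache hit: reuse plaintext
        return storage[key].decode('utf-8')
    text = parse()                            # cache miss: parse the PDF
    storage[key] = text.encode('utf-8')       # save plaintext for caching
    return text

print(load_pdf_text('k.plaintext', lambda: 'page one\n\npage two'))
print(load_pdf_text('k.plaintext', lambda: 'never called'))  # served from cache
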
 from collections.abc import Sequence
 from typing import Any, Optional, cast

-from langchain.schema import Document
 from sqlalchemy import func

 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
......
 import logging
 from typing import Optional

-from flask import current_app
-
-from core.embedding.cached_embedding import CacheEmbedding
 from core.entities.application_entities import InvokeFrom
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
 from extensions.ext_database import db
 from models.dataset import Dataset
 from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
             embedding_provider_name = collection_binding_detail.provider_name
             embedding_model_name = collection_binding_detail.model_name

-            model_manager = ModelManager()
-            model_instance = model_manager.get_model_instance(
-                tenant_id=app_record.tenant_id,
-                provider=embedding_provider_name,
-                model_type=ModelType.TEXT_EMBEDDING,
-                model=embedding_model_name
-            )
-
-            # get embedding model
-            embeddings = CacheEmbedding(model_instance)
-
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                 embedding_provider_name,
                 embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
                 collection_binding_id=dataset_collection_binding.id
             )

-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings,
-                attributes=['doc_id', 'annotation_id', 'app_id']
-            )
+            vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])

-            documents = vector_index.search(
-                query=query,
-                search_type='similarity_score_threshold',
-                search_kwargs={
-                    'k': 1,
-                    'score_threshold': score_threshold,
-                    'filter': {
-                        'group_id': [dataset.id]
-                    }
-                }
-            )
+            documents = vector.search_by_vector(
+                query=query,
+                k=1,
+                score_threshold=score_threshold,
+                filter={
+                    'group_id': [dataset.id]
+                }
+            )
......
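
The retrieval call changes shape in this hunk: VectorIndex.search() packed options into search_kwargs, while Vector.search_by_vector() takes plain keyword arguments. A side-by-side sketch of the two argument shapes (values are placeholders):

old_style = dict(
    search_type='similarity_score_threshold',
    search_kwargs={
        'k': 1,
        'score_threshold': 0.9,
        'filter': {'group_id': ['<dataset-id>']},
    },
)
new_style = dict(
    k=1,
    score_threshold=0.9,
    filter={'group_id': ['<dataset-id>']},
)
print(old_style)
print(new_style)
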
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset
class IndexBuilder:
@classmethod
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
if indexing_technique == "high_quality":
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
provider=dataset.embedding_model_provider,
model=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
elif indexing_technique == "economy":
return KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
else:
raise ValueError('Unknown indexing technique')
@classmethod
def get_default_high_quality_index(cls, dataset: Dataset):
embeddings = OpenAIEmbeddings(openai_api_key=' ')
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
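
Note the contract of get_index() above: asking for 'high_quality' on a dataset that is not high quality returns None (unless ignore_high_quality_check is set), so callers must handle None. A tiny runnable restatement of the branching:

def choose_index(indexing_technique: str, dataset_technique: str):
    # mirrors IndexBuilder.get_index's branching contract above
    if indexing_technique == 'high_quality':
        if dataset_technique != 'high_quality':
            return None  # caller must handle the silent skip
        return 'vector-index'
    elif indexing_technique == 'economy':
        return 'keyword-table-index'
    raise ValueError('Unknown indexing technique')

assert choose_index('high_quality', 'economy') is None
assert choose_index('economy', 'economy') == 'keyword-table-index'
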
import json
import logging
from abc import abstractmethod
from typing import Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import VectorStore
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod
def get_index_name(self, dataset: Dataset) -> str:
raise NotImplementedError
@abstractmethod
def to_index_struct(self) -> dict:
raise NotImplementedError
@abstractmethod
def _get_vector_store(self) -> VectorStore:
raise NotImplementedError
@abstractmethod
def _get_vector_store_class(self) -> type:
raise NotImplementedError
@abstractmethod
def search_by_full_text_index(
self, query: str,
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
if search_type == 'similarity_score_threshold':
score_threshold = search_kwargs.get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
search_kwargs['score_threshold'] = .0
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
query, **search_kwargs
)
docs = []
for doc, similarity in docs_with_similarity:
doc.metadata['score'] = similarity
docs.append(doc)
return docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document], **kwargs):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if kwargs.get('duplicate_check', False):
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if self.dataset.collection_binding_id:
vector_store.delete_by_group_id(group_id)
else:
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
return False
def recreate_dataset(self, dataset: Dataset):
logging.info(f"Recreating dataset {dataset.id}")
try:
self.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
origin_index_struct = self.dataset.index_struct[:]
self.dataset.index_struct = None
if documents:
try:
self.create(documents)
except Exception as e:
self.dataset.index_struct = origin_index_struct
raise e
dataset.index_struct = json.dumps(self.to_index_struct())
db.session.commit()
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")
def create_qdrant_dataset(self, dataset: Dataset):
logging.info(f"create_qdrant_dataset {dataset.id}")
try:
self.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.create(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def update_qdrant_dataset(self, dataset: Dataset):
logging.info(f"update_qdrant_dataset {dataset.id}")
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).first()
if segment:
try:
exist = self.text_exists(segment.index_node_id)
if exist:
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.add_texts(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"delete original collection: {dataset.id}")
self.delete()
dataset.collection_binding_id = dataset_collection_binding.id
db.session.add(dataset)
db.session.commit()
logging.info(f"Dataset {dataset.id} recreate successfully.")
from typing import Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from models.dataset import Dataset
class MilvusConfig(BaseModel):
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
class MilvusVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'milvus'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
connection_args=self._client_config.to_milvus_params(),
index_params=index_params
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content'
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
return MilvusVectorStore(
collection_name=self.get_index_name(self.dataset),
embedding_function=self._embeddings,
connection_args=self._client_config.to_milvus_params()
)
def _get_vector_store_class(self) -> type:
return MilvusVectorStore
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_document_id(document_id)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_metadata_field(key, value)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_doc_ids(doc_ids)
vector_store.del_texts({
'filter': f' id in {ids}'
})
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
import os
from typing import Any, Optional, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
}
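
to_qdrant_params() above switches between embedded mode (a 'path:' prefix, resolved against root_path) and server mode (a URL). A self-contained restatement of that branch:

import os

def to_qdrant_params(endpoint: str, api_key=None, timeout=20.0, root_path='/app'):
    # mirrors QdrantConfig.to_qdrant_params above; root_path is a placeholder
    if endpoint.startswith('path:'):
        path = endpoint.replace('path:', '')
        if not os.path.isabs(path):
            path = os.path.join(root_path, path)
        return {'path': path}
    return {'url': endpoint, 'api_key': api_key, 'timeout': timeout}

print(to_qdrant_params('path:storage/qdrant'))    # embedded: {'path': '/app/storage/qdrant'}
print(to_qdrant_params('http://localhost:6333'))  # server mode
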
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
return dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
))
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
for node_id in ids:
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
))
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=group_id),
),
],
))
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
return vector_store.similarity_search_by_bm25(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
],
), kwargs.get('top_k', 2))
import json
from flask import current_app
from langchain.embeddings.base import Embeddings
from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
self._attributes = attributes
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError("Vector store must be specified.")
if vector_type == "weaviate":
from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings,
attributes=attributes
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
return QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
),
embeddings=embeddings
)
elif vector_type == "milvus":
from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex
return MilvusVectorIndex(
dataset=dataset,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document], **kwargs):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts, **kwargs)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts, **kwargs)
def __getattr__(self, name):
if self._vector_index is not None:
method = getattr(self._vector_index, name)
if callable(method):
return method
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
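
VectorIndex above is a thin facade: add_texts() handles first-time index creation, and __getattr__ forwards every other call to the backend-specific index. A minimal stand-alone demo of that delegation pattern (Backend and Facade are hypothetical names):

class Backend:
    def delete_by_ids(self, ids):
        print('deleting', ids)

class Facade:
    def __init__(self):
        self._impl = Backend()

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails,
        # so self._impl itself never recurses here
        method = getattr(self._impl, name)
        if callable(method):
            return method
        raise AttributeError(name)

Facade().delete_by_ids(['a', 'b'])  # forwarded to Backend
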
from typing import Any, Optional, cast
import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = self._attributes
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": [key],
"valueText": value
})
def delete_by_group_id(self, group_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
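
Both the Weaviate and Qdrant classes share the naming convention visible above: legacy ('origin') collections use a bare class_prefix, newer ones carry a '_Node' suffix, and _is_origin() keys off that suffix. A runnable restatement:

def index_name(dataset_id: str, class_prefix: str | None = None) -> str:
    # mirrors get_index_name() above; the dataset id is a placeholder UUID
    if class_prefix:
        return class_prefix if class_prefix.endswith('_Node') else class_prefix + '_Node'
    return 'Vector_index_' + dataset_id.replace('-', '_') + '_Node'

def is_origin(class_prefix: str) -> bool:
    return not class_prefix.endswith('_Node')

print(index_name('0b1c2d3e-0000-0000-0000-000000000000'))
print(index_name('', class_prefix='Legacy'))  # -> 'Legacy_Node'
print(is_origin('Legacy'))                    # True: pre-migration collection
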
@@ -9,20 +9,20 @@ from typing import Optional, cast

 from flask import Flask, current_app
 from flask_login import current_user
-from langchain.schema import Document
 from langchain.text_splitter import TextSplitter
 from sqlalchemy.orm.exc import ObjectDeletedError

-from core.data_loader.file_extractor import FileExtractor
-from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.errors.error import ProviderTokenNotInitError
 from core.generator.llm_generator import LLMGenerator
-from core.index.index import IndexBuilder
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import Document
 from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -31,7 +31,6 @@ from libs import helper
 from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import UploadFile
-from models.source import DataSourceBinding
 from services.feature_service import FeatureService
...@@ -57,38 +56,19 @@ class IndexingRunner:
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
-# load file
-text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
-# get embedding model instance
-embedding_model_instance = None
-if dataset.indexing_technique == 'high_quality':
-if dataset.embedding_model_provider:
-embedding_model_instance = self.model_manager.get_model_instance(
-tenant_id=dataset.tenant_id,
-provider=dataset.embedding_model_provider,
-model_type=ModelType.TEXT_EMBEDDING,
-model=dataset.embedding_model
-)
-else:
-embedding_model_instance = self.model_manager.get_default_model_instance(
-tenant_id=dataset.tenant_id,
-model_type=ModelType.TEXT_EMBEDDING,
-)
-# get splitter
-splitter = self._get_splitter(processing_rule, embedding_model_instance)
-# split to documents
-documents = self._step_split(
-text_docs=text_docs,
-splitter=splitter,
-dataset=dataset,
-dataset_document=dataset_document,
-processing_rule=processing_rule
-)
-self._build_index(
+index_type = dataset_document.doc_form
+index_processor = IndexProcessorFactory(index_type).init_index_processor()
+# extract
+text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+# transform
+documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
+# save segment
+self._load_segments(dataset, dataset_document, documents)
+# load
+self._load(
+index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
...@@ -134,39 +114,19 @@ class IndexingRunner:
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
-# load file
-text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
-# get embedding model instance
-embedding_model_instance = None
-if dataset.indexing_technique == 'high_quality':
-if dataset.embedding_model_provider:
-embedding_model_instance = self.model_manager.get_model_instance(
-tenant_id=dataset.tenant_id,
-provider=dataset.embedding_model_provider,
-model_type=ModelType.TEXT_EMBEDDING,
-model=dataset.embedding_model
-)
-else:
-embedding_model_instance = self.model_manager.get_default_model_instance(
-tenant_id=dataset.tenant_id,
-model_type=ModelType.TEXT_EMBEDDING,
-)
-# get splitter
-splitter = self._get_splitter(processing_rule, embedding_model_instance)
-# split to documents
-documents = self._step_split(
-text_docs=text_docs,
-splitter=splitter,
-dataset=dataset,
-dataset_document=dataset_document,
-processing_rule=processing_rule
-)
-# build index
-self._build_index(
+index_type = dataset_document.doc_form
+index_processor = IndexProcessorFactory(index_type).init_index_processor()
+# extract
+text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+# transform
+documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
+# save segment
+self._load_segments(dataset, dataset_document, documents)
+# load
+self._load(
+index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
...@@ -220,7 +180,15 @@ class IndexingRunner:
documents.append(document)
# build index
-self._build_index(
+# get the process rule
+processing_rule = db.session.query(DatasetProcessRule). \
+filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+first()
+index_type = dataset_document.doc_form
+index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
+self._load(
+index_processor=index_processor,
dataset=dataset,
dataset_document=dataset_document,
documents=documents
...@@ -239,16 +207,16 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
-def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
+def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
# check document limit
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
-count = len(file_details)
+count = len(extract_settings)
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
...@@ -284,16 +252,18 @@ class IndexingRunner:
total_segments = 0
total_price = 0
currency = 'USD'
-for file_detail in file_details:
+index_type = doc_form
+index_processor = IndexProcessorFactory(index_type).init_index_processor()
+all_text_docs = []
+for extract_setting in extract_settings:
+# extract
+text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
+all_text_docs.extend(text_docs)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
-# load data from file
-text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
...@@ -305,7 +275,6 @@ class IndexingRunner:
)
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
...@@ -364,154 +333,8 @@ class IndexingRunner:
"preview": preview_texts
}
-def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
-doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
-indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
# check document limit
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
count = len(notion_info_list)
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
# load data from notion
tokens = 0
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == current_user.current_tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
for page in notion_info['pages']:
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page['page_id'],
notion_page_type=page['type']
)
documents = loader.load()
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
# split to documents
documents = self._split_to_documents_for_estimate(
text_docs=documents,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(documents)
embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' and embedding_model_type_instance:
tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
texts=[document.page_content]
)
if doc_form and doc_form == 'qa_model':
model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(price_info.total_amount),
"currency": price_info.currency,
"qa_preview": document_qa_list,
"preview": preview_texts
}
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}
-def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
+def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
+-> list[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
...@@ -527,11 +350,27 @@ class IndexingRunner:
one_or_none()
if file_detail:
-text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
+extract_setting = ExtractSetting(
+datasource_type="upload_file",
+upload_file=file_detail,
+document_model=dataset_document.doc_form
+)
+text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
elif dataset_document.data_source_type == 'notion_import':
-loader = NotionLoader.from_document(dataset_document)
-text_docs = loader.load()
+if (not data_source_info or 'notion_workspace_id' not in data_source_info
+or 'notion_page_id' not in data_source_info):
+raise ValueError("no notion import info found")
+extract_setting = ExtractSetting(
+datasource_type="notion_import",
+notion_info={
+"notion_workspace_id": data_source_info['notion_workspace_id'],
+"notion_obj_id": data_source_info['notion_page_id'],
+"notion_page_type": data_source_info['notion_page_type'],
+"document": dataset_document
+},
+document_model=dataset_document.doc_form
+)
+text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,
...@@ -545,8 +384,6 @@ class IndexingRunner:
# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
-# remove invalid symbol
-text_doc.page_content = self.filter_string(text_doc.page_content)
text_doc.metadata['document_id'] = dataset_document.id
text_doc.metadata['dataset_id'] = dataset_document.dataset_id
...@@ -787,12 +624,12 @@ class IndexingRunner:
for q, a in matches if q and a
]
-def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
+def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
+dataset_document: DatasetDocument, documents: list[Document]) -> None:
"""
-Build the index for the document.
+insert index and update document/segment status to completed
"""
-vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_model_instance(
...@@ -825,13 +662,8 @@ class IndexingRunner:
)
for document in chunk_documents
)
-# save vector index
-if vector_index:
-vector_index.add_texts(chunk_documents)
-# save keyword index
-keyword_table_index.add_texts(chunk_documents)
+# load index
+index_processor.load(dataset, chunk_documents)
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
...@@ -911,14 +743,64 @@ class IndexingRunner:
)
documents.append(document)
-# save vector index
-index = IndexBuilder.get_index(dataset, 'high_quality')
-if index:
-index.add_texts(documents, duplicate_check=True)
-# save keyword index
-index = IndexBuilder.get_index(dataset, 'economy')
-if index:
-index.add_texts(documents)
+index_type = dataset.doc_form
+index_processor = IndexProcessorFactory(index_type).init_index_processor()
+index_processor.load(dataset, documents)
+def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
+text_docs: list[Document], process_rule: dict) -> list[Document]:
+# get embedding model instance
+embedding_model_instance = None
+if dataset.indexing_technique == 'high_quality':
+if dataset.embedding_model_provider:
+embedding_model_instance = self.model_manager.get_model_instance(
+tenant_id=dataset.tenant_id,
+provider=dataset.embedding_model_provider,
+model_type=ModelType.TEXT_EMBEDDING,
+model=dataset.embedding_model
+)
+else:
+embedding_model_instance = self.model_manager.get_default_model_instance(
+tenant_id=dataset.tenant_id,
+model_type=ModelType.TEXT_EMBEDDING,
+)
+documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
+process_rule=process_rule)
+return documents
+def _load_segments(self, dataset, dataset_document, documents):
+# save node to document segment
+doc_store = DatasetDocumentStore(
+dataset=dataset,
+user_id=dataset_document.created_by,
+document_id=dataset_document.id
+)
+# add document segments
+doc_store.add_documents(documents)
+# update document status to indexing
+cur_time = datetime.datetime.utcnow()
+self._update_document_index_status(
+document_id=dataset_document.id,
+after_indexing_status="indexing",
+extra_update_params={
+DatasetDocument.cleaning_completed_at: cur_time,
+DatasetDocument.splitting_completed_at: cur_time,
+}
+)
+# update segment status to indexing
+self._update_segments_by_document(
+dataset_document_id=dataset_document.id,
+update_params={
+DocumentSegment.status: "indexing",
+DocumentSegment.indexing_at: datetime.datetime.utcnow()
+}
+)
+pass
class DocumentIsPausedException(Exception):
......
import re
class CleanProcessor:
@classmethod
def clean(cls, text: str, process_rule: dict) -> str:
# default clean
# remove invalid symbol
text = re.sub(r'<\|', '<', text)
text = re.sub(r'\|>', '>', text)
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
# Unicode U+FFFE
text = re.sub('\uFFFE', '', text)
rules = process_rule['rules'] if process_rule else {}
if 'pre_processing_rules' in rules:
pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules:
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
# Remove extra spaces
pattern = r'\n{3,}'
text = re.sub(pattern, '\n\n', text)
pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
text = re.sub(pattern, ' ', text)
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
# Remove email
pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
text = re.sub(pattern, '', text)
# Remove URL
pattern = r'https?://[^\s]+'
text = re.sub(pattern, '', text)
return text
def filter_string(self, text):
return text
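A quick sketch of driving CleanProcessor.clean with a process rule; the rule ids match the branches above, and the sample text is purely illustrative:
sample_rule = {
    'rules': {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': True},
        ]
    }
}
raw = "Write to someone@example.com\n\n\n\nor visit https://example.com  today"
print(CleanProcessor.clean(raw, sample_rule))
# invalid control characters are always stripped; with these rules enabled the
# email, the URL and the run of blank lines are removed as well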
"""Abstract interface for document cleaner implementations."""
from abc import ABC, abstractmethod
class BaseCleaner(ABC):
"""Interface for clean chunk content.
"""
@abstractmethod
def clean(self, content: str):
raise NotImplementedError
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredExtraWhitespaceCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.core import clean_extra_whitespace
# e.g. "ITEM 1A:     RISK FACTORS\n" becomes "ITEM 1A: RISK FACTORS"
return clean_extra_whitespace(content)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
import re
from unstructured.cleaners.core import group_broken_paragraphs
para_split_re = re.compile(r"(\s*\n\s*){3}")
return group_broken_paragraphs(content, paragraph_split=para_split_re)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.core import clean_non_ascii_chars
# Returns "This text containsnon-ascii characters!"
return clean_non_ascii_chars(content)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredReplaceUnicodeQuotesCleaner(BaseCleaner):
def clean(self, content) -> str:
"""Replaces unicode quote characters, such as the \x91 character in a string."""
from unstructured.cleaners.core import replace_unicode_quotes
return replace_unicode_quotes(content)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredTranslateTextCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.translate import translate_text
return translate_text(content)
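Each wrapper defers its import so the optional unstructured dependency is only pulled in when the cleaner is actually used. A minimal sketch, assuming unstructured is installed and using the non-ASCII cleaner defined above:
cleaner = UnstructuredNonAsciiCharsCleaner()
print(cleaner.clean("\x88This text contains®non-ascii characters!®"))
# non-ASCII characters are stripped by unstructured's clean_non_ascii_chars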
from typing import Optional
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document
from core.rerank.rerank import RerankRunner
class DataPostProcessor:
"""Interface for data post-processing document.
"""
def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)
return documents
def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
if reranking_model:
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
)
except InvokeAuthorizationError:
return None
return RerankRunner(rerank_model_instance)
return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
if reorder_enabled:
return ReorderRunner()
return None
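A sketch of composing the post-processor; the tenant id and the reranking model entries are placeholders for values configured in a real workspace, and candidate_docs stands for previously retrieved documents:
post_processor = DataPostProcessor(
    tenant_id='tenant-uuid',  # placeholder
    reranking_model={
        'reranking_provider_name': 'cohere',           # placeholder provider
        'reranking_model_name': 'rerank-english-v2.0'  # placeholder model
    },
    reorder_enabled=True
)
ranked = post_processor.invoke(query='refund policy', documents=candidate_docs, top_n=4)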
from langchain.schema import Document
class ReorderRunner:
def run(self, documents: list[Document]) -> list[Document]:
# Elements at even indices (0, 2, 4, ...), i.e. the 1st, 3rd, 5th documents
odd_elements = documents[::2]
# Elements at odd indices (1, 3, 5, ...), i.e. the 2nd, 4th, 6th documents
even_elements = documents[1::2]
# Reverse the second group so the weakest matches end up in the middle
even_elements_reversed = even_elements[::-1]
new_documents = odd_elements + even_elements_reversed
return new_documents
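This interleave-and-reverse shuffle places the strongest matches at both ends of the context, a common mitigation for the "lost in the middle" effect. A small sketch with stand-in documents ranked best-first:
docs = [Document(page_content=f'doc {i}') for i in range(5)]
print([d.page_content for d in ReorderRunner().run(docs)])
# ['doc 0', 'doc 2', 'doc 4', 'doc 3', 'doc 1']: top hits at the edges, tail in the middle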
from abc import ABC, abstractmethod
class Embeddings(ABC):
"""Interface for embedding models."""
@abstractmethod
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError
async def aembed_query(self, text: str) -> list[float]:
"""Asynchronous Embed query text."""
raise NotImplementedError
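A toy subclass to illustrate the contract; the hash-derived vectors are purely a stand-in for tests and examples, not a real embedding model:
import hashlib

class FakeEmbeddings(Embeddings):
    """Deterministic stand-in embedder (illustration only)."""

    def __init__(self, size: int = 8):
        self._size = size

    def _embed(self, text: str) -> list[float]:
        digest = hashlib.sha256(text.encode('utf-8')).digest()
        return [byte / 255 for byte in digest[:self._size]]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self._embed(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        return self._embed(text)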
...@@ -2,11 +2,11 @@ import json
from collections import defaultdict
from typing import Any, Optional
-from langchain.schema import BaseRetriever, Document
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel
-from core.index.base import BaseIndex
-from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.keyword_base import BaseKeyword
+from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
...@@ -15,59 +15,19 @@ class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
-class KeywordTableIndex(BaseIndex):
-def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
+class Jieba(BaseKeyword):
+def __init__(self, dataset: Dataset):
super().__init__(dataset)
-self._config = config
+self._config = KeywordTableConfig()
-def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
keyword_table_handler = JiebaKeywordTableHandler()
-keyword_table = {}
+keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
...@@ -76,8 +36,13 @@ class KeywordTableIndex(BaseIndex):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
-for text in texts:
-keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+keywords_list = kwargs.get('keywords_list', None)
+for i in range(len(texts)):
+text = texts[i]
+if keywords_list:
+keywords = keywords_list[i]
+else:
+keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
...@@ -107,20 +72,13 @@ class KeywordTableIndex(BaseIndex):
self._save_dataset_keyword_table(keyword_table)
def delete_by_metadata_field(self, key: str, value: str):
pass
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
-search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
-k = search_kwargs.get('k') if search_kwargs.get('k') else 4
+k = kwargs.get('top_k', 4)
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
...@@ -150,12 +108,6 @@ class KeywordTableIndex(BaseIndex):
db.session.delete(dataset_keyword_table)
db.session.commit()
def delete_by_group_id(self, group_id: str) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
...@@ -242,6 +194,7 @@ class KeywordTableIndex(BaseIndex):
).first()
if document_segment:
document_segment.keywords = keywords
+db.session.add(document_segment)
db.session.commit()
def create_segment_keywords(self, node_id: str, keywords: list[str]):
...@@ -272,31 +225,6 @@ class KeywordTableIndex(BaseIndex):
self._save_dataset_keyword_table(keyword_table)
class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex
search_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> list[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
return self.index.search(query, **self.search_kwargs)
async def aget_relevant_documents(self, query: str) -> list[Document]:
raise NotImplementedError("KeywordTableRetriever does not support async")
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
......
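SetEncoder exists because the keyword table maps each keyword to a set of segment ids, which json.dumps cannot serialize natively; the encoder emits sets as JSON lists. A one-line sketch:
import json
print(json.dumps({'keyword': {'seg-1', 'seg-2'}}, cls=SetEncoder))  # sets become JSON arrays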
...@@ -3,7 +3,7 @@ import re
import jieba
from jieba.analyse import default_tfidf
-from core.index.keyword_table_index.stopwords import STOPWORDS
+from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
......
...@@ -3,22 +3,17 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
-from langchain.schema import BaseRetriever, Document
+from core.rag.models.document import Document
from models.dataset import Dataset
-class BaseIndex(ABC):
+class BaseKeyword(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@abstractmethod
-def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-raise NotImplementedError
-@abstractmethod
-def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
raise NotImplementedError
@abstractmethod
...@@ -34,31 +29,18 @@ class BaseIndex(ABC):
raise NotImplementedError
@abstractmethod
-def delete_by_metadata_field(self, key: str, value: str) -> None:
-raise NotImplementedError
-@abstractmethod
-def delete_by_group_id(self, group_id: str) -> None:
-raise NotImplementedError
-@abstractmethod
-def delete_by_document_id(self, document_id: str):
+def delete_by_document_id(self, document_id: str) -> None:
raise NotImplementedError
-@abstractmethod
-def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+def delete(self) -> None:
raise NotImplementedError
-@abstractmethod
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
-def delete(self) -> None:
-raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts:
doc_id = text.metadata['doc_id']
......
from typing import Any, cast
from flask import current_app
from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from models.dataset import Dataset
class Keyword:
def __init__(self, dataset: Dataset):
self._dataset = dataset
self._keyword_processor = self._init_keyword()
def _init_keyword(self) -> BaseKeyword:
config = cast(dict, current_app.config)
keyword_type = config.get('KEYWORD_STORE')
if not keyword_type:
raise ValueError("Keyword store must be specified.")
if keyword_type == "jieba":
return Jieba(
dataset=self._dataset
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def create(self, texts: list[Document], **kwargs):
self._keyword_processor.create(texts, **kwargs)
def add_texts(self, texts: list[Document], **kwargs):
self._keyword_processor.add_texts(texts, **kwargs)
def text_exists(self, id: str) -> bool:
return self._keyword_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
self._keyword_processor.delete_by_ids(ids)
def delete_by_document_id(self, document_id: str) -> None:
self._keyword_processor.delete_by_document_id(document_id)
def delete(self) -> None:
self._keyword_processor.delete()
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
return self._keyword_processor.search(query, **kwargs)
def __getattr__(self, name):
if self._keyword_processor is not None:
method = getattr(self._keyword_processor, name)
if callable(method):
return method
raise AttributeError(f"'Keyword' object has no attribute '{name}'")
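With the new KEYWORD_STORE default of 'jieba', the factory resolves to the Jieba handler above. A minimal usage sketch, assuming an app context, a loaded Dataset row and already-transformed documents:
keyword = Keyword(dataset)                 # dataset: a models.dataset.Dataset instance
keyword.add_texts(documents)               # extracts and persists keywords per chunk
hits = keyword.search('退款政策', top_k=4)  # keyword-table lookup, defaults to 4 results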
+import threading
from typing import Optional
from flask import Flask, current_app
-from langchain.embeddings.base import Embeddings
+from flask_login import current_user
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
-from core.rerank.rerank import RerankRunner
+from core.rag.data_post_processor.data_post_processor import DataPostProcessor
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
...@@ -25,48 +24,109 @@ default_retrieval_model = {
class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
all_documents = []
threads = []
# retrieval_model source with keyword
if retrival_method == 'keyword_search':
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'all_documents': all_documents
})
threads.append(keyword_thread)
keyword_thread.start()
# retrieval_model source with semantic
if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
'retrival_method': retrival_method
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval source with full text
if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'retrival_method': retrival_method,
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
'all_documents': all_documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
if retrival_method == 'hybrid_search':
data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
)
return all_documents
@classmethod
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
keyword = Keyword(
dataset=dataset
)
documents = keyword.search(
query,
k=top_k
)
all_documents.extend(documents)
@classmethod
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-all_documents: list, search_method: str, embeddings: Embeddings):
+all_documents: list, retrival_method: str):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
-vector_index = VectorIndex(
-dataset=dataset,
-config=current_app.config,
-embeddings=embeddings
+vector = Vector(
+dataset=dataset
)
-documents = vector_index.search(
+documents = vector.search_by_vector(
query,
search_type='similarity_score_threshold',
-search_kwargs={
-'k': top_k,
-'score_threshold': score_threshold,
-'filter': {
-'group_id': [dataset.id]
-}
+k=top_k,
+score_threshold=score_threshold,
+filter={
+'group_id': [dataset.id]
}
)
if documents:
-if reranking_model and search_method == 'semantic_search':
-try:
-model_manager = ModelManager()
-rerank_model_instance = model_manager.get_model_instance(
-tenant_id=dataset.tenant_id,
-provider=reranking_model['reranking_provider_name'],
-model_type=ModelType.RERANK,
-model=reranking_model['reranking_model_name']
-)
-except InvokeAuthorizationError:
-return
-rerank_runner = RerankRunner(rerank_model_instance)
-all_documents.extend(rerank_runner.run(
+if reranking_model and retrival_method == 'semantic_search':
+data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
...@@ -78,38 +138,24 @@ class RetrievalService:
@classmethod
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-all_documents: list, search_method: str, embeddings: Embeddings):
+all_documents: list, retrival_method: str):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
-vector_index = VectorIndex(
+vector_processor = Vector(
dataset=dataset,
-config=current_app.config,
-embeddings=embeddings
)
-documents = vector_index.search_by_full_text_index(
+documents = vector_processor.search_by_full_text(
query,
-search_type='similarity_score_threshold',
top_k=top_k
)
if documents:
-if reranking_model and search_method == 'full_text_search':
-try:
-model_manager = ModelManager()
-rerank_model_instance = model_manager.get_model_instance(
-tenant_id=dataset.tenant_id,
-provider=reranking_model['reranking_provider_name'],
-model_type=ModelType.RERANK,
-model=reranking_model['reranking_model_name']
-)
-except InvokeAuthorizationError:
-return
-rerank_runner = RerankRunner(rerank_model_instance)
-all_documents.extend(rerank_runner.run(
+if reranking_model and retrival_method == 'full_text_search':
+data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
......
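A sketch of the new unified entry point, called inside a request/app context; the dataset id and reranking model are placeholders, and 'hybrid_search' fans out to the embedding and full-text threads above before reranking the merged candidates:
docs = RetrievalService.retrieve(
    retrival_method='hybrid_search',  # or 'keyword_search' / 'semantic_search' / 'full_text_search'
    dataset_id='dataset-uuid',        # placeholder
    query='how do refunds work?',
    top_k=4,
    score_threshold=0.5,
    reranking_model={
        'reranking_provider_name': 'cohere',           # placeholder
        'reranking_model_name': 'rerank-english-v2.0'  # placeholder
    }
)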
from enum import Enum
class Field(Enum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
TEXT_KEY = "text"
PRIMARY_KEY = " id"
import logging
from typing import Any, Optional
from uuid import uuid4
from pydantic import BaseModel, root_validator
from pymilvus import MilvusClient, MilvusException, connections
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._fields = []
def get_type(self) -> str:
return 'milvus'
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
metadatas = [d.metadata for d in texts]
# Grab the existing collection if it exists
from pymilvus import utility
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
if not utility.has_collection(self._collection_name, using=alias):
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
}
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
raise e
return pks
def delete_by_document_id(self, document_id: str):
ids = self.get_ids_by_metadata_field('document_id', document_id)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
if result:
return [item["id"] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, doc_ids: list[str]) -> None:
self._client.delete(collection_name=self._collection_name, pks=doc_ids)
def delete(self) -> None:
from pymilvus import utility
utility.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result['entity'].get(Field.METADATA_KEY.value)
metadata['score'] = result['distance']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
) -> str:
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
# Determine embedding dim
dim = len(embeddings[0])
fields = []
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
# Create the schema for the collection
schema = CollectionSchema(fields)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
self._fields.remove(Field.PRIMARY_KEY.value)
# Create the collection
collection_name = self._collection_name
self._client.create_collection_with_schema(collection_name=collection_name,
schema=schema, index_param=index_params,
consistency_level=self._consistency_level)
return collection_name
def _init_client(self, config) -> MilvusClient:
if config.secure:
uri = "https://" + str(config.host) + ":" + str(config.port)
else:
uri = "http://" + str(config.host) + ":" + str(config.port)
client = MilvusClient(uri=uri, user=config.user, password=config.password)
return client
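End to end, the Milvus store is driven with precomputed embeddings; a sketch assuming a reachable Milvus server, with placeholder credentials, a stand-in document and a four-dimensional toy vector:
config = MilvusConfig(host='localhost', port=19530, user='root', password='Milvus')  # placeholders
milvus = MilvusVector(collection_name='Vector_index_example_Node', config=config)    # placeholder name
doc = Document(page_content='hello dify', metadata={'doc_id': 'seg-1', 'document_id': 'doc-a'})
milvus.create([doc], [[0.1, 0.2, 0.3, 0.4]])  # provisions the collection on first use
matches = milvus.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=1)
milvus.delete_by_document_id('doc-a')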
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import qdrant_client
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
FilterSelector,
HnswConfigDiff,
PayloadSchemaType,
TextIndexParams,
TextIndexType,
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
}
class QdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
self._distance_func = distance_func.upper()
self._group_id = group_id
def get_type(self) -> str:
return 'qdrant'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# get embedding vector size
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
)
# create payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
# create full-text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, metadatas, uuids, 64, self._group_id
):
self._client.upsert(
collection_name=self._collection_name, points=points
)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = list(islice(embeddings_iterator, batch_size))
points = [
rest.PointStruct(
id=point_id,
vector=vector,
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
Field.GROUP_KEY.value,
),
)
]
yield batch_ids, points
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append(
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def delete(self):
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
for node_id in ids:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def text_exists(self, id: str) -> bool:
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[id]
)
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
query_filter=filter,
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", .0)
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# check the score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
metadata['score'] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar by bm25.
Returns:
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get('top_k', 2),
with_payload=True,
with_vectors=True
)
results = response[0]
documents = []
for result in results:
if result:
documents.append(self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
))
return documents
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
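For orientation, a minimal usage sketch of the Qdrant-backed store above (not part of this patch): the endpoint, collection name, group id, and toy 4-dimensional embeddings are all illustrative, and it assumes the collection was already created via create().

from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from core.rag.models.document import Document

vector = QdrantVector(
    collection_name="Vector_index_example_Node",  # illustrative name
    group_id="example-dataset-id",                # illustrative dataset id
    config=QdrantConfig(endpoint="http://localhost:6333", api_key=None,
                        root_path="/tmp", timeout=20)
)
docs = [Document(page_content="hello dify", metadata={"doc_id": "d1", "doc_hash": "h1"})]
vector.add_texts(documents=docs, embeddings=[[0.1, 0.2, 0.3, 0.4]])  # toy 4-dim vector
hits = vector.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=4, score_threshold=0.0)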
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name
@abstractmethod
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError
@abstractmethod
def search_by_vector(
self,
query_vector: list[float],
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        # iterate over a copy: removing items from a list while iterating over it
        # silently skips the element after each removal
        for text in texts[:]:
            doc_id = text.metadata['doc_id']
            if self.text_exists(doc_id):
                texts.remove(text)
        return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
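To make the contract concrete, here is a sketch (not from this patch) of what a new backend has to implement; MyVector and its in-memory dict are hypothetical stand-ins for a real vector database client.

from typing import Any

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document


class MyVector(BaseVector):
    """Toy in-memory backend; real implementations talk to a vector database."""

    def __init__(self, collection_name: str):
        super().__init__(collection_name)
        self._store: dict[str, Document] = {}

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        for document in documents:
            self._store[document.metadata['doc_id']] = document

    def text_exists(self, id: str) -> bool:
        return id in self._store

    def delete_by_ids(self, ids: list[str]) -> None:
        for id in ids:
            self._store.pop(id, None)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._store = {k: d for k, d in self._store.items() if d.metadata.get(key) != value}

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        return list(self._store.values())[:kwargs.get('top_k', 4)]  # no real ANN search here

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return [d for d in self._store.values() if query in d.page_content]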
from typing import Any, cast
from flask import current_app
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class Vector:
def __init__(self, dataset: Dataset, attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
self._vector_processor = self._init_vector()
def _init_vector(self) -> BaseVector:
config = cast(dict, current_app.config)
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError("Vector store must be specified.")
if vector_type == "weaviate":
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
attributes=self._attributes
)
elif vector_type == "qdrant":
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
if self._dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset collection binding does not exist.')
else:
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return QdrantVector(
collection_name=collection_name,
group_id=self._dataset.id,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
)
)
elif vector_type == "milvus":
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
)
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def create(self, texts: list = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(
texts=texts,
embeddings=embeddings,
**kwargs
)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get('duplicate_check', False):
documents = self._filter_duplicate_texts(documents)
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.add_texts(
documents=documents,
embeddings=embeddings,
**kwargs
)
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
self._vector_processor.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(
self, query: str,
**kwargs: Any
) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self) -> None:
self._vector_processor.delete()
def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model
)
return CacheEmbedding(embedding_model)
    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        # iterate over a copy: removing items from a list while iterating over it
        # silently skips the element after each removal
        for text in texts[:]:
            doc_id = text.metadata['doc_id']
            if self.text_exists(doc_id):
                texts.remove(text)
        return texts
def __getattr__(self, name):
if self._vector_processor is not None:
method = getattr(self._vector_processor, name)
if callable(method):
return method
raise AttributeError(f"'vector_processor' object has no attribute '{name}'")
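A usage sketch of the facade (not from this patch), assuming a Dataset row with an embedding model configured and VECTOR_STORE set (e.g. 'qdrant'); embedding happens inside Vector, so callers pass plain text.

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset

dataset = db.session.query(Dataset).first()  # assumes such a dataset exists
documents = [Document(page_content="hello dify", metadata={"doc_id": "d1", "doc_hash": "h1"})]
vector = Vector(dataset)  # backend chosen from VECTOR_STORE or the dataset's index_struct_dict
vector.add_texts(documents, duplicate_check=True)
results = vector.search_by_vector("what is dify?", top_k=4, score_threshold=0.5)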
import datetime
from typing import Any, Optional
import requests
import weaviate
from pydantic import BaseModel, root_validator
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
super().__init__(collection_name)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_collection_name(self, dataset: Dataset) -> str:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
schema = self._default_schema(self._collection_name)
# check whether the index already exists
if not self._client.schema.contains(schema):
# create collection
self._client.schema.create_class(schema)
# create vector
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
ids = []
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
if metadatas is not None:
for key, val in metadatas[i].items():
data_properties[key] = self._json_serializable(val)
batch.add_data_object(
data_object=data_properties,
class_name=self._collection_name,
uuid=uuids[i],
vector=embeddings[i] if embeddings else None,
)
ids.append(uuids[i])
return ids
def delete_by_metadata_field(self, key: str, value: str):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
def delete(self):
self._client.schema.delete_class(self._collection_name)
def text_exists(self, id: str) -> bool:
collection_name = self._collection_name
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][collection_name]
return len(entries) > 0
def delete_by_ids(self, ids: list[str]) -> None:
self._client.data_object.delete(
ids,
class_name=self._collection_name
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
# copy the attribute list so repeated searches don't keep appending to it
properties = self._attributes + [Field.TEXT_KEY.value]
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
result = (
query_obj.with_near_vector(vector)
.with_limit(kwargs.get("top_k", 4))
.with_additional(["vector", "distance"])
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold") or 0.0
# check score threshold
if score > score_threshold:
doc.metadata['score'] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query.
"""
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
# copy the attribute list so repeated searches don't keep appending to it
properties = self._attributes + [Field.TEXT_KEY.value]
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
docs.append(Document(page_content=text, metadata=res))
return docs
def _default_schema(self, index_name: str) -> dict:
return {
"class": index_name,
"properties": [
{
"name": "text",
"dataType": ["text"],
}
],
}
def _json_serializable(self, value: Any) -> Any:
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
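For symmetry with the Qdrant example, an illustrative wiring of the Weaviate store (not from this patch); the endpoint is a placeholder and the attribute list mirrors the defaults the factory passes in.

from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector

vector = WeaviateVector(
    collection_name="Vector_index_example_Node",  # placeholder
    config=WeaviateConfig(endpoint="http://localhost:8080", api_key=None, batch_size=100),
    attributes=['doc_id', 'dataset_id', 'document_id', 'doc_hash']
)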
"""Schema for Blobs and Blob Loaders.
The goal is to facilitate decoupling of content loading from content parsing code.
In addition, content loading code should provide a lazy loading interface by default.
"""
from __future__ import annotations
import contextlib
import mimetypes
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Mapping
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Optional, Union
from pydantic import BaseModel, root_validator
PathLike = Union[str, PurePath]
class Blob(BaseModel):
"""A blob is used to represent raw data by either reference or value.
Provides an interface to materialize the blob in different representations, and
help to decouple the development of data loaders from the downstream parsing of
the raw data.
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
"""
data: Union[bytes, str, None] # Raw data
mimetype: Optional[str] = None # Not to be confused with a file extension
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
# Location where the original content was found
# Represent location on the local file system
# Useful for situations where downstream code assumes it must work with file paths
# rather than in-memory content.
path: Optional[PathLike] = None
class Config:
arbitrary_types_allowed = True
frozen = True
@property
def source(self) -> Optional[str]:
"""The source location of the blob as string if known otherwise none."""
return str(self.path) if self.path else None
@root_validator(pre=True)
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:
raise ValueError("Either data or path must be provided")
return values
def as_string(self) -> str:
"""Read data as a string."""
if self.data is None and self.path:
with open(str(self.path), encoding=self.encoding) as f:
return f.read()
elif isinstance(self.data, bytes):
return self.data.decode(self.encoding)
elif isinstance(self.data, str):
return self.data
else:
raise ValueError(f"Unable to get string for blob {self}")
def as_bytes(self) -> bytes:
"""Read data as bytes."""
if isinstance(self.data, bytes):
return self.data
elif isinstance(self.data, str):
return self.data.encode(self.encoding)
elif self.data is None and self.path:
with open(str(self.path), "rb") as f:
return f.read()
else:
raise ValueError(f"Unable to get bytes for blob {self}")
@contextlib.contextmanager
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
"""Read data as a byte stream."""
if isinstance(self.data, bytes):
yield BytesIO(self.data)
elif self.data is None and self.path:
with open(str(self.path), "rb") as f:
yield f
else:
raise NotImplementedError(f"Unable to convert blob {self}")
@classmethod
def from_path(
cls,
path: PathLike,
*,
encoding: str = "utf-8",
mime_type: Optional[str] = None,
guess_type: bool = True,
) -> Blob:
"""Load the blob from a path like object.
Args:
path: path like object to file to be read
encoding: Encoding to use if decoding the bytes into a string
mime_type: if provided, will be set as the mime-type of the data
guess_type: If True, the mimetype will be guessed from the file extension,
if a mime-type was not provided
Returns:
Blob instance
"""
if mime_type is None and guess_type:
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None
else:
_mimetype = mime_type
# We do not load the data immediately, instead we treat the blob as a
# reference to the underlying data.
return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path)
@classmethod
def from_data(
cls,
data: Union[str, bytes],
*,
encoding: str = "utf-8",
mime_type: Optional[str] = None,
path: Optional[str] = None,
) -> Blob:
"""Initialize the blob from in-memory data.
Args:
data: the in-memory data associated with the blob
encoding: Encoding to use if decoding the bytes into a string
mime_type: if provided, will be set as the mime-type of the data
path: if provided, will be set as the source from which the data came
Returns:
Blob instance
"""
return cls(data=data, mimetype=mime_type, encoding=encoding, path=path)
def __repr__(self) -> str:
"""Define the blob representation."""
str_repr = f"Blob {id(self)}"
if self.source:
str_repr += f" {self.source}"
return str_repr
class BlobLoader(ABC):
"""Abstract interface for blob loaders implementation.
Implementer should be able to load raw content from a datasource system according
to some criteria and return the raw content lazily as a stream of blobs.
"""
@abstractmethod
def yield_blobs(
self,
) -> Iterable[Blob]:
"""A lazy loader for raw data represented by LangChain's Blob object.
Returns:
A generator over blobs
"""
"""Abstract interface for document loader implementations."""
import csv
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class CSVExtractor(BaseExtractor):
"""Load CSV files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None,
csv_args: Optional[dict] = None,
):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
self.source_column = source_column
self.csv_args = csv_args or {}
def extract(self) -> list[Document]:
"""Load data into document objects."""
try:
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
return docs
def _read_from_file(self, csvfile) -> list[Document]:
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs
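An illustrative run of the CSV extractor (not from this patch); 'people.csv' and the 'name' source column are made-up examples.

from core.rag.extractor.csv_extractor import CSVExtractor

extractor = CSVExtractor("people.csv", autodetect_encoding=True, source_column="name")
for doc in extractor.extract():
    # one Document per row, e.g. "name: Ada\nrole: Engineer"
    print(doc.metadata["row"], doc.metadata["source"], doc.page_content)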
from enum import Enum
class DatasourceType(Enum):
FILE = "upload_file"
NOTION = "notion_import"
from pydantic import BaseModel
from models.dataset import Document
from models.model import UploadFile
class NotionInfo(BaseModel):
"""
Notion import info.
"""
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str
document: Document = None
class Config:
arbitrary_types_allowed = True
def __init__(self, **data) -> None:
super().__init__(**data)
class ExtractSetting(BaseModel):
"""
Model class for document extraction settings.
"""
datasource_type: str
upload_file: UploadFile = None
notion_info: NotionInfo = None
document_model: str = None
class Config:
arbitrary_types_allowed = True
def __init__(self, **data) -> None:
super().__init__(**data)
import logging """Abstract interface for document loader implementations."""
from typing import Optional
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook from openpyxl.reader.excel import load_workbook
logger = logging.getLogger(__name__) from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class ExcelLoader(BaseLoader): class ExcelExtractor(BaseExtractor):
"""Load xlxs files. """Load Excel files.
Args: Args:
...@@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader): ...@@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader):
""" """
def __init__( def __init__(
self, self,
file_path: str file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False
): ):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def load(self) -> list[Document]: def extract(self) -> list[Document]:
"""Load from file path."""
data = [] data = []
keys = [] keys = []
wb = load_workbook(filename=self._file_path, read_only=True) wb = load_workbook(filename=self._file_path, read_only=True)
......
import tempfile
from pathlib import Path
from typing import Union
import requests
from flask import current_app
from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
from models.model import UploadFile
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
class ExtractProcessor:
@classmethod
def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
-> Union[list[Document], str]:
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=upload_file,
document_model='text_model'
)
if return_text:
delimiter = '\n'
return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
else:
return cls.extract(extract_setting, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)
extract_setting = ExtractSetting(
datasource_type="upload_file",
document_model='text_model'
)
if return_text:
delimiter = '\n'
return delimiter.join([document.page_content for document in cls.extract(
extract_setting=extract_setting, file_path=file_path)])
else:
return cls.extract(extract_setting=extract_setting, file_path=file_path)
@classmethod
def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
file_path: str = None) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE.value:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
upload_file: UploadFile = extract_setting.upload_file
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
extractor = ExcelExtractor(file_path)
elif file_extension == '.pdf':
extractor = PdfExtractor(file_path)
elif file_extension in ['.md', '.markdown']:
extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
else MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
elif file_extension == '.eml':
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
elif file_extension == '.ppt':
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url)
elif file_extension == '.pptx':
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
elif file_extension == '.xml':
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
else:
# txt
extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
else TextExtractor(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
extractor = ExcelExtractor(file_path)
elif file_extension == '.pdf':
extractor = PdfExtractor(file_path)
elif file_extension in ['.md', '.markdown']:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
extractor = WordExtractor(file_path)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
else:
# txt
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document
)
return extractor.extract()
else:
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
class BaseExtractor(ABC):
"""Interface for extract files.
"""
@abstractmethod
def extract(self):
raise NotImplementedError
"""Document loader helpers."""
import concurrent.futures
from typing import NamedTuple, Optional, cast
class FileEncoding(NamedTuple):
"""A file encoding as the NamedTuple."""
encoding: Optional[str]
"""The encoding of the file."""
confidence: float
"""The confidence of the encoding."""
language: Optional[str]
"""The language of the file."""
def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
"""Try to detect the file encoding.
Returns a list of `FileEncoding` tuples with the detected encodings ordered
by confidence.
Args:
file_path: The path to the file to detect the encoding for.
timeout: The timeout in seconds for the encoding detection.
"""
import chardet
def read_and_detect(file_path: str) -> list[dict]:
with open(file_path, "rb") as f:
rawdata = f.read()
return cast(list[dict], chardet.detect_all(rawdata))
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(read_and_detect, file_path)
try:
encodings = future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
raise TimeoutError(
f"Timeout reached while detecting encoding for {file_path}"
)
if all(encoding["encoding"] is None for encoding in encodings):
raise RuntimeError(f"Could not detect encoding for {file_path}")
return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
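A quick sketch of the helper in use (not from this patch); 'legacy.txt' is a placeholder file of unknown encoding.

from core.rag.extractor.helpers import detect_file_encodings

# candidates come back ordered by chardet confidence
for candidate in detect_file_encodings("legacy.txt"):
    print(candidate.encoding, candidate.confidence, candidate.language)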
import csv """Abstract interface for document loader implementations."""
import logging
from typing import Optional from typing import Optional
from langchain.document_loaders import CSVLoader as LCCSVLoader from core.rag.extractor.extractor_base import BaseExtractor
from langchain.document_loaders.helpers import detect_file_encodings from core.rag.extractor.helpers import detect_file_encodings
from langchain.schema import Document from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class HtmlExtractor(BaseExtractor):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
class CSVLoader(LCCSVLoader):
def __init__( def __init__(
self, self,
file_path: str, file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None, source_column: Optional[str] = None,
csv_args: Optional[dict] = None, csv_args: Optional[dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
): ):
self.file_path = file_path """Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
self.source_column = source_column self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {} self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
def load(self) -> list[Document]: def extract(self) -> list[Document]:
"""Load data into document objects.""" """Load data into document objects."""
try: try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile: with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile) docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
if self.autodetect_encoding: if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path) detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings: for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try: try:
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile: with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile) docs = self._read_from_file(csvfile)
break break
except UnicodeDecodeError: except UnicodeDecodeError:
continue continue
else: else:
raise RuntimeError(f"Error loading {self.file_path}") from e raise RuntimeError(f"Error loading {self._file_path}") from e
return docs return docs
def _read_from_file(self, csvfile): def _read_from_file(self, csvfile) -> list[Document]:
docs = [] docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader): for i, row in enumerate(csv_reader):
......
import logging """Abstract interface for document loader implementations."""
import re import re
from typing import Optional, cast from typing import Optional, cast
from langchain.document_loaders.base import BaseLoader from core.rag.extractor.extractor_base import BaseExtractor
from langchain.document_loaders.helpers import detect_file_encodings from core.rag.extractor.helpers import detect_file_encodings
from langchain.schema import Document from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class MarkdownExtractor(BaseExtractor):
class MarkdownLoader(BaseLoader): """Load Markdown files.
"""Load md files.
Args: Args:
file_path: Path to the file to load. file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
remove_images: Whether to remove images from the text.
encoding: File encoding to use. If `None`, the file will be loaded
with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
""" """
def __init__( def __init__(
self, self,
file_path: str, file_path: str,
remove_hyperlinks: bool = True, remove_hyperlinks: bool = True,
remove_images: bool = True, remove_images: bool = True,
encoding: Optional[str] = None, encoding: Optional[str] = None,
autodetect_encoding: bool = True, autodetect_encoding: bool = True,
): ):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
...@@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader): ...@@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader):
self._encoding = encoding self._encoding = encoding
self._autodetect_encoding = autodetect_encoding self._autodetect_encoding = autodetect_encoding
def load(self) -> list[Document]: def extract(self) -> list[Document]:
"""Load from file path."""
tups = self.parse_tups(self._file_path) tups = self.parse_tups(self._file_path)
documents = [] documents = []
for header, value in tups: for header, value in tups:
...@@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader): ...@@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader):
if self._autodetect_encoding: if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath) detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings: for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try: try:
with open(filepath, encoding=encoding.encoding) as f: with open(filepath, encoding=encoding.encoding) as f:
content = f.read() content = f.read()
......
@@ -4,9 +4,10 @@ from typing import Any, Optional
import requests
from flask import current_app
from flask_login import current_user
from langchain.schema import Document

from core.rag.extractor.extractor_base import BaseExtractor
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding
@@ -22,52 +23,37 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']


class NotionExtractor(BaseExtractor):

    def __init__(
        self,
        notion_workspace_id: str,
        notion_obj_id: str,
        notion_page_type: str,
        document_model: Optional[DocumentModel] = None,
        notion_access_token: Optional[str] = None
    ):
        self._document_model = document_model
        self._notion_workspace_id = notion_workspace_id
        self._notion_obj_id = notion_obj_id
        self._notion_page_type = notion_page_type
        if notion_access_token:
            self._notion_access_token = notion_access_token
        else:
            self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
                                                               self._notion_workspace_id)
            if not self._notion_access_token:
                integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
                if integration_token is None:
                    raise ValueError(
                        "Must specify `integration_token` or set environment "
                        "variable `NOTION_INTEGRATION_TOKEN`."
                    )
                self._notion_access_token = integration_token

    def extract(self) -> list[Document]:
        self.update_last_edited_time(
            self._document_model
        )
        ...
"""Abstract interface for document loader implementations."""
from collections.abc import Iterator
from typing import Optional
from core.rag.extractor.blod.blod import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
class PdfExtractor(BaseExtractor):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
file_cache_key: Optional[str] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._file_cache_key = file_cache_key
    def extract(self) -> list[Document]:
        if self._file_cache_key:
            # return the cached plaintext if a previous run stored it
            try:
                text = storage.load(self._file_cache_key).decode('utf-8')
                return [Document(page_content=text)]
            except FileNotFoundError:
                pass
        documents = list(self.load())
        text = "\n\n".join(document.page_content for document in documents)
        # save plaintext for caching so later runs can skip PDF parsing
        if self._file_cache_key:
            storage.save(self._file_cache_key, text.encode('utf-8'))
        return documents
def load(
self,
) -> Iterator[Document]:
"""Lazy load given path as pages."""
blob = Blob.from_path(self._file_path)
yield from self.parse(blob)
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
try:
for page_number, page in enumerate(pdf_reader):
text_page = page.get_textpage()
content = text_page.get_text_range()
text_page.close()
page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata)
finally:
pdf_reader.close()
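A usage sketch (not from this patch); the file name and cache key are illustrative. With a cache key, a second extract() call returns stored plaintext instead of re-parsing the PDF.

from core.rag.extractor.pdf_extractor import PdfExtractor

extractor = PdfExtractor("manual.pdf", file_cache_key="plain/manual.txt")
pages = extractor.extract()  # one Document per page with {'source', 'page'} metadata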
"""Abstract interface for document loader implementations."""
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class TextExtractor(BaseExtractor):
"""Load text files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False
):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def extract(self) -> list[Document]:
"""Load from file path."""
text = ""
try:
with open(self._file_path, encoding=self._encoding) as f:
text = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, encoding=encoding.encoding) as f:
text = f.read()
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
except Exception as e:
raise RuntimeError(f"Error loading {self._file_path}") from e
metadata = {"source": self._file_path}
return [Document(page_content=text, metadata=metadata)]
import logging
import os
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredWordExtractor(BaseExtractor):
"""Loader that uses unstructured to load word documents.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.__version__ import __version__ as __unstructured_version__
from unstructured.file_utils.filetype import FileType, detect_filetype
unstructured_version = tuple(
[int(x) for x in __unstructured_version__.split(".")]
)
# check the file extension
try:
import magic # noqa: F401
is_doc = detect_filetype(self._file_path) == FileType.DOC
except ImportError:
_, extension = os.path.splitext(str(self._file_path))
is_doc = extension == ".doc"
if is_doc and unstructured_version < (0, 4, 11):
raise ValueError(
f"You are on unstructured version {__unstructured_version__}. "
"Partitioning .doc files is only supported in unstructured>=0.4.11. "
"Please upgrade the unstructured package and try again."
)
if is_doc:
from unstructured.partition.doc import partition_doc
elements = partition_doc(filename=self._file_path)
else:
from unstructured.partition.docx import partition_docx
elements = partition_docx(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents
@@ -2,13 +2,14 @@
import base64
import logging

from bs4 import BeautifulSoup

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

logger = logging.getLogger(__name__)


class UnstructuredEmailExtractor(BaseExtractor):
    """Load eml files.

    Args:
        file_path: Path to the file to load.
@@ -23,7 +24,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
        self._file_path = file_path
        self._api_url = api_url

    def extract(self) -> list[Document]:
        from unstructured.partition.email import partition_email

        elements = partition_email(filename=self._file_path, api_url=self._api_url)
        ...
import logging

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

logger = logging.getLogger(__name__)


class UnstructuredMarkdownExtractor(BaseExtractor):
    """Load md files.
@@ -33,7 +33,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
        self._file_path = file_path
        self._api_url = api_url

    def extract(self) -> list[Document]:
        from unstructured.partition.md import partition_md

        elements = partition_md(filename=self._file_path, api_url=self._api_url)
        ...
import logging

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

logger = logging.getLogger(__name__)


class UnstructuredMsgExtractor(BaseExtractor):
    """Load msg files.
@@ -23,7 +23,7 @@ class UnstructuredMsgExtractor(BaseExtractor):
        self._file_path = file_path
        self._api_url = api_url

    def extract(self) -> list[Document]:
        from unstructured.partition.msg import partition_msg

        elements = partition_msg(filename=self._file_path, api_url=self._api_url)
        ...
import logging

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

logger = logging.getLogger(__name__)


class UnstructuredPPTExtractor(BaseExtractor):
    """Load ppt files.
@@ -14,15 +15,15 @@
    """

    def __init__(
        self,
        file_path: str,
        api_url: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def extract(self) -> list[Document]:
        from unstructured.partition.ppt import partition_ppt

        elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
        ...
import logging

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

logger = logging.getLogger(__name__)


class UnstructuredPPTXExtractor(BaseExtractor):
    """Load pptx files.
@@ -13,15 +15,15 @@
    """

    def __init__(
        self,
        file_path: str,
        api_url: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._api_url = api_url

    def extract(self) -> list[Document]:
        from unstructured.partition.pptx import partition_pptx

        elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
        ...
import logging

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

logger = logging.getLogger(__name__)


class UnstructuredTextExtractor(BaseExtractor):
    """Load text files.
@@ -23,7 +23,7 @@ class UnstructuredTextExtractor(BaseExtractor):
        self._file_path = file_path
        self._api_url = api_url

    def extract(self) -> list[Document]:
        from unstructured.partition.text import partition_text

        elements = partition_text(filename=self._file_path, api_url=self._api_url)
        ...
import logging

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

logger = logging.getLogger(__name__)


class UnstructuredXmlExtractor(BaseExtractor):
    """Load xml files.
@@ -23,7 +23,7 @@ class UnstructuredXmlExtractor(BaseExtractor):
        self._file_path = file_path
        self._api_url = api_url

    def extract(self) -> list[Document]:
        from unstructured.partition.xml import partition_xml

        elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
        ...
"""Abstract interface for document loader implementations."""
import os
import tempfile
from urllib.parse import urlparse
import requests
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class WordExtractor(BaseExtractor):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(self, file_path: str):
"""Initialize with file path."""
self.file_path = file_path
if "~" in self.file_path:
self.file_path = os.path.expanduser(self.file_path)
# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
r = requests.get(self.file_path)
if r.status_code != 200:
raise ValueError(
"Check the url of your file; returned status code %s"
% r.status_code
)
self.web_path = self.file_path
self.temp_file = tempfile.NamedTemporaryFile()
self.temp_file.write(r.content)
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):
raise ValueError("File path %s is not a valid file or url" % self.file_path)
def __del__(self) -> None:
if hasattr(self, "temp_file"):
self.temp_file.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""
import docx2txt
return [
Document(
page_content=docx2txt.process(self.file_path),
metadata={"source": self.file_path},
)
]
@staticmethod
def _is_valid_url(url: str) -> bool:
"""Check if the url is valid."""
parsed = urlparse(url)
return bool(parsed.netloc) and bool(parsed.scheme)
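A one-line usage sketch (not from this patch); the URL is a placeholder. The extractor accepts a local path or an http(s) URL, which is downloaded to a temporary file first.

from core.rag.extractor.word_extractor import WordExtractor

docs = WordExtractor("https://example.com/spec.docx").extract()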
from enum import Enum
class IndexType(Enum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "parent_child_index"
SUMMARY_INDEX = "summary_index"
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
from typing import Optional
from langchain.text_splitter import TextSplitter
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from models.dataset import Dataset, DatasetProcessRule
class BaseIndexProcessor(ABC):
"""Interface for extract files.
"""
@abstractmethod
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
raise NotImplementedError
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
raise NotImplementedError
@abstractmethod
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict) -> list[Document]:
raise NotImplementedError
def _get_splitter(self, processing_rule: dict,
embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
if processing_rule['mode'] == "custom":
# The user-defined segmentation rule
rules = processing_rule['rules']
segmentation = rules["segmentation"]
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
raise ValueError("Custom segment length should be between 50 and 1000.")
separator = segmentation["separator"]
if separator:
separator = separator.replace('\\n', '\n')
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=0,
fixed_separator=separator,
separators=["\n\n", "。", ".", " ", ""],
embedding_model_instance=embedding_model_instance
)
else:
# Automatic segmentation
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=0,
separators=["\n\n", "。", ".", " ", ""],
embedding_model_instance=embedding_model_instance
)
return character_splitter
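# Illustrative process-rule payload accepted by _get_splitter above (custom
# mode); the field names mirror the checks in the method, the values are
# placeholders.
_example_process_rule = {
    "mode": "custom",
    "rules": {
        "segmentation": {
            "max_tokens": 500,   # must fall within [50, 1000]
            "separator": "\\n",  # unescaped to a real newline by _get_splitter
        }
    },
}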
"""Abstract interface for document loader implementations."""
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
class IndexProcessorFactory:
"""IndexProcessorInit.
"""
def __init__(self, index_type: str):
self._index_type = index_type
def init_index_processor(self) -> BaseIndexProcessor:
"""Init index processor."""
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX.value:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX.value:
return QAIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")
"""Paragraph index processor."""
import uuid
from typing import Optional
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from libs import helper
from models.dataset import Dataset
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
is_automatic=kwargs.get('process_rule_mode') == "automatic")
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
# Split the text documents into nodes.
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
embedding_model_instance=kwargs.get('embedding_model_instance'))
all_documents = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
                    # strip a leading splitter character left over from segmentation
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content
split_documents.append(document_node)
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset)
vector.create(documents)
if with_keywords:
keyword = Keyword(dataset)
keyword.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
else:
vector.delete()
if with_keywords:
keyword = Keyword(dataset)
if node_ids:
keyword.delete_by_ids(node_ids)
else:
keyword.delete()
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata['score'] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
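# Pipeline sketch (illustrative; `extract_setting`, `rule`, and `dataset` are
# assumed to exist): the three stages above are meant to be chained.
# processor = ParagraphIndexProcessor()
# docs = processor.extract(extract_setting, process_rule_mode="automatic")
# nodes = processor.transform(docs, process_rule=rule, embedding_model_instance=None)
# processor.load(dataset, nodes, with_keywords=True)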
"""Paragraph index processor."""
import logging
import re
import threading
import uuid
from typing import Optional
import pandas as pd
from flask import Flask, current_app
from flask_login import current_user
from werkzeug.datastructures import FileStorage
from core.generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from libs import helper
from models.dataset import Dataset
class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
is_automatic=kwargs.get('process_rule_mode') == "automatic")
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
embedding_model_instance=None)
# Split the text documents into nodes.
all_documents = []
all_qa_documents = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
                    # strip a leading splitter character left over from segmentation
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content
split_documents.append(document_node)
all_documents.extend(split_documents)
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i:i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
'flask_app': current_app._get_current_object(),
'tenant_id': current_user.current_tenant.id,
'document_node': doc,
'all_qa_documents': all_qa_documents,
'document_language': kwargs.get('document_language', 'English')})
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
return all_qa_documents
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
            # pandas treats the first row as the header, so it is skipped as data
df = pd.read_csv(file)
text_docs = []
for index, row in df.iterrows():
data = Document(page_content=row[0], metadata={'answer': row[1]})
text_docs.append(data)
if len(text_docs) == 0:
raise ValueError("The CSV file is empty.")
except Exception as e:
raise ValueError(str(e))
return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset)
vector.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
else:
vector.delete()
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict):
# Set search parameters.
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata['score'] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
with flask_app.app_context():
try:
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
document_qa_list = self._format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result['question'])
qa_document.metadata['answer'] = result['answer']
qa_document.metadata['doc_id'] = doc_id
qa_document.metadata['doc_hash'] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception(e)
all_qa_documents.extend(format_documents)
def _format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE)
return [
{
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
}
for q, a in matches if q and a
]
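# Illustrative input/output for _format_split_text above; the LLM response is
# expected to interleave "Qn:" and "An:" markers.
if __name__ == "__main__":
    _sample = "Q1: What is RAG?\nA1: Retrieval-augmented generation.\nQ2: Why chunk?\nA2: To fit context windows."
    print(QAIndexProcessor()._format_split_text(_sample))
    # [{'question': 'What is RAG?', 'answer': 'Retrieval-augmented generation.'},
    #  {'question': 'Why chunk?', 'answer': 'To fit context windows.'}]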
from typing import Optional
from pydantic import BaseModel, Field
class Document(BaseModel):
"""Class for storing a piece of text and associated metadata."""
    page_content: str

    metadata: Optional[dict] = Field(default_factory=dict)
    """Arbitrary metadata about the page content (e.g., source, relationships to other
    documents, etc.).
    """
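# Example: a Document pairs text with free-form metadata; the index processors
# in this change store doc_id, doc_hash, answer, and score there.
if __name__ == "__main__":
    print(Document(page_content="hello", metadata={"doc_id": "123"}))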
...@@ -21,9 +21,9 @@ from pydantic import BaseModel, Field ...@@ -21,9 +21,9 @@ from pydantic import BaseModel, Field
from regex import regex from regex import regex
from core.chain.llm_chain import LLMChain from core.chain.llm_chain import LLMChain
from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
FULL_TEMPLATE = """ FULL_TEMPLATE = """
TITLE: {title} TITLE: {title}
...@@ -146,7 +146,7 @@ def get_url(url: str) -> str: ...@@ -146,7 +146,7 @@ def get_url(url: str) -> str:
headers = { headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
} }
supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
...@@ -158,8 +158,8 @@ def get_url(url: str) -> str: ...@@ -158,8 +158,8 @@ def get_url(url: str) -> str:
if main_content_type not in supported_content_types: if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type) return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return FileExtractor.load_from_url(url, return_text=True) return ExtractProcessor.load_from_url(url, return_text=True)
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
a = extract_using_readabilipy(response.text) a = extract_using_readabilipy(response.text)
......
...@@ -6,15 +6,12 @@ from langchain.tools import BaseTool ...@@ -6,15 +6,12 @@ from langchain.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner from core.rerank.rerank import RerankRunner
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.retrieval_service import RetrievalService
default_retrieval_model = { default_retrieval_model = {
'search_method': 'semantic_search', 'search_method': 'semantic_search',
...@@ -174,76 +171,24 @@ class DatasetMultiRetrieverTool(BaseTool): ...@@ -174,76 +171,24 @@ class DatasetMultiRetrieverTool(BaseTool):
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
kw_table_index = KeywordTableIndex( documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset=dataset, dataset_id=dataset.id,
config=KeywordTableConfig( query=query,
max_keywords_per_chunk=5 top_k=self.top_k
) )
)
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
if documents: if documents:
all_documents.extend(documents) all_documents.extend(documents)
else: else:
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
return []
except ProviderTokenNotInitError:
return []
embeddings = CacheEmbedding(embedding_model)
documents = []
threads = []
if self.top_k > 0: if self.top_k > 0:
# retrieval_model source with semantic # retrieval source
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[ documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
'search_method'] == 'hybrid_search': dataset_id=dataset.id,
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ query=query,
'flask_app': current_app._get_current_object(), top_k=self.top_k,
'dataset_id': str(dataset.id), score_threshold=retrieval_model['score_threshold']
'query': query, if retrieval_model['score_threshold_enabled'] else None,
'top_k': self.top_k, reranking_model=retrieval_model['reranking_model']
'score_threshold': self.score_threshold, if retrieval_model['reranking_enable'] else None
'reranking_model': None, )
'all_documents': documents,
'search_method': 'hybrid_search',
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval_model source with full text
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[
'search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': str(dataset.id),
'query': query,
'search_method': 'hybrid_search',
'embeddings': embeddings,
'score_threshold': retrieval_model[
'score_threshold'] if retrieval_model[
'score_threshold_enabled'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model[
'reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
all_documents.extend(documents) all_documents.extend(documents)
\ No newline at end of file
import threading
from typing import Optional from typing import Optional
from flask import current_app
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding from core.rag.datasource.retrieval_service import RetrievalService
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.retrieval_service import RetrievalService
default_retrieval_model = { default_retrieval_model = {
'search_method': 'semantic_search', 'search_method': 'semantic_search',
...@@ -77,94 +69,24 @@ class DatasetRetrieverTool(BaseTool): ...@@ -77,94 +69,24 @@ class DatasetRetrieverTool(BaseTool):
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
kw_table_index = KeywordTableIndex( documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset=dataset, dataset_id=dataset.id,
config=KeywordTableConfig( query=query,
max_keywords_per_chunk=5 top_k=self.top_k
) )
)
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
return str("\n".join([document.page_content for document in documents])) return str("\n".join([document.page_content for document in documents]))
else: else:
# get embedding model instance
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except InvokeAuthorizationError:
return ''
embeddings = CacheEmbedding(embedding_model)
documents = []
threads = []
if self.top_k > 0: if self.top_k > 0:
# retrieval source with semantic # retrieval source
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ dataset_id=dataset.id,
'flask_app': current_app._get_current_object(), query=query,
'dataset_id': str(dataset.id), top_k=self.top_k,
'query': query, score_threshold=retrieval_model['score_threshold']
'top_k': self.top_k, if retrieval_model['score_threshold_enabled'] else None,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ reranking_model=retrieval_model['reranking_model']
'score_threshold_enabled'] else None, if retrieval_model['reranking_enable'] else None
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ )
'reranking_enable'] else None,
'all_documents': documents,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval_model source with full text
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': str(dataset.id),
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enabled'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
# hybrid search: rerank after all documents have been searched
if retrieval_model['search_method'] == 'hybrid_search':
# get rerank model instance
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=retrieval_model['reranking_model']['reranking_provider_name'],
model_type=ModelType.RERANK,
model=retrieval_model['reranking_model']['reranking_model_name']
)
except InvokeAuthorizationError:
return ''
rerank_runner = RerankRunner(rerank_model_instance)
documents = rerank_runner.run(
query=query,
documents=documents,
score_threshold=retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enabled'] else None,
top_n=self.top_k
)
else: else:
documents = [] documents = []
...@@ -234,4 +156,4 @@ class DatasetRetrieverTool(BaseTool): ...@@ -234,4 +156,4 @@ class DatasetRetrieverTool(BaseTool):
return str("\n".join(document_context_list)) return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str: async def _arun(self, tool_input: str) -> str:
raise NotImplementedError() raise NotImplementedError()
\ No newline at end of file
...@@ -21,9 +21,9 @@ from pydantic import BaseModel, Field ...@@ -21,9 +21,9 @@ from pydantic import BaseModel, Field
from regex import regex from regex import regex
from core.chain.llm_chain import LLMChain from core.chain.llm_chain import LLMChain
from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
FULL_TEMPLATE = """ FULL_TEMPLATE = """
TITLE: {title} TITLE: {title}
...@@ -149,7 +149,7 @@ def get_url(url: str, user_agent: str = None) -> str: ...@@ -149,7 +149,7 @@ def get_url(url: str, user_agent: str = None) -> str:
if user_agent: if user_agent:
headers["User-Agent"] = user_agent headers["User-Agent"] = user_agent
supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
...@@ -161,8 +161,8 @@ def get_url(url: str, user_agent: str = None) -> str: ...@@ -161,8 +161,8 @@ def get_url(url: str, user_agent: str = None) -> str:
if main_content_type not in supported_content_types: if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type) return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return FileExtractor.load_from_url(url, return_text=True) return ExtractProcessor.load_from_url(url, return_text=True)
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
a = extract_using_readabilipy(response.text) a = extract_using_readabilipy(response.text)
......
from core.vector_store.vector.milvus import Milvus
class MilvusVectorStore(Milvus):
def del_texts(self, where_filter: dict):
if not where_filter:
raise ValueError('where_filter must not be empty')
self.col.delete(where_filter.get('filter'))
def del_text(self, uuid: str) -> None:
expr = f"id == {uuid}"
self.col.delete(expr)
def text_exists(self, uuid: str) -> bool:
result = self.col.query(
expr=f'metadata["doc_id"] == "{uuid}"',
output_fields=["id"]
)
return len(result) > 0
def get_ids_by_document_id(self, document_id: str):
result = self.col.query(
expr=f'metadata["document_id"] == "{document_id}"',
output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
return None
def get_ids_by_metadata_field(self, key: str, value: str):
result = self.col.query(
expr=f'metadata["{key}"] == "{value}"',
output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
return None
def get_ids_by_doc_ids(self, doc_ids: list):
result = self.col.query(
expr=f'metadata["doc_id"] in {doc_ids}',
output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
return None
def delete(self):
from pymilvus import utility
utility.drop_collection(self.collection_name, None, self.alias)
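# Usage sketch (illustrative; assumes a reachable Milvus instance and an
# initialized `store: MilvusVectorStore`): dropping every vector belonging to
# one uploaded document. "id" is the wrapper's auto-generated primary key.
# ids = store.get_ids_by_document_id("<document-uuid>")
# if ids:
#     store.del_texts({'filter': f'id in {ids}'})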
from typing import Any, cast
from langchain.schema import Document
from qdrant_client.http.models import Filter, FilterSelector, PointIdsList
from qdrant_client.local.qdrant_local import QdrantLocal
from core.vector_store.vector.qdrant import Qdrant
class QdrantVectorStore(Qdrant):
def del_texts(self, filter: Filter):
if not filter:
raise ValueError('filter must not be empty')
self._reload_if_needed()
self.client.delete(
collection_name=self.collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def del_text(self, uuid: str) -> None:
self._reload_if_needed()
self.client.delete(
collection_name=self.collection_name,
points_selector=PointIdsList(
points=[uuid],
),
)
def text_exists(self, uuid: str) -> bool:
self._reload_if_needed()
response = self.client.retrieve(
collection_name=self.collection_name,
ids=[uuid]
)
return len(response) > 0
def delete(self):
self._reload_if_needed()
self.client.delete_collection(collection_name=self.collection_name)
def delete_group(self):
self._reload_if_needed()
self.client.delete_collection(collection_name=self.collection_name)
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
if scored_point.payload.get('doc_id'):
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata={'doc_id': scored_point.id}
)
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
def _reload_if_needed(self):
if isinstance(self.client, QdrantLocal):
self.client = cast(QdrantLocal, self.client)
self.client._load()
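# Usage sketch (illustrative; assumes an initialized `store: QdrantVectorStore`):
# deleting points whose payload metadata matches a document id.
# from qdrant_client.http import models as rest
# store.del_texts(rest.Filter(must=[rest.FieldCondition(
#     key="metadata.document_id", match=rest.MatchValue(value="<doc-id>"))]))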
"""Wrapper around the Milvus vector database."""
from __future__ import annotations
import logging
from collections.abc import Iterable, Sequence
from typing import Any, Optional, Union
from uuid import uuid4
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__)
DEFAULT_MILVUS_CONNECTION = {
"host": "localhost",
"port": "19530",
"user": "",
"password": "",
"secure": False,
}
class Milvus(VectorStore):
"""Initialize wrapper around the milvus vector database.
In order to use this you need to have `pymilvus` installed and a
running Milvus
See the following documentation for how to run a Milvus instance:
https://milvus.io/docs/install_standalone-docker.md
If looking for a hosted Milvus, take a look at this documentation:
https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
this project,
IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
Args:
embedding_function (Embeddings): Function used to embed the text.
collection_name (str): Which Milvus collection to use. Defaults to
"LangChainCollection".
        connection_args (Optional[dict[str, Any]]): The connection args used for
            this class come in the form of a dict.
consistency_level (str): The consistency level to use for a collection.
Defaults to "Session".
index_params (Optional[dict]): Which index params to use. Defaults to
HNSW/AUTOINDEX depending on service.
search_params (Optional[dict]): Which search params to use. Defaults to
default of index.
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
to False.
    The connection args used for this class come in the form of a dict;
    here are a few of the options:
address (str): The actual address of Milvus
instance. Example address: "localhost:19530"
uri (str): The uri of Milvus instance. Example uri:
"http://randomwebsite:19530",
"tcp:foobarsite:19530",
"https://ok.s3.south.com:19530".
host (str): The host of Milvus instance. Default at "localhost",
PyMilvus will fill in the default host if only port is provided.
port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
will fill in the default port if only host is provided.
user (str): Use which user to connect to Milvus instance. If user and
password are provided, we will add related header in every RPC call.
password (str): Required when user is provided. The password
corresponding to the user.
secure (bool): Default is false. If set to true, tls will be enabled.
client_key_path (str): If use tls two-way authentication, need to
write the client.key path.
client_pem_path (str): If use tls two-way authentication, need to
write the client.pem path.
ca_pem_path (str): If use tls two-way authentication, need to write
the ca.pem path.
server_pem_path (str): If use tls one-way authentication, need to
write the server.pem path.
server_name (str): If use tls, need to write the common name.
Example:
.. code-block:: python
from langchain import Milvus
from langchain.embeddings import OpenAIEmbeddings
embedding = OpenAIEmbeddings()
# Connect to a milvus instance on localhost
milvus_store = Milvus(
            embedding_function = embedding,
collection_name = "LangChainCollection",
drop_old = True,
)
Raises:
ValueError: If the pymilvus python package is not installed.
"""
def __init__(
self,
embedding_function: Embeddings,
collection_name: str = "LangChainCollection",
connection_args: Optional[dict[str, Any]] = None,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
drop_old: Optional[bool] = False,
):
"""Initialize the Milvus vector store."""
try:
from pymilvus import Collection, utility
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
# Default search params when one is not provided.
self.default_search_params = {
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
"HNSW": {"metric_type": "L2", "params": {"ef": 10}},
"RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
"RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
"RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
"IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
"AUTOINDEX": {"metric_type": "L2", "params": {}},
}
self.embedding_func = embedding_function
self.collection_name = collection_name
self.index_params = index_params
self.search_params = search_params
self.consistency_level = consistency_level
        # In order for a collection to be compatible, the primary key needs to be auto-id and int
self._primary_field = "id"
# In order for compatibility, the text field will need to be called "text"
self._text_field = "page_content"
# In order for compatibility, the vector field needs to be called "vector"
self._vector_field = "vectors"
# In order for compatibility, the metadata field will need to be called "metadata"
self._metadata_field = "metadata"
self.fields: list[str] = []
# Create the connection to the server
if connection_args is None:
connection_args = DEFAULT_MILVUS_CONNECTION
self.alias = self._create_connection_alias(connection_args)
self.col: Optional[Collection] = None
# Grab the existing collection if it exists
if utility.has_collection(self.collection_name, using=self.alias):
self.col = Collection(
self.collection_name,
using=self.alias,
)
# If need to drop old, drop it
if drop_old and isinstance(self.col, Collection):
self.col.drop()
self.col = None
# Initialize the vector store
self._init()
@property
def embeddings(self) -> Embeddings:
return self.embedding_func
def _create_connection_alias(self, connection_args: dict) -> str:
"""Create the connection to the Milvus server."""
from pymilvus import MilvusException, connections
# Grab the connection arguments that are used for checking existing connection
host: str = connection_args.get("host", None)
port: Union[str, int] = connection_args.get("port", None)
address: str = connection_args.get("address", None)
uri: str = connection_args.get("uri", None)
user = connection_args.get("user", None)
# Order of use is host/port, uri, address
if host is not None and port is not None:
given_address = str(host) + ":" + str(port)
elif uri is not None:
given_address = uri.split("https://")[1]
elif address is not None:
given_address = address
else:
given_address = None
logger.debug("Missing standard address type for reuse atttempt")
# User defaults to empty string when getting connection info
if user is not None:
tmp_user = user
else:
tmp_user = ""
# If a valid address was given, then check if a connection exists
if given_address is not None:
for con in connections.list_connections():
addr = connections.get_connection_addr(con[0])
if (
con[1]
and ("address" in addr)
and (addr["address"] == given_address)
and ("user" in addr)
and (addr["user"] == tmp_user)
):
logger.debug("Using previous connection: %s", con[0])
return con[0]
# Generate a new connection if one doesn't exist
alias = uuid4().hex
try:
connections.connect(alias=alias, **connection_args)
logger.debug("Created new connection using: %s", alias)
return alias
except MilvusException as e:
logger.error("Failed to create new connection using: %s", alias)
raise e
def _init(
self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
) -> None:
if embeddings is not None:
self._create_collection(embeddings, metadatas)
self._extract_fields()
self._create_index()
self._create_search_params()
self._load()
def _create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None
) -> None:
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException
from pymilvus.orm.types import infer_dtype_bydata
# Determine embedding dim
dim = len(embeddings[0])
fields = []
# Determine metadata schema
# if metadatas:
# # Create FieldSchema for each entry in metadata.
# for key, value in metadatas[0].items():
# # Infer the corresponding datatype of the metadata
# dtype = infer_dtype_bydata(value)
# # Datatype isn't compatible
# if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
# logger.error(
# "Failure to create collection, unrecognized dtype for key: %s",
# key,
# )
# raise ValueError(f"Unrecognized datatype for {key}.")
        # # Datatype is a string/varchar equivalent
# elif dtype == DataType.VARCHAR:
# fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
# else:
# fields.append(FieldSchema(key, dtype))
if metadatas:
fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
)
# Create the primary key field
fields.append(
FieldSchema(
self._primary_field, DataType.INT64, is_primary=True, auto_id=True
)
)
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
)
# Create the schema for the collection
schema = CollectionSchema(fields)
# Create the collection
try:
self.col = Collection(
name=self.collection_name,
schema=schema,
consistency_level=self.consistency_level,
using=self.alias,
)
except MilvusException as e:
logger.error(
"Failed to create collection: %s error: %s", self.collection_name, e
)
raise e
def _extract_fields(self) -> None:
"""Grab the existing fields from the Collection"""
from pymilvus import Collection
if isinstance(self.col, Collection):
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)
# Since primary field is auto-id, no need to track it
self.fields.remove(self._primary_field)
def _get_index(self) -> Optional[dict[str, Any]]:
"""Return the vector index information if it exists"""
from pymilvus import Collection
if isinstance(self.col, Collection):
for x in self.col.indexes:
if x.field_name == self._vector_field:
return x.to_dict()
return None
def _create_index(self) -> None:
"""Create a index on the collection"""
from pymilvus import Collection, MilvusException
if isinstance(self.col, Collection) and self._get_index() is None:
try:
# If no index params, use a default HNSW based one
if self.index_params is None:
self.index_params = {
"metric_type": "IP",
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}
try:
self.col.create_index(
self._vector_field,
index_params=self.index_params,
using=self.alias,
)
# If default did not work, most likely on Zilliz Cloud
except MilvusException:
# Use AUTOINDEX based index
self.index_params = {
"metric_type": "L2",
"index_type": "AUTOINDEX",
"params": {},
}
self.col.create_index(
self._vector_field,
index_params=self.index_params,
using=self.alias,
)
logger.debug(
"Successfully created an index on collection: %s",
self.collection_name,
)
except MilvusException as e:
logger.error(
"Failed to create an index on collection: %s", self.collection_name
)
raise e
def _create_search_params(self) -> None:
"""Generate search params based on the current index type"""
from pymilvus import Collection
if isinstance(self.col, Collection) and self.search_params is None:
index = self._get_index()
if index is not None:
index_type: str = index["index_param"]["index_type"]
metric_type: str = index["index_param"]["metric_type"]
self.search_params = self.default_search_params[index_type]
self.search_params["metric_type"] = metric_type
def _load(self) -> None:
"""Load the collection if available."""
from pymilvus import Collection
if isinstance(self.col, Collection) and self._get_index() is not None:
self.col.load()
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[list[dict]] = None,
timeout: Optional[int] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> list[str]:
"""Insert text data into Milvus.
        Inserting data when the collection has not been made yet will result
        in creating a new Collection. The data of the first entity decides
        the schema of the new collection, the dim is extracted from the first
        embedding and the columns are decided by the first metadata dict.
        Metadata keys will need to be present for all inserted values. At
the moment there is no None equivalent in Milvus.
Args:
texts (Iterable[str]): The texts to embed, it is assumed
that they all fit in memory.
metadatas (Optional[List[dict]]): Metadata dicts attached to each of
the texts. Defaults to None.
timeout (Optional[int]): Timeout for each batch insert. Defaults
to None.
batch_size (int, optional): Batch size to use for insertion.
Defaults to 1000.
Raises:
MilvusException: Failure to add texts
Returns:
List[str]: The resulting keys for each inserted element.
"""
from pymilvus import Collection, MilvusException
texts = list(texts)
try:
embeddings = self.embedding_func.embed_documents(texts)
except NotImplementedError:
embeddings = [self.embedding_func.embed_query(x) for x in texts]
if len(embeddings) == 0:
logger.debug("Nothing to insert, skipping.")
return []
# If the collection hasn't been initialized yet, perform all steps to do so
if not isinstance(self.col, Collection):
self._init(embeddings, metadatas)
# Dict to hold all insert columns
insert_dict: dict[str, list] = {
self._text_field: texts,
self._vector_field: embeddings,
}
# Collect the metadata into the insert dict.
# if metadatas is not None:
# for d in metadatas:
# for key, value in d.items():
# if key in self.fields:
# insert_dict.setdefault(key, []).append(value)
if metadatas is not None:
for d in metadatas:
insert_dict.setdefault(self._metadata_field, []).append(d)
# Total insert count
vectors: list = insert_dict[self._vector_field]
total_count = len(vectors)
pks: list[str] = []
assert isinstance(self.col, Collection)
for i in range(0, total_count, batch_size):
# Grab end index
end = min(i + batch_size, total_count)
# Convert dict to list of lists batch for insertion
insert_list = [insert_dict[x][i:end] for x in self.fields]
# Insert into the collection.
try:
res: Collection
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
pks.extend(res.primary_keys)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
raise e
return pks
def similarity_search(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> list[Document]:
"""Perform a similarity search against the query string.
Args:
query (str): The text to search.
k (int, optional): How many results to return. Defaults to 4.
param (dict, optional): The search params for the index type.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
res = self.similarity_search_with_score(
query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return [doc for doc, _ in res]
def similarity_search_by_vector(
self,
embedding: list[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> list[Document]:
"""Perform a similarity search against the query string.
Args:
embedding (List[float]): The embedding vector to search.
k (int, optional): How many results to return. Defaults to 4.
param (dict, optional): The search params for the index type.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
res = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return [doc for doc, _ in res]
def similarity_search_with_score(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args:
query (str): The text being searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
        Returns:
            List[Tuple[Document, float]]: Result doc and score.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
# Embed the query text.
embedding = self.embedding_func.embed_query(query)
res = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return res
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
return self.similarity_search_with_score(query, k, **kwargs)
def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args:
embedding (List[float]): The embedding vector being searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Tuple[Document, float]]: Result doc and score.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
if param is None:
param = self.search_params
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
# Perform the search.
res = self.col.search(
data=[embedding],
anns_field=self._vector_field,
param=param,
limit=k,
expr=expr,
output_fields=output_fields,
timeout=timeout,
**kwargs,
)
# Organize results.
ret = []
for result in res[0]:
meta = {x: result.entity.get(x) for x in output_fields}
doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
pair = (doc, result.score)
ret.append(pair)
return ret
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> list[Document]:
"""Perform a search and return results that are reordered by MMR.
Args:
query (str): The text being searched.
k (int, optional): How many results to give. Defaults to 4.
fetch_k (int, optional): Total results to select k from.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
embedding = self.embedding_func.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding=embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
param=param,
expr=expr,
timeout=timeout,
**kwargs,
)
def max_marginal_relevance_search_by_vector(
self,
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> list[Document]:
"""Perform a search and return results that are reordered by MMR.
Args:
embedding (str): The embedding vector being searched.
k (int, optional): How many results to give. Defaults to 4.
fetch_k (int, optional): Total results to select k from.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
if param is None:
param = self.search_params
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
# Perform the search.
res = self.col.search(
data=[embedding],
anns_field=self._vector_field,
param=param,
limit=fetch_k,
expr=expr,
output_fields=output_fields,
timeout=timeout,
**kwargs,
)
# Organize results.
ids = []
documents = []
scores = []
for result in res[0]:
meta = {x: result.entity.get(x) for x in output_fields}
doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
documents.append(doc)
scores.append(result.score)
ids.append(result.id)
vectors = self.col.query(
expr=f"{self._primary_field} in {ids}",
output_fields=[self._primary_field, self._vector_field],
timeout=timeout,
)
# Reorganize the results from query to match search order.
vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
ordered_result_embeddings = [vectors[x] for x in ids]
# Get the new order of results.
new_ordering = maximal_marginal_relevance(
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
)
# Reorder the values and return.
ret = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
else:
ret.append(documents[x])
return ret
@classmethod
def from_texts(
cls,
texts: list[str],
embedding: Embeddings,
metadatas: Optional[list[dict]] = None,
collection_name: str = "LangChainCollection",
connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
drop_old: bool = False,
batch_size: int = 100,
ids: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> Milvus:
"""Create a Milvus collection, indexes it with HNSW, and insert data.
Args:
texts (List[str]): Text data.
embedding (Embeddings): Embedding function.
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
Defaults to None.
collection_name (str, optional): Collection name to use. Defaults to
"LangChainCollection".
connection_args (dict[str, Any], optional): Connection args to use. Defaults
to DEFAULT_MILVUS_CONNECTION.
consistency_level (str, optional): Which consistency level to use. Defaults
to "Session".
index_params (Optional[dict], optional): Which index_params to use. Defaults
to None.
search_params (Optional[dict], optional): Which search params to use.
Defaults to None.
drop_old (Optional[bool], optional): Whether to drop the collection with
that name if it exists. Defaults to False.
            batch_size:
                How many vectors to upload per request.
                Default: 100
            ids (Optional[Sequence[str]]): Optional ids to associate with the texts.
Returns:
Milvus: Milvus Vector Store
"""
vector_db = cls(
embedding_function=embedding,
collection_name=collection_name,
connection_args=connection_args,
consistency_level=consistency_level,
index_params=index_params,
search_params=search_params,
drop_old=drop_old,
**kwargs,
)
vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
return vector_db
"""Wrapper around Qdrant vector database."""
from __future__ import annotations
import asyncio
import functools
import uuid
import warnings
from collections.abc import Callable, Generator, Iterable, Sequence
from itertools import islice
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType, TextIndexParams, TextIndexType, TokenizerType
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
class QdrantException(Exception):
"""Base class for all the Qdrant related exceptions"""
def sync_call_fallback(method: Callable) -> Callable:
"""
Decorator to call the synchronous method of the class if the async method is not
    implemented. This decorator should only be used for methods that are defined
as async in the class.
"""
@functools.wraps(method)
async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
try:
return await method(self, *args, **kwargs)
except NotImplementedError:
            # If the async method is not implemented, call the synchronous method
            # by removing the first letter from the method name. For example,
            # if the async method is called ``aadd_texts``, the synchronous method
            # will be called ``add_texts``.
sync_method = functools.partial(
getattr(self, method.__name__[1:]), *args, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, sync_method)
return wrapper
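# Minimal sketch of the fallback behavior (illustrative): an async method that
# raises NotImplementedError transparently delegates to its sync counterpart.
if __name__ == "__main__":
    class _Demo:
        def add_texts(self, texts):
            return list(texts)

        @sync_call_fallback
        async def aadd_texts(self, texts):
            raise NotImplementedError

    print(asyncio.run(_Demo().aadd_texts(["a", "b"])))  # -> ['a', 'b']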
class Qdrant(VectorStore):
"""Wrapper around Qdrant vector database.
To use you should have the ``qdrant-client`` package installed.
Example:
.. code-block:: python
from qdrant_client import QdrantClient
from langchain import Qdrant
client = QdrantClient()
collection_name = "MyCollection"
qdrant = Qdrant(client, collection_name, embedding_function)
"""
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR_NAME = None
def __init__(
self,
client: Any,
collection_name: str,
embeddings: Optional[Embeddings] = None,
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME,
embedding_function: Optional[Callable] = None, # deprecated
is_new_collection: bool = False
):
"""Initialize with necessary components."""
try:
import qdrant_client
except ImportError:
raise ValueError(
"Could not import qdrant-client python package. "
"Please install it with `pip install qdrant-client`."
)
if not isinstance(client, qdrant_client.QdrantClient):
raise ValueError(
f"client should be an instance of qdrant_client.QdrantClient, "
f"got {type(client)}"
)
if embeddings is None and embedding_function is None:
raise ValueError(
"`embeddings` value can't be None. Pass `Embeddings` instance."
)
if embeddings is not None and embedding_function is not None:
raise ValueError(
"Both `embeddings` and `embedding_function` are passed. "
"Use `embeddings` only."
)
self._embeddings = embeddings
self._embeddings_function = embedding_function
self.client: qdrant_client.QdrantClient = client
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
self.group_payload_key = group_payload_key or self.GROUP_KEY
self.vector_name = vector_name or self.VECTOR_NAME
self.group_id = group_id
        self.is_new_collection = is_new_collection
if embedding_function is not None:
warnings.warn(
"Using `embedding_function` is deprecated. "
"Pass `Embeddings` instance to `embeddings` instead."
)
if not isinstance(embeddings, Embeddings):
warnings.warn(
"`embeddings` should be an instance of `Embeddings`."
"Using `embeddings` as `embedding_function` which is deprecated"
)
self._embeddings_function = embeddings
self._embeddings = None
self.distance_strategy = distance_strategy.upper()
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embeddings
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
**kwargs: Any,
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids:
Optional list of ids to associate with the texts. Ids have to be
uuid-like strings.
batch_size:
How many vectors upload per-request.
Default: 64
Note: the instance-level ``group_id`` is written into each point's
payload under ``group_payload_key``; it is not a method parameter.
Returns:
List of ids from adding the texts into the vectorstore.
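Example:
    .. code-block:: python
        # Hypothetical call; assumes an initialized ``qdrant`` instance.
        ids = qdrant.add_texts(
            ["first chunk", "second chunk"],
            metadatas=[{"doc_id": "d1"}, {"doc_id": "d2"}],
        )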
"""
added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, metadatas, ids, batch_size
):
self.client.upsert(
collection_name=self.collection_name, points=points
)
added_ids.extend(batch_ids)
# if is new collection, create payload index on group_id
if self.is_new_collection:
# create payload index
self.client.create_payload_index(self.collection_name, self.group_payload_key,
field_schema=PayloadSchemaType.KEYWORD)
# create a full-text index on the content payload for keyword search
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
self.client.create_payload_index(self.collection_name, self.content_payload_key,
field_schema=text_index_params)
return added_ids
@sync_call_fallback
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
**kwargs: Any,
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids:
Optional list of ids to associate with the texts. Ids have to be
uuid-like strings.
batch_size:
How many vectors upload per-request.
Default: 64
Returns:
List of ids from adding the texts into the vectorstore.
"""
from qdrant_client import grpc # noqa
from qdrant_client.conversions.conversion import RestToGrpc
added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, metadatas, ids, batch_size
):
await self.client.async_grpc_points.Upsert(
grpc.UpsertPoints(
collection_name=self.collection_name,
points=[RestToGrpc.convert_point_struct(point) for point in points],
)
)
added_ids.extend(batch_ids)
return added_ids
def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
offset:
Offset of the first result to return.
May be used to paginate results.
Note: large offset values may cause performance issues.
score_threshold:
Define a minimal score threshold for the result.
If defined, less similar results will not be returned.
Score of the returned result might be higher or smaller than the
threshold depending on the Distance function used.
E.g. for cosine similarity only higher scores will be returned.
consistency:
Read consistency of the search. Defines how many replicas should be
queried before returning the result.
Values:
- int - number of replicas to query, values should present in all
queried replicas
- 'majority' - query all replicas, but return values present in the
majority of replicas
- 'quorum' - query the majority of replicas, return values present in
all of them
- 'all' - query all replicas, and return values present in all replicas
Returns:
List of Documents most similar to the query.
"""
results = self.similarity_search_with_score(
query,
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return list(map(itemgetter(0), results))
@sync_call_fallback
async def asimilarity_search(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilter] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
Returns:
List of Documents most similar to the query.
"""
results = await self.asimilarity_search_with_score(query, k, filter, **kwargs)
return list(map(itemgetter(0), results))
def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
offset:
Offset of the first result to return.
May be used to paginate results.
Note: large offset values may cause performance issues.
score_threshold:
Define a minimal score threshold for the result.
If defined, less similar results will not be returned.
Score of the returned result might be higher or smaller than the
threshold depending on the Distance function used.
E.g. for cosine similarity only higher scores will be returned.
consistency:
Read consistency of the search. Defines how many replicas should be
queried before returning the result.
Values:
- int - number of replicas to query, values should present in all
queried replicas
- 'majority' - query all replicas, but return values present in the
majority of replicas
- 'quorum' - query the majority of replicas, return values present in
all of them
- 'all' - query all replicas, and return values present in all replicas
Returns:
List of documents most similar to the query text and distance for each.
"""
return self.similarity_search_with_score_by_vector(
self._embed_query(query),
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
@sync_call_fallback
async def asimilarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
offset:
Offset of the first result to return.
May be used to paginate results.
Note: large offset values may cause performance issues.
score_threshold:
Define a minimal score threshold for the result.
If defined, less similar results will not be returned.
Score of the returned result might be higher or smaller than the
threshold depending on the Distance function used.
E.g. for cosine similarity only higher scores will be returned.
consistency:
Read consistency of the search. Defines how many replicas should be
queried before returning the result.
Values:
- int - number of replicas to query, values should present in all
queried replicas
- 'majority' - query all replicas, but return values present in the
majority of replicas
- 'quorum' - query the majority of replicas, return values present in
all of them
- 'all' - query all replicas, and return values present in all replicas
Returns:
List of documents most similar to the query text and distance for each.
"""
return await self.asimilarity_search_with_score_by_vector(
self._embed_query(query),
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
def similarity_search_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
offset:
Offset of the first result to return.
May be used to paginate results.
Note: large offset values may cause performance issues.
score_threshold:
Define a minimal score threshold for the result.
If defined, less similar results will not be returned.
Score of the returned result might be higher or smaller than the
threshold depending on the Distance function used.
E.g. for cosine similarity only higher scores will be returned.
consistency:
Read consistency of the search. Defines how many replicas should be
queried before returning the result.
Values:
- int - number of replicas to query, values should present in all
queried replicas
- 'majority' - query all replicas, but return values present in the
majority of replicas
- 'quorum' - query the majority of replicas, return values present in
all of them
- 'all' - query all replicas, and return values present in all replicas
Returns:
List of Documents most similar to the query.
"""
results = self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return list(map(itemgetter(0), results))
@sync_call_fallback
async def asimilarity_search_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
offset:
Offset of the first result to return.
May be used to paginate results.
Note: large offset values may cause performance issues.
score_threshold:
Define a minimal score threshold for the result.
If defined, less similar results will not be returned.
Score of the returned result might be higher or smaller than the
threshold depending on the Distance function used.
E.g. for cosine similarity only higher scores will be returned.
consistency:
Read consistency of the search. Defines how many replicas should be
queried before returning the result.
Values:
- int - number of replicas to query, values should present in all
queried replicas
- 'majority' - query all replicas, but return values present in the
majority of replicas
- 'quorum' - query the majority of replicas, return values present in
all of them
- 'all' - query all replicas, and return values present in all replicas
Returns:
List of Documents most similar to the query.
"""
results = await self.asimilarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return list(map(itemgetter(0), results))
def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
offset:
Offset of the first result to return.
May be used to paginate results.
Note: large offset values may cause performance issues.
score_threshold:
Define a minimal score threshold for the result.
If defined, less similar results will not be returned.
Score of the returned result might be higher or smaller than the
threshold depending on the Distance function used.
E.g. for cosine similarity only higher scores will be returned.
consistency:
Read consistency of the search. Defines how many replicas should be
queried before returning the result.
Values:
- int - number of replicas to query, values should present in all
queried replicas
- 'majority' - query all replicas, but return values present in the
majority of replicas
- 'quorum' - query the majority of replicas, return values present in
all of them
- 'all' - query all replicas, and return values present in all replicas
Returns:
List of documents most similar to the query text and distance for each.
"""
if filter is not None and isinstance(filter, dict):
warnings.warn(
"Using dict as a `filter` is deprecated. Please use qdrant-client "
"filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/",
DeprecationWarning,
)
qdrant_filter = self._qdrant_filter_from_dict(filter)
else:
qdrant_filter = filter
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, embedding) # type: ignore[assignment]
results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=qdrant_filter,
search_params=search_params,
limit=k,
offset=offset,
with_payload=True,
with_vectors=True,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return [
(
self._document_from_scored_point(
result, self.content_payload_key, self.metadata_payload_key
),
result.score,
)
for result in results
]
def similarity_search_by_bm25(
self,
filter: Optional[MetadataFilter] = None,
k: int = 4
) -> list[Document]:
"""Return docs most similar by bm25.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
Returns:
List of documents most similar to the query text and distance for each.
"""
response = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=filter,
limit=k,
with_payload=True,
with_vectors=True
)
results = response[0]
documents = []
for result in results:
if result:
documents.append(self._document_from_scored_point(
result, self.content_payload_key, self.metadata_payload_key
))
return documents
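# Usage sketch (hypothetical filter): fetch up to ``k`` points of a single
# dataset group with a plain payload filter; no vector scoring is involved.
#     from qdrant_client.http import models as rest
#     docs = qdrant.similarity_search_by_bm25(
#         rest.Filter(must=[rest.FieldCondition(
#             key="group_id", match=rest.MatchValue(value="<dataset-id>"))]),
#         k=4)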
@sync_call_fallback
async def asimilarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
offset:
Offset of the first result to return.
May be used to paginate results.
Note: large offset values may cause performance issues.
score_threshold:
Define a minimal score threshold for the result.
If defined, less similar results will not be returned.
Score of the returned result might be higher or smaller than the
threshold depending on the Distance function used.
E.g. for cosine similarity only higher scores will be returned.
consistency:
Read consistency of the search. Defines how many replicas should be
queried before returning the result.
Values:
- int - number of replicas to query, values should present in all
queried replicas
- 'majority' - query all replicas, but return values present in the
majority of replicas
- 'quorum' - query the majority of replicas, return values present in
all of them
- 'all' - query all replicas, and return values present in all replicas
Returns:
List of documents most similar to the query text and distance for each.
"""
from qdrant_client import grpc # noqa
from qdrant_client.conversions.conversion import RestToGrpc
from qdrant_client.http import models as rest
if filter is not None and isinstance(filter, dict):
warnings.warn(
"Using dict as a `filter` is deprecated. Please use qdrant-client "
"filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/",
DeprecationWarning,
)
qdrant_filter = self._qdrant_filter_from_dict(filter)
else:
qdrant_filter = filter
if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter):
qdrant_filter = RestToGrpc.convert_filter(qdrant_filter)
response = await self.client.async_grpc_points.Search(
grpc.SearchPoints(
collection_name=self.collection_name,
vector_name=self.vector_name,
vector=embedding,
filter=qdrant_filter,
params=search_params,
limit=k,
offset=offset,
with_payload=grpc.WithPayloadSelector(enable=True),
with_vectors=grpc.WithVectorsSelector(enable=False),
score_threshold=score_threshold,
read_consistency=consistency,
**kwargs,
)
)
return [
(
self._document_from_scored_point_grpc(
result, self.content_payload_key, self.metadata_payload_key
),
result.score,
)
for result in response.result
]
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
query_embedding = self._embed_query(query)
return self.max_marginal_relevance_search_by_vector(
query_embedding, k, fetch_k, lambda_mult, **kwargs
)
@sync_call_fallback
async def amax_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
query_embedding = self._embed_query(query)
return await self.amax_marginal_relevance_search_by_vector(
query_embedding, k, fetch_k, lambda_mult, **kwargs
)
def max_marginal_relevance_search_by_vector(
self,
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
results = self.max_marginal_relevance_search_with_score_by_vector(
embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
)
return list(map(itemgetter(0), results))
@sync_call_fallback
async def amax_marginal_relevance_search_by_vector(
self,
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance and distance for
each.
"""
results = await self.amax_marginal_relevance_search_with_score_by_vector(
embedding, k, fetch_k, lambda_mult, **kwargs
)
return list(map(itemgetter(0), results))
def max_marginal_relevance_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance and distance for
each.
"""
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
with_payload=True,
with_vectors=True,
limit=fetch_k,
)
embeddings = [
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
if self.vector_name is not None
else result.vector
for result in results
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
(
self._document_from_scored_point(
results[i], self.content_payload_key, self.metadata_payload_key
),
results[i].score,
)
for i in mmr_selected
]
@sync_call_fallback
async def amax_marginal_relevance_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance and distance for
each.
"""
from qdrant_client import grpc # noqa
from qdrant_client.conversions.conversion import GrpcToRest
response = await self.client.async_grpc_points.Search(
grpc.SearchPoints(
collection_name=self.collection_name,
vector_name=self.vector_name,
vector=embedding,
with_payload=grpc.WithPayloadSelector(enable=True),
with_vectors=grpc.WithVectorsSelector(enable=True),
limit=fetch_k,
)
)
results = [
GrpcToRest.convert_vectors(result.vectors) for result in response.result
]
embeddings: list[list[float]] = [
result.get(self.vector_name) # type: ignore
if isinstance(result, dict)
else result
for result in results
]
mmr_selected: list[int] = maximal_marginal_relevance(
np.array(embedding),
embeddings,
k=k,
lambda_mult=lambda_mult,
)
return [
(
self._document_from_scored_point_grpc(
response.result[i],
self.content_payload_key,
self.metadata_payload_key,
),
response.result[i].score,
)
for i in mmr_selected
]
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete by vector ID or other criteria.
Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments that subclasses might use.
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
from qdrant_client.http import models as rest
result = self.client.delete(
collection_name=self.collection_name,
points_selector=ids,
)
return result.status == rest.UpdateStatus.COMPLETED
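# Usage sketch: ``qdrant.delete(ids=["<point-id>"])`` returns True once
# Qdrant reports the update status as COMPLETED.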
@classmethod
def from_texts(
cls: type[Qdrant],
texts: list[str],
embedding: Embeddings,
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: Optional[str] = None,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None,
hnsw_config: Optional[common_types.HnswConfigDiff] = None,
optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
"""Construct Qdrant wrapper from a list of texts.
Args:
texts: A list of texts to be indexed in Qdrant.
embedding: A subclass of `Embeddings`, responsible for text vectorization.
metadatas:
An optional list of metadata. If provided, it has to be of the same
length as the list of texts.
ids:
Optional list of ids to associate with the texts. Ids have to be
uuid-like strings.
location:
If `:memory:` - use in-memory Qdrant instance.
If `str` - use it as a `url` parameter.
If `None` - fallback to relying on `host` and `port` parameters.
url: either host or str of "Optional[scheme], host, Optional[port],
Optional[prefix]". Default: `None`
port: Port of the REST API interface. Default: 6333
grpc_port: Port of the gRPC interface. Default: 6334
prefer_grpc:
If true - use gRPC interface whenever possible in custom methods.
Default: False
https: If true - use HTTPS(SSL) protocol. Default: None
api_key: API key for authentication in Qdrant Cloud. Default: None
prefix:
If not None - add prefix to the REST URL path.
Example: service/v1 will result in
http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
Default: None
timeout:
Timeout for REST and gRPC API requests.
Default: 5.0 seconds for REST and unlimited for gRPC
host:
Host name of Qdrant service. If url and host are None, set to
'localhost'. Default: None
path:
Path in which the vectors will be stored while using local mode.
Default: None
collection_name:
Name of the Qdrant collection to be used. If not provided,
it will be created randomly. Default: None
distance_func:
Distance function. One of: "Cosine" / "Euclid" / "Dot".
Default: "Cosine"
content_payload_key:
A payload key used to store the content of the document.
Default: "page_content"
metadata_payload_key:
A payload key used to store the metadata of the document.
Default: "metadata"
group_payload_key:
A payload key used to store the group id of the document.
Default: "group_id"
group_id:
Group id written into each point's payload under ``group_payload_key``.
vector_name:
Name of the vector to be used internally in Qdrant.
Default: None
batch_size:
How many vectors upload per-request.
Default: 64
shard_number: Number of shards in collection. Default is 1, minimum is 1.
replication_factor:
Replication factor for collection. Default is 1, minimum is 1.
Defines how many copies of each shard will be created.
Has effect only in distributed mode.
write_consistency_factor:
Write consistency factor for collection. Default is 1, minimum is 1.
Defines how many replicas should apply the operation for us to consider
it successful. Increasing this number will make the collection more
resilient to inconsistencies, but will also make it fail if not enough
replicas are available.
Does not have any performance impact.
Has effect only in distributed mode.
on_disk_payload:
If true - point's payload will not be stored in memory.
It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time.
Note: those payload values that are involved in filtering and are
indexed - remain in RAM.
hnsw_config: Params for HNSW index
optimizers_config: Params for optimizer
wal_config: Params for Write-Ahead-Log
quantization_config:
Params for quantization, if None - quantization will be disabled
init_from:
Use data stored in another collection to initialize this collection
force_recreate:
Force recreating the collection
**kwargs:
Additional arguments passed directly into REST client initialization
This is a user-friendly interface that:
1. Creates embeddings, one for each text
2. Initializes the Qdrant database as an in-memory docstore by default
(and overridable to a remote docstore)
3. Adds the text embeddings to the Qdrant database
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain import Qdrant
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
qdrant = Qdrant.from_texts(texts, embeddings, "localhost")
"""
qdrant = cls._construct_instance(
texts,
embedding,
metadatas,
ids,
location,
url,
port,
grpc_port,
prefer_grpc,
https,
api_key,
prefix,
timeout,
host,
path,
collection_name,
distance_func,
content_payload_key,
metadata_payload_key,
group_payload_key,
group_id,
vector_name,
shard_number,
replication_factor,
write_consistency_factor,
on_disk_payload,
hnsw_config,
optimizers_config,
wal_config,
quantization_config,
init_from,
force_recreate,
**kwargs,
)
qdrant.add_texts(texts, metadatas, ids, batch_size)
return qdrant
@classmethod
@sync_call_fallback
async def afrom_texts(
cls: type[Qdrant],
texts: list[str],
embedding: Embeddings,
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None,
hnsw_config: Optional[common_types.HnswConfigDiff] = None,
optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
"""Construct Qdrant wrapper from a list of texts.
Args:
texts: A list of texts to be indexed in Qdrant.
embedding: A subclass of `Embeddings`, responsible for text vectorization.
metadatas:
An optional list of metadata. If provided, it has to be of the same
length as the list of texts.
ids:
Optional list of ids to associate with the texts. Ids have to be
uuid-like strings.
location:
If `:memory:` - use in-memory Qdrant instance.
If `str` - use it as a `url` parameter.
If `None` - fallback to relying on `host` and `port` parameters.
url: either host or str of "Optional[scheme], host, Optional[port],
Optional[prefix]". Default: `None`
port: Port of the REST API interface. Default: 6333
grpc_port: Port of the gRPC interface. Default: 6334
prefer_grpc:
If true - use gRPC interface whenever possible in custom methods.
Default: False
https: If true - use HTTPS(SSL) protocol. Default: None
api_key: API key for authentication in Qdrant Cloud. Default: None
prefix:
If not None - add prefix to the REST URL path.
Example: service/v1 will result in
http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
Default: None
timeout:
Timeout for REST and gRPC API requests.
Default: 5.0 seconds for REST and unlimited for gRPC
host:
Host name of Qdrant service. If url and host are None, set to
'localhost'. Default: None
path:
Path in which the vectors will be stored while using local mode.
Default: None
collection_name:
Name of the Qdrant collection to be used. If not provided,
it will be created randomly. Default: None
distance_func:
Distance function. One of: "Cosine" / "Euclid" / "Dot".
Default: "Cosine"
content_payload_key:
A payload key used to store the content of the document.
Default: "page_content"
metadata_payload_key:
A payload key used to store the metadata of the document.
Default: "metadata"
vector_name:
Name of the vector to be used internally in Qdrant.
Default: None
batch_size:
How many vectors upload per-request.
Default: 64
shard_number: Number of shards in collection. Default is 1, minimum is 1.
replication_factor:
Replication factor for collection. Default is 1, minimum is 1.
Defines how many copies of each shard will be created.
Has effect only in distributed mode.
write_consistency_factor:
Write consistency factor for collection. Default is 1, minimum is 1.
Defines how many replicas should apply the operation for us to consider
it successful. Increasing this number will make the collection more
resilient to inconsistencies, but will also make it fail if not enough
replicas are available.
Does not have any performance impact.
Has effect only in distributed mode.
on_disk_payload:
If true - point's payload will not be stored in memory.
It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time.
Note: those payload values that are involved in filtering and are
indexed - remain in RAM.
hnsw_config: Params for HNSW index
optimizers_config: Params for optimizer
wal_config: Params for Write-Ahead-Log
quantization_config:
Params for quantization, if None - quantization will be disabled
init_from:
Use data stored in another collection to initialize this collection
force_recreate:
Force recreating the collection
**kwargs:
Additional arguments passed directly into REST client initialization
This is a user-friendly interface that:
1. Creates embeddings, one for each text
2. Initializes the Qdrant database as an in-memory docstore by default
(and overridable to a remote docstore)
3. Adds the text embeddings to the Qdrant database
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain import Qdrant
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost")
"""
qdrant = cls._construct_instance(
texts,
embedding,
metadatas,
ids,
location,
url,
port,
grpc_port,
prefer_grpc,
https,
api_key,
prefix,
timeout,
host,
path,
collection_name,
distance_func,
content_payload_key,
metadata_payload_key,
vector_name,
shard_number,
replication_factor,
write_consistency_factor,
on_disk_payload,
hnsw_config,
optimizers_config,
wal_config,
quantization_config,
init_from,
force_recreate,
**kwargs,
)
await qdrant.aadd_texts(texts, metadatas, ids, batch_size)
return qdrant
@classmethod
def _construct_instance(
cls: type[Qdrant],
texts: list[str],
embedding: Embeddings,
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: Optional[str] = None,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None,
hnsw_config: Optional[common_types.HnswConfigDiff] = None,
optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
try:
import qdrant_client
except ImportError:
raise ValueError(
"Could not import qdrant-client python package. "
"Please install it with `pip install qdrant-client`."
)
from qdrant_client.http import models as rest
# Just do a single quick embedding to get vector size
partial_embeddings = embedding.embed_documents(texts[:1])
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
is_new_collection = False
client = qdrant_client.QdrantClient(
location=location,
url=url,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
prefix=prefix,
timeout=timeout,
host=host,
path=path,
**kwargs,
)
all_collection_name = []
collections_response = client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[distance_func],
)
# If vector name was provided, we're going to use the named vectors feature
# with just a single vector.
if vector_name is not None:
vectors_config = { # type: ignore[assignment]
vector_name: vectors_config,
}
client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
shard_number=shard_number,
replication_factor=replication_factor,
write_consistency_factor=write_consistency_factor,
on_disk_payload=on_disk_payload,
hnsw_config=hnsw_config,
optimizers_config=optimizers_config,
wal_config=wal_config,
quantization_config=quantization_config,
init_from=init_from,
timeout=int(timeout) if timeout is not None else None,
)
is_new_collection = True
if force_recreate and not is_new_collection:
raise ValueError(
f"Collection {collection_name} already exists and `force_recreate` "
"is set; delete the collection before reusing its name."
)
# Get the vector configuration of the existing collection and vector, if it
# was specified. If the old configuration does not match the current one,
# an exception is being thrown.
collection_info = client.get_collection(collection_name=collection_name)
current_vector_config = collection_info.config.params.vectors
if isinstance(current_vector_config, dict) and vector_name is not None:
if vector_name not in current_vector_config:
raise QdrantException(
f"Existing Qdrant collection {collection_name} does not "
f"contain vector named {vector_name}. Did you mean one of the "
f"existing vectors: {', '.join(current_vector_config.keys())}? "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_vector_config = current_vector_config.get(
vector_name
) # type: ignore[assignment]
elif isinstance(current_vector_config, dict) and vector_name is None:
raise QdrantException(
f"Existing Qdrant collection {collection_name} uses named vectors. "
f"If you want to reuse it, please set `vector_name` to any of the "
f"existing named vectors: "
f"{', '.join(current_vector_config.keys())}." # noqa
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
elif (
not isinstance(current_vector_config, dict) and vector_name is not None
):
raise QdrantException(
f"Existing Qdrant collection {collection_name} doesn't use named "
f"vectors. If you want to reuse it, please set `vector_name` to "
f"`None`. If you want to recreate the collection, set "
f"`force_recreate` parameter to `True`."
)
# Check if the vector configuration has the same dimensionality.
if current_vector_config.size != vector_size: # type: ignore[union-attr]
raise QdrantException(
f"Existing Qdrant collection is configured for vectors with "
f"{current_vector_config.size} " # type: ignore[union-attr]
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_distance_func = (
current_vector_config.distance.name.upper() # type: ignore[union-attr]
)
if current_distance_func != distance_func:
raise QdrantException(
f"Existing Qdrant collection is configured for "
f"{current_vector_config.distance} " # type: ignore[union-attr]
f"similarity. Please set `distance_func` parameter to "
f"`{distance_func}` if you want to reuse it. If you want to "
f"recreate the collection, set `force_recreate` parameter to "
f"`True`."
)
qdrant = cls(
client=client,
collection_name=collection_name,
embeddings=embedding,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
group_id=group_id,
group_payload_key=group_payload_key,
is_new_collection=is_new_collection
)
return qdrant
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
The 'correct' relevance function
may differ depending on a few things, including:
- the distance / similarity metric used by the VectorStore
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
- embedding dimensionality
- etc.
"""
if self.distance_strategy == "COSINE":
return self._cosine_relevance_score_fn
elif self.distance_strategy == "DOT":
return self._max_inner_product_relevance_score_fn
elif self.distance_strategy == "EUCLID":
return self._euclidean_relevance_score_fn
else:
raise ValueError(
"Unknown distance strategy; must be COSINE, DOT, or EUCLID"
)
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 and 1 used to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
return self.similarity_search_with_score(query, k, **kwargs)
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append(
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
return payloads
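# With the default payload keys, each payload built above has this shape
# (values hypothetical):
#     {"page_content": "some chunk", "metadata": {"doc_id": "d1"},
#      "group_id": "g1"}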
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
@classmethod
def _document_from_scored_point_grpc(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
from qdrant_client.conversions.conversion import grpc_to_payload
payload = grpc_to_payload(scored_point.payload)
return Document(
page_content=payload[content_payload_key],
metadata=payload.get(metadata_payload_key) or {},
)
def _build_condition(self, key: str, value: Any) -> list[rest.FieldCondition]:
from qdrant_client.http import models as rest
out = []
if isinstance(value, dict):
for _key, _value in value.items():
out.extend(self._build_condition(f"{key}.{_key}", _value))
elif isinstance(value, list):
for _value in value:
if isinstance(_value, dict):
out.extend(self._build_condition(f"{key}[]", _value))
else:
out.extend(self._build_condition(f"{key}", _value))
else:
out.append(
rest.FieldCondition(
key=key,
match=rest.MatchValue(value=value),
)
)
return out
def _qdrant_filter_from_dict(
self, filter: Optional[DictFilter]
) -> Optional[rest.Filter]:
from qdrant_client.http import models as rest
if not filter:
return None
return rest.Filter(
must=[
condition
for key, value in filter.items()
for condition in self._build_condition(key, value)
]
)
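# Sketch of the dict-to-filter conversion (hypothetical input):
#     self._qdrant_filter_from_dict({"group_id": "g1", "meta": {"lang": "en"}})
# builds rest.Filter(must=[
#     FieldCondition(key="group_id", match=MatchValue(value="g1")),
#     FieldCondition(key="meta.lang", match=MatchValue(value="en"))])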
def _embed_query(self, query: str) -> list[float]:
"""Embed query text.
Used to provide backward compatibility with `embedding_function` argument.
Args:
query: Query text.
Returns:
List of floats representing the query embedding.
"""
if self.embeddings is not None:
embedding = self.embeddings.embed_query(query)
else:
if self._embeddings_function is not None:
embedding = self._embeddings_function(query)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embedding.tolist() if hasattr(embedding, "tolist") else embedding
def _embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
"""Embed search texts.
Used to provide backward compatibility with `embedding_function` argument.
Args:
texts: Iterable of texts to embed.
Returns:
List of floats representing the texts embedding.
"""
if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(list(texts))
if hasattr(embeddings, "tolist"):
embeddings = embeddings.tolist()
elif self._embeddings_function is not None:
embeddings = []
for text in texts:
embedding = self._embeddings_function(text)
if hasattr(embedding, "tolist"):
embedding = embedding.tolist()
embeddings.append(embedding)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embeddings
def _generate_rest_batches(
self,
texts: Iterable[str],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
# Materialize the texts first so a one-shot generator is not exhausted
# while pre-generating the fallback ids below.
texts = list(texts)
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in texts])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = self._embed_texts(batch_texts)
points = [
rest.PointStruct(
id=point_id,
vector=vector
if self.vector_name is None
else {self.vector_name: vector},
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
self.group_id,
self.group_payload_key
),
)
]
yield batch_ids, points
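# The generator above yields ``(batch_ids, points)`` tuples of at most
# ``batch_size`` items each, so callers such as ``add_texts`` can upsert
# incrementally instead of materializing every embedding at once.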
"""Wrapper around weaviate vector database."""
from __future__ import annotations
import datetime
from collections.abc import Callable, Iterable
from typing import Any, Optional
from uuid import uuid4
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
def _default_schema(index_name: str) -> dict:
return {
"class": index_name,
"properties": [
{
"name": "text",
"dataType": ["text"],
}
],
}
def _create_weaviate_client(**kwargs: Any) -> Any:
client = kwargs.get("client")
if client is not None:
return client
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
try:
# the weaviate api key param should not be mandatory
weaviate_api_key = get_from_dict_or_env(
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
)
except ValueError:
weaviate_api_key = None
try:
import weaviate
except ImportError:
raise ValueError(
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`"
)
auth = (
weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
if weaviate_api_key is not None
else None
)
client = weaviate.Client(weaviate_url, auth_client_secret=auth)
return client
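# Usage sketch (hypothetical URL): the client is taken from an explicit
# ``client`` kwarg if given, else built from ``weaviate_url``/
# ``weaviate_api_key`` kwargs or the WEAVIATE_URL/WEAVIATE_API_KEY
# environment variables.
#     client = _create_weaviate_client(weaviate_url="http://localhost:8080")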
def _default_score_normalizer(val: float) -> float:
return 1 - val
def _json_serializable(value: Any) -> Any:
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
class Weaviate(VectorStore):
"""Wrapper around Weaviate vector database.
To use, you should have the ``weaviate-client`` python package installed.
Example:
.. code-block:: python
import weaviate
from langchain.vectorstores import Weaviate
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
weaviate = Weaviate(client, index_name, text_key)
"""
def __init__(
self,
client: Any,
index_name: str,
text_key: str,
embedding: Optional[Embeddings] = None,
attributes: Optional[list[str]] = None,
relevance_score_fn: Optional[
Callable[[float], float]
] = _default_score_normalizer,
by_text: bool = True,
):
"""Initialize with Weaviate client."""
try:
import weaviate
except ImportError:
raise ValueError(
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`."
)
if not isinstance(client, weaviate.Client):
raise ValueError(
f"client should be an instance of weaviate.Client, got {type(client)}"
)
self._client = client
self._index_name = index_name
self._embedding = embedding
self._text_key = text_key
self._query_attrs = [self._text_key]
self.relevance_score_fn = relevance_score_fn
self._by_text = by_text
if attributes is not None:
self._query_attrs.extend(attributes)
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embedding
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return (
self.relevance_score_fn
if self.relevance_score_fn
else _default_score_normalizer
)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> list[str]:
"""Upload texts with metadata (properties) to Weaviate."""
from weaviate.util import get_valid_uuid
ids = []
embeddings: Optional[list[list[float]]] = None
if self._embedding:
if not isinstance(texts, list):
texts = list(texts)
embeddings = self._embedding.embed_documents(texts)
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {self._text_key: text}
if metadatas is not None:
for key, val in metadatas[i].items():
data_properties[key] = _json_serializable(val)
# Allow for ids (consistent w/ other methods)
# or uuids (backwards compatible w/ existing arg).
# If the UUID of one of the objects already exists,
# the existing object will be replaced by the new object.
_id = get_valid_uuid(uuid4())
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
elif "ids" in kwargs:
_id = kwargs["ids"][i]
batch.add_data_object(
data_object=data_properties,
class_name=self._index_name,
uuid=_id,
vector=embeddings[i] if embeddings else None,
)
ids.append(_id)
return ids
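# Usage sketch (hypothetical data): objects are written through a single
# batch context; passing ``uuids``/``ids`` replaces objects sharing a UUID.
#     ids = store.add_texts(["hello world"], metadatas=[{"source": "s1"}])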
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query.
"""
if self._by_text:
return self.similarity_search_by_text(query, k, **kwargs)
else:
if self._embedding is None:
raise ValueError(
"_embedding cannot be None for similarity_search when "
"_by_text=False"
)
embedding = self._embedding.embed_query(query)
return self.similarity_search_by_vector(embedding, k, **kwargs)
def similarity_search_by_text(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query.
"""
content: dict[str, Any] = {"concepts": [query]}
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
result = query_obj.with_near_text(content).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
docs.append(Document(page_content=text, metadata=res))
return docs
def similarity_search_by_bm25(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query.
"""
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
properties = [self._text_key]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
docs.append(Document(page_content=text, metadata=res))
return docs
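# Usage sketch: a BM25 keyword query over the text property, e.g.
#     docs = store.similarity_search_by_bm25("hybrid retrieval", k=4)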
def similarity_search_by_vector(
self, embedding: list[float], k: int = 4, **kwargs: Any
) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
vector = {"vector": embedding}
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
result = query_obj.with_near_vector(vector).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
docs.append(Document(page_content=text, metadata=res))
return docs
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
if self._embedding is not None:
embedding = self._embedding.embed_query(query)
else:
raise ValueError(
"max_marginal_relevance_search requires a suitable Embeddings object"
)
return self.max_marginal_relevance_search_by_vector(
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
)
def max_marginal_relevance_search_by_vector(
self,
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
vector = {"vector": embedding}
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
results = (
query_obj.with_additional("vector")
.with_near_vector(vector)
.with_limit(fetch_k)
.do()
)
payload = results["data"]["Get"][self._index_name]
embeddings = [result["_additional"]["vector"] for result in payload]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
docs = []
for idx in mmr_selected:
text = payload[idx].pop(self._text_key)
payload[idx].pop("_additional")
meta = payload[idx]
docs.append(Document(page_content=text, metadata=meta))
return docs
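# For reference, a minimal sketch of the greedy selection that
# maximal_marginal_relevance is assumed to perform (the imported helper
# may differ in detail):
#
#   def mmr_sketch(query_vec, doc_vecs, k, lambda_mult):
#       def cos(a, b):
#           return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
#       selected, candidates = [], list(range(len(doc_vecs)))
#       while candidates and len(selected) < k:
#           best = max(
#               candidates,
#               key=lambda i: lambda_mult * cos(query_vec, doc_vecs[i])
#               - (1 - lambda_mult)
#               * max((cos(doc_vecs[i], doc_vecs[j]) for j in selected), default=0.0),
#           )
#           selected.append(best)
#           candidates.remove(best)
#       return selected  # indices into doc_vecs, relevant-but-diverse first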
def similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> list[tuple[Document, float]]:
"""
Return list of documents most similar to the query
text and cosine distance in float for each.
Lower score represents more similarity.
"""
if self._embedding is None:
raise ValueError(
"_embedding cannot be None for similarity_search_with_score"
)
content: dict[str, Any] = {"concepts": [query]}
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(self._index_name, self._query_attrs)
embedded_query = self._embedding.embed_query(query)
if not self._by_text:
vector = {"vector": embedded_query}
result = (
query_obj.with_near_vector(vector)
.with_limit(k)
.with_additional(["vector", "distance"])
.do()
)
else:
result = (
query_obj.with_near_text(content)
.with_limit(k)
.with_additional(["vector", "distance"])
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
score = res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
return docs_and_scores
@classmethod
def from_texts(
cls: type[Weaviate],
texts: list[str],
embedding: Embeddings,
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> Weaviate:
"""Construct Weaviate wrapper from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
2. Creates a new index for the embeddings in the Weaviate instance.
3. Adds the documents to the newly created Weaviate index.
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain.vectorstores.weaviate import Weaviate
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
weaviate = Weaviate.from_texts(
texts,
embeddings,
weaviate_url="http://localhost:8080"
)
"""
client = _create_weaviate_client(**kwargs)
from weaviate.util import get_valid_uuid
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
embeddings = embedding.embed_documents(texts) if embedding else None
text_key = "text"
schema = _default_schema(index_name)
attributes = list(metadatas[0].keys()) if metadatas else None
# check whether the index already exists
if not client.schema.contains(schema):
client.schema.create_class(schema)
with client.batch as batch:
for i, text in enumerate(texts):
data_properties = {
text_key: text,
}
if metadatas is not None:
for key in metadatas[i].keys():
data_properties[key] = metadatas[i][key]
# If the UUID of one of the objects already exists
# then the existing object will be replaced by the new object.
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
else:
_id = get_valid_uuid(uuid4())
# if an embedding strategy is not provided, we let
# weaviate create the embedding. Note that this will only
# work if weaviate has been installed with a vectorizer module
# like text2vec-contextionary for example
params = {
"uuid": _id,
"data_object": data_properties,
"class_name": index_name,
}
if embeddings is not None:
params["vector"] = embeddings[i]
batch.add_data_object(**params)
batch.flush()
relevance_score_fn = kwargs.get("relevance_score_fn")
by_text: bool = kwargs.get("by_text", False)
return cls(
client,
index_name,
text_key,
embedding=embedding,
attributes=attributes,
relevance_score_fn=relevance_score_fn,
by_text=by_text,
)
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None:
"""Delete by vector IDs.
Args:
ids: List of ids to delete.
"""
if ids is None:
raise ValueError("No ids provided to delete.")
# TODO: Check if this can be done in bulk
for id in ids:
self._client.data_object.delete(uuid=id)
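# Illustrative usage of the wrapper above; the URL, texts, and embedding
# object are placeholders supplied by the caller:
#
#   store = Weaviate.from_texts(
#       ["Dify supports several vector stores.", "Weaviate is one of them."],
#       embedding=my_embeddings,          # any Embeddings implementation
#       weaviate_url="http://localhost:8080",
#       by_text=False,                    # query by vector rather than near-text
#   )
#   docs = store.similarity_search("which vector stores are supported?", k=2)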
from core.vector_store.vector.weaviate import Weaviate
class WeaviateVectorStore(Weaviate):
def del_texts(self, where_filter: dict):
if not where_filter:
raise ValueError('where_filter must not be empty')
self._client.batch.delete_objects(
class_name=self._index_name,
where=where_filter,
output='minimal'
)
def del_text(self, uuid: str) -> None:
self._client.data_object.delete(
uuid,
class_name=self._index_name
)
def text_exists(self, uuid: str) -> bool:
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": uuid,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][self._index_name]
if len(entries) == 0:
return False
return True
def delete(self):
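# drops the entire Weaviate class, i.e. the whole collection for this index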
self._client.schema.delete_class(self._index_name)
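# Illustrative only: removing every chunk of one document through del_texts,
# then checking a single node by its doc_id; both use the same
# {"path", "operator", "valueText"} filter shape as text_exists above:
#
#   store.del_texts({
#       "path": ["document_id"],
#       "operator": "Equal",
#       "valueText": document_id,   # placeholder
#   })
#   assert not store.text_exists(node_id)   # node_id: placeholder doc_id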
...@@ -6,4 +6,4 @@ from tasks.clean_dataset_task import clean_dataset_task ...@@ -6,4 +6,4 @@ from tasks.clean_dataset_task import clean_dataset_task
def handle(sender, **kwargs): def handle(sender, **kwargs):
dataset = sender dataset = sender
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
dataset.index_struct, dataset.collection_binding_id) dataset.index_struct, dataset.collection_binding_id, dataset.doc_form)
...@@ -6,4 +6,5 @@ from tasks.clean_document_task import clean_document_task ...@@ -6,4 +6,5 @@ from tasks.clean_document_task import clean_document_task
def handle(sender, **kwargs): def handle(sender, **kwargs):
document_id = sender document_id = sender
dataset_id = kwargs.get('dataset_id') dataset_id = kwargs.get('dataset_id')
clean_document_task.delay(document_id, dataset_id) doc_form = kwargs.get('doc_form')
clean_document_task.delay(document_id, dataset_id, doc_form)
...@@ -94,6 +94,14 @@ class Dataset(db.Model): ...@@ -94,6 +94,14 @@ class Dataset(db.Model):
return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
.filter(Document.dataset_id == self.id).scalar() .filter(Document.dataset_id == self.id).scalar()
@property
def doc_form(self):
document = db.session.query(Document).filter(
Document.dataset_id == self.id).first()
if document:
return document.doc_form
return None
@property @property
def retrieval_model_dict(self): def retrieval_model_dict(self):
default_retrieval_model = { default_retrieval_model = {
......
...@@ -6,7 +6,7 @@ from flask import current_app ...@@ -6,7 +6,7 @@ from flask import current_app
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import app import app
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, Document from models.dataset import Dataset, DatasetQuery, Document
...@@ -41,18 +41,9 @@ def clean_unused_datasets_task(): ...@@ -41,18 +41,9 @@ def clean_unused_datasets_task():
if not documents or len(documents) == 0: if not documents or len(documents) == 0:
try: try:
# remove index # remove index
vector_index = IndexBuilder.get_index(dataset, 'high_quality') index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
kw_index = IndexBuilder.get_index(dataset, 'economy') index_processor.clean(dataset, None)
# delete from vector index
if vector_index:
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
vector_index.delete()
kw_index.delete()
# update document # update document
update_params = { update_params = {
Document.enabled: False Document.enabled: False
......
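This factory dispatch — pick an index processor from the dataset's doc_form, then delegate load/clean — is the pattern the remaining tasks below are rewritten to. A rough sketch of the dispatch, where the two processor classes are assumptions about what the factory maps 'text_model' and 'qa_model' to:

    class IndexProcessorFactory:
        def __init__(self, index_type: str):
            self._index_type = index_type

        def init_index_processor(self):
            if not self._index_type:
                raise ValueError('Index type must be specified.')
            if self._index_type == 'text_model':
                return ParagraphIndexProcessor()  # assumed paragraph-chunk processor
            if self._index_type == 'qa_model':
                return QAIndexProcessor()         # assumed question/answer processor
            raise ValueError(f'Index type {self._index_type} is not supported.')

Passing None as the node list to clean(), as above, reads as "remove everything for this dataset"; a concrete list removes only those nodes.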
...@@ -11,10 +11,11 @@ from flask_login import current_user ...@@ -11,10 +11,11 @@ from flask_login import current_user
from sqlalchemy import func from sqlalchemy import func
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.index.index import IndexBuilder
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.models.document import Document as RAGDocument
from events.dataset_event import dataset_was_deleted from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted from events.document_event import document_was_deleted
from extensions.ext_database import db from extensions.ext_database import db
...@@ -402,7 +403,7 @@ class DocumentService: ...@@ -402,7 +403,7 @@ class DocumentService:
@staticmethod @staticmethod
def delete_document(document): def delete_document(document):
# trigger document_was_deleted signal # trigger document_was_deleted signal
document_was_deleted.send(document.id, dataset_id=document.dataset_id) document_was_deleted.send(document.id, dataset_id=document.dataset_id, doc_form=document.doc_form)
db.session.delete(document) db.session.delete(document)
db.session.commit() db.session.commit()
...@@ -1060,7 +1061,7 @@ class SegmentService: ...@@ -1060,7 +1061,7 @@ class SegmentService:
# save vector index # save vector index
try: try:
VectorService.create_segment_vector(args['keywords'], segment_document, dataset) VectorService.create_segments_vector([args['keywords']], [segment_document], dataset)
except Exception as e: except Exception as e:
logging.exception("create segment index failed") logging.exception("create segment index failed")
segment_document.enabled = False segment_document.enabled = False
...@@ -1087,6 +1088,7 @@ class SegmentService: ...@@ -1087,6 +1088,7 @@ class SegmentService:
).scalar() ).scalar()
pre_segment_data_list = [] pre_segment_data_list = []
segment_data_list = [] segment_data_list = []
keywords_list = []
for segment_item in segments: for segment_item in segments:
content = segment_item['content'] content = segment_item['content']
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
...@@ -1119,15 +1121,13 @@ class SegmentService: ...@@ -1119,15 +1121,13 @@ class SegmentService:
segment_document.answer = segment_item['answer'] segment_document.answer = segment_item['answer']
db.session.add(segment_document) db.session.add(segment_document)
segment_data_list.append(segment_document) segment_data_list.append(segment_document)
pre_segment_data = {
'segment': segment_document, pre_segment_data_list.append(segment_document)
'keywords': segment_item['keywords'] keywords_list.append(segment_item['keywords'])
}
pre_segment_data_list.append(pre_segment_data)
try: try:
# save vector index # save vector index
VectorService.multi_create_segment_vector(pre_segment_data_list, dataset) VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
except Exception as e: except Exception as e:
logging.exception("create segment index failed") logging.exception("create segment index failed")
for segment_document in segment_data_list: for segment_document in segment_data_list:
...@@ -1157,11 +1157,18 @@ class SegmentService: ...@@ -1157,11 +1157,18 @@ class SegmentService:
db.session.commit() db.session.commit()
# update segment index task # update segment index task
if args['keywords']: if args['keywords']:
kw_index = IndexBuilder.get_index(dataset, 'economy') keyword = Keyword(dataset)
# delete from keyword index keyword.delete_by_ids([segment.index_node_id])
kw_index.delete_by_ids([segment.index_node_id]) document = RAGDocument(
# save keyword index page_content=segment.content,
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords) metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
keyword.add_texts([document], keywords_list=[args['keywords']])
else: else:
segment_hash = helper.generate_text_hash(content) segment_hash = helper.generate_text_hash(content)
tokens = 0 tokens = 0
......
...@@ -9,8 +9,8 @@ from flask_login import current_user ...@@ -9,8 +9,8 @@ from flask_login import current_user
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.data_loader.file_extractor import FileExtractor
from core.file.upload_file_parser import UploadFileParser from core.file.upload_file_parser import UploadFileParser
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
from models.account import Account from models.account import Account
...@@ -32,7 +32,8 @@ class FileService: ...@@ -32,7 +32,8 @@ class FileService:
def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
extension = file.filename.split('.')[-1] extension = file.filename.split('.')[-1]
etl_type = current_app.config['ETL_TYPE'] etl_type = current_app.config['ETL_TYPE']
allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
if extension.lower() not in allowed_extensions: if extension.lower() not in allowed_extensions:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
elif only_image and extension.lower() not in IMAGE_EXTENSIONS: elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
...@@ -136,7 +137,7 @@ class FileService: ...@@ -136,7 +137,7 @@ class FileService:
if extension.lower() not in allowed_extensions: if extension.lower() not in allowed_extensions:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
text = FileExtractor.load(upload_file, return_text=True) text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else '' text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return text return text
...@@ -164,7 +165,7 @@ class FileService: ...@@ -164,7 +165,7 @@ class FileService:
return generator, upload_file.mime_type return generator, upload_file.mime_type
@staticmethod @staticmethod
def get_public_image_preview(file_id: str) -> str: def get_public_image_preview(file_id: str) -> tuple[Generator, str]:
upload_file = db.session.query(UploadFile) \ upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \ .filter(UploadFile.id == file_id) \
.first() .first()
......
import logging import logging
import threading
import time import time
import numpy as np import numpy as np
from flask import current_app
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
from core.embedding.cached_embedding import CacheEmbedding from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rerank.rerank import RerankRunner from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Dataset, DatasetQuery, DocumentSegment
from services.retrieval_service import RetrievalService
default_retrieval_model = { default_retrieval_model = {
'search_method': 'semantic_search', 'search_method': 'semantic_search',
...@@ -28,6 +25,7 @@ default_retrieval_model = { ...@@ -28,6 +25,7 @@ default_retrieval_model = {
'score_threshold_enabled': False 'score_threshold_enabled': False
} }
class HitTestingService: class HitTestingService:
@classmethod @classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
...@@ -57,61 +55,15 @@ class HitTestingService: ...@@ -57,61 +55,15 @@ class HitTestingService:
embeddings = CacheEmbedding(embedding_model) embeddings = CacheEmbedding(embedding_model)
all_documents = [] all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
threads = [] dataset_id=dataset.id,
query=query,
# retrieval_model source with semantic top_k=retrieval_model['top_k'],
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': score_threshold=retrieval_model['score_threshold']
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ if retrieval_model['score_threshold_enabled'] else None,
'flask_app': current_app._get_current_object(), reranking_model=retrieval_model['reranking_model']
'dataset_id': str(dataset.id), if retrieval_model['reranking_enable'] else None
'query': query, )
'top_k': retrieval_model['top_k'],
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval source with full text
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': str(dataset.id),
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
'top_k': retrieval_model['top_k'],
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
if retrieval_model['search_method'] == 'hybrid_search':
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=retrieval_model['reranking_model']['reranking_provider_name'],
model_type=ModelType.RERANK,
model=retrieval_model['reranking_model']['reranking_model_name']
)
rerank_runner = RerankRunner(rerank_model_instance)
all_documents = rerank_runner.run(
query=query,
documents=all_documents,
score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
top_n=retrieval_model['top_k'],
user=f"account-{account.id}"
)
end = time.perf_counter() end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
...@@ -203,4 +155,3 @@ class HitTestingService: ...@@ -203,4 +155,3 @@ class HitTestingService:
if not query or len(query) > 250: if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters') raise ValueError('Query is required and cannot exceed 250 characters')
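With the refactor above, hit testing no longer spawns per-method threads; one call expresses the whole retrieval. A sketch, keeping the keyword-argument spelling retrival_method used by the new code:

    all_documents = RetrievalService.retrieve(
        retrival_method=retrieval_model['search_method'],
        dataset_id=dataset.id,
        query=query,
        top_k=retrieval_model['top_k'],
        score_threshold=None,    # set only when score_threshold_enabled
        reranking_model=None,    # set only when reranking_enable
    )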
from typing import Optional from typing import Optional
from langchain.schema import Document from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.index.index import IndexBuilder from core.rag.models.document import Document
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
class VectorService: class VectorService:
@classmethod @classmethod
def create_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): def create_segments_vector(cls, keywords_list: Optional[list[list[str]]],
document = Document( segments: list[DocumentSegment], dataset: Dataset):
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
index.add_texts([document], duplicate_check=True)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
if keywords and len(keywords) > 0:
index.create_segment_keywords(segment.index_node_id, keywords)
else:
index.add_texts([document])
@classmethod
def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
documents = [] documents = []
for pre_segment_data in pre_segment_data_list: for segment in segments:
segment = pre_segment_data['segment']
document = Document( document = Document(
page_content=segment.content, page_content=segment.content,
metadata={ metadata={
...@@ -49,30 +23,26 @@ class VectorService: ...@@ -49,30 +23,26 @@ class VectorService:
} }
) )
documents.append(document) documents.append(document)
if dataset.indexing_technique == 'high_quality':
# save vector index # save vector index
index = IndexBuilder.get_index(dataset, 'high_quality') vector = Vector(
if index: dataset=dataset
index.add_texts(documents, duplicate_check=True) )
vector.add_texts(documents, duplicate_check=True)
# save keyword index # save keyword index
keyword_index = IndexBuilder.get_index(dataset, 'economy') keyword = Keyword(dataset)
if keyword_index:
keyword_index.multi_create_segment_keywords(pre_segment_data_list) if keywords_list and len(keywords_list) > 0:
keyword.add_texts(documents, keywords_list=keywords_list)
else:
keyword.add_texts(documents)
@classmethod @classmethod
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
# update segment index task # update segment index task
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete_by_ids([segment.index_node_id])
# delete from keyword index # format new index
kw_index.delete_by_ids([segment.index_node_id])
# add new index
document = Document( document = Document(
page_content=segment.content, page_content=segment.content,
metadata={ metadata={
...@@ -82,13 +52,20 @@ class VectorService: ...@@ -82,13 +52,20 @@ class VectorService:
"dataset_id": segment.dataset_id, "dataset_id": segment.dataset_id,
} }
) )
if dataset.indexing_technique == 'high_quality':
# update vector index
vector = Vector(
dataset=dataset
)
vector.delete_by_ids([segment.index_node_id])
vector.add_texts([document], duplicate_check=True)
# save vector index # update keyword index
if vector_index: keyword = Keyword(dataset)
vector_index.add_texts([document], duplicate_check=True) keyword.delete_by_ids([segment.index_node_id])
# save keyword index # save keyword index
if keywords and len(keywords) > 0: if keywords and len(keywords) > 0:
kw_index.create_segment_keywords(segment.index_node_id, keywords) keyword.add_texts([document], keywords_list=[keywords])
else: else:
kw_index.add_texts([document]) keyword.add_texts([document])
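Taken together, the two facades give every write path the same two-step recipe; a sketch, with segment standing for a DocumentSegment row:

    document = Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        },
    )
    if dataset.indexing_technique == 'high_quality':
        Vector(dataset=dataset).add_texts([document], duplicate_check=True)
    keyword = Keyword(dataset)
    if keywords:
        keyword.add_texts([document], keywords_list=[keywords])
    else:
        keyword.add_texts([document])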
...@@ -4,10 +4,10 @@ import time ...@@ -4,10 +4,10 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
...@@ -60,15 +60,9 @@ def add_document_to_index_task(dataset_document_id: str): ...@@ -60,15 +60,9 @@ def add_document_to_index_task(dataset_document_id: str):
if not dataset: if not dataset:
raise Exception('Document has no dataset') raise Exception('Document has no dataset')
# save vector index index_type = dataset.doc_form
index = IndexBuilder.get_index(dataset, 'high_quality') index_processor = IndexProcessorFactory(index_type).init_index_processor()
if index: index_processor.load(dataset, documents)
index.add_texts(documents)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(documents)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
......
...@@ -5,7 +5,7 @@ import click ...@@ -5,7 +5,7 @@ import click
from celery import shared_task from celery import shared_task
from langchain.schema import Document from langchain.schema import Document
from core.index.index import IndexBuilder from core.rag.datasource.vdb.vector_factory import Vector
from models.dataset import Dataset from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService from services.dataset_service import DatasetCollectionBindingService
...@@ -48,9 +48,9 @@ def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: s ...@@ -48,9 +48,9 @@ def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: s
"doc_id": annotation_id "doc_id": annotation_id
} }
) )
index = IndexBuilder.get_index(dataset, 'high_quality') vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
if index: vector.create([document], duplicate_check=True)
index.add_texts([document])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
click.style( click.style(
......
...@@ -6,7 +6,7 @@ from celery import shared_task ...@@ -6,7 +6,7 @@ from celery import shared_task
from langchain.schema import Document from langchain.schema import Document
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
...@@ -79,9 +79,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: ...@@ -79,9 +79,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
collection_binding_id=dataset_collection_binding.id collection_binding_id=dataset_collection_binding.id
) )
index = IndexBuilder.get_index(dataset, 'high_quality') vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
if index: vector.create(documents, duplicate_check=True)
index.add_texts(documents)
db.session.commit() db.session.commit()
redis_client.setex(indexing_cache_key, 600, 'completed') redis_client.setex(indexing_cache_key, 600, 'completed')
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from core.index.index import IndexBuilder from core.rag.datasource.vdb.vector_factory import Vector
from models.dataset import Dataset from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService from services.dataset_service import DatasetCollectionBindingService
...@@ -30,12 +30,11 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str ...@@ -30,12 +30,11 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
collection_binding_id=dataset_collection_binding.id collection_binding_id=dataset_collection_binding.id
) )
vector_index = IndexBuilder.get_default_high_quality_index(dataset) try:
if vector_index: vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
try: vector.delete_by_metadata_field('annotation_id', annotation_id)
vector_index.delete_by_metadata_field('annotation_id', annotation_id) except Exception:
except Exception: logging.exception("Delete annotation index failed when annotation deleted.")
logging.exception("Delete annotation index failed when annotation deleted.")
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at),
......
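Each annotation task now builds the facade the same way, naming the metadata attributes the collection must keep so the filter can match on them; the shared shape:

    vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
    # delete every vector whose stored metadata field equals the given value
    vector.delete_by_metadata_field('annotation_id', annotation_id)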
...@@ -5,7 +5,7 @@ import click ...@@ -5,7 +5,7 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
...@@ -48,12 +48,11 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): ...@@ -48,12 +48,11 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
collection_binding_id=app_annotation_setting.collection_binding_id collection_binding_id=app_annotation_setting.collection_binding_id
) )
vector_index = IndexBuilder.get_default_high_quality_index(dataset) try:
if vector_index: vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
try: vector.delete_by_metadata_field('app_id', app_id)
vector_index.delete_by_metadata_field('app_id', app_id) except Exception:
except Exception: logging.exception("Delete annotation index failed when annotation reply was disabled.")
logging.exception("Delete doc index failed when dataset deleted.")
redis_client.setex(disable_app_annotation_job_key, 600, 'completed') redis_client.setex(disable_app_annotation_job_key, 600, 'completed')
# delete annotation setting # delete annotation setting
......
...@@ -7,7 +7,7 @@ from celery import shared_task ...@@ -7,7 +7,7 @@ from celery import shared_task
from langchain.schema import Document from langchain.schema import Document
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
...@@ -81,15 +81,15 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_ ...@@ -81,15 +81,15 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_
} }
) )
documents.append(document) documents.append(document)
index = IndexBuilder.get_index(dataset, 'high_quality')
if index: vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
try: try:
index.delete_by_metadata_field('app_id', app_id) vector.delete_by_metadata_field('app_id', app_id)
except Exception as e: except Exception as e:
logging.info( logging.info(
click.style('Delete annotation index error: {}'.format(str(e)), click.style('Delete annotation index error: {}'.format(str(e)),
fg='red')) fg='red'))
index.add_texts(documents) vector.add_texts(documents)
db.session.commit() db.session.commit()
redis_client.setex(enable_app_annotation_job_key, 600, 'completed') redis_client.setex(enable_app_annotation_job_key, 600, 'completed')
end_at = time.perf_counter() end_at = time.perf_counter()
......
...@@ -5,7 +5,7 @@ import click ...@@ -5,7 +5,7 @@ import click
from celery import shared_task from celery import shared_task
from langchain.schema import Document from langchain.schema import Document
from core.index.index import IndexBuilder from core.rag.datasource.vdb.vector_factory import Vector
from models.dataset import Dataset from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService from services.dataset_service import DatasetCollectionBindingService
...@@ -49,10 +49,9 @@ def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id ...@@ -49,10 +49,9 @@ def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id
"doc_id": annotation_id "doc_id": annotation_id
} }
) )
index = IndexBuilder.get_index(dataset, 'high_quality') vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
if index: vector.delete_by_metadata_field('annotation_id', annotation_id)
index.delete_by_metadata_field('annotation_id', annotation_id) vector.add_texts([document])
index.add_texts([document])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
click.style( click.style(
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import ( from models.dataset import (
AppDatasetJoin, AppDatasetJoin,
...@@ -18,7 +18,7 @@ from models.dataset import ( ...@@ -18,7 +18,7 @@ from models.dataset import (
@shared_task(queue='dataset') @shared_task(queue='dataset')
def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
index_struct: str, collection_binding_id: str): index_struct: str, collection_binding_id: str, doc_form: str):
""" """
Clean dataset when dataset deleted. Clean dataset when dataset deleted.
:param dataset_id: dataset id :param dataset_id: dataset id
...@@ -26,6 +26,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, ...@@ -26,6 +26,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
:param indexing_technique: indexing technique :param indexing_technique: indexing technique
:param index_struct: index struct dict :param index_struct: index struct dict
:param collection_binding_id: collection binding id :param collection_binding_id: collection binding id
:param doc_form: document form (index type)
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct, collection_binding_id, doc_form)
""" """
...@@ -38,26 +39,14 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, ...@@ -38,26 +39,14 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
tenant_id=tenant_id, tenant_id=tenant_id,
indexing_technique=indexing_technique, indexing_technique=indexing_technique,
index_struct=index_struct, index_struct=index_struct,
collection_binding_id=collection_binding_id collection_binding_id=collection_binding_id,
doc_form=doc_form
) )
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
kw_index = IndexBuilder.get_index(dataset, 'economy') index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None)
# delete from vector index
if dataset.indexing_technique == 'high_quality':
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
try:
vector_index.delete_by_group_id(dataset.id)
except Exception:
logging.exception("Delete doc index failed when dataset deleted.")
# delete from keyword index
try:
kw_index.delete()
except Exception:
logging.exception("Delete nodes index failed when dataset deleted.")
for document in documents: for document in documents:
db.session.delete(document) db.session.delete(document)
......
...@@ -4,17 +4,18 @@ import time ...@@ -4,17 +4,18 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
@shared_task(queue='dataset') @shared_task(queue='dataset')
def clean_document_task(document_id: str, dataset_id: str): def clean_document_task(document_id: str, dataset_id: str, doc_form: str):
""" """
Clean document when document deleted. Clean document when document deleted.
:param document_id: document id :param document_id: document id
:param dataset_id: dataset id :param dataset_id: dataset id
:param doc_form: document form (index type)
Usage: clean_document_task.delay(document_id, dataset_id) Usage: clean_document_task.delay(document_id, dataset_id, doc_form)
""" """
...@@ -27,21 +28,12 @@ def clean_document_task(document_id: str, dataset_id: str): ...@@ -27,21 +28,12 @@ def clean_document_task(document_id: str, dataset_id: str):
if not dataset: if not dataset:
raise Exception('Document has no dataset') raise Exception('Document has no dataset')
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
# check whether any segments exist # check whether any segments exist
if segments: if segments:
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
# delete from vector index index_processor.clean(dataset, index_node_ids)
if vector_index:
vector_index.delete_by_document_id(document_id)
# delete from keyword index
if index_node_ids:
kw_index.delete_by_ids(index_node_ids)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
...@@ -26,9 +26,8 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): ...@@ -26,9 +26,8 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
if not dataset: if not dataset:
raise Exception('Document has no dataset') raise Exception('Document has no dataset')
index_type = dataset.doc_form
vector_index = IndexBuilder.get_index(dataset, 'high_quality') index_processor = IndexProcessorFactory(index_type).init_index_processor()
kw_index = IndexBuilder.get_index(dataset, 'economy')
for document_id in document_ids: for document_id in document_ids:
document = db.session.query(Document).filter( document = db.session.query(Document).filter(
Document.id == document_id Document.id == document_id
...@@ -38,13 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): ...@@ -38,13 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index index_processor.clean(dataset, index_node_ids)
if vector_index:
vector_index.delete_by_document_id(document_id)
# delete from keyword index
if index_node_ids:
kw_index.delete_by_ids(index_node_ids)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
......
...@@ -5,10 +5,10 @@ from typing import Optional ...@@ -5,10 +5,10 @@ from typing import Optional
import click import click
from celery import shared_task from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
...@@ -68,18 +68,9 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] ...@@ -68,18 +68,9 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
return return
# save vector index index_type = dataset.doc_form
index = IndexBuilder.get_index(dataset, 'high_quality') index_processor = IndexProcessorFactory(index_type).init_index_processor()
if index: index_processor.load(dataset, [document])
index.add_texts([document], duplicate_check=True)
# save keyword index
index = IndexBuilder.get_index(dataset, 'economy')
if index:
if keywords and len(keywords) > 0:
index.create_segment_keywords(segment.index_node_id, keywords)
else:
index.add_texts([document])
# update segment to completed # update segment to completed
update_params = { update_params = {
......
...@@ -3,9 +3,9 @@ import time ...@@ -3,9 +3,9 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from langchain.schema import Document
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
...@@ -29,10 +29,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): ...@@ -29,10 +29,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if not dataset: if not dataset:
raise Exception('Dataset not found') raise Exception('Dataset not found')
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove": if action == "remove":
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True) index_processor.clean(dataset, None, with_keywords=False)
index.delete_by_group_id(dataset.id)
elif action == "add": elif action == "add":
dataset_documents = db.session.query(DatasetDocument).filter( dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset_id, DatasetDocument.dataset_id == dataset_id,
...@@ -42,8 +42,6 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): ...@@ -42,8 +42,6 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
).all() ).all()
if dataset_documents: if dataset_documents:
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
documents = [] documents = []
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
# delete from vector index # delete from vector index
...@@ -65,7 +63,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): ...@@ -65,7 +63,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document) documents.append(document)
# save vector index # save vector index
index.create(documents) index_processor.load(dataset, documents, with_keywords=False)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document from models.dataset import Dataset, Document
...@@ -39,15 +39,9 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ ...@@ -39,15 +39,9 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan')) logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan'))
return return
vector_index = IndexBuilder.get_index(dataset, 'high_quality') index_type = dataset_document.doc_form
kw_index = IndexBuilder.get_index(dataset, 'economy') index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [index_node_id])
# delete from vector index
if vector_index:
vector_index.delete_by_ids([index_node_id])
# delete from keyword index
kw_index.delete_by_ids([index_node_id])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green')) logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green'))
......
...@@ -5,7 +5,7 @@ import click ...@@ -5,7 +5,7 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
...@@ -48,15 +48,9 @@ def disable_segment_from_index_task(segment_id: str): ...@@ -48,15 +48,9 @@ def disable_segment_from_index_task(segment_id: str):
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
return return
vector_index = IndexBuilder.get_index(dataset, 'high_quality') index_type = dataset_document.doc_form
kw_index = IndexBuilder.get_index(dataset, 'economy') index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [segment.index_node_id])
# delete from vector index
if vector_index:
vector_index.delete_by_ids([segment.index_node_id])
# delete from keyword index
kw_index.delete_by_ids([segment.index_node_id])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
......
...@@ -6,9 +6,9 @@ import click ...@@ -6,9 +6,9 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.data_loader.loader.notion import NotionLoader
from core.index.index import IndexBuilder
from core.indexing_runner import DocumentIsPausedException, IndexingRunner from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from models.source import DataSourceBinding from models.source import DataSourceBinding
...@@ -54,11 +54,11 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ...@@ -54,11 +54,11 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
if not data_source_binding: if not data_source_binding:
raise ValueError('Data source binding not found.') raise ValueError('Data source binding not found.')
loader = NotionLoader( loader = NotionExtractor(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id, notion_workspace_id=workspace_id,
notion_obj_id=page_id, notion_obj_id=page_id,
notion_page_type=page_type notion_page_type=page_type,
notion_access_token=data_source_binding.access_token
) )
last_edited_time = loader.get_notion_last_edited_time() last_edited_time = loader.get_notion_last_edited_time()
...@@ -74,20 +74,14 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ...@@ -74,20 +74,14 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise Exception('Dataset not found') raise Exception('Dataset not found')
index_type = document.doc_form
vector_index = IndexBuilder.get_index(dataset, 'high_quality') index_processor = IndexProcessorFactory(index_type).init_index_processor()
kw_index = IndexBuilder.get_index(dataset, 'economy')
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
if vector_index: index_processor.clean(dataset, index_node_ids)
vector_index.delete_by_document_id(document_id)
# delete from keyword index
if index_node_ids:
kw_index.delete_by_ids(index_node_ids)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
......
...@@ -6,8 +6,8 @@ import click ...@@ -6,8 +6,8 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from core.indexing_runner import DocumentIsPausedException, IndexingRunner from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
...@@ -42,19 +42,14 @@ def document_indexing_update_task(dataset_id: str, document_id: str): ...@@ -42,19 +42,14 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
if not dataset: if not dataset:
raise Exception('Dataset not found') raise Exception('Dataset not found')
vector_index = IndexBuilder.get_index(dataset, 'high_quality') index_type = document.doc_form
kw_index = IndexBuilder.get_index(dataset, 'economy') index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
if vector_index: index_processor.clean(dataset, index_node_ids)
vector_index.delete_by_ids(index_node_ids)
# delete from keyword index
if index_node_ids:
kw_index.delete_by_ids(index_node_ids)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
......
...@@ -4,10 +4,10 @@ import time ...@@ -4,10 +4,10 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
@@ -60,15 +60,9 @@ def enable_segment_to_index_task(segment_id: str):
             logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
             return

+        index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+
         # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts([document], duplicate_check=True)
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            index.add_texts([document])
+        index_processor.load(dataset, [document])

         end_at = time.perf_counter()
         logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
......
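enable_segment_to_index_task still builds an in-memory Document (now from core.rag.models.document instead of langchain.schema) but hands it to a single load() call. A sketch of the new enable path; the helper name is hypothetical, and the metadata keys are copied from the segment-update task reproduced later in this section:

# Sketch: enable_segment is a hypothetical helper; Document, the factory,
# and load() are the ones this diff introduces.
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document


def enable_segment(segment, dataset, dataset_document) -> None:
    document = Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        }
    )

    index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
    # load() replaces the two separate add_texts() calls on the vector and
    # keyword indexes that this hunk deletes
    index_processor.load(dataset, [document])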
@@ -5,7 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Document, DocumentSegment
@@ -37,18 +37,12 @@ def remove_document_from_index_task(document_id: str):
     if not dataset:
         raise Exception('Document has no dataset')

-    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-    kw_index = IndexBuilder.get_index(dataset, 'economy')
+    index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

-    # delete from vector index
-    if vector_index:
-        vector_index.delete_by_document_id(document.id)
-    # delete from keyword index
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
     index_node_ids = [segment.index_node_id for segment in segments]
     if index_node_ids:
-        kw_index.delete_by_ids(index_node_ids)
+        index_processor.clean(dataset, index_node_ids)

     end_at = time.perf_counter()
     logging.info(
......
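Here again one clean() call absorbs both deletion paths. The diff shows only the call site; purely as an illustration of why a single call can stand in for both blocks, a processor's clean() might fan out to the two stores roughly like this (Vector and Keyword are toy stand-ins, not APIs confirmed by this commit):

# Illustrative stubs only: the real store wrappers behind clean() are not
# shown in this diff.
class Vector:
    def __init__(self, dataset):
        self._dataset = dataset

    def delete_by_ids(self, node_ids):
        print('vector store: drop {}'.format(node_ids))


class Keyword:
    def __init__(self, dataset):
        self._dataset = dataset

    def delete_by_ids(self, node_ids):
        print('keyword store: drop {}'.format(node_ids))


class ParagraphIndexProcessor:  # assumed processor name, as in the earlier sketch
    def clean(self, dataset, node_ids) -> None:
        # fanning out to both stores is what would let the separate vector
        # and keyword deletions above collapse into a single call
        if node_ids:
            Vector(dataset).delete_by_ids(node_ids)
            Keyword(dataset).delete_by_ids(node_ids)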
import datetime
import logging
import time
from typing import Optional

import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment


@shared_task(queue='dataset')
def update_segment_index_task(segment_id: str, keywords: Optional[list[str]] = None):
    """
    Async update segment index
    :param segment_id:
    :param keywords:
    Usage: update_segment_index_task.delay(segment_id)
    """
    logging.info(click.style('Start update segment index: {}'.format(segment_id), fg='green'))
    start_at = time.perf_counter()

    segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
    if not segment:
        raise NotFound('Segment not found')

    if segment.status != 'updating':
        return

    indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

    try:
        dataset = segment.dataset
        if not dataset:
            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
            return

        dataset_document = segment.document
        if not dataset_document:
            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
            return

        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
            return

        # update segment status to indexing
        update_params = {
            DocumentSegment.status: "indexing",
            DocumentSegment.indexing_at: datetime.datetime.utcnow()
        }
        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
        db.session.commit()

        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        kw_index = IndexBuilder.get_index(dataset, 'economy')

        # delete from vector index
        if vector_index:
            vector_index.delete_by_ids([segment.index_node_id])

        # delete from keyword index
        kw_index.delete_by_ids([segment.index_node_id])

        # add new index
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )

        # save vector index (reuse the handle fetched above instead of calling get_index again)
        if vector_index:
            vector_index.add_texts([document], duplicate_check=True)

        # save keyword index (likewise reuse kw_index)
        if kw_index:
            if keywords and len(keywords) > 0:
                kw_index.create_segment_keywords(segment.index_node_id, keywords)
            else:
                kw_index.add_texts([document])

        # update segment to completed
        update_params = {
            DocumentSegment.status: "completed",
            DocumentSegment.completed_at: datetime.datetime.utcnow()
        }
        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
        db.session.commit()

        end_at = time.perf_counter()
        logging.info(click.style('Segment index updated: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
    except Exception as e:
        logging.exception("update segment index failed")
        segment.enabled = False
        segment.disabled_at = datetime.datetime.utcnow()
        segment.status = 'error'
        segment.error = str(e)
        db.session.commit()
    finally:
        redis_client.delete(indexing_cache_key)
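The deleted task above re-indexes an edited segment by dropping the old node from both indexes and re-adding the new content. Ported to the index-processor API used elsewhere in this commit, the same delete-then-reinsert flow might look like the sketch below; the helper name is hypothetical, and keyword-specific handling such as create_segment_keywords has no direct equivalent among the calls this diff shows:

# Hedged port, not part of the commit: clean() then load() with the same
# metadata the deleted task built.
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document


def update_segment_index(segment, dataset, dataset_document) -> None:
    index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

    # drop the stale node from all backing stores...
    index_processor.clean(dataset, [segment.index_node_id])

    # ...then re-add the fresh content under the same node id
    document = Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        }
    )
    index_processor.load(dataset, [document])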
import datetime
import logging
import time

import click
from celery import shared_task
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment


@shared_task(queue='dataset')
def update_segment_keyword_index_task(segment_id: str):
    """
    Async update segment keyword index
    :param segment_id:
    Usage: update_segment_keyword_index_task.delay(segment_id)
    """
    logging.info(click.style('Start update segment keyword index: {}'.format(segment_id), fg='green'))
    start_at = time.perf_counter()

    segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
    if not segment:
        raise NotFound('Segment not found')

    indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

    try:
        dataset = segment.dataset
        if not dataset:
            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
            return

        dataset_document = segment.document
        if not dataset_document:
            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
            return

        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
            return

        kw_index = IndexBuilder.get_index(dataset, 'economy')
        if kw_index:
            # delete the stale entry, then save the new keywords
            # (one lookup instead of fetching the 'economy' index twice)
            kw_index.delete_by_ids([segment.index_node_id])
            kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)

        end_at = time.perf_counter()
        logging.info(click.style('Segment keyword index updated: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
    except Exception as e:
        logging.exception("update segment keyword index failed")
        segment.enabled = False
        segment.disabled_at = datetime.datetime.utcnow()
        segment.status = 'error'
        segment.error = str(e)
        db.session.commit()
    finally:
        redis_client.delete(indexing_cache_key)
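With keyword handling now configured through the KEYWORD_STORE setting, a keyword-only refresh like the deleted task above no longer needs IndexBuilder. A toy sketch with an assumed Keyword wrapper; only delete_by_ids and update_segment_keywords_index are method names taken from this section:

# Toy sketch: Keyword is an assumed stand-in for whatever implementation the
# KEYWORD_STORE setting selects; it is not an API confirmed by this commit.
class Keyword:
    def __init__(self, dataset):
        self._dataset = dataset

    def delete_by_ids(self, node_ids):
        print('keyword store: drop {}'.format(node_ids))

    def update_segment_keywords_index(self, node_id, keywords):
        print('keyword store: {} -> {}'.format(node_id, keywords))


def update_segment_keywords(segment, dataset) -> None:
    # one handler covers both the delete and the re-save
    handler = Keyword(dataset)
    handler.delete_by_ids([segment.index_node_id])
    handler.update_segment_keywords_index(segment.index_node_id, segment.keywords)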