Unverified Commit 6c4e6bf1 authored by Jyong, committed by GitHub

Feat/dify rag (#2528)

Co-authored-by: jyong <jyong@dify.ai>
parent 97fe8171
@@ -56,6 +56,7 @@ DEFAULTS = {
    'BILLING_ENABLED': 'False',
    'CAN_REPLACE_LOGO': 'False',
    'ETL_TYPE': 'dify',
+   'KEYWORD_STORE': 'jieba',
    'BATCH_UPLOAD_LIMIT': 20
}
@@ -183,7 +184,7 @@ class Config:
        # Currently, only support: qdrant, milvus, zilliz, weaviate
        # ------------------------
        self.VECTOR_STORE = get_env('VECTOR_STORE')
+       self.KEYWORD_STORE = get_env('KEYWORD_STORE')
        # qdrant settings
        self.QDRANT_URL = get_env('QDRANT_URL')
        self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
......
@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
-from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required
@@ -173,14 +174,14 @@ class DataSourceNotionApi(Resource):
        if not data_source_binding:
            raise NotFound('Data source binding not found.')

-       loader = NotionLoader(
-           notion_access_token=data_source_binding.access_token,
-           notion_workspace_id=workspace_id,
-           notion_obj_id=page_id,
-           notion_page_type=page_type
-       )
+       extractor = NotionExtractor(
+           notion_workspace_id=workspace_id,
+           notion_obj_id=page_id,
+           notion_page_type=page_type,
+           notion_access_token=data_source_binding.access_token
+       )

-       text_docs = loader.load()
+       text_docs = extractor.extract()
        return {
            'content': "\n".join([doc.page_content for doc in text_docs])
        }, 200
@@ -192,11 +193,30 @@ class DataSourceNotionApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+       parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+       parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
        args = parser.parse_args()

        # validate args
        DocumentService.estimate_args_validate(args)
+       notion_info_list = args['notion_info_list']
+       extract_settings = []
+       for notion_info in notion_info_list:
+           workspace_id = notion_info['workspace_id']
+           for page in notion_info['pages']:
+               extract_setting = ExtractSetting(
+                   datasource_type="notion_import",
+                   notion_info={
+                       "notion_workspace_id": workspace_id,
+                       "notion_obj_id": page['page_id'],
+                       "notion_page_type": page['type']
+                   },
+                   document_model=args['doc_form']
+               )
+               extract_settings.append(extract_setting)
        indexing_runner = IndexingRunner()
-       response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
+       response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                    args['process_rule'], args['doc_form'],
+                                                    args['doc_language'])
        return response, 200
......
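For orientation, the change above (and the matching controller changes below) routes every indexing estimate through the same two building blocks: a list of ExtractSetting objects plus a single IndexingRunner.indexing_estimate() call. A minimal sketch of that flow follows, assuming it runs inside the Dify API application; the tenant, workspace and page values are placeholders, and the process_rule shape is an assumption rather than something taken from this commit.

# Minimal sketch (not part of this commit) of the unified estimate flow shown above.
# Assumes the Dify API app context; tenant/workspace/page values are placeholders,
# and the process_rule dict shape is an assumption.
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting

extract_settings = [
    ExtractSetting(
        datasource_type="notion_import",
        notion_info={
            "notion_workspace_id": "<workspace-id>",  # placeholder
            "notion_obj_id": "<page-id>",             # placeholder
            "notion_page_type": "page"                # placeholder
        },
        document_model="text_model"
    )
]

indexing_runner = IndexingRunner()
# Positional arguments mirror the call sites in this diff:
# (tenant_id, extract_settings, process_rule, doc_form, doc_language)
response = indexing_runner.indexing_estimate(
    "<tenant-id>",                       # placeholder tenant id
    extract_settings,
    {"mode": "automatic", "rules": {}},  # assumed process_rule shape
    "text_model",
    "English"
)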
@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
+from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -178,9 +179,9 @@ class DatasetApi(Resource):
                            location='json', store_missing=False,
                            type=_validate_description_length)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            nullable=True,
                            help='Invalid indexing technique.')
        parser.add_argument('permission', type=str, location='json', choices=(
            'only_me', 'all_team_members'), help='Invalid permission.')
        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('indexing_technique', type=str, required=True,
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            nullable=True, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
@@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
        args = parser.parse_args()
        # validate args
        DocumentService.estimate_args_validate(args)
+       extract_settings = []
        if args['info_list']['data_source_type'] == 'upload_file':
            file_ids = args['info_list']['file_info_list']['file_ids']
            file_details = db.session.query(UploadFile).filter(
@@ -278,37 +280,44 @@ class DatasetIndexingEstimateApi(Resource):
            if file_details is None:
                raise NotFound("File not found.")

-           indexing_runner = IndexingRunner()
-           try:
-               response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                 args['process_rule'], args['doc_form'],
-                                                                 args['doc_language'], args['dataset_id'],
-                                                                 args['indexing_technique'])
-           except LLMBadRequestError:
-               raise ProviderNotInitializeError(
-                   "No Embedding Model available. Please configure a valid provider "
-                   "in the Settings -> Model Provider.")
-           except ProviderTokenNotInitError as ex:
-               raise ProviderNotInitializeError(ex.description)
+           if file_details:
+               for file_detail in file_details:
+                   extract_setting = ExtractSetting(
+                       datasource_type="upload_file",
+                       upload_file=file_detail,
+                       document_model=args['doc_form']
+                   )
+                   extract_settings.append(extract_setting)
        elif args['info_list']['data_source_type'] == 'notion_import':
-           indexing_runner = IndexingRunner()
-           try:
-               response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                   args['info_list']['notion_info_list'],
-                                                                   args['process_rule'], args['doc_form'],
-                                                                   args['doc_language'], args['dataset_id'],
-                                                                   args['indexing_technique'])
-           except LLMBadRequestError:
-               raise ProviderNotInitializeError(
-                   "No Embedding Model available. Please configure a valid provider "
-                   "in the Settings -> Model Provider.")
-           except ProviderTokenNotInitError as ex:
-               raise ProviderNotInitializeError(ex.description)
+           notion_info_list = args['info_list']['notion_info_list']
+           for notion_info in notion_info_list:
+               workspace_id = notion_info['workspace_id']
+               for page in notion_info['pages']:
+                   extract_setting = ExtractSetting(
+                       datasource_type="notion_import",
+                       notion_info={
+                           "notion_workspace_id": workspace_id,
+                           "notion_obj_id": page['page_id'],
+                           "notion_page_type": page['type']
+                       },
+                       document_model=args['doc_form']
+                   )
+                   extract_settings.append(extract_setting)
        else:
            raise ValueError('Data source type not support')
+       indexing_runner = IndexingRunner()
+       try:
+           response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                        args['process_rule'], args['doc_form'],
+                                                        args['doc_language'], args['dataset_id'],
+                                                        args['indexing_technique'])
+       except LLMBadRequestError:
+           raise ProviderNotInitializeError(
+               "No Embedding Model available. Please configure a valid provider "
+               "in the Settings -> Model Provider.")
+       except ProviderTokenNotInitError as ex:
+           raise ProviderNotInitializeError(ex.description)
        return response, 200
@@ -508,4 +517,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
@@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.document_fields import (
@@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource):
        req_data = request.args
        document_id = req_data.get('document_id')
        # get default rules
        mode = DocumentService.DEFAULT_RULES['mode']
        rules = DocumentService.DEFAULT_RULES['rules']
@@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource):
            if not file:
                raise NotFound('File not found.')

+           extract_setting = ExtractSetting(
+               datasource_type="upload_file",
+               upload_file=file,
+               document_model=document.doc_form
+           )
            indexing_runner = IndexingRunner()
            try:
-               response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-                                                                 data_process_rule_dict, None,
-                                                                 'English', dataset_id)
+               response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
+                                                            data_process_rule_dict, document.doc_form,
+                                                            'English', dataset_id)
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider "
@@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
        data_process_rule = documents[0].dataset_process_rule
        data_process_rule_dict = data_process_rule.to_dict()
        info_list = []
+       extract_settings = []
        for document in documents:
            if document.indexing_status in ['completed', 'error']:
                raise DocumentAlreadyFinishedError()
@@ -424,42 +432,48 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                }
                info_list.append(notion_info)

-           if dataset.data_source_type == 'upload_file':
-               file_details = db.session.query(UploadFile).filter(
-                   UploadFile.tenant_id == current_user.current_tenant_id,
-                   UploadFile.id.in_(info_list)
-               ).all()
-               if file_details is None:
-                   raise NotFound("File not found.")
-               indexing_runner = IndexingRunner()
-               try:
-                   response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                     data_process_rule_dict, None,
-                                                                     'English', dataset_id)
-               except LLMBadRequestError:
-                   raise ProviderNotInitializeError(
-                       "No Embedding Model available. Please configure a valid provider "
-                       "in the Settings -> Model Provider.")
-               except ProviderTokenNotInitError as ex:
-                   raise ProviderNotInitializeError(ex.description)
-           elif dataset.data_source_type == 'notion_import':
+           if document.data_source_type == 'upload_file':
+               file_id = data_source_info['upload_file_id']
+               file_detail = db.session.query(UploadFile).filter(
+                   UploadFile.tenant_id == current_user.current_tenant_id,
+                   UploadFile.id == file_id
+               ).first()
+               if file_detail is None:
+                   raise NotFound("File not found.")
+               extract_setting = ExtractSetting(
+                   datasource_type="upload_file",
+                   upload_file=file_detail,
+                   document_model=document.doc_form
+               )
+               extract_settings.append(extract_setting)
+           elif document.data_source_type == 'notion_import':
+               extract_setting = ExtractSetting(
+                   datasource_type="notion_import",
+                   notion_info={
+                       "notion_workspace_id": data_source_info['notion_workspace_id'],
+                       "notion_obj_id": data_source_info['notion_page_id'],
+                       "notion_page_type": data_source_info['type']
+                   },
+                   document_model=document.doc_form
+               )
+               extract_settings.append(extract_setting)
+           else:
+               raise ValueError('Data source type not support')
        indexing_runner = IndexingRunner()
        try:
-           response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                               info_list,
-                                                               data_process_rule_dict,
-                                                               None, 'English', dataset_id)
+           response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                        data_process_rule_dict, document.doc_form,
+                                                        'English', dataset_id)
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
-       else:
-           raise ValueError('Data source type not support')
        return response
......
import tempfile
from pathlib import Path
from typing import Optional, Union
import requests
from flask import current_app
from langchain.document_loaders import Docx2txtLoader, TextLoader
from langchain.schema import Document
from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)
return cls.load_from_file(file_path, return_text)
@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None,
is_automatic: bool = False) -> Union[list[Document], str]:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
else MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
elif file_extension == '.eml':
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
elif file_extension == '.ppt':
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
elif file_extension == '.pptx':
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
elif file_extension == '.xml':
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
else TextLoader(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
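For reference, FileExtractor.load() above downloads the stored file to a temporary path and then dispatches to a loader chosen from the file extension and the ETL_TYPE setting. A hedged usage sketch follows; the module path and the UploadFile lookup are assumptions, and the call only works inside the Dify app context with storage configured.

# Hypothetical usage sketch (not part of this commit); assumes the pre-refactor
# module path core.data_loader.file_extractor, a running app context and storage.
from core.data_loader.file_extractor import FileExtractor
from extensions.ext_database import db
from models.model import UploadFile

upload_file = db.session.query(UploadFile).filter(UploadFile.id == "<file-id>").first()  # placeholder id

# return_text=True joins every loaded Document's page_content with newlines;
# return_text=False returns the list[Document] produced by the extension-matched loader.
text = FileExtractor.load(upload_file, return_text=True)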
import logging
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class HTMLLoader(BaseLoader):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> list[Document]:
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
import logging
from typing import Optional
from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from extensions.ext_storage import storage
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfLoader(BaseLoader):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> list[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.0625.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = PyPDFium2Loader(file_path=self._file_path).load()
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return documents
from collections.abc import Sequence
from typing import Any, Optional, cast

-from langchain.schema import Document
from sqlalchemy import func

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
......
import logging
from typing import Optional

-from flask import current_app

-from core.embedding.cached_embedding import CacheEmbedding
from core.entities.application_entities import InvokeFrom
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
            embedding_provider_name = collection_binding_detail.provider_name
            embedding_model_name = collection_binding_detail.model_name

-           model_manager = ModelManager()
-           model_instance = model_manager.get_model_instance(
-               tenant_id=app_record.tenant_id,
-               provider=embedding_provider_name,
-               model_type=ModelType.TEXT_EMBEDDING,
-               model=embedding_model_name
-           )
-
-           # get embedding model
-           embeddings = CacheEmbedding(model_instance)
-
            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                embedding_provider_name,
                embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
                collection_binding_id=dataset_collection_binding.id
            )

-           vector_index = VectorIndex(
-               dataset=dataset,
-               config=current_app.config,
-               embeddings=embeddings,
-               attributes=['doc_id', 'annotation_id', 'app_id']
-           )
+           vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])

-           documents = vector_index.search(
+           documents = vector.search_by_vector(
                query=query,
-               search_type='similarity_score_threshold',
-               search_kwargs={
-                   'k': 1,
-                   'score_threshold': score_threshold,
-                   'filter': {
-                       'group_id': [dataset.id]
-                   }
-               }
+               k=1,
+               score_threshold=score_threshold,
+               filter={
+                   'group_id': [dataset.id]
+               }
            )
......
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset
class IndexBuilder:
@classmethod
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
if indexing_technique == "high_quality":
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
provider=dataset.embedding_model_provider,
model=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
elif indexing_technique == "economy":
return KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
else:
raise ValueError('Unknown indexing technique')
@classmethod
def get_default_high_quality_index(cls, dataset: Dataset):
embeddings = OpenAIEmbeddings(openai_api_key=' ')
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
from typing import Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from models.dataset import Dataset
class MilvusConfig(BaseModel):
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
class MilvusVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'milvus'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
connection_args=self._client_config.to_milvus_params(),
index_params=index_params
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content'
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
return MilvusVectorStore(
collection_name=self.get_index_name(self.dataset),
embedding_function=self._embeddings,
connection_args=self._client_config.to_milvus_params()
)
def _get_vector_store_class(self) -> type:
return MilvusVectorStore
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_document_id(document_id)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_metadata_field(key, value)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_doc_ids(doc_ids)
vector_store.del_texts({
'filter': f' id in {ids}'
})
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
import os
from typing import Any, Optional, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
}
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
return dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
))
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
for node_id in ids:
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
))
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=group_id),
),
],
))
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
return vector_store.similarity_search_by_bm25(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
],
), kwargs.get('top_k', 2))
import json
from flask import current_app
from langchain.embeddings.base import Embeddings
from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
self._attributes = attributes
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError("Vector store must be specified.")
if vector_type == "weaviate":
from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings,
attributes=attributes
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
return QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
),
embeddings=embeddings
)
elif vector_type == "milvus":
from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex
return MilvusVectorIndex(
dataset=dataset,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document], **kwargs):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts, **kwargs)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts, **kwargs)
def __getattr__(self, name):
if self._vector_index is not None:
method = getattr(self._vector_index, name)
if callable(method):
return method
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
from typing import Any, Optional, cast
import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = self._attributes
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": [key],
"valueText": value
})
def delete_by_group_id(self, group_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
import re
class CleanProcessor:
@classmethod
def clean(cls, text: str, process_rule: dict) -> str:
# default clean
# remove invalid symbol
text = re.sub(r'<\|', '<', text)
text = re.sub(r'\|>', '>', text)
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
# Unicode U+FFFE
text = re.sub('\uFFFE', '', text)
rules = process_rule['rules'] if process_rule else None
if 'pre_processing_rules' in rules:
pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules:
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
# Remove extra spaces
pattern = r'\n{3,}'
text = re.sub(pattern, '\n\n', text)
pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
text = re.sub(pattern, ' ', text)
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
# Remove email
pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
text = re.sub(pattern, '', text)
# Remove URL
pattern = r'https?://[^\s]+'
text = re.sub(pattern, '', text)
return text
def filter_string(self, text):
return text
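CleanProcessor.clean() above always strips control characters and the <| |> markers, then applies whichever pre-processing rules are enabled. A small self-contained example of the rule structure it expects follows, inferred from the conditionals in the code; the module path core.rag.cleaner.clean_processor is an assumption.

# Example input for the CleanProcessor above; the process_rule shape is inferred
# from clean() itself, and the import path is assumed.
from core.rag.cleaner.clean_processor import CleanProcessor

process_rule = {
    "rules": {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": True}
        ]
    }
}

raw = "Contact   me at someone@example.com\n\n\n\nvia https://example.com  please"
cleaned = CleanProcessor.clean(raw, process_rule)
# Blank-line runs collapse to one empty line, repeated spaces are squeezed,
# and the email address and URL are stripped.
print(cleaned)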
"""Abstract interface for document cleaner implementations."""
from abc import ABC, abstractmethod
class BaseCleaner(ABC):
"""Interface for clean chunk content.
"""
@abstractmethod
def clean(self, content: str):
raise NotImplementedError
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.core import clean_extra_whitespace
# Returns "ITEM 1A: RISK FACTORS"
return clean_extra_whitespace(content)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
import re
from unstructured.cleaners.core import group_broken_paragraphs
para_split_re = re.compile(r"(\s*\n\s*){3}")
return group_broken_paragraphs(content, paragraph_split=para_split_re)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.core import clean_non_ascii_chars
# Returns "This text containsnon-ascii characters!"
return clean_non_ascii_chars(content)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""Replaces unicode quote characters, such as the \x91 character in a string."""
from unstructured.cleaners.core import replace_unicode_quotes
return replace_unicode_quotes(content)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredTranslateTextCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.translate import translate_text
return translate_text(content)
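The cleaners above all follow the same pattern: subclass BaseCleaner and implement clean(). A minimal cleaner in the same style, purely illustrative and not part of this commit:

from core.rag.cleaner.cleaner_base import BaseCleaner


class LowercaseCleaner(BaseCleaner):
    """Illustrative cleaner that lower-cases chunk content."""

    def clean(self, content: str) -> str:
        return content.lower()


print(LowercaseCleaner().clean("Hello RAG"))  # -> "hello rag"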
from typing import Optional
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document
from core.rerank.rerank import RerankRunner
class DataPostProcessor:
"""Interface for data post-processing document.
"""
def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)
return documents
def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
if reranking_model:
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
)
except InvokeAuthorizationError:
return None
return RerankRunner(rerank_model_instance)
return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
if reorder_enabled:
return ReorderRunner()
return None
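A hedged usage sketch for the DataPostProcessor above: with an empty reranking_model dict no rerank runner is created, so only the reorder step runs. The module path and tenant id below are assumptions.

# Illustrative only; module path assumed, tenant id is a placeholder.
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from langchain.schema import Document

docs = [Document(page_content=f"chunk {i}") for i in range(5)]

# Empty reranking_model -> rerank step is skipped; reorder_enabled=True applies
# the interleave-and-reverse shuffle implemented by ReorderRunner below.
post_processor = DataPostProcessor(tenant_id="<tenant-id>", reranking_model={}, reorder_enabled=True)
reordered = post_processor.invoke(query="example", documents=docs)
print([d.page_content for d in reordered])  # -> ['chunk 0', 'chunk 2', 'chunk 4', 'chunk 3', 'chunk 1']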
from langchain.schema import Document
class ReorderRunner:
def run(self, documents: list[Document]) -> list[Document]:
# Elements at odd positions (1st, 3rd, 5th, ...) of the documents list, i.e. indices 0, 2, 4, ...
odd_elements = documents[::2]
# Elements at even positions (2nd, 4th, 6th, ...) of the documents list, i.e. indices 1, 3, 5, ...
even_elements = documents[1::2]
# Reverse the list of elements from even indices
even_elements_reversed = even_elements[::-1]
new_documents = odd_elements + even_elements_reversed
return new_documents
from abc import ABC, abstractmethod
class Embeddings(ABC):
"""Interface for embedding models."""
@abstractmethod
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError
async def aembed_query(self, text: str) -> list[float]:
"""Asynchronous Embed query text."""
raise NotImplementedError
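The Embeddings interface above only requires the two synchronous methods; the async variants raise NotImplementedError by default. A minimal concrete implementation sketch follows; the toy hash-based embedder and the import path are illustrative assumptions, not part of this commit.

import hashlib

# Import path is an assumption; adjust to wherever the Embeddings ABC above lives.
from core.rag.embedding.embedding_base import Embeddings


class ToyHashEmbeddings(Embeddings):
    """Illustrative 8-dimensional embedder derived from a SHA-256 digest; not for real retrieval."""

    def _embed(self, text: str) -> list[float]:
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        return [byte / 255.0 for byte in digest[:8]]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self._embed(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        return self._embed(text)


print(ToyHashEmbeddings().embed_query("hello")[:3])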
@@ -2,11 +2,11 @@ import json
from collections import defaultdict
from typing import Any, Optional

-from langchain.schema import BaseRetriever, Document
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel

-from core.index.base import BaseIndex
-from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.keyword_base import BaseKeyword
+from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
@@ -15,59 +15,19 @@ class KeywordTableConfig(BaseModel):
    max_keywords_per_chunk: int = 10

-class KeywordTableIndex(BaseIndex):
-   def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
+class Jieba(BaseKeyword):
+   def __init__(self, dataset: Dataset):
        super().__init__(dataset)
-       self._config = config
+       self._config = KeywordTableConfig()

-   def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+   def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
        keyword_table_handler = JiebaKeywordTableHandler()
-       keyword_table = {}
-       for text in texts:
-           keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-           self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-           keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
-       dataset_keyword_table = DatasetKeywordTable(
-           dataset_id=self.dataset.id,
-           keyword_table=json.dumps({
-               '__type__': 'keyword_table',
-               '__data__': {
-                   "index_id": self.dataset.id,
-                   "summary": None,
-                   "table": {}
-               }
-           }, cls=SetEncoder)
-       )
-       db.session.add(dataset_keyword_table)
-       db.session.commit()
-       self._save_dataset_keyword_table(keyword_table)
-       return self
-
-   def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-       keyword_table_handler = JiebaKeywordTableHandler()
-       keyword_table = {}
+       keyword_table = self._get_dataset_keyword_table()
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
-       dataset_keyword_table = DatasetKeywordTable(
-           dataset_id=self.dataset.id,
-           keyword_table=json.dumps({
-               '__type__': 'keyword_table',
-               '__data__': {
-                   "index_id": self.dataset.id,
-                   "summary": None,
-                   "table": {}
-               }
-           }, cls=SetEncoder)
-       )
-       db.session.add(dataset_keyword_table)
-       db.session.commit()
        self._save_dataset_keyword_table(keyword_table)

        return self
@@ -76,8 +36,13 @@ class KeywordTableIndex(BaseIndex):
        keyword_table_handler = JiebaKeywordTableHandler()
        keyword_table = self._get_dataset_keyword_table()
-       for text in texts:
-           keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+       keywords_list = kwargs.get('keywords_list', None)
+       for i in range(len(texts)):
+           text = texts[i]
+           if keywords_list:
+               keywords = keywords_list[i]
+           else:
+               keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
@@ -107,20 +72,13 @@ class KeywordTableIndex(BaseIndex):
        self._save_dataset_keyword_table(keyword_table)

-   def delete_by_metadata_field(self, key: str, value: str):
-       pass
-
-   def get_retriever(self, **kwargs: Any) -> BaseRetriever:
-       return KeywordTableRetriever(index=self, **kwargs)
-
    def search(
            self, query: str,
            **kwargs: Any
    ) -> list[Document]:
        keyword_table = self._get_dataset_keyword_table()
-       search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
-       k = search_kwargs.get('k') if search_kwargs.get('k') else 4
+       k = kwargs.get('top_k', 4)

        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
@@ -150,12 +108,6 @@ class KeywordTableIndex(BaseIndex):
            db.session.delete(dataset_keyword_table)
            db.session.commit()

-   def delete_by_group_id(self, group_id: str) -> None:
-       dataset_keyword_table = self.dataset.dataset_keyword_table
-       if dataset_keyword_table:
-           db.session.delete(dataset_keyword_table)
-           db.session.commit()
-
    def _save_dataset_keyword_table(self, keyword_table):
        keyword_table_dict = {
            '__type__': 'keyword_table',
@@ -242,6 +194,7 @@ class KeywordTableIndex(BaseIndex):
        ).first()
        if document_segment:
            document_segment.keywords = keywords
+           db.session.add(document_segment)
            db.session.commit()

    def create_segment_keywords(self, node_id: str, keywords: list[str]):
@@ -272,31 +225,6 @@ class KeywordTableIndex(BaseIndex):
        self._save_dataset_keyword_table(keyword_table)

-class KeywordTableRetriever(BaseRetriever, BaseModel):
-   index: KeywordTableIndex
-   search_kwargs: dict = Field(default_factory=dict)
-
-   class Config:
-       """Configuration for this pydantic object."""
-       extra = Extra.forbid
-       arbitrary_types_allowed = True
-
-   def get_relevant_documents(self, query: str) -> list[Document]:
-       """Get documents relevant for a query.
-
-       Args:
-           query: string to find relevant documents for
-
-       Returns:
-           List of relevant documents
-       """
-       return self.index.search(query, **self.search_kwargs)
-
-   async def aget_relevant_documents(self, query: str) -> list[Document]:
-       raise NotImplementedError("KeywordTableRetriever does not support async")

class SetEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
......
...@@ -3,7 +3,7 @@ import re ...@@ -3,7 +3,7 @@ import re
import jieba import jieba
from jieba.analyse import default_tfidf from jieba.analyse import default_tfidf
from core.index.keyword_table_index.stopwords import STOPWORDS from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class JiebaKeywordTableHandler: class JiebaKeywordTableHandler:
......
...@@ -3,22 +3,17 @@ from __future__ import annotations ...@@ -3,22 +3,17 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
from langchain.schema import BaseRetriever, Document from core.rag.models.document import Document
from models.dataset import Dataset from models.dataset import Dataset
class BaseIndex(ABC): class BaseKeyword(ABC):
def __init__(self, dataset: Dataset): def __init__(self, dataset: Dataset):
self.dataset = dataset self.dataset = dataset
@abstractmethod @abstractmethod
def create(self, texts: list[Document], **kwargs) -> BaseIndex: def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
raise NotImplementedError
@abstractmethod
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
...@@ -34,31 +29,18 @@ class BaseIndex(ABC): ...@@ -34,31 +29,18 @@ class BaseIndex(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None: def delete_by_document_id(self, document_id: str) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError raise NotImplementedError
@abstractmethod def delete(self) -> None:
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
raise NotImplementedError raise NotImplementedError
@abstractmethod
def search( def search(
self, query: str, self, query: str,
**kwargs: Any **kwargs: Any
) -> list[Document]: ) -> list[Document]:
raise NotImplementedError raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts: for text in texts:
doc_id = text.metadata['doc_id'] doc_id = text.metadata['doc_id']
......
from typing import Any, cast
from flask import current_app
from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from models.dataset import Dataset
class Keyword:
def __init__(self, dataset: Dataset):
self._dataset = dataset
self._keyword_processor = self._init_keyword()
def _init_keyword(self) -> BaseKeyword:
config = cast(dict, current_app.config)
keyword_type = config.get('KEYWORD_STORE')
if not keyword_type:
raise ValueError("Keyword store must be specified.")
if keyword_type == "jieba":
return Jieba(
dataset=self._dataset
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def create(self, texts: list[Document], **kwargs):
self._keyword_processor.create(texts, **kwargs)
def add_texts(self, texts: list[Document], **kwargs):
self._keyword_processor.add_texts(texts, **kwargs)
def text_exists(self, id: str) -> bool:
return self._keyword_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
self._keyword_processor.delete_by_ids(ids)
def delete_by_document_id(self, document_id: str) -> None:
self._keyword_processor.delete_by_document_id(document_id)
def delete(self) -> None:
self._keyword_processor.delete()
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
return self._keyword_processor.search(query, **kwargs)
def __getattr__(self, name):
if self._keyword_processor is not None:
method = getattr(self._keyword_processor, name)
if callable(method):
return method
raise AttributeError(f"'Keyword' object has no attribute '{name}'")
import threading
from typing import Optional from typing import Optional
from flask import Flask, current_app from flask import Flask, current_app
from langchain.embeddings.base import Embeddings from flask_login import current_user
from core.index.vector_index.vector_index import VectorIndex from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.model_manager import ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword
from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.vdb.vector_factory import Vector
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset
...@@ -25,48 +24,109 @@ default_retrieval_model = { ...@@ -25,48 +24,109 @@ default_retrieval_model = {
class RetrievalService: class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
all_documents = []
threads = []
# keyword search retrieval
if retrival_method == 'keyword_search':
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'all_documents': all_documents
})
threads.append(keyword_thread)
keyword_thread.start()
# semantic (embedding) search retrieval
if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
'retrival_method': retrival_method
})
threads.append(embedding_thread)
embedding_thread.start()
# full-text search retrieval
if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'retrival_method': retrival_method,
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
'all_documents': all_documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
if retrival_method == 'hybrid_search':
data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
)
return all_documents
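An illustrative call of the unified entry point above (not part of the diff); it assumes a Flask request context with an authenticated user, since hybrid search reranks under current_user.current_tenant_id, and the dataset id and reranking model names are placeholders:

documents = RetrievalService.retrieve(
    retrival_method='hybrid_search',          # 'keyword_search', 'semantic_search', 'full_text_search' also accepted
    dataset_id='your-dataset-uuid',           # placeholder
    query="how does dify build the keyword table?",
    top_k=4,
    score_threshold=0.5,
    reranking_model={
        'reranking_provider_name': 'cohere',            # placeholder provider
        'reranking_model_name': 'rerank-english-v2.0',  # placeholder model
    },
)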
@classmethod
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
keyword = Keyword(
dataset=dataset
)
documents = keyword.search(
query,
top_k=top_k
)
all_documents.extend(documents)
@classmethod @classmethod
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings): all_documents: list, retrival_method: str):
with flask_app.app_context(): with flask_app.app_context():
dataset = db.session.query(Dataset).filter( dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id Dataset.id == dataset_id
).first() ).first()
vector_index = VectorIndex( vector = Vector(
dataset=dataset, dataset=dataset
config=current_app.config,
embeddings=embeddings
) )
documents = vector_index.search( documents = vector.search_by_vector(
query, query,
search_type='similarity_score_threshold', search_type='similarity_score_threshold',
search_kwargs={ k=top_k,
'k': top_k, score_threshold=score_threshold,
'score_threshold': score_threshold, filter={
'filter': { 'group_id': [dataset.id]
'group_id': [dataset.id]
}
} }
) )
if documents: if documents:
if reranking_model and search_method == 'semantic_search': if reranking_model and retrival_method == 'semantic_search':
try: data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
model_manager = ModelManager() all_documents.extend(data_post_processor.invoke(
rerank_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
)
except InvokeAuthorizationError:
return
rerank_runner = RerankRunner(rerank_model_instance)
all_documents.extend(rerank_runner.run(
query=query, query=query,
documents=documents, documents=documents,
score_threshold=score_threshold, score_threshold=score_threshold,
...@@ -78,38 +138,24 @@ class RetrievalService: ...@@ -78,38 +138,24 @@ class RetrievalService:
@classmethod @classmethod
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings): all_documents: list, retrival_method: str):
with flask_app.app_context(): with flask_app.app_context():
dataset = db.session.query(Dataset).filter( dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id Dataset.id == dataset_id
).first() ).first()
vector_index = VectorIndex( vector_processor = Vector(
dataset=dataset, dataset=dataset,
config=current_app.config,
embeddings=embeddings
) )
documents = vector_index.search_by_full_text_index( documents = vector_processor.search_by_full_text(
query, query,
search_type='similarity_score_threshold',
top_k=top_k top_k=top_k
) )
if documents: if documents:
if reranking_model and search_method == 'full_text_search': if reranking_model and retrival_method == 'full_text_search':
try: data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
model_manager = ModelManager() all_documents.extend(data_post_processor.invoke(
rerank_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
)
except InvokeAuthorizationError:
return
rerank_runner = RerankRunner(rerank_model_instance)
all_documents.extend(rerank_runner.run(
query=query, query=query,
documents=documents, documents=documents,
score_threshold=score_threshold, score_threshold=score_threshold,
......
from enum import Enum
class Field(Enum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
import logging
from typing import Any, Optional
from uuid import uuid4
from pydantic import BaseModel, root_validator
from pymilvus import MilvusClient, MilvusException, connections
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._fields = []
def get_type(self) -> str:
return 'milvus'
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
metadatas = [d.metadata for d in texts]
# Grab the existing collection if it exists
from pymilvus import utility
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
if not utility.has_collection(self._collection_name, using=alias):
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
}
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
raise e
return pks
def delete_by_document_id(self, document_id: str):
ids = self.get_ids_by_metadata_field('document_id', document_id)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
if result:
return [item["id"] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, doc_ids: list[str]) -> None:
self._client.delete(collection_name=self._collection_name, pks=doc_ids)
def delete(self) -> None:
from pymilvus import utility
utility.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result['entity'].get(Field.METADATA_KEY.value)
metadata['score'] = result['distance']
score_threshold = kwargs.get('score_threshold') or 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
) -> str:
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
# Determine embedding dim
dim = len(embeddings[0])
fields = []
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
# Create the schema for the collection
schema = CollectionSchema(fields)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
self._fields.remove(Field.PRIMARY_KEY.value)
# Create the collection
collection_name = self._collection_name
self._client.create_collection_with_schema(collection_name=collection_name,
schema=schema, index_param=index_params,
consistency_level=self._consistency_level)
return collection_name
def _init_client(self, config) -> MilvusClient:
if config.secure:
uri = "https://" + str(config.host) + ":" + str(config.port)
else:
uri = "http://" + str(config.host) + ":" + str(config.port)
client = MilvusClient(uri=uri, user=config.user, password=config.password)
return client
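A sketch of driving MilvusVector directly (not part of the diff); the host, credentials and the 768-dimension embedding are placeholders, and in the application flow the embeddings come from the dataset's embedding model via the Vector factory:

from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from core.rag.models.document import Document

config = MilvusConfig(host='127.0.0.1', port=19530, user='root', password='Milvus')  # placeholders
vector = MilvusVector(collection_name='Vector_index_demo_Node', config=config)

doc = Document(page_content="hello milvus",
               metadata={'doc_id': 'demo-1', 'document_id': 'doc-1'})  # placeholder ids
embedding = [0.1] * 768  # placeholder vector; dimension must match the collection
vector.create(texts=[doc], embeddings=[embedding])
hits = vector.search_by_vector(embedding, top_k=4, score_threshold=0.0)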
This diff is collapsed.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name
@abstractmethod
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError
@abstractmethod
def search_by_vector(
self,
query_vector: list[float],
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts:
doc_id = text.metadata['doc_id']
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
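For orientation, a deliberately naive in-memory driver (hypothetical, not part of the diff) showing the surface every BaseVector implementation has to cover:

from typing import Any

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document

class InMemoryVector(BaseVector):
    """Toy driver: stores (document, embedding) pairs in a dict keyed by doc_id."""

    def __init__(self, collection_name: str):
        super().__init__(collection_name)
        self._store = {}  # doc_id -> (Document, embedding)

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        for doc, emb in zip(documents, embeddings):
            self._store[doc.metadata['doc_id']] = (doc, emb)

    def text_exists(self, id: str) -> bool:
        return id in self._store

    def delete_by_ids(self, ids: list[str]) -> None:
        for id_ in ids:
            self._store.pop(id_, None)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._store = {k: v for k, v in self._store.items() if v[0].metadata.get(key) != value}

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        # No real similarity scoring here; a real driver ranks by distance to query_vector.
        return [doc for doc, _ in self._store.values()][:kwargs.get('top_k', 4)]

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return [doc for doc, _ in self._store.values() if query in doc.page_content]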
from typing import Any, cast
from flask import current_app
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class Vector:
def __init__(self, dataset: Dataset, attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
self._vector_processor = self._init_vector()
def _init_vector(self) -> BaseVector:
config = cast(dict, current_app.config)
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError("Vector store must be specified.")
if vector_type == "weaviate":
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
attributes=self._attributes
)
elif vector_type == "qdrant":
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
if self._dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Binding does not exist.')
else:
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return QdrantVector(
collection_name=collection_name,
group_id=self._dataset.id,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
)
)
elif vector_type == "milvus":
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
)
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def create(self, texts: list = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(
texts=texts,
embeddings=embeddings,
**kwargs
)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get('duplicate_check', False):
documents = self._filter_duplicate_texts(documents)
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.add_texts(
documents=documents,
embeddings=embeddings,
**kwargs
)
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
self._vector_processor.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(
self, query: str,
**kwargs: Any
) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self) -> None:
self._vector_processor.delete()
def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model
)
return CacheEmbedding(embedding_model)
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts:
doc_id = text.metadata['doc_id']
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def __getattr__(self, name):
if self._vector_processor is not None:
method = getattr(self._vector_processor, name)
if callable(method):
return method
raise AttributeError(f"'vector_processor' object has no attribute '{name}'")
import datetime
from typing import Any, Optional
import requests
import weaviate
from pydantic import BaseModel, root_validator
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
super().__init__(collection_name)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_collection_name(self, dataset: Dataset) -> str:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
schema = self._default_schema(self._collection_name)
# check whether the index already exists
if not self._client.schema.contains(schema):
# create collection
self._client.schema.create_class(schema)
# create vector
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
ids = []
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
if metadatas is not None:
for key, val in metadatas[i].items():
data_properties[key] = self._json_serializable(val)
batch.add_data_object(
data_object=data_properties,
class_name=self._collection_name,
uuid=uuids[i],
vector=embeddings[i] if embeddings else None,
)
ids.append(uuids[i])
return ids
def delete_by_metadata_field(self, key: str, value: str):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
def delete(self):
self._client.schema.delete_class(self._collection_name)
def text_exists(self, id: str) -> bool:
collection_name = self._collection_name
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][collection_name]
if len(entries) == 0:
return False
return True
def delete_by_ids(self, ids: list[str]) -> None:
self._client.data_object.delete(
ids,
class_name=self._collection_name
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
result = (
query_obj.with_near_vector(vector)
.with_limit(kwargs.get("top_k", 4))
.with_additional(["vector", "distance"])
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
# check score threshold
if score > score_threshold:
doc.metadata['score'] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
top_k: Number of Documents to return. Defaults to 2.
Returns:
List of Documents most similar to the query.
"""
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
docs.append(Document(page_content=text, metadata=res))
return docs
def _default_schema(self, index_name: str) -> dict:
return {
"class": index_name,
"properties": [
{
"name": "text",
"dataType": ["text"],
}
],
}
def _json_serializable(self, value: Any) -> Any:
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
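A sketch of constructing the Weaviate driver directly (not part of the diff); the endpoint, API key and collection name are placeholders, and `attributes` lists the metadata fields returned by queries, matching the defaults used in the Vector factory:

from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector

weaviate_vector = WeaviateVector(
    collection_name='Vector_index_demo_Node',                  # placeholder
    config=WeaviateConfig(endpoint='http://localhost:8080',    # placeholder endpoint
                          api_key='placeholder-key',
                          batch_size=100),
    attributes=['doc_id', 'dataset_id', 'document_id', 'doc_hash'],
)
# BM25 keyword search over the 'text' property; vector search goes through search_by_vector().
docs = weaviate_vector.search_by_full_text("keyword table", top_k=2)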
"""Schema for Blobs and Blob Loaders.
The goal is to facilitate decoupling of content loading from content parsing code.
In addition, content loading code should provide a lazy loading interface by default.
"""
from __future__ import annotations
import contextlib
import mimetypes
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Mapping
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Optional, Union
from pydantic import BaseModel, root_validator
PathLike = Union[str, PurePath]
class Blob(BaseModel):
"""A blob is used to represent raw data by either reference or value.
Provides an interface to materialize the blob in different representations, and
helps to decouple the development of data loaders from the downstream parsing of
the raw data.
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
"""
data: Union[bytes, str, None] # Raw data
mimetype: Optional[str] = None # Not to be confused with a file extension
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
# Location where the original content was found
# Represent location on the local file system
# Useful for situations where downstream code assumes it must work with file paths
# rather than in-memory content.
path: Optional[PathLike] = None
class Config:
arbitrary_types_allowed = True
frozen = True
@property
def source(self) -> Optional[str]:
"""The source location of the blob as string if known otherwise none."""
return str(self.path) if self.path else None
@root_validator(pre=True)
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:
raise ValueError("Either data or path must be provided")
return values
def as_string(self) -> str:
"""Read data as a string."""
if self.data is None and self.path:
with open(str(self.path), encoding=self.encoding) as f:
return f.read()
elif isinstance(self.data, bytes):
return self.data.decode(self.encoding)
elif isinstance(self.data, str):
return self.data
else:
raise ValueError(f"Unable to get string for blob {self}")
def as_bytes(self) -> bytes:
"""Read data as bytes."""
if isinstance(self.data, bytes):
return self.data
elif isinstance(self.data, str):
return self.data.encode(self.encoding)
elif self.data is None and self.path:
with open(str(self.path), "rb") as f:
return f.read()
else:
raise ValueError(f"Unable to get bytes for blob {self}")
@contextlib.contextmanager
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
"""Read data as a byte stream."""
if isinstance(self.data, bytes):
yield BytesIO(self.data)
elif self.data is None and self.path:
with open(str(self.path), "rb") as f:
yield f
else:
raise NotImplementedError(f"Unable to convert blob {self}")
@classmethod
def from_path(
cls,
path: PathLike,
*,
encoding: str = "utf-8",
mime_type: Optional[str] = None,
guess_type: bool = True,
) -> Blob:
"""Load the blob from a path like object.
Args:
path: path like object to file to be read
encoding: Encoding to use if decoding the bytes into a string
mime_type: if provided, will be set as the mime-type of the data
guess_type: If True, the mimetype will be guessed from the file extension,
if a mime-type was not provided
Returns:
Blob instance
"""
if mime_type is None and guess_type:
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None
else:
_mimetype = mime_type
# We do not load the data immediately, instead we treat the blob as a
# reference to the underlying data.
return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path)
@classmethod
def from_data(
cls,
data: Union[str, bytes],
*,
encoding: str = "utf-8",
mime_type: Optional[str] = None,
path: Optional[str] = None,
) -> Blob:
"""Initialize the blob from in-memory data.
Args:
data: the in-memory data associated with the blob
encoding: Encoding to use if decoding the bytes into a string
mime_type: if provided, will be set as the mime-type of the data
path: if provided, will be set as the source from which the data came
Returns:
Blob instance
"""
return cls(data=data, mimetype=mime_type, encoding=encoding, path=path)
def __repr__(self) -> str:
"""Define the blob representation."""
str_repr = f"Blob {id(self)}"
if self.source:
str_repr += f" {self.source}"
return str_repr
class BlobLoader(ABC):
"""Abstract interface for blob loaders implementation.
Implementer should be able to load raw content from a datasource system according
to some criteria and return the raw content lazily as a stream of blobs.
"""
@abstractmethod
def yield_blobs(
self,
) -> Iterable[Blob]:
"""A lazy loader for raw data represented by LangChain's Blob object.
Returns:
A generator over blobs
"""
"""Abstract interface for document loader implementations."""
import csv
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class CSVExtractor(BaseExtractor):
"""Load CSV files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None,
csv_args: Optional[dict] = None,
):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
self.source_column = source_column
self.csv_args = csv_args or {}
def extract(self) -> list[Document]:
"""Load data into document objects."""
try:
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
return docs
def _read_from_file(self, csvfile) -> list[Document]:
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs
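A usage sketch (not part of the diff): each CSV row becomes one Document whose page_content joins the columns as "name: value" lines; the path and source column are placeholders:

from core.rag.extractor.csv_extractor import CSVExtractor

extractor = CSVExtractor("/tmp/products.csv",        # placeholder path
                         autodetect_encoding=True,
                         source_column="url")        # placeholder column name
for doc in extractor.extract():
    print(doc.metadata["source"], doc.metadata["row"])
    print(doc.page_content)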
from enum import Enum
class DatasourceType(Enum):
FILE = "upload_file"
NOTION = "notion_import"
from pydantic import BaseModel
from models.dataset import Document
from models.model import UploadFile
class NotionInfo(BaseModel):
"""
Notion import info.
"""
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str
document: Document = None
class Config:
arbitrary_types_allowed = True
def __init__(self, **data) -> None:
super().__init__(**data)
class ExtractSetting(BaseModel):
"""
Model class for document extraction settings.
"""
datasource_type: str
upload_file: UploadFile = None
notion_info: NotionInfo = None
document_model: str = None
class Config:
arbitrary_types_allowed = True
def __init__(self, **data) -> None:
super().__init__(**data)
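A sketch of an ExtractSetting for a Notion import (not part of the diff); the workspace and page ids are placeholders, and `document` may optionally carry the Document row being indexed:

from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo

notion_setting = ExtractSetting(
    datasource_type="notion_import",          # DatasourceType.NOTION.value
    notion_info=NotionInfo(
        notion_workspace_id="workspace-id",   # placeholder
        notion_obj_id="page-id",              # placeholder
        notion_page_type="page",              # placeholder page type
    ),
    document_model="text_model",
)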
import logging """Abstract interface for document loader implementations."""
from typing import Optional
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook from openpyxl.reader.excel import load_workbook
logger = logging.getLogger(__name__) from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class ExcelLoader(BaseLoader): class ExcelExtractor(BaseExtractor):
"""Load xlxs files. """Load Excel files.
Args: Args:
...@@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader): ...@@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader):
""" """
def __init__( def __init__(
self, self,
file_path: str file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False
): ):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def load(self) -> list[Document]: def extract(self) -> list[Document]:
"""Load from file path."""
data = [] data = []
keys = [] keys = []
wb = load_workbook(filename=self._file_path, read_only=True) wb = load_workbook(filename=self._file_path, read_only=True)
......
import tempfile
from pathlib import Path
from typing import Union
import requests
from flask import current_app
from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.markdown_extractor import MarkdownExtractor
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.extractor.pdf_extractor import PdfExtractor
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
from models.model import UploadFile
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
class ExtractProcessor:
@classmethod
def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
-> Union[list[Document], str]:
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=upload_file,
document_model='text_model'
)
if return_text:
delimiter = '\n'
return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
else:
return cls.extract(extract_setting, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)
extract_setting = ExtractSetting(
datasource_type="upload_file",
document_model='text_model'
)
if return_text:
delimiter = '\n'
return delimiter.join([document.page_content for document in cls.extract(
extract_setting=extract_setting, file_path=file_path)])
else:
return cls.extract(extract_setting=extract_setting, file_path=file_path)
@classmethod
def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
file_path: str = None) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE.value:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
upload_file: UploadFile = extract_setting.upload_file
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
extractor = ExcelExtractor(file_path)
elif file_extension == '.pdf':
extractor = PdfExtractor(file_path)
elif file_extension in ['.md', '.markdown']:
extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
else MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
elif file_extension == '.eml':
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
elif file_extension == '.ppt':
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url)
elif file_extension == '.pptx':
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
elif file_extension == '.xml':
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
else:
# txt
extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
else TextExtractor(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
extractor = ExcelExtractor(file_path)
elif file_extension == '.pdf':
extractor = PdfExtractor(file_path)
elif file_extension in ['.md', '.markdown']:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
extractor = WordExtractor(file_path)
elif file_extension == '.csv':
extractor = CSVExtractor(file_path, autodetect_encoding=True)
else:
# txt
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document
)
return extractor.extract()
else:
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
class BaseExtractor(ABC):
"""Interface for extract files.
"""
@abstractmethod
def extract(self):
raise NotImplementedError
"""Document loader helpers."""
import concurrent.futures
from typing import NamedTuple, Optional, cast
class FileEncoding(NamedTuple):
"""A file encoding as the NamedTuple."""
encoding: Optional[str]
"""The encoding of the file."""
confidence: float
"""The confidence of the encoding."""
language: Optional[str]
"""The language of the file."""
def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
"""Try to detect the file encoding.
Returns a list of `FileEncoding` tuples with the detected encodings ordered
by confidence.
Args:
file_path: The path to the file to detect the encoding for.
timeout: The timeout in seconds for the encoding detection.
"""
import chardet
def read_and_detect(file_path: str) -> list[dict]:
with open(file_path, "rb") as f:
rawdata = f.read()
return cast(list[dict], chardet.detect_all(rawdata))
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(read_and_detect, file_path)
try:
encodings = future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
raise TimeoutError(
f"Timeout reached while detecting encoding for {file_path}"
)
if all(encoding["encoding"] is None for encoding in encodings):
raise RuntimeError(f"Could not detect encoding for {file_path}")
return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
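A usage sketch (not part of the diff); the path is a placeholder and the first entry is the highest-confidence guess:

from core.rag.extractor.helpers import detect_file_encodings

encodings = detect_file_encodings("/tmp/unknown.txt", timeout=5)  # placeholder path
best_guess = encodings[0]
with open("/tmp/unknown.txt", encoding=best_guess.encoding) as f:
    content = f.read()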
import csv """Abstract interface for document loader implementations."""
import logging
from typing import Optional from typing import Optional
from langchain.document_loaders import CSVLoader as LCCSVLoader from core.rag.extractor.extractor_base import BaseExtractor
from langchain.document_loaders.helpers import detect_file_encodings from core.rag.extractor.helpers import detect_file_encodings
from langchain.schema import Document from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class HtmlExtractor(BaseExtractor):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
class CSVLoader(LCCSVLoader):
def __init__( def __init__(
self, self,
file_path: str, file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None, source_column: Optional[str] = None,
csv_args: Optional[dict] = None, csv_args: Optional[dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
): ):
self.file_path = file_path """Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
self.source_column = source_column self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {} self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
def load(self) -> list[Document]: def extract(self) -> list[Document]:
"""Load data into document objects.""" """Load data into document objects."""
try: try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile: with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile) docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
if self.autodetect_encoding: if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path) detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings: for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try: try:
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile: with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile) docs = self._read_from_file(csvfile)
break break
except UnicodeDecodeError: except UnicodeDecodeError:
continue continue
else: else:
raise RuntimeError(f"Error loading {self.file_path}") from e raise RuntimeError(f"Error loading {self._file_path}") from e
return docs return docs
def _read_from_file(self, csvfile): def _read_from_file(self, csvfile) -> list[Document]:
docs = [] docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader): for i, row in enumerate(csv_reader):
......
import logging """Abstract interface for document loader implementations."""
import re import re
from typing import Optional, cast from typing import Optional, cast
from langchain.document_loaders.base import BaseLoader from core.rag.extractor.extractor_base import BaseExtractor
from langchain.document_loaders.helpers import detect_file_encodings from core.rag.extractor.helpers import detect_file_encodings
from langchain.schema import Document from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class MarkdownExtractor(BaseExtractor):
class MarkdownLoader(BaseLoader): """Load Markdown files.
"""Load md files.
Args: Args:
file_path: Path to the file to load. file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
remove_images: Whether to remove images from the text.
encoding: File encoding to use. If `None`, the file will be loaded
with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
""" """
def __init__( def __init__(
self, self,
file_path: str, file_path: str,
remove_hyperlinks: bool = True, remove_hyperlinks: bool = True,
remove_images: bool = True, remove_images: bool = True,
encoding: Optional[str] = None, encoding: Optional[str] = None,
autodetect_encoding: bool = True, autodetect_encoding: bool = True,
): ):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
...@@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader): ...@@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader):
self._encoding = encoding self._encoding = encoding
self._autodetect_encoding = autodetect_encoding self._autodetect_encoding = autodetect_encoding
def load(self) -> list[Document]: def extract(self) -> list[Document]:
"""Load from file path."""
tups = self.parse_tups(self._file_path) tups = self.parse_tups(self._file_path)
documents = [] documents = []
for header, value in tups: for header, value in tups:
...@@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader): ...@@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader):
if self._autodetect_encoding: if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath) detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings: for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try: try:
with open(filepath, encoding=encoding.encoding) as f: with open(filepath, encoding=encoding.encoding) as f:
content = f.read() content = f.read()
......
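A usage sketch of the Markdown extractor above (not part of the diff); the path is a placeholder, and the extractor yields one Document per header/body pair parsed from the file:

from core.rag.extractor.markdown_extractor import MarkdownExtractor

extractor = MarkdownExtractor("/tmp/README.md",      # placeholder path
                              remove_hyperlinks=True,
                              remove_images=True)
sections = extractor.extract()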
@@ -4,9 +4,10 @@ from typing import Any, Optional
 import requests
 from flask import current_app
-from langchain.document_loaders.base import BaseLoader
+from flask_login import current_user
 from langchain.schema import Document
 
+from core.rag.extractor.extractor_base import BaseExtractor
 from extensions.ext_database import db
 from models.dataset import Document as DocumentModel
 from models.source import DataSourceBinding
 
@@ -22,52 +23,37 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
 
 
-class NotionLoader(BaseLoader):
+class NotionExtractor(BaseExtractor):
     def __init__(
         self,
-        notion_access_token: str,
         notion_workspace_id: str,
         notion_obj_id: str,
         notion_page_type: str,
-        document_model: Optional[DocumentModel] = None
+        document_model: Optional[DocumentModel] = None,
+        notion_access_token: Optional[str] = None
     ):
+        self._notion_access_token = None
         self._document_model = document_model
         self._notion_workspace_id = notion_workspace_id
         self._notion_obj_id = notion_obj_id
         self._notion_page_type = notion_page_type
-        self._notion_access_token = notion_access_token
-
-        if not self._notion_access_token:
-            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
-            if integration_token is None:
-                raise ValueError(
-                    "Must specify `integration_token` or set environment "
-                    "variable `NOTION_INTEGRATION_TOKEN`."
-                )
-
-            self._notion_access_token = integration_token
-
-    @classmethod
-    def from_document(cls, document_model: DocumentModel):
-        data_source_info = document_model.data_source_info_dict
-        if not data_source_info or 'notion_page_id' not in data_source_info \
-                or 'notion_workspace_id' not in data_source_info:
-            raise ValueError("no notion page found")
-
-        notion_workspace_id = data_source_info['notion_workspace_id']
-        notion_obj_id = data_source_info['notion_page_id']
-        notion_page_type = data_source_info['type']
-        notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
-
-        return cls(
-            notion_access_token=notion_access_token,
-            notion_workspace_id=notion_workspace_id,
-            notion_obj_id=notion_obj_id,
-            notion_page_type=notion_page_type,
-            document_model=document_model
-        )
-
-    def load(self) -> list[Document]:
+        if notion_access_token:
+            self._notion_access_token = notion_access_token
+        else:
+            self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
+                                                               self._notion_workspace_id)
+            if not self._notion_access_token:
+                integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
+                if integration_token is None:
+                    raise ValueError(
+                        "Must specify `integration_token` or set environment "
+                        "variable `NOTION_INTEGRATION_TOKEN`."
+                    )
+
+                self._notion_access_token = integration_token
+
+    def extract(self) -> list[Document]:
         self.update_last_edited_time(
             self._document_model
         )
...
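When a token is passed explicitly it is used as-is; otherwise the constructor resolves one from the current tenant's DataSourceBinding (via `current_user`) and finally falls back to `NOTION_INTEGRATION_TOKEN`. A usage sketch; the placeholder IDs and the `"page"` value for `notion_page_type` are assumptions:

```python
# Import path as used elsewhere in this commit; placeholder values are illustrative only.
from core.rag.extractor.notion_extractor import NotionExtractor

extractor = NotionExtractor(
    notion_workspace_id="<workspace_id>",
    notion_obj_id="<page_id>",
    notion_page_type="page",              # assumed value; databases use a different type
    notion_access_token="<access_token>",
)
docs = extractor.extract()

# Omitting notion_access_token only works inside an authenticated request
# context, because the fallback reads current_user.current_tenant_id.
```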
"""Abstract interface for document loader implementations."""
from collections.abc import Iterator
from typing import Optional
from core.rag.extractor.blod.blod import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
class PdfExtractor(BaseExtractor):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
file_cache_key: Optional[str] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._file_cache_key = file_cache_key
    def extract(self) -> list[Document]:
        plaintext_file_exists = False
        if self._file_cache_key:
            try:
                # reuse previously extracted plaintext if it is already cached
                text = storage.load(self._file_cache_key).decode('utf-8')
                plaintext_file_exists = True
                return [Document(page_content=text)]
            except FileNotFoundError:
                pass
        documents = list(self.load())
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and self._file_cache_key:
            storage.save(self._file_cache_key, text.encode('utf-8'))

        return documents
def load(
self,
) -> Iterator[Document]:
"""Lazy load given path as pages."""
blob = Blob.from_path(self._file_path)
yield from self.parse(blob)
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
try:
for page_number, page in enumerate(pdf_reader):
text_page = page.get_textpage()
content = text_page.get_text_range()
text_page.close()
page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata)
finally:
pdf_reader.close()
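A usage sketch of the caching behaviour above; the module path and the cache key naming are assumptions:

```python
# Hypothetical usage; module path and cache key naming are assumptions.
from core.rag.extractor.pdf_extractor import PdfExtractor

extractor = PdfExtractor("/tmp/report.pdf", file_cache_key="plaintext/report.txt")

docs = extractor.extract()
# First run: pypdfium2 extracts one Document per page and the joined plaintext
# is stored under file_cache_key. Later runs: the cached plaintext is returned
# as a single Document without reopening the PDF.
```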
"""Abstract interface for document loader implementations."""
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class TextExtractor(BaseExtractor):
"""Load text files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False
):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def extract(self) -> list[Document]:
"""Load from file path."""
text = ""
try:
with open(self._file_path, encoding=self._encoding) as f:
text = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, encoding=encoding.encoding) as f:
text = f.read()
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
except Exception as e:
raise RuntimeError(f"Error loading {self._file_path}") from e
metadata = {"source": self._file_path}
return [Document(page_content=text, metadata=metadata)]
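A usage sketch of the encoding fallback; the module path is an assumption:

```python
# Hypothetical usage; the module path is assumed.
from core.rag.extractor.text_extractor import TextExtractor

# With autodetect_encoding=True, a UnicodeDecodeError triggers
# detect_file_encodings() and each candidate encoding is tried in turn;
# with the default (False), the error is re-raised as a RuntimeError.
docs = TextExtractor("notes.txt", autodetect_encoding=True).extract()
assert docs[0].metadata["source"] == "notes.txt"
```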
import logging
import os
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredWordExtractor(BaseExtractor):
"""Loader that uses unstructured to load word documents.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
def extract(self) -> list[Document]:
from unstructured.__version__ import __version__ as __unstructured_version__
from unstructured.file_utils.filetype import FileType, detect_filetype
unstructured_version = tuple(
[int(x) for x in __unstructured_version__.split(".")]
)
# check the file extension
try:
import magic # noqa: F401
is_doc = detect_filetype(self._file_path) == FileType.DOC
except ImportError:
_, extension = os.path.splitext(str(self._file_path))
is_doc = extension == ".doc"
if is_doc and unstructured_version < (0, 4, 11):
raise ValueError(
f"You are on unstructured version {__unstructured_version__}. "
"Partitioning .doc files is only supported in unstructured>=0.4.11. "
"Please upgrade the unstructured package and try again."
)
if is_doc:
from unstructured.partition.doc import partition_doc
elements = partition_doc(filename=self._file_path)
else:
from unstructured.partition.docx import partition_docx
elements = partition_docx(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents
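A usage sketch; the module path is an assumption, and the version constraint is the one enforced in `extract()` above:

```python
# Hypothetical usage; the module path is assumed.
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor

# .docx goes straight to partition_docx; .doc additionally requires
# unstructured>=0.4.11 (detected via python-magic when available,
# otherwise by file extension) or a ValueError is raised.
extractor = UnstructuredWordExtractor("handbook.docx", api_url="")
docs = extractor.extract()  # title-based chunks of at most 2000 characters each
```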
@@ -2,13 +2,14 @@ import base64
 import logging
 
 from bs4 import BeautifulSoup
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredEmailLoader(BaseLoader):
+class UnstructuredEmailExtractor(BaseExtractor):
     """Load msg files.
 
     Args:
         file_path: Path to the file to load.
 
@@ -23,7 +24,7 @@ class UnstructuredEmailLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.email import partition_email
 
         elements = partition_email(filename=self._file_path, api_url=self._api_url)
...
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredMarkdownLoader(BaseLoader):
+class UnstructuredMarkdownExtractor(BaseExtractor):
     """Load md files.
 
@@ -33,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.md import partition_md
 
         elements = partition_md(filename=self._file_path, api_url=self._api_url)
...
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredMsgLoader(BaseLoader):
+class UnstructuredMsgExtractor(BaseExtractor):
     """Load msg files.
 
@@ -23,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.msg import partition_msg
 
         elements = partition_msg(filename=self._file_path, api_url=self._api_url)
...
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredPPTLoader(BaseLoader):
+class UnstructuredPPTExtractor(BaseExtractor):
     """Load msg files.
 
@@ -14,15 +15,15 @@ class UnstructuredPPTLoader(BaseLoader):
     """
 
     def __init__(
         self,
         file_path: str,
         api_url: str
     ):
         """Initialize with file path."""
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.ppt import partition_ppt
 
         elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
...
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredPPTXLoader(BaseLoader):
+class UnstructuredPPTXExtractor(BaseExtractor):
     """Load msg files.
 
@@ -13,15 +15,15 @@ class UnstructuredPPTXLoader(BaseLoader):
     """
 
     def __init__(
         self,
         file_path: str,
         api_url: str
     ):
         """Initialize with file path."""
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.pptx import partition_pptx
 
         elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
...
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredTextLoader(BaseLoader):
+class UnstructuredTextExtractor(BaseExtractor):
     """Load msg files.
 
@@ -23,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.text import partition_text
 
         elements = partition_text(filename=self._file_path, api_url=self._api_url)
...
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredXmlLoader(BaseLoader):
+class UnstructuredXmlExtractor(BaseExtractor):
     """Load msg files.
 
@@ -23,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.xml import partition_xml
 
         elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
...
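The email, markdown, msg, ppt, pptx, text and xml extractors above all follow the same rename (`*Loader` to `*Extractor`, `load()` to `extract()`) and the same two-step body: partition the file with the matching `unstructured.partition.*` function, then wrap the results in Document objects. The truncated hunks hide the wrapping step, so the sketch below is an assumption of that shared shape rather than the exact code:

```python
# Hypothetical shared shape of the Unstructured*Extractor classes in this commit.
from unstructured.partition.text import partition_text  # swap in partition_md, partition_msg, ...

from core.rag.models.document import Document


def extract_with_unstructured(file_path: str) -> list[Document]:
    # The real extractors also forward an api_url argument; it is omitted here
    # to keep the sketch runnable against stock `unstructured`.
    elements = partition_text(filename=file_path)
    # Simplest faithful wrapping: one Document per partitioned element;
    # the truncated hunks may join or chunk elements before wrapping them.
    return [Document(page_content=element.text) for element in elements]
```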
"""Abstract interface for document loader implementations."""
import os
import tempfile
from urllib.parse import urlparse
import requests
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class WordExtractor(BaseExtractor):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(self, file_path: str):
"""Initialize with file path."""
self.file_path = file_path
if "~" in self.file_path:
self.file_path = os.path.expanduser(self.file_path)
# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
r = requests.get(self.file_path)
if r.status_code != 200:
raise ValueError(
"Check the url of your file; returned status code %s"
% r.status_code
)
self.web_path = self.file_path
self.temp_file = tempfile.NamedTemporaryFile()
self.temp_file.write(r.content)
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):
raise ValueError("File path %s is not a valid file or url" % self.file_path)
def __del__(self) -> None:
if hasattr(self, "temp_file"):
self.temp_file.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""
import docx2txt
return [
Document(
page_content=docx2txt.process(self.file_path),
metadata={"source": self.file_path},
)
]
@staticmethod
def _is_valid_url(url: str) -> bool:
"""Check if the url is valid."""
parsed = urlparse(url)
return bool(parsed.netloc) and bool(parsed.scheme)
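A usage sketch covering both constructor paths (local path and URL); the module path and the sample URL are assumptions:

```python
# Hypothetical usage; the module path is assumed.
from core.rag.extractor.word_extractor import WordExtractor

# Local file ("~" is expanded by the constructor):
docs = WordExtractor("~/contracts/nda.docx").extract()

# Remote file: the constructor downloads it into a NamedTemporaryFile that is
# kept for the extractor's lifetime and closed in __del__.
docs = WordExtractor("https://example.com/files/nda.docx").extract()
print(docs[0].metadata["source"])
```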