Commit 52e6f458 authored by jyong

add rag test

parent 703aefbd
...@@ -3,6 +3,8 @@ import datetime ...@@ -3,6 +3,8 @@ import datetime
import uuid import uuid
from typing import Optional from typing import Optional
import pytest
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
...@@ -16,9 +18,9 @@ from models.dataset import Dataset ...@@ -16,9 +18,9 @@ from models.dataset import Dataset
from models.model import UploadFile from models.model import UploadFile
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self) -> list[Document]: @pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
def extract() -> list[Document]:
file_detail = UploadFile( file_detail = UploadFile(
tenant_id='test', tenant_id='test',
storage_type='local', storage_type='local',
...@@ -40,11 +42,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): ...@@ -40,11 +42,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
) )
text_docs = ExtractProcessor.extract(extract_setting=extract_setting, text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
is_automatic=False) is_automatic=True)
assert isinstance(text_docs, list)
return text_docs return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]:
# Split the text documents into nodes. # Split the text documents into nodes.
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
embedding_model_instance=kwargs.get('embedding_model_instance')) embedding_model_instance=kwargs.get('embedding_model_instance'))
...@@ -74,7 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): ...@@ -74,7 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents) all_documents.extend(split_documents)
return all_documents return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset) vector = Vector(dataset)
vector.create(documents) vector.create(documents)
...@@ -82,7 +84,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): ...@@ -82,7 +84,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword = Keyword(dataset) keyword = Keyword(dataset)
keyword.create(documents) keyword.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset) vector = Vector(dataset)
if node_ids: if node_ids:
...@@ -96,7 +98,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): ...@@ -96,7 +98,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
else: else:
keyword.delete() keyword.delete()
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict) -> list[Document]: score_threshold: float, reranking_model: dict) -> list[Document]:
# Set search parameters. # Set search parameters.
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment