Commit 52e6f458 authored by jyong

add rag test

parent 703aefbd
@@ -3,6 +3,8 @@ import datetime
 import uuid
 from typing import Optional

+import pytest
+from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.retrieval_service import RetrievalService
@@ -16,98 +18,98 @@ from models.dataset import Dataset
 from models.model import UploadFile
 class ParagraphIndexProcessor(BaseIndexProcessor):

-    def extract(self) -> list[Document]:
-        file_detail = UploadFile(
-            tenant_id='test',
-            storage_type='local',
-            key='test.txt',
-            name='test.txt',
-            size=1024,
-            extension='txt',
-            mime_type='text/plain',
-            created_by='test',
-            created_at=datetime.datetime.utcnow(),
-            used=True,
-            used_by='d48632d7-c972-484a-8ed9-262490919c79',
-            used_at=datetime.datetime.utcnow()
-        )
-        extract_setting = ExtractSetting(
-            datasource_type="upload_file",
-            upload_file=file_detail,
-            document_model='text_model'
-        )
-        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
-                                             is_automatic=False)
-        return text_docs
+    @pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
+    def extract() -> list[Document]:
+        file_detail = UploadFile(
+            tenant_id='test',
+            storage_type='local',
+            key='test.txt',
+            name='test.txt',
+            size=1024,
+            extension='txt',
+            mime_type='text/plain',
+            created_by='test',
+            created_at=datetime.datetime.utcnow(),
+            used=True,
+            used_by='d48632d7-c972-484a-8ed9-262490919c79',
+            used_at=datetime.datetime.utcnow()
+        )
+        extract_setting = ExtractSetting(
+            datasource_type="upload_file",
+            upload_file=file_detail,
+            document_model='text_model'
+        )
+        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                             is_automatic=True)
+        assert isinstance(text_docs, list)
+        return text_docs
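Aside (added for context, not part of this commit): with `indirect=True`, pytest routes the parameter list `['partition_md', 'chunk_by_title']` into a fixture named `setup_unstructured_mock` rather than into the test itself; that fixture lives elsewhere in the test suite. A minimal sketch of what such a fixture could look like, assuming the mocks target the `unstructured` library's `partition_md` and `chunk_by_title` helpers (the patch paths are assumptions, a real fixture would patch wherever the extractor imports them from):

```python
import pytest
from unittest.mock import MagicMock, patch

# Assumed module paths for the functions being mocked out.
_MOCK_TARGETS = {
    'partition_md': 'unstructured.partition.md.partition_md',
    'chunk_by_title': 'unstructured.chunking.title.chunk_by_title',
}

@pytest.fixture
def setup_unstructured_mock(request):
    # request.param is the list passed via parametrize(..., indirect=True).
    patchers = [patch(_MOCK_TARGETS[name], MagicMock()) for name in request.param]
    for p in patchers:
        p.start()
    yield
    for p in patchers:
        p.stop()
```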
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         # Split the text documents into nodes.
         splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
                                       embedding_model_instance=kwargs.get('embedding_model_instance'))
         all_documents = []
         for document in documents:
             # Clean the document text.
             document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
             document.page_content = document_text
             # Parse the document into nodes.
             document_nodes = splitter.split_documents([document])
             split_documents = []
             for document_node in document_nodes:
                 if document_node.page_content.strip():
                     doc_id = str(uuid.uuid4())
                     hash = helper.generate_text_hash(document_node.page_content)
                     document_node.metadata['doc_id'] = doc_id
                     document_node.metadata['doc_hash'] = hash
                     # Strip a leading splitter character.
                     page_content = document_node.page_content
                     if page_content.startswith(".") or page_content.startswith("。"):
                         page_content = page_content[1:]
                     document_node.page_content = page_content
                     split_documents.append(document_node)
             all_documents.extend(split_documents)
         return all_documents
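Note (illustrative, not in the commit): `transform` establishes three invariants on every node it returns: the content is non-empty, `doc_id`/`doc_hash` metadata is attached, and a leading separator character has been stripped. A test could pin those down directly; a short sketch:

```python
def assert_transform_invariants(nodes):
    # Invariants established by ParagraphIndexProcessor.transform above.
    for node in nodes:
        assert node.page_content.strip(), "empty nodes are filtered out"
        assert 'doc_id' in node.metadata and 'doc_hash' in node.metadata
        assert not node.page_content.startswith(('.', '。')), "leading splitter char removed"
```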
     def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
         if dataset.indexing_technique == 'high_quality':
             vector = Vector(dataset)
             vector.create(documents)
         if with_keywords:
             keyword = Keyword(dataset)
             keyword.create(documents)
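Aside (illustrative, not part of the commit): because `load` only writes embeddings for 'high_quality' datasets, a unit test can stub out `Vector` and `Keyword` and check which store is touched. A sketch, assuming hypothetical `processor`, `dataset`, and `documents` fixtures and an assumed patch target (patch where this module imports the classes):

```python
from unittest.mock import patch

def test_load_economy_skips_vector(processor, dataset, documents):
    # Assumed module path of ParagraphIndexProcessor.
    base = 'core.rag.index_processor.processor.paragraph_index_processor'
    with patch(f'{base}.Vector') as vector_cls, patch(f'{base}.Keyword') as keyword_cls:
        dataset.indexing_technique = 'economy'
        processor.load(dataset, documents, with_keywords=True)
        vector_cls.assert_not_called()
        keyword_cls.return_value.create.assert_called_once_with(documents)
```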
     def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
         if dataset.indexing_technique == 'high_quality':
             vector = Vector(dataset)
             if node_ids:
                 vector.delete_by_ids(node_ids)
             else:
                 vector.delete()
         if with_keywords:
             keyword = Keyword(dataset)
             if node_ids:
                 keyword.delete_by_ids(node_ids)
             else:
                 keyword.delete()
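Similarly (illustrative): `clean` falls back to a full wipe when `node_ids` is None, which is worth asserting so a refactor cannot silently turn targeted deletion into a full one. Same assumed fixtures and patch target as the `load` sketch above:

```python
from unittest.mock import patch

def test_clean_without_node_ids_wipes_index(processor, dataset):
    base = 'core.rag.index_processor.processor.paragraph_index_processor'
    with patch(f'{base}.Vector') as vector_cls:
        dataset.indexing_technique = 'high_quality'
        processor.clean(dataset, node_ids=None, with_keywords=False)
        vector_cls.return_value.delete.assert_called_once()
        vector_cls.return_value.delete_by_ids.assert_not_called()
```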
     def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
                  score_threshold: float, reranking_model: dict) -> list[Document]:
         # Set search parameters.
         results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
                                             top_k=top_k, score_threshold=score_threshold,
                                             reranking_model=reranking_model)
         # Organize results.
         docs = []
         for result in results:
             metadata = result.metadata
             metadata['score'] = result.score
             if result.score > score_threshold:
                 doc = Document(page_content=result.page_content, metadata=metadata)
                 docs.append(doc)
         return docs
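One behavioral detail worth noting (not in the commit): the post-filter uses a strict comparison, so a result whose score exactly equals `score_threshold` is dropped even though the retrieval service was handed the same threshold. A tiny self-contained check of that edge case, using a stub result object:

```python
from dataclasses import dataclass, field

@dataclass
class _StubResult:
    page_content: str
    score: float
    metadata: dict = field(default_factory=dict)

def test_score_equal_to_threshold_is_dropped():
    results = [_StubResult('above', 0.51), _StubResult('equal', 0.50)]
    kept = [r for r in results if r.score > 0.50]  # mirrors the strict '>' above
    assert [r.page_content for r in kept] == ['above']
```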