Commit 703aefbd authored by jyong

add rag test

parent cc84d077
 from ctypes import Union
-from typing import List, Optional, Tuple
-
-from qdrant_client.conversions import common_types as types
+from typing import List


 class MockMilvusClass(object):
     @staticmethod
-    def get_collections() -> types.CollectionsResponse:
-        collections_response = types.CollectionsResponse(
-            collections=["test"]
-        )
-        return collections_response
-
-    @staticmethod
-    def recreate_collection() -> bool:
-        return True
-
-    @staticmethod
-    def create_payload_index() -> types.UpdateResult:
-        update_result = types.UpdateResult(
-            updated=1
-        )
-        return update_result
-
-    @staticmethod
-    def upsert() -> types.UpdateResult:
-        update_result = types.UpdateResult(
-            updated=1
-        )
-        return update_result
-
-    @staticmethod
-    def insert() -> List[Union[str, int]]:
-        result = ['d48632d7-c972-484a-8ed9-262490919c79']
-        return result
-
-    @staticmethod
-    def delete() -> List[Union[str, int]]:
-        result = ['d48632d7-c972-484a-8ed9-262490919c79']
-        return result
-
-    @staticmethod
-    def scroll() -> Tuple[List[types.Record], Optional[types.PointId]]:
-        record = types.Record(
-            id='d48632d7-c972-484a-8ed9-262490919c79',
-            payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                     'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                                  'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
-                                  'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
-                                  'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
-                     'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
-            vector=[0.23333 for _ in range(233)]
-        )
-        return [record], 'd48632d7-c972-484a-8ed9-262490919c79'
-
-    @staticmethod
-    def search() -> List[types.ScoredPoint]:
-        result = types.ScoredPoint(
-            id='d48632d7-c972-484a-8ed9-262490919c79',
-            payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                     'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                                  'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
-                                  'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
-                                  'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
-                     'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
-            vision=999,
-            vector=[0.23333 for _ in range(233)],
-            score=0.99
-        )
-        return [result]
+    def insert() -> List[Union[str, int]]:
+        result = [447829498067199697]
+        return result
+
+    @staticmethod
+    def delete() -> List[Union[str, int]]:
+        result = [447829498067199697]
+        return result
+
+    @staticmethod
+    def search() -> List[dict]:
+        result = [
+            {
+                'id': 447829498067199697,
+                'distance': 0.8776655793190002,
+                'entity': {
+                    'page_content': 'Dify is a company that provides a platform for the development of AI models.',
+                    'metadata': {
+                        'doc_id': '327d1cb8-15ce-4934-bede-936a13c19ace',
+                        'doc_hash': '7ee3cf010e606bb768c3bca7b1397ff651fd008ef10e56a646c537d2c8afb319',
+                        'document_id': '6c4619dd-2169-4879-b05a-b8937c98c80c',
+                        'dataset_id': 'a2f4f4eb-75eb-4432-8c5f-788100533454'
+                    }
+                }
+            }
+        ]
+        return result
+
+    @staticmethod
+    def query() -> List[dict]:
+        result = [
+            {
+                'id': 447829498067199697,
+                'distance': 0.8776655793190002,
+                'entity': {
+                    'page_content': 'Dify is a company that provides a platform for the development of AI models.',
+                    'metadata': {
+                        'doc_id': '327d1cb8-15ce-4934-bede-936a13c19ace',
+                        'doc_hash': '7ee3cf010e606bb768c3bca7b1397ff651fd008ef10e56a646c537d2c8afb319',
+                        'document_id': '6c4619dd-2169-4879-b05a-b8937c98c80c',
+                        'dataset_id': 'a2f4f4eb-75eb-4432-8c5f-788100533454'
+                    }
+                }
+            }
+        ]
+        return result
+
+    @staticmethod
+    def create_collection_with_schema():
+        pass
+
+    @staticmethod
+    def has_collection() -> bool:
+        return True
@@ -27,18 +27,18 @@ def mock_milvus(monkeypatch: MonkeyPatch, methods: List[Literal["get_collections
     if "connect" in methods:
         monkeypatch.setattr(Connections, "connect", MockMilvusClass.delete())
-    if "get_collections" in methods:
-        monkeypatch.setattr(utility, "has_collection", MockMilvusClass.get_collections())
+    if "has_collection" in methods:
+        monkeypatch.setattr(utility, "has_collection", MockMilvusClass.has_collection())
     if "insert" in methods:
         monkeypatch.setattr(MilvusClient, "insert", MockMilvusClass.insert())
-    if "create_payload_index" in methods:
-        monkeypatch.setattr(QdrantClient, "create_payload_index", MockMilvusClass.create_payload_index())
-    if "upsert" in methods:
-        monkeypatch.setattr(QdrantClient, "upsert", MockMilvusClass.upsert())
-    if "scroll" in methods:
-        monkeypatch.setattr(QdrantClient, "scroll", MockMilvusClass.scroll())
+    if "query" in methods:
+        monkeypatch.setattr(MilvusClient, "query", MockMilvusClass.query())
+    if "delete" in methods:
+        monkeypatch.setattr(MilvusClient, "delete", MockMilvusClass.delete())
     if "search" in methods:
-        monkeypatch.setattr(QdrantClient, "search", MockMilvusClass.search())
+        monkeypatch.setattr(MilvusClient, "search", MockMilvusClass.search())
+    if "create_collection_with_schema" in methods:
+        monkeypatch.setattr(MilvusClient, "create_collection_with_schema", MockMilvusClass.create_collection_with_schema())
     return unpatch
...
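
For orientation, a minimal sketch of how a test could drive the patcher above; it is not part of this commit. The import path of mock_milvus is an assumption (adjust it to wherever the mock module lives), and the assertions only check the stubbed values that the patcher assigns onto MilvusClient, so no running Milvus instance is needed.

from _pytest.monkeypatch import MonkeyPatch
from pymilvus import MilvusClient

# from tests.integration_tests.vector_store.__mock.milvus import mock_milvus  # path assumed for illustration


def test_mock_milvus_patches_client(monkeypatch: MonkeyPatch):
    # Patch only the methods this test relies on; unpatch is expected to undo the monkeypatching.
    unpatch = mock_milvus(monkeypatch, methods=["insert", "search", "has_collection"])
    try:
        # The patcher assigns the stubbed return values directly onto the client class,
        # so the mocked data can be asserted against without a Milvus server.
        assert MilvusClient.insert == [447829498067199697]
        assert MilvusClient.search[0]['distance'] == 0.8776655793190002
    finally:
        unpatch()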
"""test paragraph index processor."""
import datetime
import uuid
from typing import Optional
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from libs import helper
from models.dataset import Dataset
from models.model import UploadFile
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self) -> list[Document]:
file_detail = UploadFile(
tenant_id='test',
storage_type='local',
key='test.txt',
name='test.txt',
size=1024,
extension='txt',
mime_type='text/plain',
created_by='test',
created_at=datetime.datetime.utcnow(),
used=True,
used_by='d48632d7-c972-484a-8ed9-262490919c79',
used_at=datetime.datetime.utcnow()
)
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file_detail,
document_model='text_model'
)
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
is_automatic=False)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
# Split the text documents into nodes.
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
embedding_model_instance=kwargs.get('embedding_model_instance'))
all_documents = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
# delete Spliter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith("。"):
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
split_documents.append(document_node)
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset)
vector.create(documents)
if with_keywords:
keyword = Keyword(dataset)
keyword.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
else:
vector.delete()
if with_keywords:
keyword = Keyword(dataset)
if node_ids:
keyword.delete_by_ids(node_ids)
else:
keyword.delete()
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata['score'] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
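
A minimal sketch of how the processor above might be exercised in a RAG test. The process_rule shape and passing embedding_model_instance=None are assumptions made for illustration and may differ from the schema the real test suite uses.

def test_transform_assigns_chunk_metadata():
    processor = ParagraphIndexProcessor()
    source_docs = [Document(page_content='Dify is a company that provides a platform for the development of AI models.',
                            metadata={'document_id': 'test-document'})]
    # transform() is expected to clean the text, split it into chunks and annotate each chunk.
    chunks = processor.transform(source_docs,
                                 process_rule={'mode': 'automatic', 'rules': {}},
                                 embedding_model_instance=None)
    assert chunks
    for chunk in chunks:
        # every chunk should carry the generated doc_id / doc_hash metadata
        assert 'doc_id' in chunk.metadata
        assert 'doc_hash' in chunk.metadata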