Commit 3622691f authored by jyong's avatar jyong

add qdrant test

parent 52e6f458
...@@ -2,25 +2,26 @@ ...@@ -2,25 +2,26 @@
import datetime import datetime
import uuid import uuid
from typing import Optional from typing import Optional
import pytest import pytest
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document from core.rag.models.document import Document
from libs import helper from libs import helper
from models.dataset import Dataset from models.dataset import Dataset
from models.model import UploadFile from models.model import UploadFile
@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True) @pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
def extract() -> list[Document]: def extract():
index_processor = IndexProcessorFactory('text_model').init_index_processor()
# extract
file_detail = UploadFile( file_detail = UploadFile(
tenant_id='test', tenant_id='test',
storage_type='local', storage_type='local',
...@@ -44,45 +45,30 @@ def extract() -> list[Document]: ...@@ -44,45 +45,30 @@ def extract() -> list[Document]:
text_docs = ExtractProcessor.extract(extract_setting=extract_setting, text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
is_automatic=True) is_automatic=True)
assert isinstance(text_docs, list) assert isinstance(text_docs, list)
return text_docs for text_doc in text_docs:
assert isinstance(text_doc, Document)
def transform(self, documents: list[Document], **kwargs) -> list[Document]: # transform
# Split the text documents into nodes. process_rule = {
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), 'pre_processing_rules': [
embedding_model_instance=kwargs.get('embedding_model_instance')) {'id': 'remove_extra_spaces', 'enabled': True},
all_documents = [] {'id': 'remove_urls_emails', 'enabled': False}
],
'segmentation': {
'delimiter': '\n',
'max_tokens': 500,
'chunk_overlap': 50
}
}
documents = index_processor.transform(text_docs, embedding_model_instance=None,
process_rule=process_rule)
for document in documents: for document in documents:
# document clean assert isinstance(document, Document)
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip(): # load
doc_id = str(uuid.uuid4()) vector = Vector(dataset)
hash = helper.generate_text_hash(document_node.page_content) vector.create(documents)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
# delete Spliter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith("。"):
page_content = page_content[1:]
else:
page_content = page_content
document_node.page_content = page_content
split_documents.append(document_node)
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
vector = Vector(dataset)
vector.create(documents)
if with_keywords:
keyword = Keyword(dataset)
keyword.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality': if dataset.indexing_technique == 'high_quality':
...@@ -98,6 +84,7 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: ...@@ -98,6 +84,7 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords:
else: else:
keyword.delete() keyword.delete()
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict) -> list[Document]: score_threshold: float, reranking_model: dict) -> list[Document]:
# Set search parameters. # Set search parameters.
......
import os
from typing import Generator
import pytest import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
SystemPromptMessage, TextPromptMessageContent,
UserPromptMessage)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
def test_validate_credentials(setup_google_mock): from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVector, QdrantConfig
model = GoogleLargeLanguageModel() from core.rag.models.document import Document
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials( @pytest.mark.parametrize('setup_qdrant_mock',
model='gemini-pro', [['get_collections', 'recreate_collection',
credentials={ 'create_payload_index', 'upsert', 'scroll',
'google_api_key': 'invalid_key' 'search']],
} indirect=True)
def test_qdrant(setup_qdrant_mock):
document = Document(page_content="test", metadata={"test": "test"})
qdrant_vector = QdrantVector(
collection_name="test",
group_id='test',
config=QdrantConfig(
endpoint="http://localhost:6333",
api_key="test",
root_path="test",
timeout=10
) )
model.validate_credentials(
model='gemini-pro',
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
}
)
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_invoke_model(setup_google_mock):
model = GoogleLargeLanguageModel()
response = model.invoke(
model='gemini-pro',
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Give me your worst dad joke or i will unplug you'
),
AssistantPromptMessage(
content='Why did the scarecrow win an award? Because he was outstanding in his field!'
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="ok something snarkier pls"
),
TextPromptMessageContent(
data="i may still unplug you"
)]
)
],
model_parameters={
'temperature': 0.5,
'top_p': 1.0,
'max_tokens_to_sample': 2048
},
stop=['How'],
stream=False,
user="abc-123"
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_invoke_stream_model(setup_google_mock):
model = GoogleLargeLanguageModel()
response = model.invoke(
model='gemini-pro',
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Give me your worst dad joke or i will unplug you'
),
AssistantPromptMessage(
content='Why did the scarecrow win an award? Because he was outstanding in his field!'
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="ok something snarkier pls"
),
TextPromptMessageContent(
data="i may still unplug you"
)]
)
],
model_parameters={
'temperature': 0.2,
'top_k': 5,
'max_tokens_to_sample': 2048
},
stream=True,
user="abc-123"
)
assert isinstance(response, Generator)
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_invoke_chat_model_with_vision(setup_google_mock):
model = GoogleLargeLanguageModel()
result = model.invoke(
model='gemini-pro-vision',
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="what do you see?"
),
ImagePromptMessageContent(
data=''
)
]
)
],
model_parameters={
'temperature': 0.3,
'top_p': 0.2,
'top_k': 3,
'max_tokens': 100
},
stream=False,
user="abc-123"
)
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
model = GoogleLargeLanguageModel()
result = model.invoke(
model='gemini-pro-vision',
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.'
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="what do you see?"
),
ImagePromptMessageContent(
data=''
)
]
),
AssistantPromptMessage(
content="I see a blue letter 'D' with a gradient from light blue to dark blue."
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="what about now?"
),
ImagePromptMessageContent(
data=''
)
]
)
],
model_parameters={
'temperature': 0.3,
'top_p': 0.2,
'top_k': 3,
'max_tokens': 100
},
stream=False,
user="abc-123"
)
print(f"resultz: {result.message.content}")
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
def test_get_num_tokens():
model = GoogleLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='gemini-pro',
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Hello World!'
)
]
) )
# create
qdrant_vector.create(texts=[document], embeddings=[[0.23333 for _ in range(233)]])
# search
result = qdrant_vector.search_by_vector(query_vector=[0.23333 for _ in range(233)])
for item in result:
assert isinstance(item, Document)
# delete
qdrant_vector.delete()
assert num_tokens > 0 # The exact number of tokens may vary based on the model's tokenization
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment