Commit 5eddcaae authored by John Wang

feat: add milvus vector support

breaking change: MUST upgrade python to 3.11.x
parent 85663d99
......@@ -14,7 +14,7 @@ You need to install and configure the following dependencies on your machine to
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x
- [Python](https://www.python.org/) version 3.11.x
## Local development
......
......@@ -12,7 +12,7 @@
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) 版本 8.x.x 或 [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) 版本 3.10.x
- [Python](https://www.python.org/) 版本 3.11.x
## 本地开发
......
......@@ -14,7 +14,7 @@
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) バージョン 8.x.x もしくは [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) バージョン 3.10.x
- [Python](https://www.python.org/) バージョン 3.11.x
## ローカル開発
......
......@@ -65,7 +65,7 @@ SESSION_REDIS_PORT=6379
SESSION_REDIS_PASSWORD=difyai123456
SESSION_REDIS_DB=2
# Vector database configuration, support: weaviate, qdrant, pinecone
# Vector database configuration, support: weaviate, qdrant, pinecone, milvus
VECTOR_STORE=weaviate
# Weaviate configuration
......@@ -81,6 +81,13 @@ QDRANT_API_KEY=your-qdrant-api-key
PINECONE_API_KEY=
PINECONE_ENVIRONMENT=us-east4-gcp
# Milvus configuration
MILVUS_HOST=localhost
MILVUS_PORT=19530
MILVUS_USER=
MILVUS_PASSWORD=
MILVUS_USE_SECURE=
# Sentry configuration
SENTRY_DSN=
......
......@@ -43,6 +43,9 @@ DEFAULTS = {
'SENTRY_TRACES_SAMPLE_RATE': 1.0,
'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
'WEAVIATE_GRPC_ENABLED': 'True',
'MILVUS_USER': '',
'MILVUS_PASSWORD': '',
'MILVUS_USE_SECURE': 'False',
'CELERY_BACKEND': 'database',
'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO',
......@@ -147,6 +150,13 @@ class Config:
self.PINECONE_API_KEY = get_env('PINECONE_API_KEY')
self.PINECONE_ENVIRONMENT = get_env('PINECONE_ENVIRONMENT')
# milvus settings
self.MILVUS_HOST = get_env('MILVUS_HOST')
self.MILVUS_PORT = int(get_env('MILVUS_PORT'))
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_USE_SECURE = get_bool_env('MILVUS_USE_SECURE')
# cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL)
......
......@@ -71,3 +71,9 @@ class InvalidMetadataError(BaseHTTPException):
error_code = 'invalid_metadata'
description = "The metadata content is incorrect. Please check and verify."
code = 400
class CurrentVectorStoreNotSupportHitTestingError(BaseHTTPException):
    """API error (status code 400) for hit testing on an unsupported vector store."""

    # Status code attached to the error.
    code = 400
    # Machine-readable identifier of the error.
    error_code = 'current_vector_store_not_support_hit_testing'
    # Human-readable explanation of the error.
    description = "The current vector store does not support hit testing."
......@@ -8,12 +8,14 @@ import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError, \
CurrentVectorStoreNotSupportHitTestingError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.errors.dataset import VectorStoreNotSupportHitTestingError
from services.hit_testing_service import HitTestingService
document_fields = {
......@@ -101,6 +103,8 @@ class HitTestingApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except VectorStoreNotSupportHitTestingError:
raise CurrentVectorStoreNotSupportHitTestingError()
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))
......
import logging
from typing import List, Optional
from pymilvus import MilvusException
from llama_index import GPTMilvusIndex, ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import MilvusIndexDict
from llama_index.data_structs.node_v2 import DocumentRelationship, Node
from llama_index.vector_stores import MilvusVectorStore
from llama_index.vector_stores.types import VectorStoreQueryResult, VectorStoreQuery, VectorStoreQueryMode
from core.embedding.openai_embedding import OpenAIEmbedding
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
class MilvusVectorStoreClient(BaseVectorStoreClient):
    """Vector store client that builds indexes backed by a Milvus server.

    The client itself only remembers the connection settings; the Milvus
    vector store object is created each time get_index() is called.
    """

    def __init__(self, host: str, port: int,
                 user: str = "", password: str = "", use_secure: bool = False):
        """Store the Milvus connection settings.

        Args:
            host: Milvus server host.
            port: Milvus server port.
            user: User name passed to the vector store (empty by default).
            password: Password passed to the vector store (empty by default).
            use_secure: Forwarded as the vector store's ``use_secure`` flag.
        """
        self._host, self._port = host, port
        self._user, self._password = user, password
        self._use_secure = use_secure

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        """Build a GPTMilvusEnhanceIndex for the collection named in *config*.

        Args:
            service_context: Passed through to the index unchanged.
            config: Index configuration; must contain a non-empty
                ``collection_name`` entry (see to_index_config()).

        Returns:
            A GPTMilvusEnhanceIndex wired to a MilvusEnhanceVectorStore.

        Raises:
            Exception: If ``collection_name`` is missing or empty.
        """
        name = config.get('collection_name')
        if not name:
            raise Exception("collection_name cannot be None.")

        # The embedding method is handed to the store as its `tokenizer` argument.
        embed = OpenAIEmbedding().get_text_embedding
        store = MilvusEnhanceVectorStore(
            collection_name=name,
            host=self._host,
            port=self._port,
            user=self._user,
            password=self._password,
            use_secure=self._use_secure,
            # presumably keeps an already existing collection intact - TODO confirm
            overwrite=False,
            tokenizer=embed
        )

        return GPTMilvusEnhanceIndex(
            service_context=service_context,
            index_struct=MilvusIndexDict(),
            vector_store=store
        )

    def to_index_config(self, dataset_id: str) -> dict:
        """Return the index config for a dataset.

        The collection name is ``vector_`` followed by the dataset id with
        every ``-`` replaced by ``_``.
        """
        normalized_id = dataset_id.replace("-", "_")
        return {"collection_name": f"vector_{normalized_id}"}
class GPTMilvusEnhanceIndex(GPTMilvusIndex, BaseGPTVectorStoreIndex):
    """GPTMilvusIndex combined with BaseGPTVectorStoreIndex; adds no members of its own."""
class MilvusEnhanceVectorStore(MilvusVectorStore, EnhanceVectorStore):
    """MilvusVectorStore extended with per-node delete and existence checks.

    Both operations address an entity through its ``id`` field using a
    Milvus boolean expression of the form ``id in ["<node_id>"]``.
    """
    # Vector field is not supported in current release.

    def delete_node(self, node_id: str):
        """Delete the entity whose ``id`` equals *node_id*.

        Args:
            node_id: Identifier of the node to remove.

        Raises:
            MilvusException: Propagated unchanged when the delete call fails.
        """
        try:
            # NOTE(review): node_id is interpolated without escaping; this
            # assumes ids never contain a double quote - confirm with callers.
            self.collection.delete(f"id in [\"{node_id}\"]")
        except MilvusException:
            # Log with traceback at error level so failed deletions are
            # visible in normal logs, then re-raise the original exception.
            logging.exception(f"Unsuccessfully deleted embedding with node_id: {node_id}")
            raise
        else:
            logging.debug(f"Successfully deleted embedding with node_id: {node_id}")

    def exists_by_node_id(self, node_id: str) -> bool:
        """Return True when an entity with ``id`` equal to *node_id* exists.

        Args:
            node_id: Identifier of the node to look up.

        Raises:
            MilvusException: Propagated unchanged when the query fails.
        """
        # One row is enough to answer the question; the query is issued with
        # consistency_level="Strong".
        rows = self.collection.query(
            expr=f"id in [\"{node_id}\"]",
            offset=0,
            limit=1,
            consistency_level="Strong"
        )
        return len(rows) > 0
......@@ -3,11 +3,12 @@ from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from core.vector_store.milvus_vector_store_client import MilvusVectorStoreClient
from core.vector_store.pinecone_vector_store_client import PineconeVectorStoreClient
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant', 'pinecone']
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant', 'pinecone', 'milvus']
class VectorStore:
......@@ -41,6 +42,14 @@ class VectorStore:
api_key=app.config['PINECONE_API_KEY'],
environment=app.config['PINECONE_ENVIRONMENT']
)
elif self._vector_store == 'milvus':
self._client = MilvusVectorStoreClient(
host=app.config['MILVUS_HOST'],
port=app.config['MILVUS_PORT'],
user=app.config['MILVUS_USER'],
password=app.config['MILVUS_PASSWORD'],
use_secure=app.config['MILVUS_USE_SECURE'],
)
app.extensions['vector_store'] = self
......@@ -65,3 +74,10 @@ class VectorStore:
raise Exception("Vector store client is not initialized.")
return self._client
def support_hit_testing(self):
    """Report whether the active vector store client can serve hit testing.

    Returns:
        False when the client is a MilvusVectorStoreClient, True otherwise.
    """
    # Milvus is excluded because its search API does not return vector data.
    return not isinstance(self._client, MilvusVectorStoreClient)
\ No newline at end of file
......@@ -30,4 +30,5 @@ sentry-sdk[flask]~=1.21.1
jieba==0.42.1
celery==5.2.7
redis~=4.5.4
pypdf==3.8.1
\ No newline at end of file
pypdf==3.8.1
pymilvus==2.2.9
\ No newline at end of file
......@@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError
class DatasetNameDuplicateError(BaseServiceError):
    """Service-layer error with no extra members.

    NOTE(review): presumably raised when a dataset name is already in use;
    the raise site is not visible here - confirm.
    """
class VectorStoreNotSupportHitTestingError(BaseServiceError):
    """Service-layer error for hit testing on a vector store that does not support it."""
......@@ -11,14 +11,19 @@ from sklearn.manifold import TSNE
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.vector_index import VectorIndex
from extensions.ext_database import db
from extensions.ext_vector_store import vector_store
from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.errors.dataset import VectorStoreNotSupportHitTestingError
from services.errors.index import IndexNotInitializedError
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
if not vector_store.support_hit_testing():
raise VectorStoreNotSupportHitTestingError()
index = VectorIndex(dataset=dataset).query_index
if not index:
......@@ -67,12 +72,22 @@ class HitTestingService:
]
for node in nodes:
embeddings.append(node.node.embedding)
if node.node.embedding:
embeddings.append(node.node.embedding)
tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
query_position = tsne_position_data.pop(0)
if not tsne_position_data:
return {
"query": {
"content": query_bundle.query_str,
"tsne_position": query_position,
},
"records": []
}
i = 0
records = []
for node in nodes:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment