Commit 5eddcaae authored by John Wang's avatar John Wang

feat: add milvus vector support

breaking change: MUST upgrade python to 3.11.x
parent 85663d99
...@@ -14,7 +14,7 @@ You need to install and configure the following dependencies on your machine to ...@@ -14,7 +14,7 @@ You need to install and configure the following dependencies on your machine to
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x - [Python](https://www.python.org/) version 3.11.x
## Local development ## Local development
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) 版本 8.x.x 或 [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) 版本 8.x.x 或 [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) 版本 3.10.x - [Python](https://www.python.org/) 版本 3.11.x
## 本地开发 ## 本地开发
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) バージョン 8.x.x もしくは [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) バージョン 8.x.x もしくは [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) バージョン 3.10.x - [Python](https://www.python.org/) バージョン 3.11.x
## ローカル開発 ## ローカル開発
......
...@@ -65,7 +65,7 @@ SESSION_REDIS_PORT=6379 ...@@ -65,7 +65,7 @@ SESSION_REDIS_PORT=6379
SESSION_REDIS_PASSWORD=difyai123456 SESSION_REDIS_PASSWORD=difyai123456
SESSION_REDIS_DB=2 SESSION_REDIS_DB=2
# Vector database configuration, support: weaviate, qdrant, pinecone # Vector database configuration, support: weaviate, qdrant, pinecone, milvus
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# Weaviate configuration # Weaviate configuration
...@@ -81,6 +81,13 @@ QDRANT_API_KEY=your-qdrant-api-key ...@@ -81,6 +81,13 @@ QDRANT_API_KEY=your-qdrant-api-key
PINECONE_API_KEY= PINECONE_API_KEY=
PINECONE_ENVIRONMENT=us-east4-gcp PINECONE_ENVIRONMENT=us-east4-gcp
# Milvus configuration
MILVUS_HOST=localhost
MILVUS_PORT=19530
MILVUS_USER=
MILVUS_PASSWORD=
MILVUS_USE_SECURE=
# Sentry configuration # Sentry configuration
SENTRY_DSN= SENTRY_DSN=
......
...@@ -43,6 +43,9 @@ DEFAULTS = { ...@@ -43,6 +43,9 @@ DEFAULTS = {
'SENTRY_TRACES_SAMPLE_RATE': 1.0, 'SENTRY_TRACES_SAMPLE_RATE': 1.0,
'SENTRY_PROFILES_SAMPLE_RATE': 1.0, 'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
'WEAVIATE_GRPC_ENABLED': 'True', 'WEAVIATE_GRPC_ENABLED': 'True',
'MILVUS_USER': '',
'MILVUS_PASSWORD': '',
'MILVUS_USE_SECURE': 'False',
'CELERY_BACKEND': 'database', 'CELERY_BACKEND': 'database',
'PDF_PREVIEW': 'True', 'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO', 'LOG_LEVEL': 'INFO',
...@@ -147,6 +150,13 @@ class Config: ...@@ -147,6 +150,13 @@ class Config:
self.PINECONE_API_KEY = get_env('PINECONE_API_KEY') self.PINECONE_API_KEY = get_env('PINECONE_API_KEY')
self.PINECONE_ENVIRONMENT = get_env('PINECONE_ENVIRONMENT') self.PINECONE_ENVIRONMENT = get_env('PINECONE_ENVIRONMENT')
# milvus settings
self.MILVUS_HOST = get_env('MILVUS_HOST')
self.MILVUS_PORT = int(get_env('MILVUS_PORT'))
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_USE_SECURE = get_bool_env('MILVUS_USE_SECURE')
# cors settings # cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL) 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL)
......
...@@ -71,3 +71,9 @@ class InvalidMetadataError(BaseHTTPException): ...@@ -71,3 +71,9 @@ class InvalidMetadataError(BaseHTTPException):
error_code = 'invalid_metadata' error_code = 'invalid_metadata'
description = "The metadata content is incorrect. Please check and verify." description = "The metadata content is incorrect. Please check and verify."
code = 400 code = 400
class CurrentVectorStoreNotSupportHitTestingError(BaseHTTPException):
error_code = 'current_vector_store_not_support_hit_testing'
description = "The current vector store does not support hit testing."
code = 400
...@@ -8,12 +8,14 @@ import services ...@@ -8,12 +8,14 @@ import services
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \ from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError ProviderModelCurrentlyNotSupportError
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError, \
CurrentVectorStoreNotSupportHitTestingError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import TimestampField from libs.helper import TimestampField
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.errors.dataset import VectorStoreNotSupportHitTestingError
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
document_fields = { document_fields = {
...@@ -101,6 +103,8 @@ class HitTestingApi(Resource): ...@@ -101,6 +103,8 @@ class HitTestingApi(Resource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
except VectorStoreNotSupportHitTestingError:
raise CurrentVectorStoreNotSupportHitTestingError()
except Exception as e: except Exception as e:
logging.exception("Hit testing failed.") logging.exception("Hit testing failed.")
raise InternalServerError(str(e)) raise InternalServerError(str(e))
......
import logging
from typing import List, Optional
from pymilvus import MilvusException
from llama_index import GPTMilvusIndex, ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import MilvusIndexDict
from llama_index.data_structs.node_v2 import DocumentRelationship, Node
from llama_index.vector_stores import MilvusVectorStore
from llama_index.vector_stores.types import VectorStoreQueryResult, VectorStoreQuery, VectorStoreQueryMode
from core.embedding.openai_embedding import OpenAIEmbedding
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
class MilvusVectorStoreClient(BaseVectorStoreClient):
def __init__(self, host: str, port: int,
user: str = "", password: str = "", use_secure: bool = False):
self._host = host
self._port = port
self._user = user
self._password = password
self._use_secure = use_secure
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = MilvusIndexDict()
collection_name = config.get('collection_name')
if not collection_name:
raise Exception("collection_name cannot be None.")
return GPTMilvusEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=MilvusEnhanceVectorStore(
collection_name=collection_name,
host=self._host,
port=self._port,
user=self._user,
password=self._password,
use_secure=self._use_secure,
overwrite=False,
tokenizer=OpenAIEmbedding().get_text_embedding
)
)
def to_index_config(self, dataset_id: str) -> dict:
index_id = "vector_" + dataset_id.replace("-", "_")
return {"collection_name": index_id}
class GPTMilvusEnhanceIndex(GPTMilvusIndex, BaseGPTVectorStoreIndex):
pass
class MilvusEnhanceVectorStore(MilvusVectorStore, EnhanceVectorStore):
# Vector field is not supported in current release.
def delete_node(self, node_id: str):
try:
# Begin by querying for the primary keys to delete
self.collection.delete(f"id in [\"{node_id}\"]")
logging.debug(f"Successfully deleted embedding with node_id: {node_id}")
except MilvusException as e:
logging.debug(f"Unsuccessfully deleted embedding with node_id: {node_id}")
raise e
def exists_by_node_id(self, node_id: str) -> bool:
try:
rst = self.collection.query(
expr=f"id in [\"{node_id}\"]",
offset=0,
limit=1,
consistency_level="Strong"
)
if len(rst) > 0:
return True
except MilvusException as e:
raise e
return False
...@@ -3,11 +3,12 @@ from llama_index import ServiceContext, GPTVectorStoreIndex ...@@ -3,11 +3,12 @@ from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt from tenacity import retry, retry_if_exception_type, stop_after_attempt
from core.vector_store.milvus_vector_store_client import MilvusVectorStoreClient
from core.vector_store.pinecone_vector_store_client import PineconeVectorStoreClient from core.vector_store.pinecone_vector_store_client import PineconeVectorStoreClient
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant', 'pinecone'] SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant', 'pinecone', 'milvus']
class VectorStore: class VectorStore:
...@@ -41,6 +42,14 @@ class VectorStore: ...@@ -41,6 +42,14 @@ class VectorStore:
api_key=app.config['PINECONE_API_KEY'], api_key=app.config['PINECONE_API_KEY'],
environment=app.config['PINECONE_ENVIRONMENT'] environment=app.config['PINECONE_ENVIRONMENT']
) )
elif self._vector_store == 'milvus':
self._client = MilvusVectorStoreClient(
host=app.config['MILVUS_HOST'],
port=app.config['MILVUS_PORT'],
user=app.config['MILVUS_USER'],
password=app.config['MILVUS_PASSWORD'],
use_secure=app.config['MILVUS_USE_SECURE'],
)
app.extensions['vector_store'] = self app.extensions['vector_store'] = self
...@@ -65,3 +74,10 @@ class VectorStore: ...@@ -65,3 +74,10 @@ class VectorStore:
raise Exception("Vector store client is not initialized.") raise Exception("Vector store client is not initialized.")
return self._client return self._client
def support_hit_testing(self):
if isinstance(self._client, MilvusVectorStoreClient):
# search API not return vector data
return False
return True
\ No newline at end of file
...@@ -30,4 +30,5 @@ sentry-sdk[flask]~=1.21.1 ...@@ -30,4 +30,5 @@ sentry-sdk[flask]~=1.21.1
jieba==0.42.1 jieba==0.42.1
celery==5.2.7 celery==5.2.7
redis~=4.5.4 redis~=4.5.4
pypdf==3.8.1 pypdf==3.8.1
\ No newline at end of file pymilvus==2.2.9
\ No newline at end of file
...@@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError ...@@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError
class DatasetNameDuplicateError(BaseServiceError): class DatasetNameDuplicateError(BaseServiceError):
pass pass
class VectorStoreNotSupportHitTestingError(BaseServiceError):
pass
...@@ -11,14 +11,19 @@ from sklearn.manifold import TSNE ...@@ -11,14 +11,19 @@ from sklearn.manifold import TSNE
from core.docstore.empty_docstore import EmptyDocumentStore from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.vector_index import VectorIndex from core.index.vector_index import VectorIndex
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_vector_store import vector_store
from models.account import Account from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.errors.dataset import VectorStoreNotSupportHitTestingError
from services.errors.index import IndexNotInitializedError from services.errors.index import IndexNotInitializedError
class HitTestingService: class HitTestingService:
@classmethod @classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
if not vector_store.support_hit_testing():
raise VectorStoreNotSupportHitTestingError()
index = VectorIndex(dataset=dataset).query_index index = VectorIndex(dataset=dataset).query_index
if not index: if not index:
...@@ -67,12 +72,22 @@ class HitTestingService: ...@@ -67,12 +72,22 @@ class HitTestingService:
] ]
for node in nodes: for node in nodes:
embeddings.append(node.node.embedding) if node.node.embedding:
embeddings.append(node.node.embedding)
tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings) tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
query_position = tsne_position_data.pop(0) query_position = tsne_position_data.pop(0)
if not tsne_position_data:
return {
"query": {
"content": query_bundle.query_str,
"tsne_position": query_position,
},
"records": []
}
i = 0 i = 0
records = [] records = []
for node in nodes: for node in nodes:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment