Unverified commit 07aab5e8, authored by Jyong, committed by GitHub

Feat/add milvus vector db (#1302)

Co-authored-by: jyong <jyong@dify.ai>
parent 875dfbbf
...@@ -63,6 +63,13 @@ WEAVIATE_BATCH_SIZE=100 ...@@ -63,6 +63,13 @@ WEAVIATE_BATCH_SIZE=100
QDRANT_URL=http://localhost:6333 QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456 QDRANT_API_KEY=difyai123456
# Milvus configuration
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# Mail configuration, support: resend # Mail configuration, support: resend
MAIL_TYPE= MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai> MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
......
...@@ -135,6 +135,14 @@ class Config: ...@@ -135,6 +135,14 @@ class Config:
self.QDRANT_URL = get_env('QDRANT_URL') self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
# milvus setting
self.MILVUS_HOST = get_env('MILVUS_HOST')
self.MILVUS_PORT = get_env('MILVUS_PORT')
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
# cors settings # cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
......
This diff is collapsed.
...@@ -9,30 +9,46 @@ from core.index.base import BaseIndex ...@@ -9,30 +9,46 @@ from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore from core.vector_store.milvus_vector_store import MilvusVectorStore
from core.vector_store.weaviate_vector_store import WeaviateVectorStore from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class MilvusConfig(BaseModel):
    """Connection settings for a Milvus server.

    The fields mirror the MILVUS_HOST / MILVUS_PORT / MILVUS_USER /
    MILVUS_PASSWORD / MILVUS_SECURE configuration values and are passed to the
    Milvus vector store as ``connection_args`` via :meth:`to_milvus_params`.
    """
    host: str
    port: int
    user: str
    password: str
    # ``False`` is a legitimate setting (the sample environment ships
    # MILVUS_SECURE=false), so validation must only reject a *missing* value.
    secure: bool
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Reject configurations with a missing or empty connection setting.

        Args:
            values: Field values collected for the model.

        Returns:
            ``values`` unchanged when every required setting is present.

        Raises:
            ValueError: naming the first MILVUS_* setting that is missing.
        """
        # ``.get`` instead of indexing: a field that is absent from ``values``
        # should produce the descriptive ValueError below, not a KeyError.
        if not values.get('host'):
            raise ValueError("config MILVUS_HOST is required")
        if not values.get('port'):
            raise ValueError("config MILVUS_PORT is required")
        # Compare against None rather than truthiness so that secure=False
        # (a plain, non-TLS connection) is accepted.
        if values.get('secure') is None:
            raise ValueError("config MILVUS_SECURE is required")
        if not values.get('user'):
            raise ValueError("config MILVUS_USER is required")
        if not values.get('password'):
            raise ValueError("config MILVUS_PASSWORD is required")
        return values

    def to_milvus_params(self):
        """Return the connection settings as a dict keyed host/port/user/password/secure."""
        return {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'secure': self.secure
        }
class MilvusVectorIndex(BaseVectorIndex): class MilvusVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings): def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings) super().__init__(dataset, embeddings)
self._client = self._init_client(config) self._client_config = config
def get_type(self) -> str: def get_type(self) -> str:
return 'milvus' return 'milvus'
...@@ -49,7 +65,6 @@ class MilvusVectorIndex(BaseVectorIndex): ...@@ -49,7 +65,6 @@ class MilvusVectorIndex(BaseVectorIndex):
dataset_id = dataset.id dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict: def to_index_struct(self) -> dict:
return { return {
"type": self.get_type(), "type": self.get_type(),
...@@ -58,26 +73,29 @@ class MilvusVectorIndex(BaseVectorIndex): ...@@ -58,26 +73,29 @@ class MilvusVectorIndex(BaseVectorIndex):
def create(self, texts: list[Document], **kwargs) -> BaseIndex: def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts) uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents( index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
self._vector_store = MilvusVectorStore.from_documents(
texts, texts,
self._embeddings, self._embeddings,
client=self._client, collection_name=self.get_index_name(self.dataset),
index_name=self.get_index_name(self.dataset), connection_args=self._client_config.to_milvus_params(),
uuids=uuids, index_params=index_params
by_text=False
) )
return self return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts) uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents( self._vector_store = MilvusVectorStore.from_documents(
texts, texts,
self._embeddings, self._embeddings,
client=self._client, collection_name=collection_name,
index_name=collection_name, ids=uuids,
uuids=uuids, content_payload_key='page_content'
by_text=False
) )
return self return self
...@@ -86,42 +104,53 @@ class MilvusVectorIndex(BaseVectorIndex): ...@@ -86,42 +104,53 @@ class MilvusVectorIndex(BaseVectorIndex):
"""Only for created index.""" """Only for created index."""
if self._vector_store: if self._vector_store:
return self._vector_store return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id'] attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id'] return MilvusVectorStore(
collection_name=self.get_index_name(self.dataset),
return WeaviateVectorStore( embedding_function=self._embeddings,
client=self._client, connection_args=self._client_config.to_milvus_params()
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
) )
def _get_vector_store_class(self) -> type: def _get_vector_store_class(self) -> type:
return MilvusVectorStore return MilvusVectorStore
def delete_by_document_id(self, document_id: str): def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store() vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store) vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_document_id(document_id)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_doc_ids(doc_ids)
vector_store.del_texts({ vector_store.del_texts({
"operator": "Equal", 'filter': f' id in {ids}'
"path": ["document_id"],
"valueText": document_id
}) })
def _is_origin(self): def delete_by_group_id(self, group_id: str) -> None:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] vector_store = self._get_vector_store()
if not class_prefix.endswith('_Node'): vector_store = cast(self._get_vector_store_class(), vector_store)
# original class_prefix
return True vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return False from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
...@@ -47,6 +47,20 @@ class VectorIndex: ...@@ -47,6 +47,20 @@ class VectorIndex:
), ),
embeddings=embeddings embeddings=embeddings
) )
elif vector_type == "milvus":
from core.index.vector_index.milvus_vector_index import MilvusVectorIndex, MilvusConfig
return MilvusVectorIndex(
dataset=dataset,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
),
embeddings=embeddings
)
else: else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
......
from langchain.vectorstores import Milvus from core.index.vector_index.milvus import Milvus
class MilvusVectorStore(Milvus): class MilvusVectorStore(Milvus):
...@@ -6,33 +6,41 @@ class MilvusVectorStore(Milvus): ...@@ -6,33 +6,41 @@ class MilvusVectorStore(Milvus):
if not where_filter: if not where_filter:
raise ValueError('where_filter must not be empty') raise ValueError('where_filter must not be empty')
self._client.batch.delete_objects( self.col.delete(where_filter.get('filter'))
class_name=self._index_name,
where=where_filter,
output='minimal'
)
def del_text(self, uuid: str) -> None: def del_text(self, uuid: str) -> None:
self._client.data_object.delete( expr = f"id == {uuid}"
uuid, self.col.delete(expr)
class_name=self._index_name
)
def text_exists(self, uuid: str) -> bool: def text_exists(self, uuid: str) -> bool:
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({ result = self.col.query(
"path": ["doc_id"], expr=f'metadata["doc_id"] == "{uuid}"',
"operator": "Equal", output_fields=["id"]
"valueText": uuid, )
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][self._index_name] return len(result) > 0
if len(entries) == 0:
return False
return True def get_ids_by_document_id(self, document_id: str):
result = self.col.query(
expr=f'metadata["document_id"] == "{document_id}"',
output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
return None
def get_ids_by_doc_ids(self, doc_ids: list):
    """Return the ``id`` of every entity whose metadata ``doc_id`` is in *doc_ids*.

    Args:
        doc_ids: The ``doc_id`` values to look up. Any iterable is accepted;
            it is materialised as a list before being rendered into the
            query expression.

    Returns:
        A list with the ``id`` field of each matching row, or ``None`` when
        *doc_ids* is empty or the query returns no rows.
    """
    wanted = list(doc_ids)
    if not wanted:
        # Nothing to look up: skip the round-trip instead of sending an
        # ``in []`` expression to the server.
        return None

    rows = self.col.query(
        expr=f'metadata["doc_id"] in {wanted}',
        output_fields=["id"]
    )
    if not rows:
        return None
    return [row["id"] for row in rows]
def delete(self): def delete(self):
self._client.schema.delete_class(self._index_name) from pymilvus import utility
utility.drop_collection(self.collection_name, None, self.alias)
...@@ -52,4 +52,5 @@ pandas==1.5.3 ...@@ -52,4 +52,5 @@ pandas==1.5.3
xinference==0.5.2 xinference==0.5.2
safetensors==0.3.2 safetensors==0.3.2
zhipuai==1.0.7 zhipuai==1.0.7
werkzeug==2.3.7 werkzeug==2.3.7
\ No newline at end of file pymilvus==2.3.0
\ No newline at end of file
# Compose stack that runs a standalone Milvus node together with the etcd and
# MinIO services it is pointed at through ETCD_ENDPOINTS / MINIO_ADDRESS.
version: '3.5'

services:
  etcd:
    container_name: milvus-etcd
    image: quay.io/coreos/etcd:v3.5.5
    environment:
      - ETCD_AUTO_COMPACTION_MODE=revision
      - ETCD_AUTO_COMPACTION_RETENTION=1000
      - ETCD_QUOTA_BACKEND_BYTES=4294967296
      - ETCD_SNAPSHOT_COUNT=50000
    volumes:
      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
    command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
    healthcheck:
      test: ["CMD", "etcdctl", "endpoint", "health"]
      interval: 30s
      timeout: 20s
      retries: 3

  minio:
    container_name: milvus-minio
    image: minio/minio:RELEASE.2023-03-20T20-16-18Z
    environment:
      MINIO_ACCESS_KEY: minioadmin
      MINIO_SECRET_KEY: minioadmin
    ports:
      - "9001:9001"
      - "9000:9000"
    volumes:
      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
    command: minio server /minio_data --console-address ":9001"
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
      interval: 30s
      timeout: 20s
      retries: 3

  standalone:
    container_name: milvus-standalone
    image: milvusdb/milvus:v2.3.1
    command: ["milvus", "run", "standalone"]
    environment:
      ETCD_ENDPOINTS: etcd:2379
      MINIO_ADDRESS: minio:9000
      # Quoted on purpose: values in an `environment` mapping must be strings,
      # numbers or null; a bare YAML boolean is rejected by Compose validation.
      common.security.authorizationEnabled: "true"
    volumes:
      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
      interval: 30s
      start_period: 90s
      timeout: 20s
      retries: 3
    ports:
      - "19530:19530"
      - "9091:9091"
    depends_on:
      - "etcd"
      - "minio"

networks:
  default:
    name: milvus
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment