Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
4f1b4b73
Commit
4f1b4b73
authored
Jul 21, 2023
by
jyong
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
delete test file
parent
ba441908
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
0 additions
and
949 deletions
+0
-949
milvus.py
api/core/index/vector_index/milvus.py
+0
-812
milvus_vector_index.py
api/core/index/vector_index/milvus_vector_index.py
+0
-137
No files found.
api/core/index/vector_index/milvus.py
deleted
100644 → 0
View file @
ba441908
"""Wrapper around the Milvus vector database."""
from
__future__
import
annotations
import
logging
from
typing
import
Any
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
uuid
import
uuid4
import
numpy
as
np
from
numpy
import
average
from
sentence_transformers
import
SentenceTransformer
from
langchain.docstore.document
import
Document
from
langchain.embeddings.base
import
Embeddings
from
langchain.vectorstores.base
import
VectorStore
from
langchain.vectorstores.utils
import
maximal_marginal_relevance
from
sklearn
import
preprocessing
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Connection settings used when the caller passes no `connection_args`.
DEFAULT_MILVUS_CONNECTION = dict(
    host="localhost",
    port="19530",
    user="",
    password="",
    secure=False,
)
class
Milvus
(
VectorStore
):
"""Wrapper around the Milvus vector database."""
def __init__(
    self,
    embedding_function: Embeddings,
    collection_name: str = "LangChainCollection",
    connection_args: Optional[dict[str, Any]] = None,
    consistency_level: str = "Session",
    index_params: Optional[dict] = None,
    search_params: Optional[dict] = None,
    drop_old: Optional[bool] = False,
):
    """Set up the wrapper around a Milvus collection.

    `pymilvus` must be installed and a Milvus/Zilliz Cloud instance must be
    reachable. To run Milvus locally see
    https://milvus.io/docs/install_standalone-docker.md; for a hosted
    service see https://zilliz.com/cloud.

    IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.

    `connection_args` is a dict; some of the supported keys are:
        address (str): The actual address of the Milvus instance,
            e.g. "localhost:19530".
        uri (str): The uri of the Milvus instance, e.g.
            "http://randomwebsite:19530", "tcp:foobarsite:19530",
            "https://ok.s3.south.com:19530".
        host (str): The host of the Milvus instance. Default "localhost";
            PyMilvus fills in the default host if only port is provided.
        port (str/int): The port of the Milvus instance. Default 19530;
            PyMilvus fills in the default port if only host is provided.
        user (str): User to connect as. If user and password are provided,
            a related header is added to every RPC call.
        password (str): Required when user is provided.
        secure (bool): Default False. If True, TLS is enabled.
        client_key_path (str): client.key path for two-way TLS.
        client_pem_path (str): client.pem path for two-way TLS.
        ca_pem_path (str): ca.pem path for two-way TLS.
        server_pem_path (str): server.pem path for one-way TLS.
        server_name (str): The common name, when TLS is used.

    Args:
        embedding_function (Embeddings): Function used to embed the text.
        collection_name (str): Which Milvus collection to use. Defaults to
            "LangChainCollection".
        connection_args (Optional[dict[str, Any]]): Connection arguments.
            Defaults to DEFAULT_MILVUS_CONNECTION.
        consistency_level (str): Consistency level for the collection.
            Defaults to "Session".
        index_params (Optional[dict]): Index params to use. When None, a
            default is chosen when the index is created.
        search_params (Optional[dict]): Search params to use. When None,
            they are derived from the index.
        drop_old (Optional[bool]): Whether to drop an existing collection
            of the same name. Defaults to False.

    Raises:
        ValueError: If `pymilvus` cannot be imported.
    """
    try:
        from pymilvus import Collection, utility
    except ImportError:
        raise ValueError(
            "Could not import pymilvus python package. "
            "Please install it with `pip install pymilvus`."
        )

    # Per-index-type `params` used when the caller supplies no search params.
    params_by_index_type = {
        "IVF_FLAT": {"nprobe": 10},
        "IVF_SQ8": {"nprobe": 10},
        "IVF_PQ": {"nprobe": 10},
        "HNSW": {"ef": 10},
        "RHNSW_FLAT": {"ef": 10},
        "RHNSW_SQ": {"ef": 10},
        "RHNSW_PQ": {"ef": 10},
        "IVF_HNSW": {"nprobe": 10, "ef": 10},
        "ANNOY": {"search_k": 10},
        "AUTOINDEX": {},
    }
    self.default_search_params = {
        index_type: {"metric_type": "L2", "params": params}
        for index_type, params in params_by_index_type.items()
    }

    self.embedding_func = embedding_function
    self.collection_name = collection_name
    self.index_params = index_params
    self.search_params = search_params
    self.consistency_level = consistency_level

    # For a collection to be compatible, pk must be an auto-id int field,
    # the text field must be called "text" and the vector field "vector".
    self._primary_field = "pk"
    self._text_field = "text"
    self._vector_field = "vector"
    self.fields: list[str] = []

    # Create (or reuse) the connection to the server.
    self.alias = self._create_connection_alias(
        DEFAULT_MILVUS_CONNECTION if connection_args is None else connection_args
    )
    self.col: Optional[Collection] = None

    # Grab the existing collection if there is one.
    if utility.has_collection(self.collection_name, using=self.alias):
        self.col = Collection(self.collection_name, using=self.alias)

    # Drop it again when the caller asked for a fresh collection.
    if drop_old and isinstance(self.col, Collection):
        self.col.drop()
        self.col = None

    # Initialize the vector store.
    self._init()
def _create_connection_alias(self, connection_args: dict) -> str:
    """Return the alias of a Milvus connection for `connection_args`.

    An already-open connection is reused when its address and user match
    the requested ones; otherwise a new connection is opened under a
    random alias.

    Args:
        connection_args (dict): Keyword arguments forwarded to
            `pymilvus.connections.connect`. The keys "host", "port",
            "uri", "address" and "user" are also inspected to look for a
            reusable connection.

    Returns:
        str: The alias of the reused or newly created connection.

    Raises:
        MilvusException: If a new connection cannot be established.
    """
    from pymilvus import MilvusException, connections

    # Connection arguments used for matching against existing connections.
    host: Optional[str] = connection_args.get("host")
    port: Optional[Union[str, int]] = connection_args.get("port")
    address: Optional[str] = connection_args.get("address")
    uri: Optional[str] = connection_args.get("uri")
    user = connection_args.get("user")

    # Order of use is host/port, uri, address.
    given_address: Optional[str]
    if host is not None and port is not None:
        given_address = f"{host}:{port}"
    elif uri is not None:
        # Drop any "<scheme>://" prefix. The previous hard-coded split on
        # "https://" raised IndexError for http:// or scheme-less URIs.
        _, separator, remainder = uri.partition("://")
        given_address = remainder if separator else uri
    elif address is not None:
        given_address = address
    else:
        given_address = None
        logger.debug("Missing standard address type for reuse atttempt")

    # User defaults to empty string when getting connection info.
    tmp_user = user if user is not None else ""

    # If a usable address was given, look for an existing connection to it.
    if given_address is not None:
        for con in connections.list_connections():
            existing_alias, handler = con[0], con[1]
            # A falsy second element means there is nothing to reuse.
            if not handler:
                continue
            addr = connections.get_connection_addr(existing_alias)
            if (
                addr.get("address") == given_address
                and "user" in addr
                and addr["user"] == tmp_user
            ):
                logger.debug("Using previous connection: %s", existing_alias)
                return existing_alias

    # Generate a new connection if none could be reused.
    alias = uuid4().hex
    try:
        connections.connect(alias=alias, **connection_args)
    except MilvusException:
        logger.error("Failed to create new connection using: %s", alias)
        raise
    logger.debug("Created new connection using: %s", alias)
    return alias
def _init(
    self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
) -> None:
    """Run the collection set-up steps in order.

    When `embeddings` is given, a collection is created from it (and from
    `metadatas`) first. Afterwards the field list, the index, the search
    params and the collection load are processed, each step deciding for
    itself whether it has anything to do.

    Args:
        embeddings (Optional[list]): Embeddings used to create the
            collection. Defaults to None (no collection is created).
        metadatas (Optional[list[dict]]): Metadata dicts passed along to
            collection creation. Defaults to None.
    """
    if embeddings is not None:
        self._create_collection(embeddings, metadatas)

    setup_steps = (
        self._extract_fields,
        self._create_index,
        self._create_search_params,
        self._load,
    )
    for step in setup_steps:
        step()
def _create_collection(
    self, embeddings: list, metadatas: Optional[list[dict]] = None
) -> None:
    """Create the Milvus collection and store it in `self.col`.

    The schema is derived from the first embedding (vector type and
    dimension) and, when present, the first metadata dict (one field per
    key). Field order is: metadata fields, text, primary key, vector.

    Args:
        embeddings (list): Embeddings; only the first entry is inspected.
        metadatas (Optional[list[dict]]): Metadata dicts; only the first
            entry is inspected. Defaults to None.

    Raises:
        ValueError: If a metadata value has no recognized Milvus datatype.
        MilvusException: If the collection cannot be created.
    """
    from pymilvus import (
        Collection,
        CollectionSchema,
        DataType,
        FieldSchema,
        MilvusException,
    )
    from pymilvus.orm.types import infer_dtype_bydata

    # The embedding dimension comes from the first vector.
    first_vector = embeddings[0]
    dim = len(first_vector)
    schema_fields = []

    # One FieldSchema per key of the first metadata dict.
    if metadatas:
        for key, value in metadatas[0].items():
            dtype = infer_dtype_bydata(value)
            if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
                # The datatype is not compatible with Milvus.
                logger.error(
                    "Failure to create collection, unrecognized dtype for key: %s",
                    key,
                )
                raise ValueError(f"Unrecognized datatype for {key}.")
            if dtype == DataType.VARCHAR:
                # String-like values need an explicit maximum length.
                metadata_field = FieldSchema(key, DataType.VARCHAR, max_length=65_535)
            else:
                metadata_field = FieldSchema(key, dtype)
            schema_fields.append(metadata_field)

    # Text field, primary key field, then the vector field (binary or float).
    schema_fields.append(
        FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
    )
    schema_fields.append(
        FieldSchema(
            self._primary_field, DataType.INT64, is_primary=True, auto_id=True
        )
    )
    schema_fields.append(
        FieldSchema(self._vector_field, infer_dtype_bydata(first_vector), dim=dim)
    )

    schema = CollectionSchema(schema_fields)

    try:
        self.col = Collection(
            name=self.collection_name,
            schema=schema,
            consistency_level=self.consistency_level,
            using=self.alias,
        )
    except MilvusException as e:
        logger.error(
            "Failed to create collection: %s error: %s", self.collection_name, e
        )
        raise e
def _extract_fields(self) -> None:
    """Refresh `self.fields` from the schema of the current collection.

    Does nothing when `self.col` is not a `Collection`. The primary field
    is left out because it is auto-id and never supplied on insert.

    The list is rebuilt rather than appended to, so calling this more than
    once does not accumulate duplicates, and a schema that lacks the
    expected primary field name no longer raises ValueError.
    """
    from pymilvus import Collection

    if not isinstance(self.col, Collection):
        return

    self.fields = [
        field.name
        for field in self.col.schema.fields
        if field.name != self._primary_field
    ]
def _get_index(self) -> Optional[dict[str, Any]]:
    """Return the index on the vector field as a dict, or None.

    None is returned when `self.col` is not a `Collection` or when none of
    its indexes is defined on `self._vector_field`.
    """
    from pymilvus import Collection

    if not isinstance(self.col, Collection):
        return None

    # The first index whose field name matches the vector field wins.
    vector_indexes = (
        index.to_dict()
        for index in self.col.indexes
        if index.field_name == self._vector_field
    )
    return next(vector_indexes, None)
def _create_index(self) -> None:
    """Create an index on the vector field when the collection has none.

    If `self.index_params` is None, an HNSW/L2 default is used. When index
    creation fails with a MilvusException, `self.index_params` is replaced
    by an AUTOINDEX/L2 configuration and creation is retried once.

    Raises:
        MilvusException: If the retry fails as well.
    """
    from pymilvus import Collection, MilvusException

    # Nothing to do without a collection, or when an index already exists.
    if not isinstance(self.col, Collection) or self._get_index() is not None:
        return

    def build_index() -> None:
        # Always builds from the current value of self.index_params.
        self.col.create_index(
            self._vector_field,
            index_params=self.index_params,
            using=self.alias,
        )

    try:
        # If no index params were given, use a default HNSW based one.
        if self.index_params is None:
            self.index_params = {
                "metric_type": "L2",
                "index_type": "HNSW",
                "params": {"M": 8, "efConstruction": 64},
            }

        try:
            build_index()
        except MilvusException:
            # If that did not work, most likely on Zilliz Cloud: fall back
            # to an AUTOINDEX based index.
            self.index_params = {
                "metric_type": "L2",
                "index_type": "AUTOINDEX",
                "params": {},
            }
            build_index()

        logger.debug(
            "Successfully created an index on collection: %s",
            self.collection_name,
        )
    except MilvusException as e:
        logger.error(
            "Failed to create an index on collection: %s", self.collection_name
        )
        raise e
def _create_search_params(self) -> None:
    """Derive `self.search_params` from the current vector index.

    Only runs when a collection exists, no search params were supplied and
    the vector field has an index. The `params` part comes from
    `self.default_search_params` for the index type; the metric type is
    taken from the index itself.

    Index types that have no entry in the defaults table get empty
    `params` instead of raising KeyError, and the result is a fresh dict so
    the defaults table itself is never modified.
    """
    from pymilvus import Collection

    if not isinstance(self.col, Collection) or self.search_params is not None:
        return

    index = self._get_index()
    if index is None:
        return

    index_param = index["index_param"]
    index_type: str = index_param["index_type"]
    metric_type: str = index_param["metric_type"]

    defaults = self.default_search_params.get(index_type)
    if defaults is None:
        logger.debug(
            "No default search params for index type %s, using empty params",
            index_type,
        )
        default_params: dict = {}
    else:
        default_params = defaults["params"]

    # Copy so later changes to search_params cannot leak into the defaults.
    self.search_params = {
        "metric_type": metric_type,
        "params": dict(default_params),
    }
def _load(self) -> None:
    """Load the collection when it exists and its vector field is indexed."""
    from pymilvus import Collection

    if not isinstance(self.col, Collection):
        return
    if self._get_index() is None:
        return
    self.col.load()
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    timeout: Optional[int] = None,
    batch_size: int = 1000,
    **kwargs: Any,
) -> List[str]:
    """Insert text data into Milvus.

    Inserting data when the collection has not been made yet results in
    creating a new Collection. The data of the first entity decides the
    schema of the new collection: the dim is extracted from the first
    embedding and the columns are decided by the first metadata dict.
    Metadata keys need to be present for all inserted values; at the
    moment there is no None equivalent in Milvus.

    Texts are embedded with the configured `embedding_func`
    (`embed_documents`, falling back to per-text `embed_query` when that
    raises NotImplementedError), the same embedder the search methods use
    for queries.

    Args:
        texts (Iterable[str]): The texts to embed; it is assumed that they
            all fit in memory.
        metadatas (Optional[List[dict]]): Metadata dicts attached to each
            of the texts. Defaults to None.
        timeout (Optional[int]): Timeout for each batch insert. Defaults
            to None.
        batch_size (int, optional): Batch size to use for insertion.
            Defaults to 1000.
        kwargs: Extra keyword arguments forwarded to `Collection.insert`.

    Raises:
        MilvusException: Failure to add texts.

    Returns:
        List[str]: The resulting keys for each inserted element.
    """
    from pymilvus import Collection, MilvusException

    texts = list(texts)

    try:
        embeddings = self.embedding_func.embed_documents(texts)
    except NotImplementedError:
        embeddings = [self.embedding_func.embed_query(x) for x in texts]

    if len(embeddings) == 0:
        logger.debug("Nothing to insert, skipping.")
        return []

    # If the collection hasn't been initialized yet, perform all steps to do so.
    if not isinstance(self.col, Collection):
        self._init(embeddings, metadatas)

    # Column name -> list of values for every entity to insert.
    insert_dict: dict[str, list] = {
        self._text_field: texts,
        self._vector_field: embeddings,
    }

    # Collect the metadata columns; keys unknown to the collection are ignored.
    if metadatas is not None:
        for d in metadatas:
            for key, value in d.items():
                if key in self.fields:
                    insert_dict.setdefault(key, []).append(value)

    total_count = len(insert_dict[self._vector_field])
    pks: list[str] = []

    assert isinstance(self.col, Collection)
    for i in range(0, total_count, batch_size):
        end = min(i + batch_size, total_count)
        # Convert the dict to a list of column slices in schema field order.
        insert_list = [insert_dict[x][i:end] for x in self.fields]
        try:
            res = self.col.insert(insert_list, timeout=timeout, **kwargs)
        except MilvusException:
            logger.error(
                "Failed to insert batch starting at entity: %s/%s", i, total_count
            )
            raise
        pks.extend(res.primary_keys)
    return pks
def embedding_test(self, texts: List[str]) -> List[List[float]]:
    """Embed `texts` with a fixed SentenceTransformer model and L2-normalize.

    The model is loaded on first use and kept on the instance, so repeated
    calls do not reload it.

    Args:
        texts (List[str]): The texts to encode.

    Returns:
        List[List[float]]: One vector per text, each divided by its own
            Euclidean norm.
    """
    model = getattr(self, "_sentence_model", None)
    if model is None:
        model = SentenceTransformer(
            'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
        )
        self._sentence_model = model

    vectors = model.encode(texts)
    return [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
def similarity_search(
    self,
    query: str,
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Document]:
    """Search for documents similar to the query string.

    Delegates to `similarity_search_with_score` and drops the scores.

    Args:
        query (str): The text to search.
        k (int, optional): How many results to return. Defaults to 4.
        param (dict, optional): The search params for the index type.
            Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Document]: Document results for search; empty when there is
            no collection.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    scored = self.similarity_search_with_score(
        query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
    )
    documents = []
    for document, _score in scored:
        documents.append(document)
    return documents
def similarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Document]:
    """Search for documents similar to an embedding vector.

    Delegates to `similarity_search_with_score_by_vector` and drops the
    scores.

    Args:
        embedding (List[float]): The embedding vector to search.
        k (int, optional): How many results to return. Defaults to 4.
        param (dict, optional): The search params for the index type.
            Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Document]: Document results for search; empty when there is
            no collection.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    scored = self.similarity_search_with_score_by_vector(
        embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
    )
    documents = []
    for document, _score in scored:
        documents.append(document)
    return documents
def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Perform a search on a query string and return results with score.

    The query is embedded with `embedding_func.embed_query` and the full
    vector is handed to `similarity_search_with_score_by_vector`.

    For more information about the search parameters, take a look at the
    pymilvus documentation found here:
    https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

    Args:
        query (str): The text being searched.
        k (int, optional): The amount of results to return. Defaults to 4.
        param (dict): The search params for the specified index.
            Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Tuple[Document, float]]: Result doc and score; empty when
            there is no collection.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    # embed_query returns the whole vector; pass it on as-is. Indexing it
    # with [0] would hand a single float to the vector search.
    embedding = self.embedding_func.embed_query(query)

    return self.similarity_search_with_score_by_vector(
        embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
    )
def normalize_embedding(self, embedding):
    """Return `embedding` passed through `preprocessing.normalize` with the 'l2' norm."""
    normalized = preprocessing.normalize(embedding, norm='l2')
    return normalized
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Perform a search on an embedding vector and return results with score.

    For more information about the search parameters, take a look at the
    pymilvus documentation found here:
    https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

    Args:
        embedding (List[float]): The embedding vector being searched.
        k (int, optional): The amount of results to return. Defaults to 4.
        param (dict): The search params for the specified index. When
            None, `self.search_params` is used. Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Tuple[Document, float]]: Result doc and score; empty when
            there is no collection.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    search_param = self.search_params if param is None else param

    # Every tracked field except the vector itself is returned with the hits.
    output_fields = list(self.fields)
    output_fields.remove(self._vector_field)

    hits = self.col.search(
        data=[embedding],
        anns_field=self._vector_field,
        param=search_param,
        limit=k,
        expr=expr,
        output_fields=output_fields,
        timeout=timeout,
        **kwargs,
    )

    # One query vector was sent, so only the first result set is used.
    scored_documents = []
    for hit in hits[0]:
        metadata = {}
        for field_name in output_fields:
            metadata[field_name] = hit.entity.get(field_name)
        # The text field becomes the page content; the rest stays metadata.
        page_content = metadata.pop(self._text_field)
        document = Document(page_content=page_content, metadata=metadata)
        scored_documents.append((document, hit.score))
    return scored_documents
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Document]:
    """Search by query text and return results reordered by MMR.

    The query is embedded with `embedding_func.embed_query` and the work is
    delegated to `max_marginal_relevance_search_by_vector`.

    Args:
        query (str): The text being searched.
        k (int, optional): How many results to give. Defaults to 4.
        fetch_k (int, optional): Total results to select k from.
            Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5
        param (dict, optional): The search params for the specified index.
            Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Document]: Document results for search; empty when there is
            no collection.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    query_vector = self.embedding_func.embed_query(query)
    reordered = self.max_marginal_relevance_search_by_vector(
        embedding=query_vector,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        param=param,
        expr=expr,
        timeout=timeout,
        **kwargs,
    )
    return reordered
def max_marginal_relevance_search_by_vector(
    self,
    embedding: list[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Document]:
    """Search by embedding vector and return results reordered by MMR.

    `fetch_k` candidates are fetched with a vector search, their stored
    vectors are read back by primary key, and
    `maximal_marginal_relevance` picks the order of the returned documents.

    Args:
        embedding (list[float]): The embedding vector being searched.
        k (int, optional): How many results to give. Defaults to 4.
        fetch_k (int, optional): Total results to select k from.
            Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5
        param (dict, optional): The search params for the specified index.
            When None, `self.search_params` is used. Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Document]: Document results for search; empty when there is
            no collection or the search yields no hits.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    if param is None:
        param = self.search_params

    # Determine result metadata fields (everything tracked but the vector).
    output_fields = self.fields[:]
    output_fields.remove(self._vector_field)

    # Perform the search.
    res = self.col.search(
        data=[embedding],
        anns_field=self._vector_field,
        param=param,
        limit=fetch_k,
        expr=expr,
        output_fields=output_fields,
        timeout=timeout,
        **kwargs,
    )

    # Organize results: ids and documents stay in search order.
    ids = []
    documents = []
    for result in res[0]:
        meta = {x: result.entity.get(x) for x in output_fields}
        documents.append(
            Document(page_content=meta.pop(self._text_field), metadata=meta)
        )
        ids.append(result.id)

    # Without hits there is nothing to reorder, and no reason to send a
    # query with an empty id list to the server.
    if not ids:
        return []

    vectors = self.col.query(
        expr=f"{self._primary_field} in {ids}",
        output_fields=[self._primary_field, self._vector_field],
        timeout=timeout,
    )
    # Reorganize the results from query to match search order.
    vector_by_pk = {x[self._primary_field]: x[self._vector_field] for x in vectors}
    ordered_result_embeddings = [vector_by_pk[x] for x in ids]

    # Get the new order of results.
    new_ordering = maximal_marginal_relevance(
        np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
    )

    # Reorder the values and return.
    ret = []
    for position in new_ordering:
        # The function can return a -1 index; stop at the first one.
        if position == -1:
            break
        ret.append(documents[position])
    return ret
@classmethod
def from_texts(
    cls,
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    collection_name: str = "LangChainCollection",
    connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
    consistency_level: str = "Session",
    index_params: Optional[dict] = None,
    search_params: Optional[dict] = None,
    drop_old: bool = False,
    **kwargs: Any,
) -> Milvus:
    """Build a Milvus vector store and insert the given texts into it.

    Args:
        texts (List[str]): Text data.
        embedding (Embeddings): Embedding function.
        metadatas (Optional[List[dict]]): Metadata for each text if it
            exists. Defaults to None.
        collection_name (str, optional): Collection name to use. Defaults
            to "LangChainCollection".
        connection_args (dict[str, Any], optional): Connection args to use.
            Defaults to DEFAULT_MILVUS_CONNECTION.
        consistency_level (str, optional): Which consistency level to use.
            Defaults to "Session".
        index_params (Optional[dict], optional): Which index_params to use.
            Defaults to None.
        search_params (Optional[dict], optional): Which search params to
            use. Defaults to None.
        drop_old (Optional[bool], optional): Whether to drop the collection
            with that name if it exists. Defaults to False.
        kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        Milvus: Milvus Vector Store
    """
    constructor_args = dict(
        embedding_function=embedding,
        collection_name=collection_name,
        connection_args=connection_args,
        consistency_level=consistency_level,
        index_params=index_params,
        search_params=search_params,
        drop_old=drop_old,
    )
    store = cls(**constructor_args, **kwargs)
    store.add_texts(texts=texts, metadatas=metadatas)
    return store
api/core/index/vector_index/milvus_vector_index.py
deleted
100644 → 0
View file @
ba441908
from
typing
import
Optional
,
cast
import
requests
import
weaviate
from
langchain.embeddings.base
import
Embeddings
from
langchain.schema
import
Document
,
BaseRetriever
from
langchain.vectorstores
import
VectorStore
from
pydantic
import
BaseModel
,
root_validator
from
core.index.base
import
BaseIndex
from
core.index.vector_index.base
import
BaseVectorIndex
from
core.vector_store.weaviate_vector_store
import
WeaviateVectorStore
from
models.dataset
import
Dataset
class MilvusConfig(BaseModel):
    """Connection settings for the Milvus vector index."""

    # URI of the vector database; must be non-empty (see validate_config).
    uri: str
    # Optional credentials.
    username: Optional[str]
    password: Optional[str]
    # Batch size handed to the client's batch configuration.
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Reject a configuration whose `uri` is missing or empty.

        `values` only contains fields that passed their own validation, so
        `uri` may be absent here; `.get` keeps that case a ValueError
        instead of an unhandled KeyError.

        Raises:
            ValueError: If `uri` is missing or falsy.
        """
        if not values.get('uri'):
            raise ValueError("config Milvus uri is required")
        return values
class
MilvusVectorIndex
(
BaseVectorIndex
):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
    """Initialise the base index and create the client from `config`.

    Args:
        dataset (Dataset): Dataset passed on to the base class.
        config (MilvusConfig): Settings handed to `_init_client`.
        embeddings (Embeddings): Embeddings passed on to the base class.
    """
    super().__init__(dataset, embeddings)
    client = self._init_client(config)
    self._client = client
def _init_client(self, config: MilvusConfig) -> weaviate.Client:
    """Build and configure the client used by this index.

    NOTE(review): despite the Milvus naming, this constructs a
    `weaviate.Client`. It also reads `config.api_key` and
    `config.endpoint`, neither of which is declared on `MilvusConfig` in
    this file (it declares uri, username, password, batch_size). Confirm
    the intended backend and config fields before relying on this method.

    Args:
        config (MilvusConfig): Connection settings.

    Returns:
        weaviate.Client: The client with batching configured.

    Raises:
        ConnectionError: If creating the client raises
            `requests.exceptions.ConnectionError`.
    """
    auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)

    # Presumably turns off gRPC usage in the weaviate client — TODO confirm.
    weaviate.connect.connection.has_grpc = False

    try:
        client = weaviate.Client(
            url=config.endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 60),
            startup_period=None
        )
    except requests.exceptions.ConnectionError:
        # Translate the requests error into the builtin ConnectionError.
        raise ConnectionError("Vector database connection error")

    client.batch.configure(
        # `batch_size` takes an `int` value to enable auto-batching
        # (`None` is used for manual batching)
        batch_size=config.batch_size,
        # dynamically update the `batch_size` based on import speed
        dynamic=True,
        # `timeout_retries` takes an `int` value to retry on time outs
        timeout_retries=3,
    )

    return client
def get_type(self) -> str:
    """Return the type identifier of this index.

    NOTE(review): the value is 'weaviate' although the class is named for
    Milvus; confirm whether that is intended.
    """
    index_type = 'weaviate'
    return index_type
def get_index_name(self, dataset: Dataset) -> str:
    """Return the index (class) name for the dataset.

    When `self.dataset` has an index struct, its stored class prefix is
    used, with '_Node' appended if it is missing. Otherwise the name is
    built from the id of the `dataset` argument, with dashes replaced by
    underscores.

    Args:
        dataset (Dataset): Dataset whose id is used when no index struct
            is stored.

    Returns:
        str: The index name, always ending in '_Node'.
    """
    node_suffix = '_Node'

    index_struct = self.dataset.index_struct_dict
    if index_struct:
        stored_prefix: str = index_struct['vector_store']['class_prefix']
        # An original class_prefix lacks the suffix; add it.
        if stored_prefix.endswith(node_suffix):
            return stored_prefix
        return stored_prefix + node_suffix

    normalized_id = dataset.id.replace("-", "_")
    return "Vector_index_" + normalized_id + node_suffix
def to_index_struct(self) -> dict:
    """Return the index struct: the index type plus the class prefix."""
    index_type = self.get_type()
    class_prefix = self.get_index_name(self.dataset)
    return {
        "type": index_type,
        "vector_store": {"class_prefix": class_prefix},
    }
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
    """Create the vector store from `texts` and return this index.

    The store is built with `WeaviateVectorStore.from_documents` using the
    uuids from `_get_uuids`, and kept in `self._vector_store`.

    Args:
        texts (list[Document]): Documents to index.
        kwargs: Accepted but not used by this method.

    Returns:
        BaseIndex: This index instance.
    """
    uuids = self._get_uuids(texts)
    store_options = dict(
        client=self._client,
        index_name=self.get_index_name(self.dataset),
        uuids=uuids,
        by_text=False,
    )
    self._vector_store = WeaviateVectorStore.from_documents(
        texts, self._embeddings, **store_options
    )
    return self
def _get_vector_store(self) -> VectorStore:
    """Return the vector store for an already created index.

    A store cached in `self._vector_store` is returned as-is; otherwise a
    new `WeaviateVectorStore` is built. For an original-style index only
    'doc_id' is requested as an attribute.
    """
    if self._vector_store:
        return self._vector_store

    if self._is_origin():
        attributes = ['doc_id']
    else:
        attributes = ['doc_id', 'dataset_id', 'document_id']

    return WeaviateVectorStore(
        client=self._client,
        index_name=self.get_index_name(self.dataset),
        text_key='text',
        embedding=self._embeddings,
        attributes=attributes,
        by_text=False
    )
def _get_vector_store_class(self) -> type:
    """Return the vector store class used by this index."""
    store_class = WeaviateVectorStore
    return store_class
def delete_by_document_id(self, document_id: str):
    """Delete the stored texts that belong to `document_id`.

    For an original-style index the whole dataset is recreated instead.
    Otherwise the vector store is asked to delete every entry whose
    `document_id` equals the given id.

    Args:
        document_id (str): Id of the document whose texts are removed.
    """
    if self._is_origin():
        self.recreate_dataset(self.dataset)
        return

    store = self._get_vector_store()
    store = cast(self._get_vector_store_class(), store)

    where_filter = {
        "operator": "Equal",
        "path": ["document_id"],
        "valueText": document_id
    }
    store.del_texts(where_filter)
def _is_origin(self):
    """Return True when the stored class prefix lacks the '_Node' suffix.

    Returns False when the dataset has no index struct.
    """
    index_struct = self.dataset.index_struct_dict
    if not index_struct:
        return False

    class_prefix: str = index_struct['vector_store']['class_prefix']
    # An original class_prefix is one without the '_Node' suffix.
    return not class_prefix.endswith('_Node')
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment