ai-tech / dify / Commits

Commit 4c596272
authored Jul 12, 2023 by jyong

fix clean dataset task

parent 90b22d8c

Showing 9 changed files with 1329 additions and 20 deletions (+1329, -20)
api/controllers/console/datasets/test.py             +176  -0
api/controllers/console/datasets/test_embedding.py   +18   -0
api/controllers/console/datasets/test_query.py       +39   -0
api/core/generator/llm_generator.py                  +1    -1
api/core/index/vector_index/milvus.py                +812  -0
api/core/index/vector_index/milvus_vector_index.py   +137  -0
api/core/index/vector_index/test-embedding.py        +123  -0
api/core/indexing_runner.py                          +22   -19
api/core/prompt/prompts.py                           +1    -0
api/controllers/console/datasets/test.py  0 → 100644 (new file)
import datetime
import re
from os import environ
from uuid import uuid4

import openai
from langchain.document_loaders import WebBaseLoader, UnstructuredFileLoader, TextLoader
from langchain.embeddings import OpenAIEmbeddings, MiniMaxEmbeddings
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from pymilvus import connections, Collection
from pymilvus.orm import utility

from core.data_loader.loader.excel import ExcelLoader
from core.generator.llm_generator import LLMGenerator
from core.index.vector_index.milvus import Milvus

OPENAI_API_KEY = "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"
# example: "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

## Set up environment variables
environ["OPENAI_API_KEY"] = OPENAI_API_KEY
environ["MINIMAX_GROUP_ID"] = "1686736670459291"
environ["MINIMAX_API_KEY"] = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJOYW1lIjoiIiwiU3ViamVjdElEIjoiMTY4NjczNjY3MDQ0NzEyNSIsIlBob25lIjoiTVRVd01UZzBNREU1TlRFPSIsIkdyb3VwSUQiOiIiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiJwYW5wYW5AZGlmeS5haSIsIkNyZWF0ZVRpbWUiOiIiLCJpc3MiOiJtaW5pbWF4In0.i9gRKYmOW3zM8vEcT7lD-Ym-0eE6UUU3vb-gVxpWfSMkdc6ObbRnkP5nYumZJbV9L-yRA00GW6nMWYcWkY3IbDWWFAi-hRmzAtl-orpkz5DxPzjRJbwAPy9snYlqBWYQ4hOQ-53zmA5wgsm0ga5pMpBTN9SCkm7EnBQDEsPEY1m121tuwXe6LhAMjdX0Kic-UI-KTYbDdWGAl6nu8h8lrSHVuEEYA6Lz3VDyJTcYfME-B435vw-x1UXSb5-V-YhMEhIixEO8ezUQXaERq0mErtIQEoZN4r7OeNNGjocsfwiHRiw_EdxbfYUWjpvAytmmekIuL3tfvfhbif-EZc4E5w"

CONVERSATION_PROMPT = (
    "你是出题人.\n"
    "用户会发送一段长文本.\n请一步一步思考"
    'Step1:了解并总结这段文本的主要内容\n'
    'Step2:这段文本提到了哪些关键信息或概念\n'
    'Step3:可分解或结合多个信息与概念\n'
    'Step4:将这些关键信息与概念生成 10 个问题与答案,问题描述清楚并且详细完整,答案详细完整.\n'
    "按格式回答: Q1:\nA1:\nQ2:\nA2:...\n"
)


def test_milvus():
    def format_split_text(text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regular expression matching the Q and A pairs
        matches = re.findall(regex, text, re.MULTILINE)  # collect all matches
        result = []  # store the final results
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                # only keep pairs where both Q and A are present
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })

        return result

    # 84b2202c-c359-46b7-a810-bce50feaa4d1
    # Use the WebBaseLoader to load specified web pages into documents
    # loader = WebBaseLoader([
    #     "https://milvus.io/docs/overview.md",
    # ])
    loader = ExcelLoader('/Users/jiangyong/Downloads/xiaoming.xlsx')
    # loader = TextLoader('/Users/jiangyong/Downloads/all.txt', autodetect_encoding=True)
    # loader = UnstructuredFileLoader('/Users/jiangyong/Downloads/douban.xlsx')
    docs = loader.load()
    #
    # # Split the documents into smaller chunks
    text_splitter = CharacterTextSplitter(chunk_size=1024, chunk_overlap=0)
    #
    docs = text_splitter.split_documents(docs)
    new_docs = []
    for doc in docs:
        openai.api_key = "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"
        response = openai.ChatCompletion.create(
            model='gpt-3.5-turbo',
            messages=[
                {
                    'role': 'system',
                    'content': CONVERSATION_PROMPT
                },
                {
                    'role': 'user',
                    'content': doc.page_content
                }
            ],
            temperature=0,
            stream=False,  # this time, we set stream=True
            n=1,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0
        )
        # response = LLMGenerator.generate_qa_document('84b2202c-c359-46b7-a810-bce50feaa4d1', doc.page_content)
        results = format_split_text(response['choices'][0]['message']['content'])
        print(results)
        # for result in results:
        #     document = Document(page_content=result['question'], metadata={'source': result['answer']})
        #     new_docs.append(document)
        # new_docs.append(doc)

    # embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
    embeddings = MiniMaxEmbeddings()

    # cont = connections.connect(
    #     alias="default",
    #     user='username',
    #     password='password',
    #     host='localhost',
    #     port='19530'
    # )
    # chunk_size = 100
    # for i in range(0, len(new_docs), chunk_size):
    #     # check document is paused
    #     chunk_documents = new_docs[i:i + chunk_size]
    #     vector_store = Milvus.from_documents(
    #         chunk_documents,
    #         collection_name='jytest5',
    #         embedding=embeddings,
    #         connection_args={"uri": 'https://in01-706333b4f51fa0b.aws-us-west-2.vectordb.zillizcloud.com:19530',
    #                          'user': 'db_admin', 'password': 'dify123456!'}
    #     )
    # collection = Collection("jytest4")  # Get an existing collection.
    # collection.release()
    # print(datetime.datetime.utcnow())
    # alias = uuid4().hex
    # # #connection_args = {"host": 'localhost', "port": '19530'}
    # connection_args = {"uri": 'https://in01-91c80c04f4aed06.aws-us-west-2.vectordb.zillizcloud.com:19530',
    #                    'user': 'db_admin', 'password': 'dify123456!'}
    # connections.connect(alias=alias, **connection_args)
    # connection = Collection(
    #     'jytest10',
    #     using=alias,
    # )
    # print(datetime.datetime.utcnow())
    # # connection.release()
    # query = '阿甘正传'
    # search_params = {"metric_type": "IP", "params": {"level": 2}}
    # docs = Milvus(embedding_function=embeddings, collection_name='jytest4').similarity_search(query)
    # docs = Milvus(embedding_function=embeddings, collection_name='jytest', connection_args={"uri": 'https://in01-706333b4f51fa0b.aws-us-west-2.vectordb.zillizcloud.com:19530',
    #               'user': 'db_admin', 'password': 'dify123456!'}).similarity_search(query)
    # docs = Milvus(embedding_function=embeddings, collection_name='jytest10', connection_args={"uri": 'https://in01-91c80c04f4aed06.aws-us-west-2.vectordb.zillizcloud.com:19530',
    #               'token': '01a3da355f5645fe949b1c6e97339c90b1931b6726094fcac3dd0594ab6312eb4ea314095ca989d7dfc8abfac1092dd1a6d46017', 'db_name': 'dify'}).similarity_search(query)
    # print(datetime.datetime.utcnow())
    # docs = vector_store.similarity_search(query)
    # cont = connections.connect(
    #     alias="default",
    #     user='username',
    #     password='password',
    #     host='localhost',
    #     port='19530'
    # )
    # docs = cont.search(query='What is milvus?', search_type='similarity',
    #                    connection_args={"host": 'localhost', "port": '19530'})
    # docs = vector_store.similarity_search(query)
    # print(docs)
    # connections.connect("default",
    #                     uri='https://in01-617651a0cb211be.aws-us-west-2.vectordb.zillizcloud.com:19533',
    #                     user='db_admin',
    #                     password='dify123456!')
    #
    # # Check if the collection exists
    # collection_name = "jytest"
    # check_collection = utility.has_collection(collection_name)
    # if check_collection:
    #     drop_result = utility.drop_collection(collection_name)
    # print("Success!")
    # collection = Collection(name=collection_name)
    # collection.
    # search_params = {"metric_type": "L2", "params": {"level": 2}}
    # results = collection.search('电影排名50',
    #                             anns_field='page_content',
    #                             param=search_params,
    #                             limit=1,
    #                             guarantee_timestamp=1)
    # connections.disconnect("default")
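For reference, here is a hypothetical illustration (not part of the commit) of what format_split_text extracts from a reply that follows the Q1:/A1: format requested by CONVERSATION_PROMPT. The helper is defined inside test_milvus, so the call is shown only for illustration; the sample reply text is invented.

# Hypothetical reply in the format requested by CONVERSATION_PROMPT.
sample_reply = (
    "Q1: What is Milvus?\n"
    "A1: Milvus is an open-source vector database.\n"
    "Q2: What does the text splitter do?\n"
    "A2: It splits long documents into 1024-character chunks.\n"
)

# format_split_text(sample_reply) would return:
# [
#     {'question': 'What is Milvus?',
#      'answer': 'Milvus is an open-source vector database.'},
#     {'question': 'What does the text splitter do?',
#      'answer': 'It splits long documents into 1024-character chunks.'},
# ]

Because the regex uses a lazy capture with a (?=Q|$) lookahead under re.MULTILINE, each answer is effectively cut at the end of its line or at the next capital "Q", so multi-line answers or answers containing a capital "Q" are truncated.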
api/controllers/console/datasets/test_embedding.py  0 → 100644 (new file)
import numpy as np
from numpy import average
from sentence_transformers import SentenceTransformer


def test_embdding():
    sentences = ["My name is john"]

    model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    embeddings = model.encode(sentences)
    for embedding in embeddings:
        print(embedding)
        embedding = (embedding / np.linalg.norm(embedding)).tolist()
        print(embedding)
        embedding = (embedding / np.linalg.norm(embedding)).tolist()
        print(embedding)
    print(embeddings)
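The doubled normalization in test_embdding is a no-op: once a vector has unit L2 norm, dividing by its norm again leaves it unchanged, which is why the second and third prints show the same values. A minimal numpy sketch with illustrative values (not from the commit):

# Illustrative values only.
import numpy as np

v = np.array([3.0, 4.0])
v1 = v / np.linalg.norm(v)      # [0.6, 0.8], unit length
v2 = v1 / np.linalg.norm(v1)    # norm of v1 is 1.0, so v2 equals v1
assert np.allclose(v1, v2)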
api/controllers/console/datasets/test_query.py  0 → 100644 (new file)
import base64
import binascii
import hashlib
import secrets
from os import environ

import numpy as np
from langchain.embeddings import MiniMaxEmbeddings
from numpy import average
from sentence_transformers import SentenceTransformer

from core.index.vector_index.milvus import Milvus

environ["MINIMAX_GROUP_ID"] = "1686736670459291"
environ["MINIMAX_API_KEY"] = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJOYW1lIjoiIiwiU3ViamVjdElEIjoiMTY4NjczNjY3MDQ0NzEyNSIsIlBob25lIjoiTVRVd01UZzBNREU1TlRFPSIsIkdyb3VwSUQiOiIiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiJwYW5wYW5AZGlmeS5haSIsIkNyZWF0ZVRpbWUiOiIiLCJpc3MiOiJtaW5pbWF4In0.i9gRKYmOW3zM8vEcT7lD-Ym-0eE6UUU3vb-gVxpWfSMkdc6ObbRnkP5nYumZJbV9L-yRA00GW6nMWYcWkY3IbDWWFAi-hRmzAtl-orpkz5DxPzjRJbwAPy9snYlqBWYQ4hOQ-53zmA5wgsm0ga5pMpBTN9SCkm7EnBQDEsPEY1m121tuwXe6LhAMjdX0Kic-UI-KTYbDdWGAl6nu8h8lrSHVuEEYA6Lz3VDyJTcYfME-B435vw-x1UXSb5-V-YhMEhIixEO8ezUQXaERq0mErtIQEoZN4r7OeNNGjocsfwiHRiw_EdxbfYUWjpvAytmmekIuL3tfvfhbif-EZc4E5w"


def test_query():
    # embeddings = MiniMaxEmbeddings()
    # query = '你对这部电影有什么感悟'
    # # search_params = {"metric_type": "IP", "params": {"level": 2}}
    # # docs = Milvus(embedding_function=embeddings, collection_name='jytest4').similarity_search(query)
    # docs = Milvus(embedding_function=embeddings, collection_name='jytest5',
    #               connection_args={"uri": 'https://in01-706333b4f51fa0b.aws-us-west-2.vectordb.zillizcloud.com:19530',
    #                                'user': 'db_admin', 'password': 'dify123456!'}).similarity_search(query)
    # print(docs)

    # generate password salt
    def hash_password(password_str, salt_byte):
        dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000)
        return binascii.hexlify(dk)

    salt = secrets.token_bytes(16)
    base64_salt = base64.b64encode(salt).decode()

    # encrypt password with salt
    password_hashed = hash_password('dify123456!', salt)
    base64_password_hashed = base64.b64encode(password_hashed).decode()
    print(base64_password_hashed)
    print('*******************')
    print(base64_salt)
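A possible companion sketch (not in the commit) for checking a password against the base64-encoded salt and hash that test_query prints, using the same PBKDF2-HMAC-SHA256 parameters (10000 iterations, hex digest). The helper name verify_password is hypothetical.

# Sketch only: verify_password is a hypothetical helper, not part of the commit.
import base64
import binascii
import hashlib
import hmac

def verify_password(password_str, base64_password_hashed, base64_salt):
    salt = base64.b64decode(base64_salt)
    expected = base64.b64decode(base64_password_hashed)   # hex-digest bytes, as stored above
    dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt, 10000)
    return hmac.compare_digest(binascii.hexlify(dk), expected)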
api/core/generator/llm_generator.py (modified)
...
@@ -178,7 +178,7 @@ class LLMGenerator:
         llm: StreamableOpenAI = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
-            max_tokens=1000
+            max_tokens=100
         )

         if isinstance(llm, BaseChatModel):
...
api/core/index/vector_index/milvus.py  0 → 100644 (new file)
"""Wrapper around the Milvus vector database."""
from
__future__
import
annotations
import
logging
from
typing
import
Any
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
uuid
import
uuid4
import
numpy
as
np
from
numpy
import
average
from
sentence_transformers
import
SentenceTransformer
from
langchain.docstore.document
import
Document
from
langchain.embeddings.base
import
Embeddings
from
langchain.vectorstores.base
import
VectorStore
from
langchain.vectorstores.utils
import
maximal_marginal_relevance
from
sklearn
import
preprocessing
logger
=
logging
.
getLogger
(
__name__
)
DEFAULT_MILVUS_CONNECTION
=
{
"host"
:
"localhost"
,
"port"
:
"19530"
,
"user"
:
""
,
"password"
:
""
,
"secure"
:
False
,
}
class
Milvus
(
VectorStore
):
"""Wrapper around the Milvus vector database."""
def
__init__
(
self
,
embedding_function
:
Embeddings
,
collection_name
:
str
=
"LangChainCollection"
,
connection_args
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
consistency_level
:
str
=
"Session"
,
index_params
:
Optional
[
dict
]
=
None
,
search_params
:
Optional
[
dict
]
=
None
,
drop_old
:
Optional
[
bool
]
=
False
,
):
"""Initialize wrapper around the milvus vector database.
In order to use this you need to have `pymilvus` installed and a
running Milvus/Zilliz Cloud instance.
See the following documentation for how to run a Milvus instance:
https://milvus.io/docs/install_standalone-docker.md
If looking for a hosted Milvus, take a looka this documentation:
https://zilliz.com/cloud
IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
The connection args used for this class comes in the form of a dict,
here are a few of the options:
address (str): The actual address of Milvus
instance. Example address: "localhost:19530"
uri (str): The uri of Milvus instance. Example uri:
"http://randomwebsite:19530",
"tcp:foobarsite:19530",
"https://ok.s3.south.com:19530".
host (str): The host of Milvus instance. Default at "localhost",
PyMilvus will fill in the default host if only port is provided.
port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
will fill in the default port if only host is provided.
user (str): Use which user to connect to Milvus instance. If user and
password are provided, we will add related header in every RPC call.
password (str): Required when user is provided. The password
corresponding to the user.
secure (bool): Default is false. If set to true, tls will be enabled.
client_key_path (str): If use tls two-way authentication, need to
write the client.key path.
client_pem_path (str): If use tls two-way authentication, need to
write the client.pem path.
ca_pem_path (str): If use tls two-way authentication, need to write
the ca.pem path.
server_pem_path (str): If use tls one-way authentication, need to
write the server.pem path.
server_name (str): If use tls, need to write the common name.
Args:
embedding_function (Embeddings): Function used to embed the text.
collection_name (str): Which Milvus collection to use. Defaults to
"LangChainCollection".
connection_args (Optional[dict[str, any]]): The arguments for connection to
Milvus/Zilliz instance. Defaults to DEFAULT_MILVUS_CONNECTION.
consistency_level (str): The consistency level to use for a collection.
Defaults to "Session".
index_params (Optional[dict]): Which index params to use. Defaults to
HNSW/AUTOINDEX depending on service.
search_params (Optional[dict]): Which search params to use. Defaults to
default of index.
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
to False.
"""
try
:
from
pymilvus
import
Collection
,
utility
except
ImportError
:
raise
ValueError
(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
# Default search params when one is not provided.
self
.
default_search_params
=
{
"IVF_FLAT"
:
{
"metric_type"
:
"L2"
,
"params"
:
{
"nprobe"
:
10
}},
"IVF_SQ8"
:
{
"metric_type"
:
"L2"
,
"params"
:
{
"nprobe"
:
10
}},
"IVF_PQ"
:
{
"metric_type"
:
"L2"
,
"params"
:
{
"nprobe"
:
10
}},
"HNSW"
:
{
"metric_type"
:
"L2"
,
"params"
:
{
"ef"
:
10
}},
"RHNSW_FLAT"
:
{
"metric_type"
:
"L2"
,
"params"
:
{
"ef"
:
10
}},
"RHNSW_SQ"
:
{
"metric_type"
:
"L2"
,
"params"
:
{
"ef"
:
10
}},
"RHNSW_PQ"
:
{
"metric_type"
:
"L2"
,
"params"
:
{
"ef"
:
10
}},
"IVF_HNSW"
:
{
"metric_type"
:
"L2"
,
"params"
:
{
"nprobe"
:
10
,
"ef"
:
10
}},
"ANNOY"
:
{
"metric_type"
:
"L2"
,
"params"
:
{
"search_k"
:
10
}},
"AUTOINDEX"
:
{
"metric_type"
:
"L2"
,
"params"
:
{}},
}
self
.
embedding_func
=
embedding_function
self
.
collection_name
=
collection_name
self
.
index_params
=
index_params
self
.
search_params
=
search_params
self
.
consistency_level
=
consistency_level
# In order for a collection to be compatible, pk needs to be auto'id and int
self
.
_primary_field
=
"pk"
# In order for compatiblility, the text field will need to be called "text"
self
.
_text_field
=
"text"
# In order for compatbility, the vector field needs to be called "vector"
self
.
_vector_field
=
"vector"
self
.
fields
:
list
[
str
]
=
[]
# Create the connection to the server
if
connection_args
is
None
:
connection_args
=
DEFAULT_MILVUS_CONNECTION
self
.
alias
=
self
.
_create_connection_alias
(
connection_args
)
self
.
col
:
Optional
[
Collection
]
=
None
# Grab the existing colection if it exists
if
utility
.
has_collection
(
self
.
collection_name
,
using
=
self
.
alias
):
self
.
col
=
Collection
(
self
.
collection_name
,
using
=
self
.
alias
,
)
# If need to drop old, drop it
if
drop_old
and
isinstance
(
self
.
col
,
Collection
):
self
.
col
.
drop
()
self
.
col
=
None
# Initialize the vector store
self
.
_init
()
def
_create_connection_alias
(
self
,
connection_args
:
dict
)
->
str
:
"""Create the connection to the Milvus server."""
from
pymilvus
import
MilvusException
,
connections
# Grab the connection arguments that are used for checking existing connection
host
:
str
=
connection_args
.
get
(
"host"
,
None
)
port
:
Union
[
str
,
int
]
=
connection_args
.
get
(
"port"
,
None
)
address
:
str
=
connection_args
.
get
(
"address"
,
None
)
uri
:
str
=
connection_args
.
get
(
"uri"
,
None
)
user
=
connection_args
.
get
(
"user"
,
None
)
# Order of use is host/port, uri, address
if
host
is
not
None
and
port
is
not
None
:
given_address
=
str
(
host
)
+
":"
+
str
(
port
)
elif
uri
is
not
None
:
given_address
=
uri
.
split
(
"https://"
)[
1
]
elif
address
is
not
None
:
given_address
=
address
else
:
given_address
=
None
logger
.
debug
(
"Missing standard address type for reuse atttempt"
)
# User defaults to empty string when getting connection info
if
user
is
not
None
:
tmp_user
=
user
else
:
tmp_user
=
""
# If a valid address was given, then check if a connection exists
if
given_address
is
not
None
:
for
con
in
connections
.
list_connections
():
addr
=
connections
.
get_connection_addr
(
con
[
0
])
if
(
con
[
1
]
and
(
"address"
in
addr
)
and
(
addr
[
"address"
]
==
given_address
)
and
(
"user"
in
addr
)
and
(
addr
[
"user"
]
==
tmp_user
)
):
logger
.
debug
(
"Using previous connection:
%
s"
,
con
[
0
])
return
con
[
0
]
# Generate a new connection if one doesnt exist
alias
=
uuid4
()
.
hex
try
:
connections
.
connect
(
alias
=
alias
,
**
connection_args
)
logger
.
debug
(
"Created new connection using:
%
s"
,
alias
)
return
alias
except
MilvusException
as
e
:
logger
.
error
(
"Failed to create new connection using:
%
s"
,
alias
)
raise
e
def
_init
(
self
,
embeddings
:
Optional
[
list
]
=
None
,
metadatas
:
Optional
[
list
[
dict
]]
=
None
)
->
None
:
if
embeddings
is
not
None
:
self
.
_create_collection
(
embeddings
,
metadatas
)
self
.
_extract_fields
()
self
.
_create_index
()
self
.
_create_search_params
()
self
.
_load
()
def
_create_collection
(
self
,
embeddings
:
list
,
metadatas
:
Optional
[
list
[
dict
]]
=
None
)
->
None
:
from
pymilvus
import
(
Collection
,
CollectionSchema
,
DataType
,
FieldSchema
,
MilvusException
,
)
from
pymilvus.orm.types
import
infer_dtype_bydata
# Determine embedding dim
dim
=
len
(
embeddings
[
0
])
fields
=
[]
# Determine metadata schema
if
metadatas
:
# Create FieldSchema for each entry in metadata.
for
key
,
value
in
metadatas
[
0
]
.
items
():
# Infer the corresponding datatype of the metadata
dtype
=
infer_dtype_bydata
(
value
)
# Datatype isnt compatible
if
dtype
==
DataType
.
UNKNOWN
or
dtype
==
DataType
.
NONE
:
logger
.
error
(
"Failure to create collection, unrecognized dtype for key:
%
s"
,
key
,
)
raise
ValueError
(
f
"Unrecognized datatype for {key}."
)
# Dataype is a string/varchar equivalent
elif
dtype
==
DataType
.
VARCHAR
:
fields
.
append
(
FieldSchema
(
key
,
DataType
.
VARCHAR
,
max_length
=
65_535
))
else
:
fields
.
append
(
FieldSchema
(
key
,
dtype
))
# Create the text field
fields
.
append
(
FieldSchema
(
self
.
_text_field
,
DataType
.
VARCHAR
,
max_length
=
65_535
)
)
# Create the primary key field
fields
.
append
(
FieldSchema
(
self
.
_primary_field
,
DataType
.
INT64
,
is_primary
=
True
,
auto_id
=
True
)
)
# Create the vector field, supports binary or float vectors
fields
.
append
(
FieldSchema
(
self
.
_vector_field
,
infer_dtype_bydata
(
embeddings
[
0
]),
dim
=
dim
)
)
# Create the schema for the collection
schema
=
CollectionSchema
(
fields
)
# Create the collection
try
:
self
.
col
=
Collection
(
name
=
self
.
collection_name
,
schema
=
schema
,
consistency_level
=
self
.
consistency_level
,
using
=
self
.
alias
,
)
except
MilvusException
as
e
:
logger
.
error
(
"Failed to create collection:
%
s error:
%
s"
,
self
.
collection_name
,
e
)
raise
e
def
_extract_fields
(
self
)
->
None
:
"""Grab the existing fields from the Collection"""
from
pymilvus
import
Collection
if
isinstance
(
self
.
col
,
Collection
):
schema
=
self
.
col
.
schema
for
x
in
schema
.
fields
:
self
.
fields
.
append
(
x
.
name
)
# Since primary field is auto-id, no need to track it
self
.
fields
.
remove
(
self
.
_primary_field
)
def
_get_index
(
self
)
->
Optional
[
dict
[
str
,
Any
]]:
"""Return the vector index information if it exists"""
from
pymilvus
import
Collection
if
isinstance
(
self
.
col
,
Collection
):
for
x
in
self
.
col
.
indexes
:
if
x
.
field_name
==
self
.
_vector_field
:
return
x
.
to_dict
()
return
None
def
_create_index
(
self
)
->
None
:
"""Create a index on the collection"""
from
pymilvus
import
Collection
,
MilvusException
if
isinstance
(
self
.
col
,
Collection
)
and
self
.
_get_index
()
is
None
:
try
:
# If no index params, use a default HNSW based one
if
self
.
index_params
is
None
:
self
.
index_params
=
{
"metric_type"
:
"L2"
,
"index_type"
:
"HNSW"
,
"params"
:
{
"M"
:
8
,
"efConstruction"
:
64
},
}
try
:
self
.
col
.
create_index
(
self
.
_vector_field
,
index_params
=
self
.
index_params
,
using
=
self
.
alias
,
)
# If default did not work, most likely on Zilliz Cloud
except
MilvusException
:
# Use AUTOINDEX based index
self
.
index_params
=
{
"metric_type"
:
"L2"
,
"index_type"
:
"AUTOINDEX"
,
"params"
:
{},
}
self
.
col
.
create_index
(
self
.
_vector_field
,
index_params
=
self
.
index_params
,
using
=
self
.
alias
,
)
logger
.
debug
(
"Successfully created an index on collection:
%
s"
,
self
.
collection_name
,
)
except
MilvusException
as
e
:
logger
.
error
(
"Failed to create an index on collection:
%
s"
,
self
.
collection_name
)
raise
e
def
_create_search_params
(
self
)
->
None
:
"""Generate search params based on the current index type"""
from
pymilvus
import
Collection
if
isinstance
(
self
.
col
,
Collection
)
and
self
.
search_params
is
None
:
index
=
self
.
_get_index
()
if
index
is
not
None
:
index_type
:
str
=
index
[
"index_param"
][
"index_type"
]
metric_type
:
str
=
index
[
"index_param"
][
"metric_type"
]
self
.
search_params
=
self
.
default_search_params
[
index_type
]
self
.
search_params
[
"metric_type"
]
=
metric_type
def
_load
(
self
)
->
None
:
"""Load the collection if available."""
from
pymilvus
import
Collection
if
isinstance
(
self
.
col
,
Collection
)
and
self
.
_get_index
()
is
not
None
:
self
.
col
.
load
()
def
add_texts
(
self
,
texts
:
Iterable
[
str
],
metadatas
:
Optional
[
List
[
dict
]]
=
None
,
timeout
:
Optional
[
int
]
=
None
,
batch_size
:
int
=
1000
,
**
kwargs
:
Any
,
)
->
List
[
str
]:
"""Insert text data into Milvus.
Inserting data when the collection has not be made yet will result
in creating a new Collection. The data of the first entity decides
the schema of the new collection, the dim is extracted from the first
embedding and the columns are decided by the first metadata dict.
Metada keys will need to be present for all inserted values. At
the moment there is no None equivalent in Milvus.
Args:
texts (Iterable[str]): The texts to embed, it is assumed
that they all fit in memory.
metadatas (Optional[List[dict]]): Metadata dicts attached to each of
the texts. Defaults to None.
timeout (Optional[int]): Timeout for each batch insert. Defaults
to None.
batch_size (int, optional): Batch size to use for insertion.
Defaults to 1000.
Raises:
MilvusException: Failure to add texts
Returns:
List[str]: The resulting keys for each inserted element.
"""
from
pymilvus
import
Collection
,
MilvusException
texts
=
list
(
texts
)
try
:
embeddings
=
self
.
embedding_test
(
texts
)
#embeddings = self.embedding_func.embed_documents(texts)
except
NotImplementedError
:
embeddings
=
[
self
.
embedding_func
.
embed_query
(
x
)
for
x
in
texts
]
if
len
(
embeddings
)
==
0
:
logger
.
debug
(
"Nothing to insert, skipping."
)
return
[]
# If the collection hasnt been initialized yet, perform all steps to do so
if
not
isinstance
(
self
.
col
,
Collection
):
self
.
_init
(
embeddings
,
metadatas
)
# Dict to hold all insert columns
insert_dict
:
dict
[
str
,
list
]
=
{
self
.
_text_field
:
texts
,
self
.
_vector_field
:
embeddings
,
}
# Collect the metadata into the insert dict.
if
metadatas
is
not
None
:
for
d
in
metadatas
:
for
key
,
value
in
d
.
items
():
if
key
in
self
.
fields
:
insert_dict
.
setdefault
(
key
,
[])
.
append
(
value
)
# Total insert count
vectors
:
list
=
insert_dict
[
self
.
_vector_field
]
total_count
=
len
(
vectors
)
pks
:
list
[
str
]
=
[]
assert
isinstance
(
self
.
col
,
Collection
)
for
i
in
range
(
0
,
total_count
,
batch_size
):
# Grab end index
end
=
min
(
i
+
batch_size
,
total_count
)
# Convert dict to list of lists batch for insertion
insert_list
=
[
insert_dict
[
x
][
i
:
end
]
for
x
in
self
.
fields
]
# Insert into the collection.
try
:
res
:
Collection
res
=
self
.
col
.
insert
(
insert_list
,
timeout
=
timeout
,
**
kwargs
)
pks
.
extend
(
res
.
primary_keys
)
except
MilvusException
as
e
:
logger
.
error
(
"Failed to insert batch starting at entity:
%
s/
%
s"
,
i
,
total_count
)
raise
e
return
pks
def
embedding_test
(
self
,
texts
:
List
[
str
])
->
List
[
List
[
float
]]:
model
=
SentenceTransformer
(
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
)
embeddings
=
model
.
encode
(
texts
)
new_embeddings
=
[]
for
i
in
range
(
len
(
texts
)):
average
=
embeddings
[
i
]
new_embeddings
.
append
((
average
/
np
.
linalg
.
norm
(
average
))
.
tolist
())
return
new_embeddings
def
similarity_search
(
self
,
query
:
str
,
k
:
int
=
4
,
param
:
Optional
[
dict
]
=
None
,
expr
:
Optional
[
str
]
=
None
,
timeout
:
Optional
[
int
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Document
]:
"""Perform a similarity search against the query string.
Args:
query (str): The text to search.
k (int, optional): How many results to return. Defaults to 4.
param (dict, optional): The search params for the index type.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if
self
.
col
is
None
:
logger
.
debug
(
"No existing collection to search."
)
return
[]
res
=
self
.
similarity_search_with_score
(
query
=
query
,
k
=
k
,
param
=
param
,
expr
=
expr
,
timeout
=
timeout
,
**
kwargs
)
return
[
doc
for
doc
,
_
in
res
]
def
similarity_search_by_vector
(
self
,
embedding
:
List
[
float
],
k
:
int
=
4
,
param
:
Optional
[
dict
]
=
None
,
expr
:
Optional
[
str
]
=
None
,
timeout
:
Optional
[
int
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Document
]:
"""Perform a similarity search against the query string.
Args:
embedding (List[float]): The embedding vector to search.
k (int, optional): How many results to return. Defaults to 4.
param (dict, optional): The search params for the index type.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if
self
.
col
is
None
:
logger
.
debug
(
"No existing collection to search."
)
return
[]
res
=
self
.
similarity_search_with_score_by_vector
(
embedding
=
embedding
,
k
=
k
,
param
=
param
,
expr
=
expr
,
timeout
=
timeout
,
**
kwargs
)
return
[
doc
for
doc
,
_
in
res
]
def
similarity_search_with_score
(
self
,
query
:
str
,
k
:
int
=
4
,
param
:
Optional
[
dict
]
=
None
,
expr
:
Optional
[
str
]
=
None
,
timeout
:
Optional
[
int
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Tuple
[
Document
,
float
]]:
"""Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args:
query (str): The text being searched.
k (int, optional): The amount of results ot return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[float], List[Tuple[Document, any, any]]:
"""
if
self
.
col
is
None
:
logger
.
debug
(
"No existing collection to search."
)
return
[]
# Embed the query text.
embeddings
=
self
.
embedding_test
([
query
])
embeddings
=
self
.
embedding_func
.
embed_query
(
query
)
res
=
self
.
similarity_search_with_score_by_vector
(
embedding
=
embeddings
[
0
],
k
=
k
,
param
=
param
,
expr
=
expr
,
timeout
=
timeout
,
**
kwargs
)
return
res
def
normalize_embedding
(
self
,
embedding
):
return
preprocessing
.
normalize
(
embedding
,
norm
=
'l2'
)
def
similarity_search_with_score_by_vector
(
self
,
embedding
:
List
[
float
],
k
:
int
=
4
,
param
:
Optional
[
dict
]
=
None
,
expr
:
Optional
[
str
]
=
None
,
timeout
:
Optional
[
int
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Tuple
[
Document
,
float
]]:
"""Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args:
embedding (List[float]): The embedding vector being searched.
k (int, optional): The amount of results ot return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Tuple[Document, float]]: Result doc and score.
"""
if
self
.
col
is
None
:
logger
.
debug
(
"No existing collection to search."
)
return
[]
if
param
is
None
:
param
=
self
.
search_params
# Determine result metadata fields.
output_fields
=
self
.
fields
[:]
output_fields
.
remove
(
self
.
_vector_field
)
# Perform the search.
res
=
self
.
col
.
search
(
data
=
[
embedding
],
anns_field
=
self
.
_vector_field
,
param
=
param
,
limit
=
k
,
expr
=
expr
,
output_fields
=
output_fields
,
timeout
=
timeout
,
**
kwargs
,
)
# Organize results.
ret
=
[]
for
result
in
res
[
0
]:
meta
=
{
x
:
result
.
entity
.
get
(
x
)
for
x
in
output_fields
}
doc
=
Document
(
page_content
=
meta
.
pop
(
self
.
_text_field
),
metadata
=
meta
)
pair
=
(
doc
,
result
.
score
)
ret
.
append
(
pair
)
return
ret
def
max_marginal_relevance_search
(
self
,
query
:
str
,
k
:
int
=
4
,
fetch_k
:
int
=
20
,
lambda_mult
:
float
=
0.5
,
param
:
Optional
[
dict
]
=
None
,
expr
:
Optional
[
str
]
=
None
,
timeout
:
Optional
[
int
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Document
]:
"""Perform a search and return results that are reordered by MMR.
Args:
query (str): The text being searched.
k (int, optional): How many results to give. Defaults to 4.
fetch_k (int, optional): Total results to select k from.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if
self
.
col
is
None
:
logger
.
debug
(
"No existing collection to search."
)
return
[]
embedding
=
self
.
embedding_func
.
embed_query
(
query
)
return
self
.
max_marginal_relevance_search_by_vector
(
embedding
=
embedding
,
k
=
k
,
fetch_k
=
fetch_k
,
lambda_mult
=
lambda_mult
,
param
=
param
,
expr
=
expr
,
timeout
=
timeout
,
**
kwargs
,
)
def
max_marginal_relevance_search_by_vector
(
self
,
embedding
:
list
[
float
],
k
:
int
=
4
,
fetch_k
:
int
=
20
,
lambda_mult
:
float
=
0.5
,
param
:
Optional
[
dict
]
=
None
,
expr
:
Optional
[
str
]
=
None
,
timeout
:
Optional
[
int
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Document
]:
"""Perform a search and return results that are reordered by MMR.
Args:
embedding (str): The embedding vector being searched.
k (int, optional): How many results to give. Defaults to 4.
fetch_k (int, optional): Total results to select k from.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if
self
.
col
is
None
:
logger
.
debug
(
"No existing collection to search."
)
return
[]
if
param
is
None
:
param
=
self
.
search_params
# Determine result metadata fields.
output_fields
=
self
.
fields
[:]
output_fields
.
remove
(
self
.
_vector_field
)
# Perform the search.
res
=
self
.
col
.
search
(
data
=
[
embedding
],
anns_field
=
self
.
_vector_field
,
param
=
param
,
limit
=
fetch_k
,
expr
=
expr
,
output_fields
=
output_fields
,
timeout
=
timeout
,
**
kwargs
,
)
# Organize results.
ids
=
[]
documents
=
[]
scores
=
[]
for
result
in
res
[
0
]:
meta
=
{
x
:
result
.
entity
.
get
(
x
)
for
x
in
output_fields
}
doc
=
Document
(
page_content
=
meta
.
pop
(
self
.
_text_field
),
metadata
=
meta
)
documents
.
append
(
doc
)
scores
.
append
(
result
.
score
)
ids
.
append
(
result
.
id
)
vectors
=
self
.
col
.
query
(
expr
=
f
"{self._primary_field} in {ids}"
,
output_fields
=
[
self
.
_primary_field
,
self
.
_vector_field
],
timeout
=
timeout
,
)
# Reorganize the results from query to match search order.
vectors
=
{
x
[
self
.
_primary_field
]:
x
[
self
.
_vector_field
]
for
x
in
vectors
}
ordered_result_embeddings
=
[
vectors
[
x
]
for
x
in
ids
]
# Get the new order of results.
new_ordering
=
maximal_marginal_relevance
(
np
.
array
(
embedding
),
ordered_result_embeddings
,
k
=
k
,
lambda_mult
=
lambda_mult
)
# Reorder the values and return.
ret
=
[]
for
x
in
new_ordering
:
# Function can return -1 index
if
x
==
-
1
:
break
else
:
ret
.
append
(
documents
[
x
])
return
ret
@
classmethod
def
from_texts
(
cls
,
texts
:
List
[
str
],
embedding
:
Embeddings
,
metadatas
:
Optional
[
List
[
dict
]]
=
None
,
collection_name
:
str
=
"LangChainCollection"
,
connection_args
:
dict
[
str
,
Any
]
=
DEFAULT_MILVUS_CONNECTION
,
consistency_level
:
str
=
"Session"
,
index_params
:
Optional
[
dict
]
=
None
,
search_params
:
Optional
[
dict
]
=
None
,
drop_old
:
bool
=
False
,
**
kwargs
:
Any
,
)
->
Milvus
:
"""Create a Milvus collection, indexes it with HNSW, and insert data.
Args:
texts (List[str]): Text data.
embedding (Embeddings): Embedding function.
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
Defaults to None.
collection_name (str, optional): Collection name to use. Defaults to
"LangChainCollection".
connection_args (dict[str, Any], optional): Connection args to use. Defaults
to DEFAULT_MILVUS_CONNECTION.
consistency_level (str, optional): Which consistency level to use. Defaults
to "Session".
index_params (Optional[dict], optional): Which index_params to use. Defaults
to None.
search_params (Optional[dict], optional): Which search params to use.
Defaults to None.
drop_old (Optional[bool], optional): Whether to drop the collection with
that name if it exists. Defaults to False.
Returns:
Milvus: Milvus Vector Store
"""
vector_db
=
cls
(
embedding_function
=
embedding
,
collection_name
=
collection_name
,
connection_args
=
connection_args
,
consistency_level
=
consistency_level
,
index_params
=
index_params
,
search_params
=
search_params
,
drop_old
=
drop_old
,
**
kwargs
,
)
vector_db
.
add_texts
(
texts
=
texts
,
metadatas
=
metadatas
)
return
vector_db
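A minimal usage sketch of the wrapper above (assumptions: a Milvus instance reachable at localhost:19530 and the MiniMax credentials set as in test.py; the collection name 'example_docs' and the sample texts are invented). Note that in this test version add_texts embeds through the hard-coded SentenceTransformer in embedding_test rather than the supplied embedding function.

# Sketch only: collection name, texts, and connection details are illustrative.
from langchain.embeddings import MiniMaxEmbeddings
from core.index.vector_index.milvus import Milvus

embeddings = MiniMaxEmbeddings()
store = Milvus.from_texts(
    texts=["Milvus is a vector database.", "Dify builds LLM applications."],
    embedding=embeddings,
    collection_name='example_docs',
    connection_args={"host": "localhost", "port": "19530"},
)
docs = store.similarity_search("What is Milvus?", k=1)
print(docs)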
api/core/index/vector_index/milvus_vector_index.py  0 → 100644 (new file)
from typing import Optional, cast

import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset


class MilvusConfig(BaseModel):
    uri: str
    username: Optional[str]
    password: Optional[str]
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['uri']:
            raise ValueError("config Milvus uri is required")
        return values


class MilvusVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)

    def _init_client(self, config: MilvusConfig) -> weaviate.Client:
        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)

        weaviate.connect.connection.has_grpc = False

        try:
            client = weaviate.Client(
                url=config.endpoint,
                auth_client_secret=auth_config,
                timeout_config=(5, 60),
                startup_period=None
            )
        except requests.exceptions.ConnectionError:
            raise ConnectionError("Vector database connection error")

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=config.batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_type(self) -> str:
        return 'weaviate'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        attributes = ['doc_id', 'dataset_id', 'document_id']
        if self._is_origin():
            attributes = ['doc_id']

        return WeaviateVectorStore(
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False
        )

    def _get_vector_store_class(self) -> type:
        return WeaviateVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        })

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False
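A small sketch (assumption, not part of the commit) of how MilvusConfig validates its input. Note that _init_client above still goes through weaviate.Client and reads config.api_key / config.endpoint, which MilvusConfig does not define, so this only exercises the config model itself; the uri value below is a placeholder.

# Sketch only: exercising MilvusConfig validation with placeholder values.
from pydantic import ValidationError

config = MilvusConfig(uri='https://example.vectordb.zillizcloud.com:19530')
print(config.batch_size)   # -> 100 (default)

try:
    MilvusConfig(uri='')   # empty uri is rejected by validate_config
except ValidationError as e:
    print(e)               # "config Milvus uri is required"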
api/core/index/vector_index/test-embedding.py  0 → 100644 (new file)
import numpy as np
import sklearn.decomposition
import pickle
import time

# Apply 'Algorithm 1' to the ada-002 embeddings to make them isotropic, taken from the paper:
# ALL-BUT-THE-TOP: SIMPLE AND EFFECTIVE POST-PROCESSING FOR WORD REPRESENTATIONS
# Jiaqi Mu, Pramod Viswanath
# This uses Principal Component Analysis (PCA) to 'evenly distribute' the embedding vectors (make them isotropic)
# For more information on PCA, see https://jamesmccaffrey.wordpress.com/2021/07/16/computing-pca-using-numpy-without-scikit/

# get the file pointer of the pickle containing the embeddings
fp = open('/path/to/your/data/Embedding-Latest.pkl', 'rb')

# the embedding data here is a dict consisting of key / value pairs
# the key is the hash of the message (SHA3-256), the value is the embedding from ada-002 (array of dimension 1536)
# the hash can be used to lookup the orignal text in a database
E = pickle.load(fp)  # load the data into memory

# seperate the keys (hashes) and values (embeddings) into seperate vectors
K = list(E.keys())  # vector of all the hash values
X = np.array(list(E.values()))  # vector of all the embeddings, converted to numpy arrays

# list the total number of embeddings
# this can be truncated if there are too many embeddings to do PCA on
print(f"Total number of embeddings: {len(X)}")

# get dimension of embeddings, used later
Dim = len(X[0])

# flash out the first few embeddings
print("First two embeddings are: ")
print(X[0])
print(f"First embedding length: {len(X[0])}")
print(X[1])
print(f"Second embedding length: {len(X[1])}")

# compute the mean of all the embeddings, and flash the result
mu = np.mean(X, axis=0)  # same as mu in paper
print(f"Mean embedding vector: {mu}")
print(f"Mean embedding vector length: {len(mu)}")

# subtract the mean vector from each embedding vector ... vectorized in numpy
X_tilde = X - mu  # same as v_tilde(w) in paper

# do the heavy lifting of extracting the principal components
# note that this is a function of the embeddings you currently have here, and this set may grow over time
# therefore the PCA basis vectors may change over time, and your final isotropic embeddings may drift over time
# but the drift should stabilize after you have extracted enough embedding data to characterize the nature of the embedding engine
print(f"Performing PCA on the normalized embeddings ...")
pca = sklearn.decomposition.PCA()  # new object
TICK = time.time()  # start timer
pca.fit(X_tilde)  # do the heavy lifting!
TOCK = time.time()  # end timer
DELTA = TOCK - TICK
print(f"PCA finished in {DELTA} seconds ...")

# dimensional reduction stage (the only hyperparameter)
# pick max dimension of PCA components to express embddings
# in general this is some integer less than or equal to the dimension of your embeddings
# it could be set as a high percentile, say 95th percentile of pca.explained_variance_ratio_
# but just hardcoding a constant here
D = 15  # hyperparameter on dimension (out of 1536 for ada-002), paper recommeds D = Dim/100

# form the set of v_prime(w), which is the final embedding
# this could be vectorized in numpy to speed it up, but coding it directly here in a double for-loop to avoid errors and to be transparent
E_prime = dict()  # output dict of the new embeddings
N = len(X_tilde)
N10 = round(N / 10)
U = pca.components_  # set of PCA basis vectors, sorted by most significant to least significant
print(f"Shape of full set of PCA componenents {U.shape}")
U = U[0:D, :]  # take the top D dimensions (or take them all if D is the size of the embedding vector)
print(f"Shape of downselected PCA componenents {U.shape}")
for ii in range(N):
    v_tilde = X_tilde[ii]
    v = X[ii]
    v_projection = np.zeros(Dim)  # start to build the projection
    # project the original embedding onto the PCA basis vectors, use only first D dimensions
    for jj in range(D):
        u_jj = U[jj, :]  # vector
        v_jj = np.dot(u_jj, v)  # scaler
        v_projection += v_jj * u_jj  # vector
    v_prime = v_tilde - v_projection  # final embedding vector
    v_prime = v_prime / np.linalg.norm(v_prime)  # create unit vector
    E_prime[K[ii]] = v_prime

    if (ii % N10 == 0) or (ii == N - 1):
        print(f"Finished with {ii+1} embeddings out of {N} ({round(100*ii/N)}% done)")

# save as new pickle
print("Saving new pickle ...")
embeddingName = '/path/to/your/data/Embedding-Latest-Isotropic.pkl'
with open(embeddingName, 'wb') as f:  # Python 3: open(..., 'wb')
    pickle.dump([E_prime, mu, U], f)
print(embeddingName)
print("Done!")


# When working with live data with a new embedding from ada-002, be sure to tranform it first with this function before comparing it
#
def projectEmbedding(v, mu, U):
    v = np.array(v)
    v_tilde = v - mu
    v_projection = np.zeros(len(v))  # start to build the projection
    # project the original embedding onto the PCA basis vectors, use only first D dimensions
    for u in U:
        v_jj = np.dot(u, v)  # scaler
        v_projection += v_jj * u  # vector
    v_prime = v_tilde - v_projection  # final embedding vector
    v_prime = v_prime / np.linalg.norm(v_prime)  # create unit vector
    return v_prime
\ No newline at end of file
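A hypothetical usage sketch (not part of the commit) showing how the saved mean vector and PCA components would be applied to a new embedding with projectEmbedding before comparison. The pickle path matches the one written above; the random vector stands in for a real ada-002 embedding.

# Sketch only: the query vector is random and purely illustrative.
import pickle
import numpy as np

with open('/path/to/your/data/Embedding-Latest-Isotropic.pkl', 'rb') as f:
    E_prime, mu, U = pickle.load(f)

new_embedding = np.random.rand(len(mu)).tolist()   # stand-in for a real ada-002 embedding
v_prime = projectEmbedding(new_embedding, mu, U)
print(np.linalg.norm(v_prime))                     # ~1.0: the projected embedding is a unit vector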
api/core/indexing_runner.py (modified)
...
@@ -71,18 +71,18 @@ class IndexingRunner:
                 dataset_document=dataset_document,
                 processing_rule=processing_rule
             )
-            new_documents = []
-            for document in documents:
-                response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
-                document_qa_list = self.format_split_text(response)
-                for result in document_qa_list:
-                    document = Document(page_content=result['question'], metadata={'source': result['answer']})
-                    new_documents.append(document)
+            # new_documents = []
+            # for document in documents:
+            #     response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
+            #     document_qa_list = self.format_split_text(response)
+            #     for result in document_qa_list:
+            #         document = Document(page_content=result['question'], metadata={'source': result['answer']})
+            #         new_documents.append(document)
             # build index
             self._build_index(
                 dataset=dataset,
                 dataset_document=dataset_document,
-                documents=new_documents
+                documents=documents
             )
         except DocumentIsPausedException:
             raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
...
@@ -251,7 +251,8 @@ class IndexingRunner:
             documents = self._split_to_documents(
                 text_docs=text_docs,
                 splitter=splitter,
-                processing_rule=processing_rule
+                processing_rule=processing_rule,
+                tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
             )
             total_segments += len(documents)
             for document in documents:
...
@@ -311,7 +312,8 @@ class IndexingRunner:
             documents = self._split_to_documents(
                 text_docs=documents,
                 splitter=splitter,
-                processing_rule=processing_rule
+                processing_rule=processing_rule,
+                tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
             )
             total_segments += len(documents)
             for document in documents:
...
@@ -414,7 +416,8 @@ class IndexingRunner:
             documents = self._split_to_documents(
                 text_docs=text_docs,
                 splitter=splitter,
-                processing_rule=processing_rule
+                processing_rule=processing_rule,
+                tenant_id=dataset.tenant_id
             )
             # save node to document segment
...
@@ -469,18 +472,18 @@ class IndexingRunner:
             if document.page_content is None or not document.page_content.strip():
                 continue
-            response = LLMGenerator.generate_qa_document(processing_rule.tenant_id, document.page_content)
+            response = LLMGenerator.generate_qa_document(tenant_id, document.page_content)
             document_qa_list = self.format_split_text(response)
+            qa_documents = []
             for result in document_qa_list:
                 document = Document(page_content=result['question'], metadata={'source': result['answer']})
-                new_documents.append(document)
                 doc_id = str(uuid.uuid4())
                 hash = helper.generate_text_hash(document.page_content)
                 document.metadata['doc_id'] = doc_id
                 document.metadata['doc_hash'] = hash
-                split_documents.append(document)
+                qa_documents.append(document)
+            split_documents.extend(qa_documents)
             all_documents.extend(split_documents)
...
api/core/prompt/prompts.py (modified)
...
@@ -51,6 +51,7 @@ GENERATOR_QA_PROMPT = (
     'Step3:可分解或结合多个信息与概念\n'
     'Step4:将这些关键信息与概念生成 10 个问题与答案,问题描述清楚并且详细完整,答案详细完整.\n'
     "按格式回答: Q1:\nA1:\nQ2:\nA2:...\n"
+    "只输出Step4中的内容"
 )

 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
...