Commit 4c596272 authored by jyong's avatar jyong

fix clean dataset task

parent 90b22d8c
import datetime
import re
from os import environ
from uuid import uuid4
import openai
from langchain.document_loaders import WebBaseLoader, UnstructuredFileLoader, TextLoader
from langchain.embeddings import OpenAIEmbeddings, MiniMaxEmbeddings
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from pymilvus import connections, Collection
from pymilvus.orm import utility
from core.data_loader.loader.excel import ExcelLoader
from core.generator.llm_generator import LLMGenerator
from core.index.vector_index.milvus import Milvus
# Credentials are read from the process environment instead of being committed
# to the repository. Export OPENAI_API_KEY, MINIMAX_GROUP_ID and MINIMAX_API_KEY
# before running these scratch tests.
OPENAI_API_KEY = environ.get("OPENAI_API_KEY", "")  # example: "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

# Warn (rather than fail at import time) so that merely importing or collecting
# this module keeps working on machines without the credentials configured.
_missing_env_vars = [
    name
    for name in ("OPENAI_API_KEY", "MINIMAX_GROUP_ID", "MINIMAX_API_KEY")
    if not environ.get(name)
]
if _missing_env_vars:
    import warnings

    warnings.warn(
        "Missing environment variables required by these tests: "
        + ", ".join(_missing_env_vars)
    )
# System prompt (in Chinese) sent with every chat-completion request in
# test_milvus. It casts the model as a question writer, asks it to work step by
# step (summarise the text, identify key information/concepts, decompose or
# combine them) and to produce 10 detailed question/answer pairs in the form
# "Q1:\nA1:\nQ2:\nA2:...", which is the layout format_split_text parses.
# NOTE(review): the second literal has no trailing "\n", so "请一步一步思考" runs
# directly into "Step1:" in the final prompt - confirm whether that is intended.
CONVERSATION_PROMPT = (
    "你是出题人.\n"
    "用户会发送一段长文本.\n请一步一步思考"
    'Step1:了解并总结这段文本的主要内容\n'
    'Step2:这段文本提到了哪些关键信息或概念\n'
    'Step3:可分解或结合多个信息与概念\n'
    'Step4:将这些关键信息与概念生成 10 个问题与答案,问题描述清楚并且详细完整,答案详细完整.\n'
    "按格式回答: Q1:\nA1:\nQ2:\nA2:...\n"
)
def test_milvus():
    """Scratch test: generate question/answer pairs from a local spreadsheet.

    Loads a workbook with ``ExcelLoader``, splits it with
    ``CharacterTextSplitter(chunk_size=1024, chunk_overlap=0)``, sends every
    chunk to the OpenAI chat API together with ``CONVERSATION_PROMPT`` and
    prints the parsed question/answer pairs.

    The OpenAI key is taken from the ``OPENAI_API_KEY`` environment variable;
    a ``KeyError`` is raised when it is not set.
    """
    # NOTE: earlier commented-out Milvus/Zilliz experiments were dropped from
    # this function because they embedded service passwords and tokens.

    # An answer runs until the next "Q<n>:" marker or the end of the text.
    # re.MULTILINE must NOT be used here: it lets "$" match at every line end,
    # which cut multi-line answers after their first line, and a bare "Q"
    # look-ahead stopped answers at any capital Q inside the answer text.
    qa_pattern = re.compile(r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|\Z)")

    def format_split_text(text):
        """Return ``[{"question": ..., "answer": ...}, ...]`` parsed from ``text``.

        Pairs with an empty question or answer are skipped; indentation after
        line breaks inside an answer is collapsed.
        """
        return [
            {
                "question": question,
                "answer": re.sub(r"\n\s*", "\n", answer.strip()),
            }
            for question, answer in qa_pattern.findall(text)
            if question and answer
        ]

    loader = ExcelLoader('/Users/jiangyong/Downloads/xiaoming.xlsx')
    docs = loader.load()

    # Split the documents into smaller chunks.
    text_splitter = CharacterTextSplitter(chunk_size=1024, chunk_overlap=0)
    docs = text_splitter.split_documents(docs)

    # Configure the client once, outside the loop, from the environment.
    openai.api_key = environ["OPENAI_API_KEY"]

    for doc in docs:
        response = openai.ChatCompletion.create(
            model='gpt-3.5-turbo',
            messages=[
                {
                    'role': 'system',
                    'content': CONVERSATION_PROMPT
                },
                {
                    'role': 'user',
                    'content': doc.page_content
                }
            ],
            temperature=0,
            stream=False,  # the complete answer is needed before it can be parsed
            n=1,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0
        )
        results = format_split_text(response['choices'][0]['message']['content'])
        print(results)

    # Kept from the original script: instantiates the MiniMax embedding client
    # (the result is not used further in this test).
    embeddings = MiniMaxEmbeddings()
import numpy as np
from numpy import average
from sentence_transformers import SentenceTransformer
def test_embdding():
    """Scratch test: encode one sentence and print it raw and L2-normalised.

    The vector is normalised twice on purpose of the original experiment, so
    the second and third printouts can be compared by eye.
    """
    encoder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    vectors = encoder.encode(["My name is john"])
    for raw_vector in vectors:
        print(raw_vector)
        normalised_once = (raw_vector / np.linalg.norm(raw_vector)).tolist()
        print(normalised_once)
        normalised_twice = (normalised_once / np.linalg.norm(normalised_once)).tolist()
        print(normalised_twice)
    print(vectors)
import base64
import binascii
import hashlib
import secrets
from os import environ
import numpy as np
from langchain.embeddings import MiniMaxEmbeddings
from numpy import average
from sentence_transformers import SentenceTransformer
from core.index.vector_index.milvus import Milvus
# MiniMax credentials must be provided through the environment
# (MINIMAX_GROUP_ID / MINIMAX_API_KEY); they are never hard-coded here.
_missing_minimax_vars = [
    name
    for name in ("MINIMAX_GROUP_ID", "MINIMAX_API_KEY")
    if not environ.get(name)
]
if _missing_minimax_vars:
    import warnings

    warnings.warn(
        "Missing MiniMax environment variables: " + ", ".join(_missing_minimax_vars)
    )
def test_query():
    """Scratch test: print a salted PBKDF2-HMAC-SHA256 password hash and its salt.

    The password is read from the ``DIFY_TEST_PASSWORD`` environment variable
    (a harmless placeholder is used when it is unset) so that no real
    credential lives in the repository. A fresh random 16-byte salt is
    generated on every call, so the printed values differ between runs.
    """

    def hash_password(password_str, salt_byte):
        """Return the hex-encoded PBKDF2-HMAC-SHA256 digest (10000 rounds) as bytes."""
        dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000)
        return binascii.hexlify(dk)

    password = environ.get("DIFY_TEST_PASSWORD", "example-password")

    # Generate the password salt.
    salt = secrets.token_bytes(16)
    base64_salt = base64.b64encode(salt).decode()

    # Hash the password with the salt, then base64-encode the hex digest.
    password_hashed = hash_password(password, salt)
    base64_password_hashed = base64.b64encode(password_hashed).decode()

    print(base64_password_hashed)
    print('*******************')
    print(base64_salt)
......@@ -178,7 +178,7 @@ class LLMGenerator:
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=1000
max_tokens=100
)
if isinstance(llm, BaseChatModel):
......
"""Wrapper around the Milvus vector database."""
from __future__ import annotations
import logging
from typing import Any, Iterable, List, Optional, Tuple, Union
from uuid import uuid4
import numpy as np
from numpy import average
from sentence_transformers import SentenceTransformer
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from sklearn import preprocessing
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Connection arguments used when a Milvus instance is created without explicit
# ``connection_args``: an unauthenticated, non-TLS server on localhost:19530.
# NOTE: this dict is shared (it is also the default of ``Milvus.from_texts``),
# so it must not be mutated in place.
DEFAULT_MILVUS_CONNECTION = {
    "host": "localhost",
    "port": "19530",
    "user": "",
    "password": "",
    "secure": False,
}
class Milvus(VectorStore):
"""Wrapper around the Milvus vector database."""
def __init__(
    self,
    embedding_function: Embeddings,
    collection_name: str = "LangChainCollection",
    connection_args: Optional[dict[str, Any]] = None,
    consistency_level: str = "Session",
    index_params: Optional[dict] = None,
    search_params: Optional[dict] = None,
    drop_old: Optional[bool] = False,
):
    """Set up the wrapper and attach to the collection if it already exists.

    Requires the ``pymilvus`` package and a reachable Milvus / Zilliz Cloud
    instance (see https://milvus.io/docs/install_standalone-docker.md or
    https://zilliz.com/cloud). When using the L2/IP metric it is highly
    suggested to normalize your data.

    ``connection_args`` is forwarded to ``pymilvus.connections.connect``.
    Commonly used keys:
        address (str): Address of the instance, e.g. "localhost:19530".
        uri (str): URI of the instance, e.g. "http://randomwebsite:19530",
            "tcp:foobarsite:19530", "https://ok.s3.south.com:19530".
        host (str) / port (str/int): Host and port of the instance; PyMilvus
            fills in the default for whichever one is omitted.
        user (str) / password (str): Credentials; the password is required
            when a user is given.
        secure (bool): Enable TLS. Defaults to False.
        client_key_path, client_pem_path, ca_pem_path (str): Files for
            two-way TLS authentication.
        server_pem_path (str): File for one-way TLS authentication.
        server_name (str): Common name to use with TLS.

    Args:
        embedding_function (Embeddings): Object used to embed text.
        collection_name (str): Milvus collection to use. Defaults to
            "LangChainCollection".
        connection_args (Optional[dict[str, any]]): Connection arguments.
            Defaults to DEFAULT_MILVUS_CONNECTION.
        consistency_level (str): Consistency level for the collection.
            Defaults to "Session".
        index_params (Optional[dict]): Index parameters. When omitted, HNSW
            is tried first and AUTOINDEX is used as fallback.
        search_params (Optional[dict]): Search parameters. When omitted they
            are derived from the index type.
        drop_old (Optional[bool]): Drop an existing collection of the same
            name. Defaults to False.

    Raises:
        ValueError: If ``pymilvus`` cannot be imported.
    """
    try:
        from pymilvus import Collection, utility
    except ImportError:
        raise ValueError(
            "Could not import pymilvus python package. "
            "Please install it with `pip install pymilvus`."
        )

    # Per-index-type search parameters used when the caller supplies none.
    default_params_by_index = {
        "IVF_FLAT": {"nprobe": 10},
        "IVF_SQ8": {"nprobe": 10},
        "IVF_PQ": {"nprobe": 10},
        "HNSW": {"ef": 10},
        "RHNSW_FLAT": {"ef": 10},
        "RHNSW_SQ": {"ef": 10},
        "RHNSW_PQ": {"ef": 10},
        "IVF_HNSW": {"nprobe": 10, "ef": 10},
        "ANNOY": {"search_k": 10},
        "AUTOINDEX": {},
    }
    self.default_search_params = {
        index_type: {"metric_type": "L2", "params": params}
        for index_type, params in default_params_by_index.items()
    }

    self.embedding_func = embedding_function
    self.collection_name = collection_name
    self.index_params = index_params
    self.search_params = search_params
    self.consistency_level = consistency_level

    # Fixed field names a collection must use to be compatible with this
    # wrapper: an auto-id int primary key "pk", the raw text in "text" and
    # the embedding in "vector".
    self._primary_field = "pk"
    self._text_field = "text"
    self._vector_field = "vector"
    self.fields: list[str] = []

    # Open (or reuse) the server connection.
    effective_args = (
        DEFAULT_MILVUS_CONNECTION if connection_args is None else connection_args
    )
    self.alias = self._create_connection_alias(effective_args)

    # Attach to the collection when it already exists on the server.
    self.col: Optional[Collection] = None
    if utility.has_collection(self.collection_name, using=self.alias):
        self.col = Collection(
            self.collection_name,
            using=self.alias,
        )

    # Honour drop_old by discarding the existing collection.
    if drop_old and isinstance(self.col, Collection):
        self.col.drop()
        self.col = None

    # Read fields, build the index and search params, and load the collection.
    self._init()
def _create_connection_alias(self, connection_args: dict) -> str:
    """Return the alias of a connection matching ``connection_args``.

    An already registered, live connection with the same address and user is
    reused; otherwise a new connection is opened under a random alias.

    Args:
        connection_args (dict): Arguments for ``pymilvus.connections.connect``.

    Returns:
        str: Alias of the reused or newly created connection.

    Raises:
        MilvusException: If a new connection cannot be established.
    """
    from pymilvus import MilvusException, connections

    # Values needed to recognise an existing connection.
    host: str = connection_args.get("host", None)
    port: Union[str, int] = connection_args.get("port", None)
    address: str = connection_args.get("address", None)
    uri: str = connection_args.get("uri", None)
    user = connection_args.get("user", None)

    # Order of use is host/port, uri, address.
    if host is not None and port is not None:
        given_address = str(host) + ":" + str(port)
    elif uri is not None:
        # Strip whatever scheme prefix is present. The previous
        # ``uri.split("https://")[1]`` raised IndexError for every URI that
        # did not start with "https://" (e.g. "http://..." or "tcp:...").
        given_address = uri.split("://", 1)[-1]
    elif address is not None:
        given_address = address
    else:
        given_address = None
        logger.debug("Missing standard address type for reuse atttempt")

    # The connection info reports an empty string when no user was given.
    tmp_user = user if user is not None else ""

    # If a valid address was given, look for a matching live connection.
    if given_address is not None:
        for con in connections.list_connections():
            addr = connections.get_connection_addr(con[0])
            if (
                con[1]
                and ("address" in addr)
                and (addr["address"] == given_address)
                and ("user" in addr)
                and (addr["user"] == tmp_user)
            ):
                logger.debug("Using previous connection: %s", con[0])
                return con[0]

    # No reusable connection: open a new one under a random alias.
    alias = uuid4().hex
    try:
        connections.connect(alias=alias, **connection_args)
    except MilvusException:
        logger.error("Failed to create new connection using: %s", alias)
        raise
    logger.debug("Created new connection using: %s", alias)
    return alias
def _init(
    self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
) -> None:
    """Run the setup pipeline for the collection.

    When ``embeddings`` are supplied the collection is created first; then the
    field list, index, search params and load steps run in that order.
    """
    if embeddings is not None:
        self._create_collection(embeddings, metadatas)
    setup_steps = (
        self._extract_fields,
        self._create_index,
        self._create_search_params,
        self._load,
    )
    for run_step in setup_steps:
        run_step()
def _create_collection(
    self, embeddings: list, metadatas: Optional[list[dict]] = None
) -> None:
    """Create the collection, deriving its schema from the first entries.

    The vector dimension and type come from ``embeddings[0]``; one scalar
    field is created per key of ``metadatas[0]`` (when metadata is given),
    followed by the text, primary-key and vector fields.

    Raises:
        ValueError: If a metadata value has no recognised Milvus datatype.
        MilvusException: If the server rejects the collection.
    """
    from pymilvus import (
        Collection,
        CollectionSchema,
        DataType,
        FieldSchema,
        MilvusException,
    )
    from pymilvus.orm.types import infer_dtype_bydata

    first_vector = embeddings[0]
    vector_dim = len(first_vector)
    schema_fields = []

    # One field per metadata key, typed from the first metadata dict.
    if metadatas:
        for meta_key, meta_value in metadatas[0].items():
            inferred = infer_dtype_bydata(meta_value)
            if inferred in (DataType.UNKNOWN, DataType.NONE):
                logger.error(
                    "Failure to create collection, unrecognized dtype for key: %s",
                    meta_key,
                )
                raise ValueError(f"Unrecognized datatype for {meta_key}.")
            if inferred == DataType.VARCHAR:
                # Strings need an explicit maximum length.
                schema_fields.append(
                    FieldSchema(meta_key, DataType.VARCHAR, max_length=65_535)
                )
            else:
                schema_fields.append(FieldSchema(meta_key, inferred))

    # Text, auto-id primary key, then the (binary or float) vector field.
    schema_fields.extend(
        [
            FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535),
            FieldSchema(
                self._primary_field, DataType.INT64, is_primary=True, auto_id=True
            ),
            FieldSchema(
                self._vector_field, infer_dtype_bydata(first_vector), dim=vector_dim
            ),
        ]
    )

    collection_schema = CollectionSchema(schema_fields)
    try:
        self.col = Collection(
            name=self.collection_name,
            schema=collection_schema,
            consistency_level=self.consistency_level,
            using=self.alias,
        )
    except MilvusException as e:
        logger.error(
            "Failed to create collection: %s error: %s", self.collection_name, e
        )
        raise e
def _extract_fields(self) -> None:
    """Record the collection's field names in ``self.fields``.

    Does nothing when no collection is attached. The primary key is left out
    because it is auto-generated and never supplied on insert.
    """
    from pymilvus import Collection

    if not isinstance(self.col, Collection):
        return
    collection_schema = self.col.schema
    self.fields.extend(field.name for field in collection_schema.fields)
    self.fields.remove(self._primary_field)
def _get_index(self) -> Optional[dict[str, Any]]:
    """Return the vector field's index description, or None if there is none."""
    from pymilvus import Collection

    if not isinstance(self.col, Collection):
        return None
    vector_indexes = (
        index.to_dict()
        for index in self.col.indexes
        if index.field_name == self._vector_field
    )
    return next(vector_indexes, None)
def _create_index(self) -> None:
    """Create an index on the vector field if the collection has none yet.

    Uses ``self.index_params`` (defaulting to an L2 HNSW index). If that
    attempt fails with a MilvusException, an L2 AUTOINDEX is tried instead,
    which is the expected case on Zilliz Cloud.

    Raises:
        MilvusException: If the index cannot be created with either setting.
    """
    from pymilvus import Collection, MilvusException

    if not isinstance(self.col, Collection) or self._get_index() is not None:
        return

    def build_index() -> None:
        # Always builds with the current value of self.index_params.
        self.col.create_index(
            self._vector_field,
            index_params=self.index_params,
            using=self.alias,
        )

    try:
        # If no index params were given, use a default HNSW based one.
        if self.index_params is None:
            self.index_params = {
                "metric_type": "L2",
                "index_type": "HNSW",
                "params": {"M": 8, "efConstruction": 64},
            }
        try:
            build_index()
        except MilvusException:
            # The first attempt did not work, most likely on Zilliz Cloud:
            # switch to an AUTOINDEX based index.
            self.index_params = {
                "metric_type": "L2",
                "index_type": "AUTOINDEX",
                "params": {},
            }
            build_index()
        logger.debug(
            "Successfully created an index on collection: %s",
            self.collection_name,
        )
    except MilvusException as e:
        logger.error(
            "Failed to create an index on collection: %s", self.collection_name
        )
        raise e
def _create_search_params(self) -> None:
    """Derive ``self.search_params`` from the existing vector index.

    Only acts when a collection is attached, no search params were supplied
    and an index exists. The defaults for the index type are selected and
    their metric type is set to the one the index was built with.
    """
    from pymilvus import Collection

    if not isinstance(self.col, Collection) or self.search_params is not None:
        return
    index_info = self._get_index()
    if index_info is None:
        return
    index_type: str = index_info["index_param"]["index_type"]
    metric_type: str = index_info["index_param"]["metric_type"]
    chosen_params = self.default_search_params[index_type]
    self.search_params = chosen_params
    chosen_params["metric_type"] = metric_type
def _load(self) -> None:
    """Load the collection when one is attached and it has a vector index."""
    from pymilvus import Collection

    if not isinstance(self.col, Collection):
        return
    if self._get_index() is None:
        return
    self.col.load()
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    timeout: Optional[int] = None,
    batch_size: int = 1000,
    **kwargs: Any,
) -> List[str]:
    """Embed ``texts`` and insert them into Milvus.

    Inserting data when the collection has not been made yet creates a new
    collection. The first entity decides its schema: the dimension is taken
    from the first embedding and the columns from the first metadata dict.
    Metadata keys need to be present for all inserted values; at the moment
    there is no None equivalent in Milvus.

    Args:
        texts (Iterable[str]): The texts to embed; it is assumed that they
            all fit in memory.
        metadatas (Optional[List[dict]]): Metadata dicts attached to each of
            the texts. Defaults to None.
        timeout (Optional[int]): Timeout for each batch insert. Defaults to
            None.
        batch_size (int, optional): Batch size to use for insertion.
            Defaults to 1000.
        kwargs: Extra keyword arguments forwarded to ``Collection.insert``.

    Raises:
        MilvusException: Failure to add texts.

    Returns:
        List[str]: The resulting keys for each inserted element.
    """
    from pymilvus import Collection, MilvusException

    texts = list(texts)

    # Embed with the configured embedding function so stored vectors live in
    # the same space as the query vectors produced by embed_query. (The
    # debug helper ``embedding_test`` was previously called here, which
    # ignored ``embedding_func`` and made the fallback below unreachable.)
    try:
        embeddings = self.embedding_func.embed_documents(texts)
    except NotImplementedError:
        embeddings = [self.embedding_func.embed_query(x) for x in texts]

    if len(embeddings) == 0:
        logger.debug("Nothing to insert, skipping.")
        return []

    # If the collection hasn't been initialized yet, perform all steps to do so.
    if not isinstance(self.col, Collection):
        self._init(embeddings, metadatas)

    # Column-oriented data to insert, keyed by field name.
    insert_dict: dict[str, list] = {
        self._text_field: texts,
        self._vector_field: embeddings,
    }

    # Collect the metadata values of known fields into their columns.
    if metadatas is not None:
        for d in metadatas:
            for key, value in d.items():
                if key in self.fields:
                    insert_dict.setdefault(key, []).append(value)

    total_count = len(insert_dict[self._vector_field])
    pks: list[str] = []

    assert isinstance(self.col, Collection)
    for i in range(0, total_count, batch_size):
        end = min(i + batch_size, total_count)
        # Slice every column, in schema field order, for this batch.
        insert_list = [insert_dict[x][i:end] for x in self.fields]
        try:
            res = self.col.insert(insert_list, timeout=timeout, **kwargs)
            pks.extend(res.primary_keys)
        except MilvusException:
            logger.error(
                "Failed to insert batch starting at entity: %s/%s", i, total_count
            )
            raise
    return pks
def embedding_test(self, texts: List[str]) -> List[List[float]]:
    """Embed ``texts`` with a fixed SentenceTransformer model, L2-normalised.

    Experimental helper: it loads the
    'paraphrase-multilingual-MiniLM-L12-v2' model on every call and does not
    use ``self.embedding_func``.

    Returns:
        List[List[float]]: One unit-length vector per input text.
    """
    encoder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    encoded = encoder.encode(texts)
    return [
        (encoded[position] / np.linalg.norm(encoded[position])).tolist()
        for position in range(len(texts))
    ]
def similarity_search(
    self,
    query: str,
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return the documents most similar to ``query``.

    Args:
        query (str): The text to search.
        k (int, optional): How many results to return. Defaults to 4.
        param (dict, optional): The search params for the index type.
            Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Document]: Matching documents, or an empty list when no
            collection exists yet.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []
    scored_documents = self.similarity_search_with_score(
        query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
    )
    documents = []
    for document, _score in scored_documents:
        documents.append(document)
    return documents
def similarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return the documents most similar to an embedding vector.

    Args:
        embedding (List[float]): The embedding vector to search.
        k (int, optional): How many results to return. Defaults to 4.
        param (dict, optional): The search params for the index type.
            Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Document]: Matching documents, or an empty list when no
            collection exists yet.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []
    scored_documents = self.similarity_search_with_score_by_vector(
        embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
    )
    documents = []
    for document, _score in scored_documents:
        documents.append(document)
    return documents
def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Perform a search on a query string and return results with score.

    For more information about the search parameters, take a look at the
    pymilvus documentation found here:
    https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

    Args:
        query (str): The text being searched.
        k (int, optional): The amount of results to return. Defaults to 4.
        param (dict): The search params for the specified index.
            Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Tuple[Document, float]]: Result doc and score, or an empty list
            when no collection exists yet.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    # embed_query returns ONE vector for the query, so it is passed on whole.
    # The previous code indexed it with [0], handing a single float to the
    # vector search, and also ran a discarded ``embedding_test`` call that
    # loaded a SentenceTransformer model on every query.
    embedding = self.embedding_func.embed_query(query)

    return self.similarity_search_with_score_by_vector(
        embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
    )
def normalize_embedding(self, embedding):
    """Return ``embedding`` scaled by sklearn's ``preprocessing.normalize`` with the L2 norm."""
    l2_normalised = preprocessing.normalize(embedding, norm='l2')
    return l2_normalised
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Search by embedding vector and return documents with their scores.

    For more information about the search parameters, take a look at the
    pymilvus documentation found here:
    https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

    Args:
        embedding (List[float]): The embedding vector being searched.
        k (int, optional): The amount of results to return. Defaults to 4.
        param (dict): The search params for the specified index.
            Defaults to None, meaning ``self.search_params``.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Tuple[Document, float]]: Result doc and score.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    effective_param = self.search_params if param is None else param

    # Return every tracked field except the vector itself.
    output_fields = list(self.fields)
    output_fields.remove(self._vector_field)

    hits = self.col.search(
        data=[embedding],
        anns_field=self._vector_field,
        param=effective_param,
        limit=k,
        expr=expr,
        output_fields=output_fields,
        timeout=timeout,
        **kwargs,
    )

    def to_scored_document(hit) -> Tuple[Document, float]:
        # The text field becomes the page content; the rest is metadata.
        metadata = {name: hit.entity.get(name) for name in output_fields}
        document = Document(
            page_content=metadata.pop(self._text_field), metadata=metadata
        )
        return (document, hit.score)

    return [to_scored_document(hit) for hit in hits[0]]
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Document]:
    """Search for ``query`` and return results reordered by MMR.

    Args:
        query (str): The text being searched.
        k (int, optional): How many results to give. Defaults to 4.
        fetch_k (int, optional): Total results to select k from.
            Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5
        param (dict, optional): The search params for the specified index.
            Defaults to None.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Document]: Document results for search, or an empty list when
            no collection exists yet.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    query_vector = self.embedding_func.embed_query(query)
    reranked = self.max_marginal_relevance_search_by_vector(
        embedding=query_vector,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        param=param,
        expr=expr,
        timeout=timeout,
        **kwargs,
    )
    return reranked
def max_marginal_relevance_search_by_vector(
    self,
    embedding: list[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    param: Optional[dict] = None,
    expr: Optional[str] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> List[Document]:
    """Search by embedding vector and return results reordered by MMR.

    Fetches ``fetch_k`` candidates, reads their stored vectors back by
    primary key and lets ``maximal_marginal_relevance`` choose ``k`` of them.

    Args:
        embedding (list[float]): The embedding vector being searched.
        k (int, optional): How many results to give. Defaults to 4.
        fetch_k (int, optional): Total results to select k from.
            Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5
        param (dict, optional): The search params for the specified index.
            Defaults to None, meaning ``self.search_params``.
        expr (str, optional): Filtering expression. Defaults to None.
        timeout (int, optional): How long to wait before timeout error.
            Defaults to None.
        kwargs: Collection.search() keyword arguments.

    Returns:
        List[Document]: Document results for search.
    """
    if self.col is None:
        logger.debug("No existing collection to search.")
        return []

    effective_param = self.search_params if param is None else param

    # Return every tracked field except the vector itself.
    output_fields = list(self.fields)
    output_fields.remove(self._vector_field)

    hits = self.col.search(
        data=[embedding],
        anns_field=self._vector_field,
        param=effective_param,
        limit=fetch_k,
        expr=expr,
        output_fields=output_fields,
        timeout=timeout,
        **kwargs,
    )

    # Collect ids, documents and scores in search order.
    ids = []
    documents = []
    scores = []
    for hit in hits[0]:
        metadata = {name: hit.entity.get(name) for name in output_fields}
        documents.append(
            Document(page_content=metadata.pop(self._text_field), metadata=metadata)
        )
        scores.append(hit.score)
        ids.append(hit.id)

    # Read the stored vectors of the candidates back by primary key.
    rows = self.col.query(
        expr=f"{self._primary_field} in {ids}",
        output_fields=[self._primary_field, self._vector_field],
        timeout=timeout,
    )

    # Line the vectors up with the search order.
    vector_by_id = {row[self._primary_field]: row[self._vector_field] for row in rows}
    candidate_vectors = [vector_by_id[pk] for pk in ids]

    mmr_order = maximal_marginal_relevance(
        np.array(embedding), candidate_vectors, k=k, lambda_mult=lambda_mult
    )

    # Pick documents in MMR order; an index of -1 ends the selection.
    selected = []
    for position in mmr_order:
        if position == -1:
            break
        selected.append(documents[position])
    return selected
@classmethod
def from_texts(
    cls,
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    collection_name: str = "LangChainCollection",
    connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
    consistency_level: str = "Session",
    index_params: Optional[dict] = None,
    search_params: Optional[dict] = None,
    drop_old: bool = False,
    **kwargs: Any,
) -> Milvus:
    """Build a Milvus vector store and insert ``texts`` into it.

    Args:
        texts (List[str]): Text data.
        embedding (Embeddings): Embedding function.
        metadatas (Optional[List[dict]]): Metadata for each text if it exists.
            Defaults to None.
        collection_name (str, optional): Collection name to use. Defaults to
            "LangChainCollection".
        connection_args (dict[str, Any], optional): Connection args to use.
            Defaults to DEFAULT_MILVUS_CONNECTION.
        consistency_level (str, optional): Which consistency level to use.
            Defaults to "Session".
        index_params (Optional[dict], optional): Which index_params to use.
            Defaults to None.
        search_params (Optional[dict], optional): Which search params to use.
            Defaults to None.
        drop_old (Optional[bool], optional): Whether to drop the collection
            with that name if it exists. Defaults to False.
        kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        Milvus: Milvus Vector Store
    """
    constructor_args = {
        "embedding_function": embedding,
        "collection_name": collection_name,
        "connection_args": connection_args,
        "consistency_level": consistency_level,
        "index_params": index_params,
        "search_params": search_params,
        "drop_old": drop_old,
    }
    store = cls(**constructor_args, **kwargs)
    store.add_texts(texts=texts, metadatas=metadatas)
    return store
from typing import Optional, cast
import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class MilvusConfig(BaseModel):
    """Connection settings consumed by ``MilvusVectorIndex``."""

    # Address of the vector database; must be non-empty (see validate_config).
    uri: str
    # Optional credentials.
    username: Optional[str]
    password: Optional[str]
    # Batch size handed to the client's batch configuration.
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Reject configurations that have no usable ``uri``.

        Args:
            values: Field values that passed per-field validation.

        Returns:
            The same ``values`` dict, unchanged.

        Raises:
            ValueError: If ``uri`` is absent or empty.
        """
        # A post root validator still runs when a field failed validation, and
        # the failed field is then missing from `values`. Using .get() keeps a
        # missing `uri` inside the normal validation-error path instead of
        # leaking a KeyError.
        if not values.get('uri'):
            raise ValueError("config Milvus uri is required")
        return values
class MilvusVectorIndex(BaseVectorIndex):
    """Vector index configured from a ``MilvusConfig``.

    NOTE(review): despite its name, every call in this class goes through the
    Weaviate client and ``WeaviateVectorStore``, and ``get_type`` reports
    'weaviate'. This looks like scaffolding copied from the Weaviate index
    that still has to be ported to a Milvus client -- confirm before relying
    on it.
    """

    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
        """Initialise the base index and open the client connection."""
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)

    def _init_client(self, config: MilvusConfig) -> weaviate.Client:
        """Create and configure the client described by ``config``.

        Args:
            config: Connection settings (uri, optional username/password,
                batch size).

        Returns:
            A client with auto-batching configured.

        Raises:
            ConnectionError: If the client cannot reach the server.
        """
        # MilvusConfig declares `uri`, `username` and `password`; it has no
        # `api_key` or `endpoint`, so reading those attributes failed with
        # AttributeError. Use the declared fields instead.
        # NOTE(review): password auth is the closest match to the declared
        # fields; confirm this is the intended scheme.
        auth_config = None
        if config.username and config.password:
            auth_config = weaviate.auth.AuthClientPassword(
                username=config.username,
                password=config.password,
            )

        weaviate.connect.connection.has_grpc = False

        try:
            client = weaviate.Client(
                url=config.uri,
                auth_client_secret=auth_config,
                timeout_config=(5, 60),
                startup_period=None
            )
        except requests.exceptions.ConnectionError as err:
            # Chain the original error so the root cause stays in the traceback.
            raise ConnectionError("Vector database connection error") from err

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=config.batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_type(self) -> str:
        """Return the type tag written into the index struct."""
        # NOTE(review): still 'weaviate'; left untouched because persisted
        # index structs may depend on this value.
        return 'weaviate'

    def get_index_name(self, dataset: Dataset) -> str:
        """Return the index (class) name for the dataset.

        When the bound dataset already has an index struct, its stored
        ``class_prefix`` is used, with '_Node' appended if it is missing.
        Otherwise the name is derived from ``dataset.id``.
        """
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        """Return the index description: type tag plus class prefix."""
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        """Build the vector store from ``texts`` and return this index."""
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        # Indexes with an original-style class prefix only carry `doc_id`.
        attributes = ['doc_id', 'dataset_id', 'document_id']
        if self._is_origin():
            attributes = ['doc_id']

        return WeaviateVectorStore(
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False
        )

    def _get_vector_store_class(self) -> type:
        """Return the concrete vector store class used by this index."""
        return WeaviateVectorStore

    def delete_by_document_id(self, document_id: str):
        """Remove all entries that belong to ``document_id``.

        For an original-style index the whole dataset is recreated instead of
        deleting by filter.
        """
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        })

    def _is_origin(self):
        """Return True when the stored class prefix lacks the '_Node' suffix."""
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False
import numpy as np
import sklearn.decomposition
import pickle
import time
# Apply 'Algorithm 1' to the ada-002 embeddings to make them isotropic, taken from the paper:
# ALL-BUT-THE-TOP: SIMPLE AND EFFECTIVE POST-PROCESSING FOR WORD REPRESENTATIONS
# Jiaqi Mu, Pramod Viswanath

# This uses Principal Component Analysis (PCA) to 'evenly distribute' the embedding vectors (make them isotropic)
# For more information on PCA, see https://jamesmccaffrey.wordpress.com/2021/07/16/computing-pca-using-numpy-without-scikit/

# The embedding data is a dict of key / value pairs:
# the key is the hash of the message (SHA3-256), the value is the embedding from ada-002 (array of dimension 1536).
# The hash can be used to look up the original text in a database.
# SECURITY NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
# The context manager closes the file as soon as the data is in memory.
with open('/path/to/your/data/Embedding-Latest.pkl', 'rb') as fp:
    E = pickle.load(fp)  # load the data into memory

# separate the keys (hashes) and values (embeddings) into separate vectors
K = list(E.keys())  # vector of all the hash values
X = np.array(list(E.values()))  # vector of all the embeddings, converted to numpy arrays

# list the total number of embeddings
# this can be truncated if there are too many embeddings to do PCA on
print(f"Total number of embeddings: {len(X)}")

# get dimension of embeddings, used later
Dim = len(X[0])

# print the first few embeddings
print("First two embeddings are: ")
print(X[0])
print(f"First embedding length: {len(X[0])}")
print(X[1])
print(f"Second embedding length: {len(X[1])}")

# compute the mean of all the embeddings, and print the result
mu = np.mean(X, axis=0)  # same as mu in paper
print(f"Mean embedding vector: {mu}")
print(f"Mean embedding vector length: {len(mu)}")

# subtract the mean vector from each embedding vector ... vectorized in numpy
X_tilde = X - mu  # same as v_tilde(w) in paper

# do the heavy lifting of extracting the principal components
# note that this is a function of the embeddings you currently have here, and this set may grow over time
# therefore the PCA basis vectors may change over time, and your final isotropic embeddings may drift over time
# but the drift should stabilize after you have extracted enough embedding data to characterize the nature of the embedding engine
print(f"Performing PCA on the normalized embeddings ...")
pca = sklearn.decomposition.PCA()  # new object
TICK = time.time()  # start timer
pca.fit(X_tilde)  # do the heavy lifting!
TOCK = time.time()  # end timer
DELTA = TOCK - TICK

print(f"PCA finished in {DELTA} seconds ...")

# dimensional reduction stage (the only hyperparameter)
# pick max dimension of PCA components to express embeddings
# in general this is some integer less than or equal to the dimension of your embeddings
# it could be set as a high percentile, say 95th percentile of pca.explained_variance_ratio_
# but just hardcoding a constant here
D = 15  # hyperparameter on dimension (out of 1536 for ada-002), paper recommends D = Dim/100

# form the set of v_prime(w), which is the final embedding
# this could be vectorized in numpy to speed it up, but coding it directly here in a double for-loop to avoid errors and to be transparent
E_prime = dict()  # output dict of the new embeddings
N = len(X_tilde)
# Progress is reported roughly every 10%; clamp to 1 so that small data sets
# (N <= 5, where round(N/10) is 0) do not divide by zero in the modulo below.
N10 = max(1, round(N/10))
U = pca.components_  # set of PCA basis vectors, sorted by most significant to least significant
print(f"Shape of full set of PCA componenents {U.shape}")
U = U[0:D,:]  # take the top D dimensions (or take them all if D is the size of the embedding vector)
print(f"Shape of downselected PCA componenents {U.shape}")
for ii in range(N):
    v_tilde = X_tilde[ii]
    v = X[ii]
    v_projection = np.zeros(Dim)  # start to build the projection
    # project the original embedding onto the PCA basis vectors, use only first D dimensions
    # Iterate over the rows actually present in U: PCA returns at most
    # min(n_samples, n_features) components, which can be fewer than D.
    for jj, u_jj in enumerate(U):
        v_jj = np.dot(u_jj,v)  # scalar
        v_projection += v_jj*u_jj  # vector
    v_prime = v_tilde - v_projection  # final embedding vector
    v_prime = v_prime/np.linalg.norm(v_prime)  # create unit vector
    E_prime[K[ii]] = v_prime
    if (ii%N10 == 0) or (ii == N-1):
        print(f"Finished with {ii+1} embeddings out of {N} ({round(100*ii/N)}% done)")

# save as new pickle
print("Saving new pickle ...")
embeddingName = '/path/to/your/data/Embedding-Latest-Isotropic.pkl'
with open(embeddingName, 'wb') as f:  # Python 3: open(..., 'wb')
    pickle.dump([E_prime,mu,U], f)
print(embeddingName)

print("Done!")
# When working with live data, transform each new ada-002 embedding with this function before comparing it against the stored isotropic embeddings
#
def projectEmbedding(v, mu, U):
    """Convert a raw embedding into its isotropic, unit-length form.

    Args:
        v: Raw embedding (any sequence accepted by ``np.array``).
        mu: Mean embedding vector subtracted from ``v``.
        U: Iterable of basis vectors whose directions are removed.

    Returns:
        The centred embedding with its components along the rows of ``U``
        removed, scaled to unit Euclidean norm.
    """
    raw = np.array(v)
    centered = raw - mu
    # Sum, in row order, the component of the raw (uncentred) vector along
    # each basis row; the zero vector is the starting value.
    along_basis = sum(
        (np.dot(basis_row, raw) * basis_row for basis_row in U),
        np.zeros(len(raw)),
    )
    residual = centered - along_basis
    return residual / np.linalg.norm(residual)
\ No newline at end of file
......@@ -71,18 +71,18 @@ class IndexingRunner:
dataset_document=dataset_document,
processing_rule=processing_rule
)
new_documents = []
for document in documents:
response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
document_qa_list = self.format_split_text(response)
for result in document_qa_list:
document = Document(page_content=result['question'], metadata={'source': result['answer']})
new_documents.append(document)
# new_documents = []
# for document in documents:
# response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
# document_qa_list = self.format_split_text(response)
# for result in document_qa_list:
# document = Document(page_content=result['question'], metadata={'source': result['answer']})
# new_documents.append(document)
# build index
self._build_index(
dataset=dataset,
dataset_document=dataset_document,
documents=new_documents
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
......@@ -251,7 +251,8 @@ class IndexingRunner:
documents = self._split_to_documents(
text_docs=text_docs,
splitter=splitter,
processing_rule=processing_rule
processing_rule=processing_rule,
tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
)
total_segments += len(documents)
for document in documents:
......@@ -311,7 +312,8 @@ class IndexingRunner:
documents = self._split_to_documents(
text_docs=documents,
splitter=splitter,
processing_rule=processing_rule
processing_rule=processing_rule,
tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
)
total_segments += len(documents)
for document in documents:
......@@ -414,7 +416,8 @@ class IndexingRunner:
documents = self._split_to_documents(
text_docs=text_docs,
splitter=splitter,
processing_rule=processing_rule
processing_rule=processing_rule,
tenant_id=dataset.tenant_id
)
# save node to document segment
......@@ -469,18 +472,18 @@ class IndexingRunner:
if document.page_content is None or not document.page_content.strip():
continue
response = LLMGenerator.generate_qa_document(processing_rule.tenant_id, document.page_content)
response = LLMGenerator.generate_qa_document(tenant_id, document.page_content)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
document = Document(page_content=result['question'], metadata={'source': result['answer']})
new_documents.append(document)
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata['doc_id'] = doc_id
document.metadata['doc_hash'] = hash
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
split_documents.append(document)
document.metadata['doc_id'] = doc_id
document.metadata['doc_hash'] = hash
qa_documents.append(document)
split_documents.extend(qa_documents)
all_documents.extend(split_documents)
......
......@@ -51,6 +51,7 @@ GENERATOR_QA_PROMPT = (
'Step3:可分解或结合多个信息与概念\n'
'Step4:将这些关键信息与概念生成 10 个问题与答案,问题描述清楚并且详细完整,答案详细完整.\n'
"按格式回答: Q1:\nA1:\nQ2:\nA2:...\n"
"只输出Step4中的内容"
)
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment