Commit 7f50f41f authored by John Wang

feat: pinecone support

parent fe688b50
......@@ -65,7 +65,7 @@ SESSION_REDIS_PORT=6379
SESSION_REDIS_PASSWORD=difyai123456
SESSION_REDIS_DB=2
# Vector database configuration, support: weaviate, qdrant
# Vector database configuration, support: weaviate, qdrant, pinecone
VECTOR_STORE=weaviate
# Weaviate configuration
......@@ -77,6 +77,10 @@ WEAVIATE_GRPC_ENABLED=false
QDRANT_URL=path:storage/qdrant
QDRANT_API_KEY=your-qdrant-api-key
# Pinecone configuration
PINECONE_API_KEY=
PINECONE_ENVIRONMENT=us-east4-gcp
# Sentry configuration
SENTRY_DSN=
......
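For reference (not part of this diff), a deployment switching to Pinecone would set the following .env values; the API key and environment below are placeholders:
VECTOR_STORE=pinecone
PINECONE_API_KEY=your-pinecone-api-key
PINECONE_ENVIRONMENT=us-east4-gcp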
......@@ -143,6 +143,10 @@ class Config:
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
# pinecone settings
self.PINECONE_API_KEY = get_env('PINECONE_API_KEY')
self.PINECONE_ENVIRONMENT = get_env('PINECONE_ENVIRONMENT')
# cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL)
......
......@@ -187,6 +187,8 @@ And answer according to the language of the user's question.
if chain_output:
human_inputs['context'] = chain_output
# Inspired by @Yorki
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
{context}
......
......@@ -21,8 +21,7 @@ class VectorIndex:
def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
if not self._dataset.index_struct_dict:
index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
self._dataset.index_struct = json.dumps(vector_store.to_index_struct(self._dataset.id))
db.session.commit()
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
......
......@@ -12,7 +12,7 @@ class BaseVectorStoreClient(ABC):
raise NotImplementedError
@abstractmethod
def to_index_config(self, index_id: str) -> dict:
def to_index_config(self, dataset_id: str) -> dict:
raise NotImplementedError
......
import logging
import time
from typing import List, cast
import pinecone
from llama_index import GPTPineconeIndex, ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import PineconeIndexDict
from llama_index.data_structs.node_v2 import DocumentRelationship, Node
from llama_index.vector_stores import PineconeVectorStore
from llama_index.vector_stores.pinecone import generate_sparse_vectors, get_node_info_from_metadata
from llama_index.vector_stores.types import VectorStoreQueryResult, VectorStoreQuery, VectorStoreQueryMode
from pinecone import NotFoundException
from core.embedding.openai_embedding import OpenAIEmbedding
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
class PineconeVectorStoreClient(BaseVectorStoreClient):
def __init__(self, api_key: str, environment: str):
self._client = self.init_from_config(api_key, environment)
@classmethod
def init_from_config(cls, api_key: str, environment: str):
pinecone.init(
api_key=api_key,
environment=environment
)
return pinecone
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = PineconeIndexDict()
if self._client is None:
raise Exception("Vector client is not initialized.")
index_name = config.get('index_name')
if not index_name:
raise Exception("index_name cannot be None.")
try:
self._client.describe_index(index_name)
except NotFoundException:
# pinecone index not found
self._client.create_index(index_name, dimension=1536, metric="cosine", pod_type="p2")
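                # poll describe_index for up to ~120 seconds until the newly created index reports ready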
waiting_iterations = 120
while waiting_iterations > 0:
try:
index_info = self._client.describe_index(index_name)
                except Exception:
logging.exception("Failed to query index status.")
break
if index_info.status['ready']:
# index is ready
break
time.sleep(1)
waiting_iterations -= 1
return GPTPineconeEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=PineconeEnhanceVectorStore(
pinecone_index=self._client.Index(index_name),
tokenizer=OpenAIEmbedding().get_text_embedding
)
)
def to_index_config(self, dataset_id: str) -> dict:
index_id = "vector-" + dataset_id
return {"index_name": index_id}
class GPTPineconeEnhanceIndex(GPTPineconeIndex, BaseGPTVectorStoreIndex):
pass
class PineconeEnhanceVectorStore(PineconeVectorStore, EnhanceVectorStore):
def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
Args:
query_embedding (List[float]): query embedding
similarity_top_k (int): top k most similar nodes
"""
sparse_vector = None
if query.mode in (VectorStoreQueryMode.SPARSE, VectorStoreQueryMode.HYBRID):
if query.query_str is None:
raise ValueError(
"query_str must be specified if mode is SPARSE or HYBRID."
)
sparse_vector = generate_sparse_vectors([query.query_str], self._tokenizer)[
0
]
if query.alpha is not None:
sparse_vector = {
"indices": sparse_vector["indices"],
"values": [v * (1 - query.alpha) for v in sparse_vector["values"]],
}
query_embedding = None
if query.mode in (VectorStoreQueryMode.DEFAULT, VectorStoreQueryMode.HYBRID):
query_embedding = cast(List[float], query.query_embedding)
if query.alpha is not None:
query_embedding = [v * query.alpha for v in query_embedding]
response = self._pinecone_index.query(
vector=query_embedding,
sparse_vector=sparse_vector,
top_k=query.similarity_top_k,
include_values=True,
include_metadata=True,
namespace=self._namespace,
filter=self._metadata_filters,
**self._pinecone_kwargs,
)
top_k_nodes = []
top_k_ids = []
top_k_scores = []
for match in response.matches:
text = match.metadata["text"]
extra_info = get_node_info_from_metadata(match.metadata, "extra_info")
node_info = get_node_info_from_metadata(match.metadata, "node_info")
doc_id = match.metadata["doc_id"]
id = match.metadata["id"]
embedding = match.values
node = Node(
text=text,
extra_info=extra_info,
node_info=node_info,
embedding=embedding,
doc_id=id,
relationships={DocumentRelationship.SOURCE: doc_id},
)
top_k_ids.append(match.id)
top_k_nodes.append(node)
top_k_scores.append(match.score)
return VectorStoreQueryResult(
nodes=top_k_nodes, similarities=top_k_scores, ids=top_k_ids
)
def delete_node(self, node_id: str):
self._pinecone_index.delete([node_id])
def exists_by_node_id(self, node_id: str) -> bool:
query_response = self._pinecone_index.query(
id=node_id
)
return len(query_response.matches) > 0
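For context (not part of the commit), a minimal sketch of how the new client is exercised, assuming a dataset id and a ServiceContext such as the one built by IndexBuilder.get_default_service_context elsewhere in this diff; the API key, environment and dataset id below are placeholders:
from core.vector_store.pinecone_vector_store_client import PineconeVectorStoreClient
client = PineconeVectorStoreClient(
    api_key='your-pinecone-api-key',   # placeholder
    environment='us-east4-gcp'         # placeholder
)
# dataset ids are UUIDs; to_index_config maps one to a Pinecone index name
config = client.to_index_config('my-dataset-uuid')  # -> {"index_name": "vector-my-dataset-uuid"}
index = client.get_index(service_context=service_context, config=config)  # GPTPineconeEnhanceIndex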
......@@ -56,7 +56,8 @@ class QdrantVectorStoreClient(BaseVectorStoreClient):
)
)
def to_index_config(self, index_id: str) -> dict:
def to_index_config(self, dataset_id: str) -> dict:
index_id = "Vector_index_" + dataset_id.replace("-", "_")
return {"collection_name": index_id}
......
......@@ -3,10 +3,11 @@ from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from core.vector_store.pinecone_vector_store_client import PineconeVectorStoreClient
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant', 'pinecone']
class VectorStore:
......@@ -35,6 +36,11 @@ class VectorStore:
api_key=app.config['QDRANT_API_KEY'],
root_path=app.root_path
)
elif self._vector_store == 'pinecone':
self._client = PineconeVectorStoreClient(
api_key=app.config['PINECONE_API_KEY'],
environment=app.config['PINECONE_ENVIRONMENT']
)
app.extensions['vector_store'] = self
......@@ -48,10 +54,10 @@ class VectorStore:
return index
def to_index_struct(self, index_id: str) -> dict:
def to_index_struct(self, dataset_id: str) -> dict:
return {
"type": self._vector_store,
"vector_store": self.get_client().to_index_config(index_id)
"vector_store": self.get_client().to_index_config(dataset_id)
}
def get_client(self):
......
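For illustration (not part of the diff), the index_struct that to_index_struct now persists on a dataset, given the per-client to_index_config implementations in this commit; <dataset_id> stands for the dataset's UUID:
# with VECTOR_STORE=pinecone
{"type": "pinecone", "vector_store": {"index_name": "vector-<dataset_id>"}}
# with VECTOR_STORE=qdrant
{"type": "qdrant", "vector_store": {"collection_name": "Vector_index_<dataset_id with '-' replaced by '_'>"}}
# with VECTOR_STORE=weaviate
{"type": "weaviate", "vector_store": {"class_prefix": "Vector_index_<dataset_id with '-' replaced by '_'>"}}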
import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional
from typing import List, Any, Dict, Optional, cast
from weaviate import UnexpectedStatusCodeException
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
......@@ -12,7 +14,7 @@ from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
parse_get_response,
validate_client,
validate_client, get_default_class_prefix,
)
......@@ -53,11 +55,68 @@ class WeaviateVectorStoreClient(BaseVectorStoreClient):
)
)
def to_index_config(self, index_id: str) -> dict:
def to_index_config(self, dataset_id: str) -> dict:
index_id = "Vector_index_" + dataset_id.replace("-", "_")
return {"class_prefix": index_id}
class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
def __init__(
self,
weaviate_client: Optional[Any] = None,
class_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
import_err_msg = (
"`weaviate` package not found, please run `pip install weaviate-client`"
)
try:
import weaviate # noqa: F401
from weaviate import Client # noqa: F401
except ImportError:
raise ImportError(import_err_msg)
if weaviate_client is None:
raise ValueError("Missing Weaviate client!")
self._client = cast(Client, weaviate_client)
# validate class prefix starts with a capital letter
if class_prefix is not None and not class_prefix[0].isupper():
raise ValueError(
"Class prefix must start with a capital letter, e.g. 'Gpt'"
)
self._class_prefix = class_prefix or get_default_class_prefix()
# try to create schema
self.create_schema(self._client, self._class_prefix)
def create_schema(self, client: Any, class_prefix: str) -> None:
"""Create schema."""
validate_client(client)
# first check if schema exists
class_name = _class_name(class_prefix)
try:
exist_class = client.schema.get(class_name)
if exist_class:
return
except UnexpectedStatusCodeException as e:
if e.status_code != 404:
raise e
except Exception as e:
raise e
properties = NODE_SCHEMA
class_obj = {
"class": class_name, # <= note the capital "A".
"description": f"Class for {class_name}",
"properties": properties,
"vectorIndexConfig": {
"efConstruction": 160,
"maxConnections": 32
},
}
client.schema.create_class(class_obj)
def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes."""
nodes = self.weaviate_query(
......
......@@ -23,6 +23,7 @@ tenacity==8.2.2
cachetools~=5.3.0
weaviate-client~=3.16.2
qdrant_client~=1.1.6
pinecone-client~=2.2.1
mailchimp-transactional~=1.0.50
scikit-learn==1.2.2
sentry-sdk[flask]~=1.21.1
......