Commit 7f50f41f authored by John Wang

feat: pinecone support

parent fe688b50
...@@ -65,7 +65,7 @@ SESSION_REDIS_PORT=6379 ...@@ -65,7 +65,7 @@ SESSION_REDIS_PORT=6379
SESSION_REDIS_PASSWORD=difyai123456 SESSION_REDIS_PASSWORD=difyai123456
SESSION_REDIS_DB=2 SESSION_REDIS_DB=2
# Vector database configuration, support: weaviate, qdrant # Vector database configuration, support: weaviate, qdrant, pinecone
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# Weaviate configuration # Weaviate configuration
...@@ -77,6 +77,10 @@ WEAVIATE_GRPC_ENABLED=false ...@@ -77,6 +77,10 @@ WEAVIATE_GRPC_ENABLED=false
QDRANT_URL=path:storage/qdrant QDRANT_URL=path:storage/qdrant
QDRANT_API_KEY=your-qdrant-api-key QDRANT_API_KEY=your-qdrant-api-key
# Pinecone configuration
PINECONE_API_KEY=
PINECONE_ENVIRONMENT=us-east4-gcp
# Sentry configuration # Sentry configuration
SENTRY_DSN= SENTRY_DSN=
......
...@@ -143,6 +143,10 @@ class Config: ...@@ -143,6 +143,10 @@ class Config:
self.QDRANT_URL = get_env('QDRANT_URL') self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
# pinecone settings
self.PINECONE_API_KEY = get_env('PINECONE_API_KEY')
self.PINECONE_ENVIRONMENT = get_env('PINECONE_ENVIRONMENT')
# cors settings # cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL) 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL)
......
...@@ -187,6 +187,8 @@ And answer according to the language of the user's question. ...@@ -187,6 +187,8 @@ And answer according to the language of the user's question.
if chain_output: if chain_output:
human_inputs['context'] = chain_output human_inputs['context'] = chain_output
# Inspired by @Yorki
human_message_prompt += """Use the following CONTEXT as your learned knowledge. human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT] [CONTEXT]
{context} {context}
......
...@@ -21,8 +21,7 @@ class VectorIndex: ...@@ -21,8 +21,7 @@ class VectorIndex:
def add_nodes(self, nodes: List[Node], duplicate_check: bool = False): def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
if not self._dataset.index_struct_dict: if not self._dataset.index_struct_dict:
index_id = "Vector_index_" + self._dataset.id.replace("-", "_") self._dataset.index_struct = json.dumps(vector_store.to_index_struct(self._dataset.id))
self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
db.session.commit() db.session.commit()
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
......
...@@ -12,7 +12,7 @@ class BaseVectorStoreClient(ABC): ...@@ -12,7 +12,7 @@ class BaseVectorStoreClient(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def to_index_config(self, index_id: str) -> dict: def to_index_config(self, dataset_id: str) -> dict:
raise NotImplementedError raise NotImplementedError
......
import logging
import time
from typing import List, cast
import pinecone
from llama_index import GPTPineconeIndex, ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import PineconeIndexDict
from llama_index.data_structs.node_v2 import DocumentRelationship, Node
from llama_index.vector_stores import PineconeVectorStore
from llama_index.vector_stores.pinecone import generate_sparse_vectors, get_node_info_from_metadata
from llama_index.vector_stores.types import VectorStoreQueryResult, VectorStoreQuery, VectorStoreQueryMode
from pinecone import NotFoundException
from core.embedding.openai_embedding import OpenAIEmbedding
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
class PineconeVectorStoreClient(BaseVectorStoreClient):
    """Vector store client that keeps dataset embeddings in Pinecone indexes."""

    def __init__(self, api_key: str, environment: str):
        """Initialise the Pinecone SDK and keep a handle to it.

        Args:
            api_key: Pinecone API key.
            environment: Pinecone environment name.
        """
        self._client = self.init_from_config(api_key, environment)

    @classmethod
    def init_from_config(cls, api_key: str, environment: str):
        """Call ``pinecone.init`` with the given credentials.

        Returns:
            The ``pinecone`` module itself, which is used as the client object.
        """
        pinecone.init(
            api_key=api_key,
            environment=environment
        )
        return pinecone

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        """Return an index object bound to the Pinecone index named in ``config``.

        The Pinecone index is created when ``describe_index`` reports that it
        does not exist. If the index is not ready yet, this call waits for it
        on a best-effort basis before returning.

        Args:
            service_context: Service context passed through to the index.
            config: Mapping that must contain a non-empty ``index_name``.

        Returns:
            A ``GPTPineconeEnhanceIndex`` wrapping the Pinecone index.

        Raises:
            Exception: If the client is not initialised or ``index_name`` is
                missing from ``config``.
        """
        index_struct = PineconeIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        index_name = config.get('index_name')
        if not index_name:
            raise Exception("index_name cannot be None.")

        try:
            index_info = self._client.describe_index(index_name)
        except NotFoundException:
            # pinecone index not found
            # NOTE(review): dimension 1536 presumably matches the embedding
            # size produced by the configured embedding model - confirm.
            self._client.create_index(index_name, dimension=1536, metric="cosine", pod_type="p2")
            index_info = None

        # Reuse the description fetched above so an index that already exists
        # and is ready does not cost a second describe_index round trip.
        if index_info is None or not index_info.status['ready']:
            self._wait_until_ready(index_name)

        return GPTPineconeEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=PineconeEnhanceVectorStore(
                pinecone_index=self._client.Index(index_name),
                tokenizer=OpenAIEmbedding().get_text_embedding
            )
        )

    def _wait_until_ready(self, index_name: str, max_attempts: int = 120) -> None:
        """Poll ``describe_index`` once per second until the index is ready.

        Waiting is best effort: a failed status query or an exhausted attempt
        budget is logged and the method returns without raising.
        """
        for _ in range(max_attempts):
            try:
                index_info = self._client.describe_index(index_name)
            except Exception:
                # Narrowed from a bare ``except`` so KeyboardInterrupt and
                # SystemExit are no longer swallowed here.
                logging.exception("Failed to query index status.")
                break

            if index_info.status['ready']:
                # index is ready
                break

            time.sleep(1)
        else:
            # Only reached when every attempt reported "not ready".
            logging.warning(
                "Pinecone index %s was not ready after %d attempts.",
                index_name, max_attempts
            )

    def to_index_config(self, dataset_id: str) -> dict:
        """Build the vector-store config stored for a dataset.

        Returns:
            ``{"index_name": "vector-<dataset_id>"}``.
        """
        index_id = "vector-" + dataset_id
        return {"index_name": index_id}
class GPTPineconeEnhanceIndex(GPTPineconeIndex, BaseGPTVectorStoreIndex):
    """Pinecone index combined with ``BaseGPTVectorStoreIndex``.

    Adds no members of its own; all behaviour comes from the two base classes.
    """
class PineconeEnhanceVectorStore(PineconeVectorStore, EnhanceVectorStore):
    """Pinecone vector store with node-level delete and existence checks."""

    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query: Query description. ``query_str`` is required for the
                SPARSE and HYBRID modes; ``query_embedding`` is used for the
                DEFAULT and HYBRID modes. When ``alpha`` is set, dense values
                are scaled by ``alpha`` and sparse values by ``1 - alpha``.

        Returns:
            The matched nodes with their similarity scores and match ids.

        Raises:
            ValueError: If the mode is SPARSE or HYBRID and ``query_str`` is
                missing.
        """
        sparse_vector = None
        if query.mode in (VectorStoreQueryMode.SPARSE, VectorStoreQueryMode.HYBRID):
            if query.query_str is None:
                raise ValueError(
                    "query_str must be specified if mode is SPARSE or HYBRID."
                )
            sparse_vector = generate_sparse_vectors([query.query_str], self._tokenizer)[
                0
            ]
            if query.alpha is not None:
                sparse_vector = {
                    "indices": sparse_vector["indices"],
                    "values": [v * (1 - query.alpha) for v in sparse_vector["values"]],
                }

        query_embedding = None
        if query.mode in (VectorStoreQueryMode.DEFAULT, VectorStoreQueryMode.HYBRID):
            query_embedding = cast(List[float], query.query_embedding)
            if query.alpha is not None:
                query_embedding = [v * query.alpha for v in query_embedding]

        response = self._pinecone_index.query(
            vector=query_embedding,
            sparse_vector=sparse_vector,
            top_k=query.similarity_top_k,
            include_values=True,
            include_metadata=True,
            namespace=self._namespace,
            filter=self._metadata_filters,
            **self._pinecone_kwargs,
        )

        top_k_nodes = []
        top_k_ids = []
        top_k_scores = []
        for match in response.matches:
            metadata = match.metadata
            # The node's own id lives in metadata["id"]; metadata["doc_id"]
            # is the source document recorded as a relationship below.
            node_id = metadata["id"]
            source_doc_id = metadata["doc_id"]

            node = Node(
                text=metadata["text"],
                extra_info=get_node_info_from_metadata(metadata, "extra_info"),
                node_info=get_node_info_from_metadata(metadata, "node_info"),
                embedding=match.values,
                doc_id=node_id,
                relationships={DocumentRelationship.SOURCE: source_doc_id},
            )
            top_k_ids.append(match.id)
            top_k_nodes.append(node)
            top_k_scores.append(match.score)

        return VectorStoreQueryResult(
            nodes=top_k_nodes, similarities=top_k_scores, ids=top_k_ids
        )

    def delete_node(self, node_id: str):
        """Delete the vector stored under ``node_id``.

        The delete is scoped to the same namespace that ``query`` reads from.
        """
        self._pinecone_index.delete(ids=[node_id], namespace=self._namespace)

    def exists_by_node_id(self, node_id: str) -> bool:
        """Return True when a query by ``node_id`` yields at least one match.

        The lookup is scoped to the same namespace that ``query`` reads from.
        """
        # NOTE(review): top_k appears to be mandatory for Pinecone query
        # requests; one match is enough to answer the question - confirm
        # against the pinned pinecone-client version.
        query_response = self._pinecone_index.query(
            id=node_id,
            top_k=1,
            namespace=self._namespace,
        )
        return len(query_response.matches) > 0
...@@ -56,7 +56,8 @@ class QdrantVectorStoreClient(BaseVectorStoreClient): ...@@ -56,7 +56,8 @@ class QdrantVectorStoreClient(BaseVectorStoreClient):
) )
) )
def to_index_config(self, index_id: str) -> dict: def to_index_config(self, dataset_id: str) -> dict:
index_id = "Vector_index_" + dataset_id.replace("-", "_")
return {"collection_name": index_id} return {"collection_name": index_id}
......
...@@ -3,10 +3,11 @@ from llama_index import ServiceContext, GPTVectorStoreIndex ...@@ -3,10 +3,11 @@ from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt from tenacity import retry, retry_if_exception_type, stop_after_attempt
from core.vector_store.pinecone_vector_store_client import PineconeVectorStoreClient
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant'] SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant', 'pinecone']
class VectorStore: class VectorStore:
...@@ -35,6 +36,11 @@ class VectorStore: ...@@ -35,6 +36,11 @@ class VectorStore:
api_key=app.config['QDRANT_API_KEY'], api_key=app.config['QDRANT_API_KEY'],
root_path=app.root_path root_path=app.root_path
) )
elif self._vector_store == 'pinecone':
self._client = PineconeVectorStoreClient(
api_key=app.config['PINECONE_API_KEY'],
environment=app.config['PINECONE_ENVIRONMENT']
)
app.extensions['vector_store'] = self app.extensions['vector_store'] = self
...@@ -48,10 +54,10 @@ class VectorStore: ...@@ -48,10 +54,10 @@ class VectorStore:
return index return index
def to_index_struct(self, index_id: str) -> dict: def to_index_struct(self, dataset_id: str) -> dict:
return { return {
"type": self._vector_store, "type": self._vector_store,
"vector_store": self.get_client().to_index_config(index_id) "vector_store": self.get_client().to_index_config(dataset_id)
} }
def get_client(self): def get_client(self):
......
import json import json
import weaviate import weaviate
from dataclasses import field from dataclasses import field
from typing import List, Any, Dict, Optional from typing import List, Any, Dict, Optional, cast
from weaviate import UnexpectedStatusCodeException
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
...@@ -12,7 +14,7 @@ from llama_index.vector_stores import WeaviateVectorStore ...@@ -12,7 +14,7 @@ from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import ( from llama_index.readers.weaviate.utils import (
parse_get_response, parse_get_response,
validate_client, validate_client, get_default_class_prefix,
) )
...@@ -53,11 +55,68 @@ class WeaviateVectorStoreClient(BaseVectorStoreClient): ...@@ -53,11 +55,68 @@ class WeaviateVectorStoreClient(BaseVectorStoreClient):
) )
) )
def to_index_config(self, index_id: str) -> dict: def to_index_config(self, dataset_id: str) -> dict:
index_id = "Vector_index_" + dataset_id.replace("-", "_")
return {"class_prefix": index_id} return {"class_prefix": index_id}
class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore): class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
def __init__(
    self,
    weaviate_client: Optional[Any] = None,
    class_prefix: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Store the Weaviate client and class prefix, then ensure the schema.

    The parent initialiser is not called here; schema creation goes through
    this class's own ``create_schema``.

    Args:
        weaviate_client: A ``weaviate.Client`` instance. Required.
        class_prefix: Prefix for the Weaviate class name. Must start with a
            capital letter. ``None`` or an empty string selects the default
            prefix from ``get_default_class_prefix()``.
        **kwargs: Accepted for signature compatibility; not used here.

    Raises:
        ImportError: If the ``weaviate`` package is not installed.
        ValueError: If ``weaviate_client`` is missing or ``class_prefix``
            does not start with a capital letter.
    """
    import_err_msg = (
        "`weaviate` package not found, please run `pip install weaviate-client`"
    )
    try:
        import weaviate  # noqa: F401
        from weaviate import Client  # noqa: F401
    except ImportError as err:
        raise ImportError(import_err_msg) from err

    if weaviate_client is None:
        raise ValueError("Missing Weaviate client!")

    self._client = cast(Client, weaviate_client)
    # validate class prefix starts with a capital letter; an empty prefix is
    # skipped here because it falls back to the default prefix below
    if class_prefix and not class_prefix[0].isupper():
        raise ValueError(
            "Class prefix must start with a capital letter, e.g. 'Gpt'"
        )
    self._class_prefix = class_prefix or get_default_class_prefix()
    # try to create schema
    self.create_schema(self._client, self._class_prefix)
def create_schema(self, client: Any, class_prefix: str) -> None:
    """Create the Weaviate class for ``class_prefix`` unless it already exists.

    Args:
        client: Weaviate client; checked with ``validate_client``.
        class_prefix: Prefix from which the class name is derived.

    Raises:
        UnexpectedStatusCodeException: If the schema lookup fails with any
            status other than 404.
    """
    validate_client(client)

    class_name = _class_name(class_prefix)

    # A 404 from the lookup means the class is missing and must be created;
    # every other error propagates to the caller unchanged.
    try:
        if client.schema.get(class_name):
            return
    except UnexpectedStatusCodeException as err:
        if err.status_code != 404:
            raise

    client.schema.create_class({
        "class": class_name,
        "description": f"Class for {class_name}",
        "properties": NODE_SCHEMA,
        "vectorIndexConfig": {
            "efConstruction": 160,
            "maxConnections": 32
        },
    })
def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult: def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.""" """Query index for top k most similar nodes."""
nodes = self.weaviate_query( nodes = self.weaviate_query(
......
...@@ -23,6 +23,7 @@ tenacity==8.2.2 ...@@ -23,6 +23,7 @@ tenacity==8.2.2
cachetools~=5.3.0 cachetools~=5.3.0
weaviate-client~=3.16.2 weaviate-client~=3.16.2
qdrant_client~=1.1.6 qdrant_client~=1.1.6
pinecone-client~=2.2.1
mailchimp-transactional~=1.0.50 mailchimp-transactional~=1.0.50
scikit-learn==1.2.2 scikit-learn==1.2.2
sentry-sdk[flask]~=1.21.1 sentry-sdk[flask]~=1.21.1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment