Unverified Commit ee9c7e20 authored by Jyong's avatar Jyong Committed by GitHub

delete document cache embedding (#2101)

Co-authored-by: 's avatarjyong <jyong@dify.ai>
parent 483dcb63
import base64 import base64
import json import json
import logging import logging
from typing import List, Optional from typing import List, Optional, cast
import numpy as np import numpy as np
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db from extensions.ext_database import db
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
...@@ -22,56 +24,33 @@ class CacheEmbedding(Embeddings): ...@@ -22,56 +24,33 @@ class CacheEmbedding(Embeddings):
self._user = user self._user = user
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs.""" """Embed search docs in batches of 10."""
# use doc embedding cache or store if not exists text_embeddings = []
text_embeddings = [None for _ in range(len(texts))]
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 3600)
text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
else:
embedding_queue_indices.append(i)
if embedding_queue_indices:
try: try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
for i in range(0, len(texts), max_chunks):
batch_texts = texts[i:i + max_chunks]
embedding_result = self._model_instance.invoke_text_embedding( embedding_result = self._model_instance.invoke_text_embedding(
texts=[texts[i] for i in embedding_queue_indices], texts=batch_texts,
user=self._user user=self._user
) )
embedding_results = embedding_result.embeddings for vector in embedding_result.embeddings:
except Exception as ex:
logger.error('Failed to embed documents: ', ex)
raise ex
for i, indice in enumerate(embedding_queue_indices):
hash = helper.generate_text_hash(texts[indice])
try: try:
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
text_embeddings[indice] = normalized_embedding text_embeddings.append(normalized_embedding)
# encode embedding to base64
embedding_vector = np.array(normalized_embedding)
vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 3600, encoded_str)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
continue except Exception as e:
except:
logging.exception('Failed to add embedding to redis') logging.exception('Failed to add embedding to redis')
continue
except Exception as ex:
logger.error('Failed to embed documents: ', ex)
raise ex
return text_embeddings return text_embeddings
...@@ -82,7 +61,7 @@ class CacheEmbedding(Embeddings): ...@@ -82,7 +61,7 @@ class CacheEmbedding(Embeddings):
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
embedding = redis_client.get(embedding_cache_key) embedding = redis_client.get(embedding_cache_key)
if embedding: if embedding:
redis_client.expire(embedding_cache_key, 3600) redis_client.expire(embedding_cache_key, 600)
return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
...@@ -105,7 +84,7 @@ class CacheEmbedding(Embeddings): ...@@ -105,7 +84,7 @@ class CacheEmbedding(Embeddings):
encoded_vector = base64.b64encode(vector_bytes) encoded_vector = base64.b64encode(vector_bytes)
# Transform to string # Transform to string
encoded_str = encoded_vector.decode("utf-8") encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 3600, encoded_str) redis_client.setex(embedding_cache_key, 600, encoded_str)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment