Unverified Commit a3c7c07e authored by Jyong's avatar Jyong Committed by GitHub

use redis to cache embeddings (#2085)

Co-authored-by: 's avatarjyong <jyong@dify.ai>
parent dc8a8af1
import base64
import json
import logging
from typing import List, Optional
......@@ -5,6 +7,8 @@ import numpy as np
from core.model_manager import ModelInstance
from extensions.ext_database import db
from langchain.embeddings.base import Embeddings
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Embedding
from sqlalchemy.exc import IntegrityError
......@@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
embedding = redis_client.get(embedding_cache_key)
if embedding:
text_embeddings[i] = embedding.get_embedding()
redis_client.expire(embedding_cache_key, 3600)
text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
else:
embedding_queue_indices.append(i)
......@@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(texts[indice])
try:
embedding = Embedding(model_name=self._model_instance.model, hash=hash)
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
text_embeddings[indice] = normalized_embedding
embedding.set_embedding(normalized_embedding)
db.session.add(embedding)
db.session.commit()
# encode embedding to base64
embedding_vector = np.array(normalized_embedding)
vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 3600, encoded_str)
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
logging.exception('Failed to add embedding to redis')
continue
return text_embeddings
......@@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings):
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
embedding = redis_client.get(embedding_cache_key)
if embedding:
return embedding.get_embedding()
redis_client.expire(embedding_cache_key, 3600)
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
try:
embedding_result = self._model_instance.invoke_text_embedding(
......@@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings):
raise ex
try:
embedding = Embedding(model_name=self._model_instance.model, hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
# encode embedding to base64
embedding_vector = np.array(embedding_results)
vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 3600, encoded_str)
except IntegrityError:
db.session.rollback()
except:
logging.exception('Failed to add embedding to db')
logging.exception('Failed to add embedding to redis')
return embedding_results
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment