Unverified Commit 1fc57d73 authored by Jyong's avatar Jyong Committed by GitHub

normalize embedding (#974)

Co-authored-by: 's avatarjyong <jyong@dify.ai>
parent 916d8be0
import logging import logging
from typing import List from typing import List
import numpy as np
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
...@@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings): ...@@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
except Exception as ex: except Exception as ex:
raise self._embeddings.handle_exceptions(ex) raise self._embeddings.handle_exceptions(ex)
i = 0 i = 0
normalized_embedding_results = []
for text in embedding_queue_texts: for text in embedding_queue_texts:
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
try: try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash) embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i]) vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
normalized_embedding_results.append(normalized_embedding)
embedding.set_embedding(normalized_embedding)
db.session.add(embedding) db.session.add(embedding)
db.session.commit() db.session.commit()
except IntegrityError: except IntegrityError:
...@@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings): ...@@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
finally: finally:
i += 1 i += 1
text_embeddings.extend(embedding_results) text_embeddings.extend(normalized_embedding_results)
return text_embeddings return text_embeddings
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
...@@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings): ...@@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):
try: try:
embedding_results = self._embeddings.client.embed_query(text) embedding_results = self._embeddings.client.embed_query(text)
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex: except Exception as ex:
raise self._embeddings.handle_exceptions(ex) raise self._embeddings.handle_exceptions(ex)
...@@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings): ...@@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):
return embedding_results return embedding_results
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment