Commit f6af0f99 authored by John Wang's avatar John Wang

feat: use embeddings table query instead of not implement of milvus vector

parent 5eddcaae
...@@ -73,7 +73,3 @@ class InvalidMetadataError(BaseHTTPException): ...@@ -73,7 +73,3 @@ class InvalidMetadataError(BaseHTTPException):
code = 400 code = 400
class CurrentVectorStoreNotSupportHitTestingError(BaseHTTPException):
error_code = 'current_vector_store_not_support_hit_testing'
description = "The current vector store does not support hit testing."
code = 400
...@@ -8,14 +8,12 @@ import services ...@@ -8,14 +8,12 @@ import services
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \ from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError ProviderModelCurrentlyNotSupportError
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError, \ from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
CurrentVectorStoreNotSupportHitTestingError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import TimestampField from libs.helper import TimestampField
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.errors.dataset import VectorStoreNotSupportHitTestingError
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
document_fields = { document_fields = {
...@@ -103,8 +101,6 @@ class HitTestingApi(Resource): ...@@ -103,8 +101,6 @@ class HitTestingApi(Resource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
except VectorStoreNotSupportHitTestingError:
raise CurrentVectorStoreNotSupportHitTestingError()
except Exception as e: except Exception as e:
logging.exception("Hit testing failed.") logging.exception("Hit testing failed.")
raise InternalServerError(str(e)) raise InternalServerError(str(e))
......
...@@ -74,10 +74,3 @@ class VectorStore: ...@@ -74,10 +74,3 @@ class VectorStore:
raise Exception("Vector store client is not initialized.") raise Exception("Vector store client is not initialized.")
return self._client return self._client
def support_hit_testing(self):
if isinstance(self._client, MilvusVectorStoreClient):
# search API not return vector data
return False
return True
\ No newline at end of file
...@@ -3,7 +3,3 @@ from services.errors.base import BaseServiceError ...@@ -3,7 +3,3 @@ from services.errors.base import BaseServiceError
class DatasetNameDuplicateError(BaseServiceError): class DatasetNameDuplicateError(BaseServiceError):
pass pass
class VectorStoreNotSupportHitTestingError(BaseServiceError):
pass
...@@ -11,19 +11,14 @@ from sklearn.manifold import TSNE ...@@ -11,19 +11,14 @@ from sklearn.manifold import TSNE
from core.docstore.empty_docstore import EmptyDocumentStore from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.vector_index import VectorIndex from core.index.vector_index import VectorIndex
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_vector_store import vector_store
from models.account import Account from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery from models.dataset import Dataset, DocumentSegment, DatasetQuery, Embedding
from services.errors.dataset import VectorStoreNotSupportHitTestingError
from services.errors.index import IndexNotInitializedError from services.errors.index import IndexNotInitializedError
class HitTestingService: class HitTestingService:
@classmethod @classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
if not vector_store.support_hit_testing():
raise VectorStoreNotSupportHitTestingError()
index = VectorIndex(dataset=dataset).query_index index = VectorIndex(dataset=dataset).query_index
if not index: if not index:
...@@ -74,6 +69,11 @@ class HitTestingService: ...@@ -74,6 +69,11 @@ class HitTestingService:
for node in nodes: for node in nodes:
if node.node.embedding: if node.node.embedding:
embeddings.append(node.node.embedding) embeddings.append(node.node.embedding)
else:
embedding = db.session.query(Embedding).filter_by(hash=node.node.doc_hash).first()
if embedding:
node.node.embedding = embedding.get_embedding()
embeddings.append(node.node.embedding)
tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings) tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment