Unverified Commit a5b80c9d authored by Jyong's avatar Jyong Committed by GitHub

Fix/multi thread parameter (#1604)

parent f704094a
...@@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool): ...@@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool):
'search_method'] == 'hybrid_search': 'search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': self.top_k, 'top_k': self.top_k,
'score_threshold': self.score_threshold, 'score_threshold': self.score_threshold,
...@@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool): ...@@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
kwargs={ kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'search_method': 'hybrid_search', 'search_method': 'hybrid_search',
'embeddings': embeddings, 'embeddings': embeddings,
......
...@@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool):
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': self.top_k, 'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
...@@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool):
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'search_method': retrieval_model['search_method'], 'search_method': retrieval_model['search_method'],
'embeddings': embeddings, 'embeddings': embeddings,
......
...@@ -61,7 +61,7 @@ class HitTestingService: ...@@ -61,7 +61,7 @@ class HitTestingService:
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': retrieval_model['top_k'], 'top_k': retrieval_model['top_k'],
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
...@@ -77,7 +77,7 @@ class HitTestingService: ...@@ -77,7 +77,7 @@ class HitTestingService:
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'search_method': retrieval_model['search_method'], 'search_method': retrieval_model['search_method'],
'embeddings': embeddings, 'embeddings': embeddings,
......
...@@ -4,6 +4,7 @@ from flask import current_app, Flask ...@@ -4,6 +4,7 @@ from flask import current_app, Flask
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from core.index.vector_index.vector_index import VectorIndex from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset
default_retrieval_model = { default_retrieval_model = {
...@@ -21,10 +22,13 @@ default_retrieval_model = { ...@@ -21,10 +22,13 @@ default_retrieval_model = {
class RetrievalService: class RetrievalService:
@classmethod @classmethod
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str, def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings): all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context(): with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
vector_index = VectorIndex( vector_index = VectorIndex(
dataset=dataset, dataset=dataset,
...@@ -56,10 +60,13 @@ class RetrievalService: ...@@ -56,10 +60,13 @@ class RetrievalService:
all_documents.extend(documents) all_documents.extend(documents)
@classmethod @classmethod
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str, def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings): all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context(): with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
vector_index = VectorIndex( vector_index = VectorIndex(
dataset=dataset, dataset=dataset,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment