Unverified Commit 4588831b authored by Jyong, committed by GitHub

Feat/add retriever rerank (#1560)

Co-authored-by: jyong <jyong@dify.ai>
parent a4f37220
......@@ -8,6 +8,8 @@ import time
import uuid
import click
import qdrant_client
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
......@@ -484,6 +486,38 @@ def normalization_collections():
click.echo(click.style('Congratulations! Restored {} dataset indexes.'.format(len(normalization_count)), fg='green'))
@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
def add_qdrant_full_text_index():
click.echo(click.style('Starting to add full-text index.', fg='green'))
binds = db.session.query(DatasetCollectionBinding).all()
if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
qdrant_url = current_app.config['QDRANT_URL']
qdrant_api_key = current_app.config['QDRANT_API_KEY']
client = qdrant_client.QdrantClient(
qdrant_url,
api_key=qdrant_api_key, # For Qdrant Cloud, None for local instance
)
for bind in binds:
try:
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
client.create_payload_index(bind.collection_name, 'page_content',
field_schema=text_index_params)
except Exception as e:
click.echo(
click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.echo(
click.style(
'Congratulations! Added full-text index to collection {} successfully.'.format(bind.collection_name),
fg='green'))
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
with flask_app.app_context():
try:
......@@ -647,10 +681,10 @@ def update_app_model_configs(batch_size):
pbar.update(len(data_batch))
@click.command('migrate_default_input_to_dataset_query_variable')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def migrate_default_input_to_dataset_query_variable(batch_size):
click.secho("Starting...", fg='green')
total_records = db.session.query(AppModelConfig) \
......@@ -731,3 +765,4 @@ def register_commands(app):
app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
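For reference, a minimal sketch of how the index created by this command is exercised (the URL and collection name are illustrative assumptions): once page_content carries a full-text index, MatchText conditions in payload filters are served by it.

import qdrant_client
from qdrant_client.http import models

# illustrative local instance and collection name
client = qdrant_client.QdrantClient("http://localhost:6333")
records, _next_offset = client.scroll(
    collection_name="example_collection",
    scroll_filter=models.Filter(must=[
        models.FieldCondition(key="page_content",
                              match=models.MatchText(text="retrieval")),
    ]),
    limit=2,
    with_payload=True,
)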
......@@ -170,6 +170,7 @@ class DatasetApi(Resource):
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
args = parser.parse_args()
# The role of the current user in the ta table must be admin or owner
......@@ -401,6 +402,7 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource):
resource_type = 'dataset'
@setup_required
@login_required
@account_initialization_required
......@@ -436,6 +438,50 @@ class DatasetApiBaseUrlApi(Resource):
}
class DatasetRetrievalSettingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
vector_type = current_app.config['VECTOR_STORE']
if vector_type == 'milvus':
return {
'retrieval_method': [
'semantic_search'
]
}
elif vector_type in ('qdrant', 'weaviate'):
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
]
}
else:
raise ValueError("Unsupported vector db type.")
class DatasetRetrievalSettingMockApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, vector_type):
if vector_type == 'milvus':
return {
'retrieval_method': [
'semantic_search'
]
}
elif vector_type in ('qdrant', 'weaviate'):
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
]
}
else:
raise ValueError("Unsupported vector db type.")
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
......@@ -445,3 +491,5 @@ api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
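A hedged usage sketch for the new endpoint (host, console API prefix and token are assumptions for illustration): the response advertises which retrieval methods the configured vector store supports.

import requests

# illustrative host, prefix and token
resp = requests.get(
    "http://localhost:5001/console/api/datasets/retrieval-setting",
    headers={"Authorization": "Bearer YOUR_CONSOLE_TOKEN"},
)
print(resp.json())
# e.g. {'retrieval_method': ['semantic_search', 'full_text_search', 'hybrid_search']}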
......@@ -221,6 +221,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']:
......@@ -263,6 +265,8 @@ class DatasetInitApi(Resource):
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
if args['indexing_technique'] == 'high_quality':
try:
......
......@@ -42,19 +42,18 @@ class HitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('query', type=str, location='json')
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
args = parser.parse_args()
query = args['query']
if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters')
HitTestingService.hit_testing_args_check(args)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=query,
query=args['query'],
account=current_user,
limit=10,
retrieval_model=args['retrieval_model'],
limit=10
)
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
......
......@@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
......@@ -71,18 +71,17 @@ class DefaultModelApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
model_settings = args['model_settings']
for model_setting in model_settings:
provider_service.update_default_model_of_model_type(
tenant_id=current_user.current_tenant_id,
model_type=args['model_type'],
provider_name=args['provider_name'],
model_name=args['model_name']
model_type=model_setting['model_type'],
provider_name=model_setting['provider_name'],
model_name=model_setting['model_name']
)
return {'result': 'success'}
......
......@@ -36,6 +36,8 @@ class DocumentAddByTextApi(DatasetApiResource):
location='json')
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
location='json')
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
......@@ -95,6 +97,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
......
......@@ -14,7 +14,6 @@ from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
......@@ -60,7 +59,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
......
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
A multi-dataset retrieval agent driven by a router.
"""
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
Return whether the agent should be used to handle the query.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
raise NotImplementedError()
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
......@@ -89,7 +89,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
......
......@@ -18,6 +18,7 @@ from langchain.agents import AgentExecutor as LCAgentExecutor
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
......@@ -78,7 +79,7 @@ class AgentExecutor:
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, (DatasetRetrieverTool, DatasetMultiRetrieverTool))]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
tools=self.configuration.tools,
......@@ -86,7 +87,7 @@ class AgentExecutor:
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, (DatasetRetrieverTool, DatasetMultiRetrieverTool))]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
tools=self.configuration.tools,
......
......@@ -10,8 +10,7 @@ from models.dataset import DocumentSegment
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
self.dataset_id = dataset_id
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
self.conversation_message_task = conversation_message_task
def on_tool_end(self, documents: List[Document]) -> None:
......@@ -21,7 +20,6 @@ class DatasetIndexToolCallbackHandler:
# add hit count to document segment
db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.index_node_id == doc_id
).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
......
......@@ -127,6 +127,7 @@ class Completion:
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback,
tenant_id=app.tenant_id,
retriever_from=retriever_from
)
......
......@@ -3,7 +3,7 @@ from pathlib import Path
from typing import List, Union, Optional
import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
from langchain.schema import Document
from core.data_loader.loader.csv_loader import CSVLoader
......@@ -20,13 +20,13 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
return cls.load_from_file(file_path, return_text, upload_file)
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
......@@ -44,10 +44,20 @@ class FileExtractor:
@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
upload_file: Optional[UploadFile] = None,
is_automatic: bool = False) -> Union[List[Document], str]:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
if is_automatic:
loader = UnstructuredFileLoader(
file_path, strategy="hi_res", mode="elements"
)
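# "hi_res" runs layout detection for higher-fidelity parsing of PDFs and
# images; mode="elements" yields one Document per detected element rather
# than one per file.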
# loader = UnstructuredAPIFileLoader(
# file_path=filenames[0],
# api_key="FAKE_API_KEY",
# )
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
......
......@@ -40,6 +40,13 @@ class BaseVectorIndex(BaseIndex):
def _get_vector_store_class(self) -> type:
raise NotImplementedError
@abstractmethod
def search_by_full_text_index(
self, query: str,
**kwargs: Any
) -> List[Document]:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
......
from typing import Optional, cast
from typing import cast, Any, List
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore, milvus
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
from models.dataset import Dataset
class MilvusConfig(BaseModel):
......@@ -152,3 +150,7 @@ class MilvusVectorIndex(BaseVectorIndex):
),
],
))
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
# Milvus/Zilliz doesn't support BM25 full-text search, so return no documents
return []
......@@ -191,3 +191,21 @@ class QdrantVectorIndex(BaseVectorIndex):
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
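# Full-text search in Qdrant is a filtered scroll: MatchText is served by
# the text index on page_content, and group_id scopes results to this
# dataset. Matches are returned unranked (no BM25 scoring), up to top_k.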
return vector_store.similarity_search_by_bm25(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
],
), kwargs.get('top_k', 2))
from typing import Optional, cast
from typing import Optional, cast, Any, List
import requests
import weaviate
......@@ -26,6 +26,7 @@ class WeaviateConfig(BaseModel):
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
......@@ -148,3 +149,9 @@ class WeaviateVectorIndex(BaseVectorIndex):
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
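# Weaviate supports BM25F keyword search natively, so full-text retrieval
# delegates straight to the store's BM25 query.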
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
......@@ -49,14 +49,14 @@ class IndexingRunner:
if not dataset:
raise ValueError("no dataset found")
# load file
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# load file
text_docs = self._load_data(dataset_document)
# get splitter
splitter = self._get_splitter(processing_rule)
......@@ -380,7 +380,7 @@ class IndexingRunner:
"preview": preview_texts
}
def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
......@@ -396,7 +396,7 @@ class IndexingRunner:
one_or_none()
if file_detail:
text_docs = FileExtractor.load(file_detail)
text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
......
......@@ -9,6 +9,7 @@ from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
......@@ -140,6 +141,44 @@ class ModelFactory:
name=model_name
)
@classmethod
def get_reranking_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseReranking]:
"""
get reranking model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name: a string representing the name of the model provider.
:param model_name: a string representing the name of the reranking model.
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Reranking Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")
# init reranking model
model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_speech2text_model(cls,
tenant_id: str,
......
......@@ -72,6 +72,9 @@ class ModelProviderFactory:
elif provider_name == 'localai':
from core.model_providers.providers.localai_provider import LocalAIProvider
return LocalAIProvider
elif provider_name == 'cohere':
from core.model_providers.providers.cohere_provider import CohereProvider
return CohereProvider
else:
raise NotImplementedError
......
......@@ -17,7 +17,7 @@ class ModelType(enum.Enum):
IMAGE = 'image'
VIDEO = 'video'
MODERATION = 'moderation'
RERANKING = 'reranking'
@staticmethod
def value_of(value):
for member in ModelType:
......
from abc import abstractmethod
from typing import Any, Optional, List
from langchain.schema import Document
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
import logging
logger = logging.getLogger(__name__)
class BaseReranking(BaseProviderModel):
name: str
type: ModelType = ModelType.RERANKING
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
@property
def base_model_name(self) -> str:
"""
get base model name
:return: str
"""
return self.name
@abstractmethod
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError
import logging
from typing import Optional, List
import cohere
import openai
from langchain.schema import Document
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.providers.base import BaseModelProvider
class CohereReranking(BaseReranking):
def __init__(self, model_provider: BaseModelProvider, name: str):
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = cohere.Client(self.credentials.get('api_key'))
super().__init__(model_provider, client, name)
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
docs = []
doc_id = []
for document in documents:
if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id'])
docs.append(document.page_content)
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
rerank_documents = []
for idx, result in enumerate(results):
# format document
rerank_document = Document(
page_content=result.document['text'],
metadata={
"doc_id": documents[result.index].metadata['doc_id'],
"doc_hash": documents[result.index].metadata['doc_hash'],
"document_id": documents[result.index].metadata['document_id'],
"dataset_id": documents[result.index].metadata['dataset_id'],
'score': result.relevance_score
}
)
# score threshold check
if score_threshold is not None:
if result.relevance_score >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)
return rerank_documents
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
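A minimal standalone sketch of the underlying Cohere call (the API key and texts are placeholders); it mirrors what CohereReranking.rerank above does: Cohere scores each candidate against the query and returns them ordered by relevance.

import cohere

co = cohere.Client("YOUR_COHERE_API_KEY")  # placeholder key
candidates = ["Dify supports hybrid search.", "A paragraph about cooking."]
results = co.rerank(query="how does hybrid search work?",
                    documents=candidates,
                    model="rerank-multilingual-v2.0",
                    top_n=2)
for result in results:
    # each result carries the candidate's original index, text and relevance score
    print(result.index, result.relevance_score, result.document["text"])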
import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.reranking.cohere_reranking import CohereReranking
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
class CohereProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'cohere'
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.RERANKING:
return [
{
'id': 'rerank-english-v2.0',
'name': 'rerank-english-v2.0'
},
{
'id': 'rerank-multilingual-v2.0',
'name': 'rerank-multilingual-v2.0'
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.RERANKING:
model_class = CohereReranking
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Cohere api_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
}
# todo validate
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
return credentials
else:
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
......@@ -13,5 +13,6 @@
"huggingface_hub",
"xinference",
"openllm",
"localai"
"localai",
"cohere"
]
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
}
\ No newline at end of file
from typing import Optional
import json
import threading
from typing import Optional, List
from flask import Flask
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
......@@ -17,6 +23,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
......@@ -25,6 +32,16 @@ from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
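# 'search_method' is one of semantic_search / full_text_search /
# hybrid_search; 'reranking_model' only takes effect when
# 'reranking_enable' is True, and a score threshold is only applied when
# 'score_threshold_enable' is True.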
class OrchestratorRuleParser:
"""Parse the orchestrator rule to entities."""
......@@ -34,7 +51,7 @@ class OrchestratorRuleParser:
self.app_model_config = app_model_config
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str,
retriever_from: str = 'dev') -> Optional[AgentExecutor]:
if not self.app_model_config.agent_mode_dict:
return None
......@@ -101,7 +118,8 @@ class OrchestratorRuleParser:
rest_tokens=rest_tokens,
return_resource=return_resource,
retriever_from=retriever_from,
dataset_configs=dataset_configs
dataset_configs=dataset_configs,
tenant_id=tenant_id
)
if len(tools) == 0:
......@@ -132,6 +150,7 @@ class OrchestratorRuleParser:
:return:
"""
tools = []
dataset_tools = []
for tool_config in tool_configs:
tool_type = list(tool_config.keys())[0]
tool_val = list(tool_config.values())[0]
......@@ -140,7 +159,7 @@ class OrchestratorRuleParser:
tool = None
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
dataset_tools.append(tool_config)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
elif tool_type == "google_search":
......@@ -156,28 +175,35 @@ class OrchestratorRuleParser:
else:
tool.callbacks = callbacks
tools.append(tool)
# format dataset tool
if len(dataset_tools) > 0:
dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs)
if dataset_retriever_tools:
tools.extend(dataset_retriever_tools)
return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
dataset_configs: dict, rest_tokens: int,
def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask,
return_resource: bool = False, retriever_from: str = 'dev',
**kwargs) \
-> Optional[BaseTool]:
-> Optional[List[BaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
:param dataset_configs:
:param tool_configs:
:param conversation_message_task:
:param return_resource:
:param retriever_from:
:return:
"""
dataset_configs = kwargs['dataset_configs']
retrieval_model = dataset_configs.get('retrieval_model', 'single')
tools = []
dataset_ids = []
tenant_id = None
for tool_config in tool_configs:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
Dataset.id == tool_config.get('dataset').get("id")
).first()
if not dataset:
......@@ -185,16 +211,18 @@ class OrchestratorRuleParser:
if dataset and dataset.available_document_count == 0:
    continue
top_k = dataset_configs.get("top_k", 2)
dataset_ids.append(dataset.id)
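# 'single' builds one DatasetRetrieverTool per dataset, each using that
# dataset's own retrieval settings; 'multiple' builds one
# DatasetMultiRetrieverTool that searches every dataset and reranks the
# merged results (see the branch below).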
if retrieval_model == 'single':
dataset_retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
top_k = dataset_retrieval_model['top_k']
# dynamically adjust top_k when the remaining token number is not enough to support top_k
top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
# top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
score_threshold = None
score_threshold_config = dataset_configs.get("score_threshold")
if score_threshold_config and score_threshold_config.get("enable"):
score_threshold = score_threshold_config.get("value")
score_threshold_enable = dataset_retrieval_model.get("score_threshold_enable")
if score_threshold_enable:
score_threshold = dataset_retrieval_model.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
......@@ -205,8 +233,23 @@ class OrchestratorRuleParser:
return_resource=return_resource,
retriever_from=retriever_from
)
tools.append(tool)
if retrieval_model == 'multiple':
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=dataset_ids,
tenant_id=kwargs['tenant_id'],
top_k=dataset_configs.get('top_k', 2),
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
retriever_from=retriever_from,
reranking_provider_name=dataset_configs.get('reranking_model').get('reranking_provider_name'),
reranking_model_name=dataset_configs.get('reranking_model').get('reranking_model_name')
)
tools.append(tool)
return tool
return tools
def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
"""
......
import json
import threading
from typing import Type, Optional, List
from flask import current_app, Flask
from langchain.tools import BaseTool
from pydantic import Field, BaseModel
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document
from services.retrieval_service import RetrievalService
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for multi-dataset retrieval.")
class DatasetMultiRetrieverTool(BaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset-"
args_schema: Type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
tenant_id: str
dataset_ids: List[str]
top_k: int = 2
score_threshold: Optional[float] = None
reranking_provider_name: str
reranking_model_name: str
conversation_message_task: ConversationMessageTask
return_resource: bool
retriever_from: str
@classmethod
def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs):
return cls(
name=f'dataset-{tenant_id}',
tenant_id=tenant_id,
dataset_ids=dataset_ids,
**kwargs
)
def _run(self, query: str) -> str:
threads = []
all_documents = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'all_documents': all_documents
})
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# do rerank for searched documents
rerank = ModelFactory.get_reranking_model(
tenant_id=self.tenant_id,
model_provider_name=self.reranking_provider_name,
model_name=self.reranking_model_name
)
all_documents = rerank.rerank(query, all_documents, self.score_threshold, self.top_k)
hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
hit_callback.on_tool_end(all_documents)
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(segments,
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
float('inf')))
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
else:
document_context_list.append(segment.content)
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(
id=segment.dataset_id
).first()
document = Document.query.filter(Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': self.retriever_from
}
if self.retriever_from == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
source['segment_position'] = segment.position
source['index_node_hash'] = segment.index_node_hash
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
else:
source['content'] = segment.content
context_list.append(source)
resource_number += 1
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
return []
# get retrieval model; if it is not set, use the default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
if documents:
all_documents.extend(documents)
else:
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
return []
except ProviderTokenNotInitError:
return []
embeddings = CacheEmbedding(embedding_model)
documents = []
threads = []
if self.top_k > 0:
# retrieval_model source with semantic
if retrieval_model['search_method'] in ('semantic_search', 'hybrid_search'):
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'top_k': self.top_k,
'score_threshold': self.score_threshold,
'reranking_model': None,
'all_documents': documents,
'search_method': 'hybrid_search',
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval_model source with full text
if retrieval_model['search_method'] in ('full_text_search', 'hybrid_search'):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'search_method': 'hybrid_search',
'embeddings': embeddings,
'score_threshold': retrieval_model[
'score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model[
'reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
all_documents.extend(documents)
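Condensed, _run above is a fan-out/join followed by a single rerank pass over the merged candidates; a generic sketch (search_one and rerank are illustrative stand-ins, not names from this commit):

import threading
from typing import Callable, List

def fan_out_and_rerank(query: str, dataset_ids: List[str],
                       search_one: Callable[[str, str, list], None],
                       rerank: Callable[[str, list, int], list],
                       top_k: int = 2) -> list:
    # each dataset is searched on its own thread into a shared candidate list
    # (list.append is atomic under CPython's GIL, matching the pattern above)
    all_documents: list = []
    threads = [threading.Thread(target=search_one,
                                args=(dataset_id, query, all_documents))
               for dataset_id in dataset_ids]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    # one rerank pass orders the union and keeps the global top_k
    return rerank(query, all_documents, top_k)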
import json
from typing import Type, Optional
import threading
from typing import Type, Optional, List
from flask import current_app
from langchain.tools import BaseTool
......@@ -14,6 +15,18 @@ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitE
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document
from services.retrieval_service import RetrievalService
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class DatasetRetrieverToolInput(BaseModel):
......@@ -56,7 +69,9 @@ class DatasetRetrieverTool(BaseTool):
).first()
if not dataset:
return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
return ''
# get retrieval model; if it is not set, use the default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
......@@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool):
return ''
embeddings = CacheEmbedding(embedding_model)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = []
threads = []
if self.top_k > 0:
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': self.top_k,
'score_threshold': self.score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
# retrieval source with semantic
if retrieval_model['search_method'] in ('semantic_search', 'hybrid_search'):
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval_model source with full text
if retrieval_model['search_method'] in ('full_text_search', 'hybrid_search'):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
# hybrid search: rerank after all documents have been searched
if retrieval_model['search_method'] == 'hybrid_search':
hybrid_rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
model_name=retrieval_model['reranking_model']['reranking_model_name']
)
documents = hybrid_rerank.rerank(query, documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
self.top_k)
else:
documents = []
hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
......
from core.index.vector_index.milvus import Milvus
from core.vector_store.vector.milvus import Milvus
class MilvusVectorStore(Milvus):
......
......@@ -4,7 +4,7 @@ from langchain.schema import Document
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal
from core.index.vector_index.qdrant import Qdrant
from core.vector_store.vector.qdrant import Qdrant
class QdrantVectorStore(Qdrant):
......@@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant):
if isinstance(self.client, QdrantLocal):
self.client = cast(QdrantLocal, self.client)
self.client._load()
......@@ -28,7 +28,7 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType
from qdrant_client.http.models import PayloadSchemaType, FilterSelector, TextIndexParams, TokenizerType, TextIndexType
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
......@@ -189,14 +189,25 @@ class Qdrant(VectorStore):
texts, metadatas, ids, batch_size
):
self.client.upsert(
collection_name=self.collection_name, points=points, **kwargs
collection_name=self.collection_name, points=points
)
added_ids.extend(batch_ids)
# if this is a new collection, create payload indexes (keyword index on group_id, full-text index on page content)
if self.is_new_collection:
# create payload index
self.client.create_payload_index(self.collection_name, self.group_payload_key,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
# create full-text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
self.client.create_payload_index(self.collection_name, self.content_payload_key,
field_schema=text_index_params)
return added_ids
@sync_call_fallback
......@@ -600,7 +611,7 @@ class Qdrant(VectorStore):
limit=k,
offset=offset,
with_payload=True,
with_vectors=True, # Langchain does not expect vectors to be returned
with_vectors=True,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
......@@ -615,6 +626,39 @@ class Qdrant(VectorStore):
for result in results
]
def similarity_search_by_bm25(
self,
filter: Optional[MetadataFilter] = None,
k: int = 4
) -> List[Document]:
"""Return docs most similar by bm25.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
Returns:
List of documents most similar to the query text and distance for each.
"""
response = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=filter,
limit=k,
with_payload=True,
with_vectors=True
)
results = response[0]
documents = []
for result in results:
if result:
documents.append(self._document_from_scored_point(
result, self.content_payload_key, self.metadata_payload_key
))
return documents
@sync_call_fallback
async def asimilarity_search_with_score_by_vector(
self,
......
"""Wrapper around weaviate vector database."""
from __future__ import annotations
import datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
from uuid import uuid4
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
def _default_schema(index_name: str) -> Dict:
return {
"class": index_name,
"properties": [
{
"name": "text",
"dataType": ["text"],
}
],
}
def _create_weaviate_client(**kwargs: Any) -> Any:
client = kwargs.get("client")
if client is not None:
return client
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
try:
# the weaviate api key param should not be mandatory
weaviate_api_key = get_from_dict_or_env(
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
)
except ValueError:
weaviate_api_key = None
try:
import weaviate
except ImportError:
raise ValueError(
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`"
)
auth = (
weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
if weaviate_api_key is not None
else None
)
client = weaviate.Client(weaviate_url, auth_client_secret=auth)
return client
def _default_score_normalizer(val: float) -> float:
return 1 - 1 / (1 + np.exp(val))
def _json_serializable(value: Any) -> Any:
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
class Weaviate(VectorStore):
"""Wrapper around Weaviate vector database.
To use, you should have the ``weaviate-client`` python package installed.
Example:
.. code-block:: python
import weaviate
from langchain.vectorstores import Weaviate
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
weaviate = Weaviate(client, index_name, text_key)
"""
def __init__(
self,
client: Any,
index_name: str,
text_key: str,
embedding: Optional[Embeddings] = None,
attributes: Optional[List[str]] = None,
relevance_score_fn: Optional[
Callable[[float], float]
] = _default_score_normalizer,
by_text: bool = True,
):
"""Initialize with Weaviate client."""
try:
import weaviate
except ImportError:
raise ValueError(
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`."
)
if not isinstance(client, weaviate.Client):
raise ValueError(
f"client should be an instance of weaviate.Client, got {type(client)}"
)
self._client = client
self._index_name = index_name
self._embedding = embedding
self._text_key = text_key
self._query_attrs = [self._text_key]
self.relevance_score_fn = relevance_score_fn
self._by_text = by_text
if attributes is not None:
self._query_attrs.extend(attributes)
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embedding
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return (
self.relevance_score_fn
if self.relevance_score_fn
else _default_score_normalizer
)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""Upload texts with metadata (properties) to Weaviate."""
from weaviate.util import get_valid_uuid
ids = []
embeddings: Optional[List[List[float]]] = None
if self._embedding:
if not isinstance(texts, list):
texts = list(texts)
embeddings = self._embedding.embed_documents(texts)
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {self._text_key: text}
if metadatas is not None:
for key, val in metadatas[i].items():
data_properties[key] = _json_serializable(val)
# Allow for ids (consistent w/ other methods)
# # Or uuids (backwards compatble w/ existing arg)
# If the UUID of one of the objects already exists
# then the existing object will be replaced by the new object.
_id = get_valid_uuid(uuid4())
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
elif "ids" in kwargs:
_id = kwargs["ids"][i]
batch.add_data_object(
data_object=data_properties,
class_name=self._index_name,
uuid=_id,
vector=embeddings[i] if embeddings else None,
)
ids.append(_id)
return ids
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query.
"""
if self._by_text:
return self.similarity_search_by_text(query, k, **kwargs)
else:
if self._embedding is None:
raise ValueError(
"_embedding cannot be None for similarity_search when "
"_by_text=False"
)
embedding = self._embedding.embed_query(query)
return self.similarity_search_by_vector(embedding, k, **kwargs)
def similarity_search_by_text(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query.
"""
content: Dict[str, Any] = {"concepts": [query]}
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
result = query_obj.with_near_text(content).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
docs.append(Document(page_content=text, metadata=res))
return docs
def similarity_search_by_bm25(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query.
"""
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
# with_bm25 expects the raw query string; the near-text style
# {"concepts": [...]} payload (and certainty) does not apply to BM25
result = query_obj.with_bm25(query=query).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
docs.append(Document(page_content=text, metadata=res))
return docs
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
vector = {"vector": embedding}
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
result = query_obj.with_near_vector(vector).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
docs.append(Document(page_content=text, metadata=res))
return docs
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
if self._embedding is not None:
embedding = self._embedding.embed_query(query)
else:
raise ValueError(
"max_marginal_relevance_search requires a suitable Embeddings object"
)
return self.max_marginal_relevance_search_by_vector(
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
)
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
vector = {"vector": embedding}
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
results = (
query_obj.with_additional("vector")
.with_near_vector(vector)
.with_limit(fetch_k)
.do()
)
payload = results["data"]["Get"][self._index_name]
embeddings = [result["_additional"]["vector"] for result in payload]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
docs = []
for idx in mmr_selected:
text = payload[idx].pop(self._text_key)
payload[idx].pop("_additional")
meta = payload[idx]
docs.append(Document(page_content=text, metadata=meta))
return docs
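# A minimal sketch of the lambda_mult trade-off this method delegates to,
# using langchain's maximal_marginal_relevance helper with toy vectors
# (the sample data is illustrative, not from the source):
#
#     query_vec = np.array([1.0, 0.0])
#     candidates = [[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]]
#     # lambda_mult=1.0 favors similarity: picks the two near-duplicates
#     # lambda_mult=0.0 favors diversity: picks indices [0, 2]
#     maximal_marginal_relevance(query_vec, candidates, k=2, lambda_mult=0.0)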
def similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""
Return a list of documents most similar to the query text, each
paired with a score computed as the dot product of the stored vector
and the query embedding. Higher scores indicate more similarity.
"""
if self._embedding is None:
raise ValueError(
"_embedding cannot be None for similarity_search_with_score"
)
content: Dict[str, Any] = {"concepts": [query]}
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(self._index_name, self._query_attrs)
embedded_query = self._embedding.embed_query(query)
if not self._by_text:
vector = {"vector": embedded_query}
result = (
query_obj.with_near_vector(vector)
.with_limit(k)
.with_additional("vector")
.do()
)
else:
result = (
query_obj.with_near_text(content)
.with_limit(k)
.with_additional("vector")
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
score = np.dot(res["_additional"]["vector"], embedded_query)
docs_and_scores.append((Document(page_content=text, metadata=res), score))
return docs_and_scores
@classmethod
def from_texts(
cls: Type[Weaviate],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> Weaviate:
"""Construct Weaviate wrapper from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
2. Creates a new index for the embeddings in the Weaviate instance.
3. Adds the documents to the newly created Weaviate index.
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain.vectorstores.weaviate import Weaviate
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
weaviate = Weaviate.from_texts(
texts,
embeddings,
weaviate_url="http://localhost:8080"
)
"""
client = _create_weaviate_client(**kwargs)
from weaviate.util import get_valid_uuid
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
embeddings = embedding.embed_documents(texts) if embedding else None
text_key = "text"
schema = _default_schema(index_name)
attributes = list(metadatas[0].keys()) if metadatas else None
# check whether the index already exists
if not client.schema.contains(schema):
client.schema.create_class(schema)
with client.batch as batch:
for i, text in enumerate(texts):
data_properties = {
text_key: text,
}
if metadatas is not None:
for key in metadatas[i].keys():
data_properties[key] = metadatas[i][key]
# If the UUID of one of the objects already exists,
# the existing object will be replaced by the new one.
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
else:
_id = get_valid_uuid(uuid4())
# If an embedding strategy is not provided, let Weaviate create
# the embedding. Note that this only works if Weaviate was set up
# with a vectorizer module such as text2vec-contextionary.
params = {
"uuid": _id,
"data_object": data_properties,
"class_name": index_name,
}
if embeddings is not None:
params["vector"] = embeddings[i]
batch.add_data_object(**params)
batch.flush()
relevance_score_fn = kwargs.get("relevance_score_fn")
by_text: bool = kwargs.get("by_text", False)
return cls(
client,
index_name,
text_key,
embedding=embedding,
attributes=attributes,
relevance_score_fn=relevance_score_fn,
by_text=by_text,
)
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
"""Delete by vector IDs.
Args:
ids: List of ids to delete.
"""
if ids is None:
raise ValueError("No ids provided to delete.")
# TODO: Check if this can be done in bulk
for id in ids:
self._client.data_object.delete(uuid=id)
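# A minimal usage sketch of this wrapper; the endpoint URL, index name and
# sample texts are placeholders, not from the diff:
#
#     import weaviate
#     from langchain.embeddings import OpenAIEmbeddings
#
#     client = weaviate.Client("http://localhost:8080")
#     store = Weaviate(client, index_name="LangChain_demo", text_key="text",
#                      embedding=OpenAIEmbeddings(), by_text=False)
#     store.add_texts(["first chunk", "second chunk"],
#                     metadatas=[{"doc_id": "1"}, {"doc_id": "2"}])
#     docs = store.similarity_search(
#         "first", k=1,
#         where_filter={"path": ["doc_id"], "operator": "Equal", "valueString": "1"})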
......@@ -12,6 +12,21 @@ dataset_fields = {
'created_at': TimestampField,
}
reranking_model_fields = {
'reranking_provider_name': fields.String,
'reranking_model_name': fields.String
}
dataset_retrieval_model_fields = {
'search_method': fields.String,
'reranking_enable': fields.Boolean,
'reranking_model': fields.Nested(reranking_model_fields),
'top_k': fields.Integer,
'score_threshold_enable': fields.Boolean,
'score_threshold': fields.Float
}
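# For reference, a dataset serialized with these fields carries a
# retrieval_model_dict shaped like the following (values are illustrative):
#
#     {
#         "search_method": "semantic_search",
#         "reranking_enable": False,
#         "reranking_model": {
#             "reranking_provider_name": "cohere",
#             "reranking_model_name": "rerank-english-v2.0"
#         },
#         "top_k": 2,
#         "score_threshold_enable": False,
#         "score_threshold": 0.5
#     }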
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
......@@ -29,7 +44,8 @@ dataset_detail_fields = {
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
'embedding_available': fields.Boolean,
'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields)
}
dataset_query_detail_fields = {
......@@ -41,3 +57,5 @@ dataset_query_detail_fields = {
"created_by": fields.String,
"created_at": TimestampField
}
"""add-dataset-retrival-model
Revision ID: fca025d3b60f
Revises: 8fe468ba0ca5
Create Date: 2023-11-03 13:08:23.246396
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'fca025d3b60f'
down_revision = '8fe468ba0ca5'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('sessions')
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_index('retrieval_model_idx', postgresql_using='gin')
batch_op.drop_column('retrieval_model')
op.create_table('sessions',
sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True),
sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True),
sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
sa.PrimaryKeyConstraint('id', name='sessions_pkey'),
sa.UniqueConstraint('session_id', name='sessions_session_id_key')
)
# ### end Alembic commands ###
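# A sketch of how this migration is applied and rolled back; the exact
# invocation depends on the deployment, but with Flask-Migrate it would be:
#
#     flask db upgrade        # adds datasets.retrieval_model + the GIN index
#     flask db downgrade -1   # drops them and recreates the sessions table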
......@@ -3,7 +3,7 @@ import pickle
from json import JSONDecodeError
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import UUID, JSONB
from extensions.ext_database import db
from models.account import Account
......@@ -15,6 +15,7 @@ class Dataset(db.Model):
__table_args__ = (
db.PrimaryKeyConstraint('id', name='dataset_pkey'),
db.Index('dataset_tenant_idx', 'tenant_id'),
db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
)
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy']
......@@ -39,7 +40,7 @@ class Dataset(db.Model):
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
@property
def dataset_keyword_table(self):
......@@ -93,6 +94,20 @@ class Dataset(db.Model):
return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
.filter(Document.dataset_id == self.id).scalar()
@property
def retrieval_model_dict(self):
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
return self.retrieval_model if self.retrieval_model else default_retrieval_model
class DatasetProcessRule(db.Model):
__tablename__ = 'dataset_process_rules'
......@@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model):
],
'segmentation': {
'delimiter': '\n',
'max_tokens': 1000
'max_tokens': 512
}
}
......@@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model):
model_name = db.Column(db.String(40), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
......@@ -160,7 +160,13 @@ class AppModelConfig(db.Model):
@property
def dataset_configs_dict(self) -> dict:
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
if self.dataset_configs:
dataset_configs = json.loads(self.dataset_configs)
if 'retrieval_model' not in dataset_configs:
return {'retrieval_model': 'single'}
else:
return dataset_configs
return {'retrieval_model': 'single'}
@property
def file_upload_dict(self) -> dict:
......
......@@ -23,7 +23,6 @@ boto3==1.28.17
tenacity==8.2.2
cachetools~=5.3.0
weaviate-client~=3.21.0
qdrant_client~=1.1.6
mailchimp-transactional~=1.0.50
scikit-learn==1.2.2
sentry-sdk[flask]~=1.21.1
......@@ -54,3 +53,5 @@ safetensors==0.3.2
zhipuai==1.0.7
werkzeug==2.3.7
pymilvus==2.3.0
qdrant-client==1.6.4
cohere~=4.32
\ No newline at end of file
......@@ -470,7 +470,16 @@ class AppModelConfigService:
# dataset_configs
if 'dataset_configs' not in config or not config["dataset_configs"]:
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
config["dataset_configs"] = {'retrieval_model': 'single'}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if config["dataset_configs"]['retrieval_model'] == 'multiple':
if not config["dataset_configs"]['reranking_model']:
raise ValueError("reranking_model has not been set")
if not isinstance(config["dataset_configs"]['reranking_model'], dict):
raise ValueError("reranking_model must be of object type")
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
......
......@@ -173,6 +173,9 @@ class DatasetService:
filtered_data['updated_by'] = user.id
filtered_data['updated_at'] = datetime.datetime.now()
# update Retrieval model
filtered_data['retrieval_model'] = data['retrieval_model']
dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
......@@ -473,7 +476,19 @@ class DocumentService:
embedding_model.name
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
dataset.retrieval_model = document_data.get('retrieval_model') or default_retrieval_model
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
......@@ -733,6 +748,7 @@ class DocumentService:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
embedding_model = None
dataset_collection_binding_id = None
retrieval_model = None
if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
......@@ -742,6 +758,20 @@ class DocumentService:
embedding_model.name
)
dataset_collection_binding_id = dataset_collection_binding.id
if 'retrieval_model' in document_data and document_data['retrieval_model']:
retrieval_model = document_data['retrieval_model']
else:
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
retrieval_model = default_retrieval_model
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
......@@ -751,7 +781,8 @@ class DocumentService:
created_by=account.id,
embedding_model=embedding_model.name if embedding_model else None,
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
collection_binding_id=dataset_collection_binding_id
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model
)
db.session.add(dataset)
......
import json
import logging
import threading
import time
from typing import List
......@@ -9,16 +11,26 @@ from langchain.schema import Document
from sklearn.manifold import TSNE
from core.embedding.cached_embedding import CacheEmbedding
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.retrieval_service import RetrievalService
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return {
"query": {
......@@ -28,31 +40,68 @@ class HitTestingService:
"records": []
}
start = time.perf_counter()
# get the retrieval model; if it is not set, use the default
if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get embedding model
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
all_documents = []
threads = []
# retrieve via semantic search
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'top_k': retrieval_model['top_k'],
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieve via full-text search
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
'top_k': retrieval_model['top_k'],
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
if retrieval_model['search_method'] == 'hybrid_search':
hybrid_rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
model_name=retrieval_model['reranking_model']['reranking_model_name']
)
all_documents = hybrid_rerank.rerank(query, all_documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
retrieval_model['top_k'])
start = time.perf_counter()
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 10,
'filter': {
'group_id': [dataset.id]
}
}
)
end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
......@@ -67,7 +116,7 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(dataset, embeddings, query, documents)
return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
......@@ -99,7 +148,7 @@ class HitTestingService:
record = {
"segment": segment,
"score": document.metadata['score'],
"score": document.metadata.get('score', None),
"tsne_position": tsne_position_data[i]
}
......@@ -136,3 +185,11 @@ class HitTestingService:
tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])})
return tsne_position_data
@classmethod
def hit_testing_args_check(cls, args):
query = args['query']
if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters')
from typing import Optional
from flask import current_app, Flask
from langchain.embeddings.base import Embeddings
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from models.dataset import Dataset
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class RetrievalService:
@classmethod
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context():
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': top_k,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
if documents:
if reranking_model and search_method == 'semantic_search':
rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=reranking_model['reranking_provider_name'],
model_name=reranking_model['reranking_model_name']
)
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
else:
all_documents.extend(documents)
@classmethod
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context():
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search_by_full_text_index(
query,
search_type='similarity_score_threshold',
top_k=top_k
)
if documents:
if reranking_model and search_method == 'full_text_search':
rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=reranking_model['reranking_provider_name'],
model_name=reranking_model['reranking_model_name']
)
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
else:
all_documents.extend(documents)
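# A minimal sketch of how a caller fans out to both classmethods for hybrid
# search, mirroring HitTestingService above; dataset, query and embeddings
# are placeholders the caller must supply:
#
#     import threading
#
#     all_documents = []
#     kwargs = dict(flask_app=current_app._get_current_object(), dataset=dataset,
#                   query=query, top_k=2, score_threshold=None, reranking_model=None,
#                   all_documents=all_documents, search_method='hybrid_search',
#                   embeddings=embeddings)
#     threads = [
#         threading.Thread(target=RetrievalService.embedding_search, kwargs=kwargs),
#         threading.Thread(target=RetrievalService.full_text_index_search, kwargs=kwargs),
#     ]
#     for t in threads:
#         t.start()
#     for t in threads:
#         t.join()
#     # all_documents now holds the merged candidates; for hybrid_search the
#     # caller reranks them afterwards (see HitTestingService above).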