Unverified Commit 4588831b authored by Jyong's avatar Jyong Committed by GitHub

Feat/add retriever rerank (#1560)

Co-authored-by: 's avatarjyong <jyong@dify.ai>
parent a4f37220
......@@ -8,6 +8,8 @@ import time
import uuid
import click
import qdrant_client
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
......@@ -484,6 +486,38 @@ def normalization_collections():
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
def add_qdrant_full_text_index():
click.echo(click.style('Start add full text index.', fg='green'))
binds = db.session.query(DatasetCollectionBinding).all()
if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
qdrant_url = current_app.config['QDRANT_URL']
qdrant_api_key = current_app.config['QDRANT_API_KEY']
client = qdrant_client.QdrantClient(
qdrant_url,
api_key=qdrant_api_key, # For Qdrant Cloud, None for local instance
)
for bind in binds:
try:
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
client.create_payload_index(bind.collection_name, 'page_content',
field_schema=text_index_params)
except Exception as e:
click.echo(
click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.echo(
click.style(
'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
fg='green'))
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
with flask_app.app_context():
try:
......@@ -647,10 +681,10 @@ def update_app_model_configs(batch_size):
pbar.update(len(data_batch))
@click.command('migrate_default_input_to_dataset_query_variable')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def migrate_default_input_to_dataset_query_variable(batch_size):
click.secho("Starting...", fg='green')
total_records = db.session.query(AppModelConfig) \
......@@ -658,13 +692,13 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
.filter(App.mode == 'completion') \
.filter(AppModelConfig.dataset_query_variable == None) \
.count()
if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
num_batches = (total_records + batch_size - 1) // batch_size
with tqdm(total=total_records, desc="Migrating Data") as pbar:
for i in range(num_batches):
offset = i * batch_size
......@@ -697,14 +731,14 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
for form in user_input_form:
paragraph = form.get('paragraph')
if paragraph \
and paragraph.get('variable') == 'query':
data.dataset_query_variable = 'query'
break
and paragraph.get('variable') == 'query':
data.dataset_query_variable = 'query'
break
if paragraph \
and paragraph.get('variable') == 'default_input':
data.dataset_query_variable = 'default_input'
break
and paragraph.get('variable') == 'default_input':
data.dataset_query_variable = 'default_input'
break
db.session.commit()
......@@ -712,7 +746,7 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch))
......@@ -731,3 +765,4 @@ def register_commands(app):
app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
......@@ -170,6 +170,7 @@ class DatasetApi(Resource):
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
args = parser.parse_args()
# The role of the current user in the ta table must be admin or owner
......@@ -401,6 +402,7 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource):
resource_type = 'dataset'
@setup_required
@login_required
@account_initialization_required
......@@ -436,6 +438,50 @@ class DatasetApiBaseUrlApi(Resource):
}
class DatasetRetrievalSettingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
vector_type = current_app.config['VECTOR_STORE']
if vector_type == 'milvus':
return {
'retrieval_method': [
'semantic_search'
]
}
elif vector_type == 'qdrant' or vector_type == 'weaviate':
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
]
}
else:
raise ValueError("Unsupported vector db type.")
class DatasetRetrievalSettingMockApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, vector_type):
if vector_type == 'milvus':
return {
'retrieval_method': [
'semantic_search'
]
}
elif vector_type == 'qdrant' or vector_type == 'weaviate':
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
]
}
else:
raise ValueError("Unsupported vector db type.")
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
......@@ -445,3 +491,5 @@ api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
......@@ -221,6 +221,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
if not dataset.indexing_technique and not args['indexing_technique']:
......@@ -263,6 +265,8 @@ class DatasetInitApi(Resource):
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
if args['indexing_technique'] == 'high_quality':
try:
......
......@@ -42,19 +42,18 @@ class HitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('query', type=str, location='json')
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
args = parser.parse_args()
query = args['query']
if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters')
HitTestingService.hit_testing_args_check(args)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=query,
query=args['query'],
account=current_user,
limit=10,
retrieval_model=args['retrieval_model'],
limit=10
)
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
......
......@@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
......@@ -71,19 +71,18 @@ class DefaultModelApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.update_default_model_of_model_type(
tenant_id=current_user.current_tenant_id,
model_type=args['model_type'],
provider_name=args['provider_name'],
model_name=args['model_name']
)
model_settings = args['model_settings']
for model_setting in model_settings:
provider_service.update_default_model_of_model_type(
tenant_id=current_user.current_tenant_id,
model_type=model_setting['model_type'],
provider_name=model_setting['provider_name'],
model_name=model_setting['model_name']
)
return {'result': 'success'}
......
......@@ -36,6 +36,8 @@ class DocumentAddByTextApi(DatasetApiResource):
location='json')
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
location='json')
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
......@@ -95,6 +97,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
......
......@@ -14,7 +14,6 @@ from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
......@@ -60,7 +59,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
......
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
raise NotImplementedError()
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
......@@ -89,7 +89,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
......
......@@ -18,6 +18,7 @@ from langchain.agents import AgentExecutor as LCAgentExecutor
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
......@@ -78,7 +79,7 @@ class AgentExecutor:
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
tools=self.configuration.tools,
......@@ -86,7 +87,7 @@ class AgentExecutor:
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
tools=self.configuration.tools,
......
......@@ -10,8 +10,7 @@ from models.dataset import DocumentSegment
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
self.dataset_id = dataset_id
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
self.conversation_message_task = conversation_message_task
def on_tool_end(self, documents: List[Document]) -> None:
......@@ -21,7 +20,6 @@ class DatasetIndexToolCallbackHandler:
# add hit count to document segment
db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.index_node_id == doc_id
).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
......
......@@ -127,6 +127,7 @@ class Completion:
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback,
tenant_id=app.tenant_id,
retriever_from=retriever_from
)
......
......@@ -3,7 +3,7 @@ from pathlib import Path
from typing import List, Union, Optional
import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
from langchain.schema import Document
from core.data_loader.loader.csv_loader import CSVLoader
......@@ -20,13 +20,13 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
return cls.load_from_file(file_path, return_text, upload_file)
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
......@@ -44,24 +44,34 @@ class FileExtractor:
@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
upload_file: Optional[UploadFile] = None,
is_automatic: bool = False) -> Union[List[Document] | str]:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension == '.docx':
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
if is_automatic:
loader = UnstructuredFileLoader(
file_path, strategy="hi_res", mode="elements"
)
# loader = UnstructuredAPIFileLoader(
# file_path=filenames[0],
# api_key="FAKE_API_KEY",
# )
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension == '.docx':
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
......@@ -40,6 +40,13 @@ class BaseVectorIndex(BaseIndex):
def _get_vector_store_class(self) -> type:
raise NotImplementedError
@abstractmethod
def search_by_full_text_index(
self, query: str,
**kwargs: Any
) -> List[Document]:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
......
from typing import Optional, cast
from typing import cast, Any, List
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore, milvus
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
from models.dataset import Dataset
class MilvusConfig(BaseModel):
......@@ -74,7 +72,7 @@ class MilvusVectorIndex(BaseVectorIndex):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
'params': {"M": 8, "efConstruction": 64}
}
self._vector_store = MilvusVectorStore.from_documents(
texts,
......@@ -152,3 +150,7 @@ class MilvusVectorIndex(BaseVectorIndex):
),
],
))
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
# milvus/zilliz doesn't support bm25 search
return []
......@@ -191,3 +191,21 @@ class QdrantVectorIndex(BaseVectorIndex):
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
return vector_store.similarity_search_by_bm25(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
],
), kwargs.get('top_k', 2))
from typing import Optional, cast
from typing import Optional, cast, Any, List
import requests
import weaviate
......@@ -26,6 +26,7 @@ class WeaviateConfig(BaseModel):
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
......@@ -148,3 +149,9 @@ class WeaviateVectorIndex(BaseVectorIndex):
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
......@@ -49,14 +49,14 @@ class IndexingRunner:
if not dataset:
raise ValueError("no dataset found")
# load file
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# load file
text_docs = self._load_data(dataset_document)
# get splitter
splitter = self._get_splitter(processing_rule)
......@@ -380,7 +380,7 @@ class IndexingRunner:
"preview": preview_texts
}
def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
......@@ -396,7 +396,7 @@ class IndexingRunner:
one_or_none()
if file_detail:
text_docs = FileExtractor.load(file_detail)
text_docs = FileExtractor.load(file_detail, is_automatic=False)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
......
......@@ -9,6 +9,7 @@ from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
......@@ -140,6 +141,44 @@ class ModelFactory:
name=model_name
)
@classmethod
def get_reranking_model(cls,
tenant_id: str,
model_provider_name: Optional[str] = None,
model_name: Optional[str] = None) -> Optional[BaseReranking]:
"""
get reranking model.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name:
:param model_name:
:return:
"""
if model_provider_name is None and model_name is None:
default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
if not default_model:
raise LLMBadRequestError(f"Default model is not available. "
f"Please configure a Default Reranking Model "
f"in the Settings -> Model Provider.")
model_provider_name = default_model.provider_name
model_name = default_model.model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
if not model_provider:
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
# init reranking model
model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
return model_class(
model_provider=model_provider,
name=model_name
)
@classmethod
def get_speech2text_model(cls,
tenant_id: str,
......
......@@ -72,6 +72,9 @@ class ModelProviderFactory:
elif provider_name == 'localai':
from core.model_providers.providers.localai_provider import LocalAIProvider
return LocalAIProvider
elif provider_name == 'cohere':
from core.model_providers.providers.cohere_provider import CohereProvider
return CohereProvider
else:
raise NotImplementedError
......
......@@ -17,7 +17,7 @@ class ModelType(enum.Enum):
IMAGE = 'image'
VIDEO = 'video'
MODERATION = 'moderation'
RERANKING = 'reranking'
@staticmethod
def value_of(value):
for member in ModelType:
......
from abc import abstractmethod
from typing import Any, Optional, List
from langchain.schema import Document
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
import logging
logger = logging.getLogger(__name__)
class BaseReranking(BaseProviderModel):
name: str
type: ModelType = ModelType.RERANKING
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
@property
def base_model_name(self) -> str:
"""
get base model name
:return: str
"""
return self.name
@abstractmethod
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError
import logging
from typing import Optional, List
import cohere
import openai
from langchain.schema import Document
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.providers.base import BaseModelProvider
class CohereReranking(BaseReranking):
def __init__(self, model_provider: BaseModelProvider, name: str):
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = cohere.Client(self.credentials.get('api_key'))
super().__init__(model_provider, client, name)
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
docs = []
doc_id = []
for document in documents:
if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id'])
docs.append(document.page_content)
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
rerank_documents = []
for idx, result in enumerate(results):
# format document
rerank_document = Document(
page_content=result.document['text'],
metadata={
"doc_id": documents[result.index].metadata['doc_id'],
"doc_hash": documents[result.index].metadata['doc_hash'],
"document_id": documents[result.index].metadata['document_id'],
"dataset_id": documents[result.index].metadata['dataset_id'],
'score': result.relevance_score
}
)
# score threshold check
if score_threshold is not None:
if result.relevance_score >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)
return rerank_documents
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.reranking.cohere_reranking import CohereReranking
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
class CohereProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'cohere'
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.RERANKING:
return [
{
'id': 'rerank-english-v2.0',
'name': 'rerank-english-v2.0'
},
{
'id': 'rerank-multilingual-v2.0',
'name': 'rerank-multilingual-v2.0'
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.RERANKING:
model_class = CohereReranking
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Cohere api_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
}
# todo validate
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
return credentials
else:
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
......@@ -13,5 +13,6 @@
"huggingface_hub",
"xinference",
"openllm",
"localai"
"localai",
"cohere"
]
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
}
\ No newline at end of file
from typing import Optional
import json
import threading
from typing import Optional, List
from flask import Flask
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
......@@ -17,6 +23,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
......@@ -25,6 +32,16 @@ from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class OrchestratorRuleParser:
"""Parse the orchestrator rule to entities."""
......@@ -34,7 +51,7 @@ class OrchestratorRuleParser:
self.app_model_config = app_model_config
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str,
retriever_from: str = 'dev') -> Optional[AgentExecutor]:
if not self.app_model_config.agent_mode_dict:
return None
......@@ -101,7 +118,8 @@ class OrchestratorRuleParser:
rest_tokens=rest_tokens,
return_resource=return_resource,
retriever_from=retriever_from,
dataset_configs=dataset_configs
dataset_configs=dataset_configs,
tenant_id=tenant_id
)
if len(tools) == 0:
......@@ -123,7 +141,7 @@ class OrchestratorRuleParser:
return chain
def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
......@@ -132,6 +150,7 @@ class OrchestratorRuleParser:
:return:
"""
tools = []
dataset_tools = []
for tool_config in tool_configs:
tool_type = list(tool_config.keys())[0]
tool_val = list(tool_config.values())[0]
......@@ -140,7 +159,7 @@ class OrchestratorRuleParser:
tool = None
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
dataset_tools.append(tool_config)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
elif tool_type == "google_search":
......@@ -156,57 +175,81 @@ class OrchestratorRuleParser:
else:
tool.callbacks = callbacks
tools.append(tool)
# format dataset tool
if len(dataset_tools) > 0:
dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs)
if dataset_retriever_tools:
tools.extend(dataset_retriever_tools)
return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
dataset_configs: dict, rest_tokens: int,
def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask,
return_resource: bool = False, retriever_from: str = 'dev',
**kwargs) \
-> Optional[BaseTool]:
-> Optional[List[BaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
:param dataset_configs:
:param tool_configs:
:param conversation_message_task:
:param return_resource:
:param retriever_from:
:return:
"""
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
if not dataset:
return None
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
return None
top_k = dataset_configs.get("top_k", 2)
# dynamically adjust top_k when the remaining token number is not enough to support top_k
top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
dataset_configs = kwargs['dataset_configs']
retrieval_model = dataset_configs.get('retrieval_model', 'single')
tools = []
dataset_ids = []
tenant_id = None
for tool_config in tool_configs:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get('dataset').get("id")
).first()
score_threshold = None
score_threshold_config = dataset_configs.get("score_threshold")
if score_threshold_config and score_threshold_config.get("enable"):
score_threshold = score_threshold_config.get("value")
if not dataset:
return None
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
retriever_from=retriever_from
)
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
return None
dataset_ids.append(dataset.id)
if retrieval_model == 'single':
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
top_k = retrieval_model['top_k']
# dynamically adjust top_k when the remaining token number is not enough to support top_k
# top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
score_threshold = None
score_threshold_enable = retrieval_model.get("score_threshold_enable")
if score_threshold_enable:
score_threshold = retrieval_model.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
retriever_from=retriever_from
)
tools.append(tool)
if retrieval_model == 'multiple':
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=dataset_ids,
tenant_id=kwargs['tenant_id'],
top_k=dataset_configs.get('top_k', 2),
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
retriever_from=retriever_from,
reranking_provider_name=dataset_configs.get('reranking_model').get('reranking_provider_name'),
reranking_model_name=dataset_configs.get('reranking_model').get('reranking_model_name')
)
tools.append(tool)
return tool
return tools
def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
"""
......
This diff is collapsed.
import json
from typing import Type, Optional
import threading
from typing import Type, Optional, List
from flask import current_app
from langchain.tools import BaseTool
......@@ -14,6 +15,18 @@ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitE
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document
from services.retrieval_service import RetrievalService
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class DatasetRetrieverToolInput(BaseModel):
......@@ -56,7 +69,9 @@ class DatasetRetrieverTool(BaseTool):
).first()
if not dataset:
return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
return ''
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
......@@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool):
return ''
embeddings = CacheEmbedding(embedding_model)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = []
threads = []
if self.top_k > 0:
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': self.top_k,
'score_threshold': self.score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
# retrieval source with semantic
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval_model source with full text
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
# hybrid search: rerank after all documents have been searched
if retrieval_model['search_method'] == 'hybrid_search':
hybrid_rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
model_name=retrieval_model['reranking_model']['reranking_model_name']
)
documents = hybrid_rerank.rerank(query, documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
self.top_k)
else:
documents = []
hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
......
from core.index.vector_index.milvus import Milvus
from core.vector_store.vector.milvus import Milvus
class MilvusVectorStore(Milvus):
......
......@@ -4,7 +4,7 @@ from langchain.schema import Document
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal
from core.index.vector_index.qdrant import Qdrant
from core.vector_store.vector.qdrant import Qdrant
class QdrantVectorStore(Qdrant):
......@@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant):
if isinstance(self.client, QdrantLocal):
self.client = cast(QdrantLocal, self.client)
self.client._load()
......@@ -28,7 +28,7 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType
from qdrant_client.http.models import PayloadSchemaType, FilterSelector, TextIndexParams, TokenizerType, TextIndexType
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
......@@ -189,14 +189,25 @@ class Qdrant(VectorStore):
texts, metadatas, ids, batch_size
):
self.client.upsert(
collection_name=self.collection_name, points=points, **kwargs
collection_name=self.collection_name, points=points
)
added_ids.extend(batch_ids)
# if is new collection, create payload index on group_id
if self.is_new_collection:
# create payload index
self.client.create_payload_index(self.collection_name, self.group_payload_key,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
# creat full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
self.client.create_payload_index(self.collection_name, self.content_payload_key,
field_schema=text_index_params)
return added_ids
@sync_call_fallback
......@@ -600,7 +611,7 @@ class Qdrant(VectorStore):
limit=k,
offset=offset,
with_payload=True,
with_vectors=True, # Langchain does not expect vectors to be returned
with_vectors=True,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
......@@ -615,6 +626,39 @@ class Qdrant(VectorStore):
for result in results
]
def similarity_search_by_bm25(
self,
filter: Optional[MetadataFilter] = None,
k: int = 4
) -> List[Document]:
"""Return docs most similar by bm25.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
Returns:
List of documents most similar to the query text and distance for each.
"""
response = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=filter,
limit=k,
with_payload=True,
with_vectors=True
)
results = response[0]
documents = []
for result in results:
if result:
documents.append(self._document_from_scored_point(
result, self.content_payload_key, self.metadata_payload_key
))
return documents
@sync_call_fallback
async def asimilarity_search_with_score_by_vector(
self,
......
This diff is collapsed.
......@@ -12,6 +12,21 @@ dataset_fields = {
'created_at': TimestampField,
}
reranking_model_fields = {
'reranking_provider_name': fields.String,
'reranking_model_name': fields.String
}
dataset_retrieval_model_fields = {
'search_method': fields.String,
'reranking_enable': fields.Boolean,
'reranking_model': fields.Nested(reranking_model_fields),
'top_k': fields.Integer,
'score_threshold_enable': fields.Boolean,
'score_threshold': fields.Float
}
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
......@@ -29,7 +44,8 @@ dataset_detail_fields = {
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
'embedding_available': fields.Boolean,
'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields)
}
dataset_query_detail_fields = {
......@@ -41,3 +57,5 @@ dataset_query_detail_fields = {
"created_by": fields.String,
"created_at": TimestampField
}
"""add-dataset-retrival-model
Revision ID: fca025d3b60f
Revises: b3a09c049e8e
Create Date: 2023-11-03 13:08:23.246396
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'fca025d3b60f'
down_revision = '8fe468ba0ca5'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('sessions')
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_index('retrieval_model_idx', postgresql_using='gin')
batch_op.drop_column('retrieval_model')
op.create_table('sessions',
sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True),
sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True),
sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
sa.PrimaryKeyConstraint('id', name='sessions_pkey'),
sa.UniqueConstraint('session_id', name='sessions_session_id_key')
)
# ### end Alembic commands ###
......@@ -3,7 +3,7 @@ import pickle
from json import JSONDecodeError
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import UUID, JSONB
from extensions.ext_database import db
from models.account import Account
......@@ -15,6 +15,7 @@ class Dataset(db.Model):
__table_args__ = (
db.PrimaryKeyConstraint('id', name='dataset_pkey'),
db.Index('dataset_tenant_idx', 'tenant_id'),
db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
)
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy']
......@@ -39,7 +40,7 @@ class Dataset(db.Model):
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
@property
def dataset_keyword_table(self):
......@@ -93,6 +94,20 @@ class Dataset(db.Model):
return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
.filter(Document.dataset_id == self.id).scalar()
@property
def retrieval_model_dict(self):
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
return self.retrieval_model if self.retrieval_model else default_retrieval_model
class DatasetProcessRule(db.Model):
__tablename__ = 'dataset_process_rules'
......@@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model):
],
'segmentation': {
'delimiter': '\n',
'max_tokens': 1000
'max_tokens': 512
}
}
......@@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model):
model_name = db.Column(db.String(40), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
......@@ -160,7 +160,13 @@ class AppModelConfig(db.Model):
@property
def dataset_configs_dict(self) -> dict:
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
if self.dataset_configs:
dataset_configs = json.loads(self.dataset_configs)
if 'retrieval_model' not in dataset_configs:
return {'retrieval_model': 'single'}
else:
return dataset_configs
return {'retrieval_model': 'single'}
@property
def file_upload_dict(self) -> dict:
......
......@@ -23,7 +23,6 @@ boto3==1.28.17
tenacity==8.2.2
cachetools~=5.3.0
weaviate-client~=3.21.0
qdrant_client~=1.1.6
mailchimp-transactional~=1.0.50
scikit-learn==1.2.2
sentry-sdk[flask]~=1.21.1
......@@ -53,4 +52,6 @@ xinference-client~=0.5.4
safetensors==0.3.2
zhipuai==1.0.7
werkzeug==2.3.7
pymilvus==2.3.0
\ No newline at end of file
pymilvus==2.3.0
qdrant-client==1.6.4
cohere~=4.32
\ No newline at end of file
......@@ -470,7 +470,16 @@ class AppModelConfigService:
# dataset_configs
if 'dataset_configs' not in config or not config["dataset_configs"]:
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
config["dataset_configs"] = {'retrieval_model': 'single'}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if config["dataset_configs"]['retrieval_model'] == 'multiple':
if not config["dataset_configs"]['reranking_model']:
raise ValueError("reranking_model has not been set")
if not isinstance(config["dataset_configs"]['reranking_model'], dict):
raise ValueError("reranking_model must be of object type")
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
......
......@@ -173,6 +173,9 @@ class DatasetService:
filtered_data['updated_by'] = user.id
filtered_data['updated_at'] = datetime.datetime.now()
# update Retrieval model
filtered_data['retrieval_model'] = data['retrieval_model']
dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
......@@ -473,7 +476,19 @@ class DocumentService:
embedding_model.name
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
......@@ -733,6 +748,7 @@ class DocumentService:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
embedding_model = None
dataset_collection_binding_id = None
retrieval_model = None
if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
......@@ -742,6 +758,20 @@ class DocumentService:
embedding_model.name
)
dataset_collection_binding_id = dataset_collection_binding.id
if 'retrieval_model' in document_data and document_data['retrieval_model']:
retrieval_model = document_data['retrieval_model']
else:
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
retrieval_model = default_retrieval_model
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
......@@ -751,7 +781,8 @@ class DocumentService:
created_by=account.id,
embedding_model=embedding_model.name if embedding_model else None,
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
collection_binding_id=dataset_collection_binding_id
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model
)
db.session.add(dataset)
......@@ -768,7 +799,7 @@ class DocumentService:
return dataset, documents, batch
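A sketch of the document_data this branch consumes; only 'indexing_technique' and 'retrieval_model' are visible in the hunk, so the rest of the shape is an assumption:
document_data = {
    'indexing_technique': 'high_quality',   # triggers the embedding-model and collection-binding lookup
    'retrieval_model': {                    # optional; default_retrieval_model is used when absent
        'search_method': 'hybrid_search',
        'reranking_enable': True,
        'reranking_model': {
            'reranking_provider_name': 'cohere',           # illustrative
            'reranking_model_name': 'rerank-english-v2.0'  # illustrative
        },
        'top_k': 4,
        'score_threshold_enable': False
    }
}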
@classmethod
def document_create_args_validate(cls, args: dict):
if 'original_document_id' not in args or not args['original_document_id']:
DocumentService.data_source_args_validate(args)
DocumentService.process_rule_args_validate(args)
......
import json
import logging
import threading
import time
from typing import List
......@@ -9,16 +11,26 @@ from langchain.schema import Document
from sklearn.manifold import TSNE
from core.embedding.cached_embedding import CacheEmbedding
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.retrieval_service import RetrievalService
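# Fallback retrieval settings, mirroring the default in the Dataset model and
# DocumentService; used when neither the request nor the dataset specifies one.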
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return {
"query": {
......@@ -28,31 +40,68 @@ class HitTestingService:
"records": []
}
start = time.perf_counter()
# get retrieval model; fall back to the default if none is set
if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get embedding model
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
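# The cached embeddings are shared by the search threads below and reused for
# the t-SNE positions computed in compact_retrieve_response.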
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
all_documents = []
threads = []
# retrieval via semantic (embedding) search
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'top_k': retrieval_model['top_k'],
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval via full-text search
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
'top_k': retrieval_model['top_k'],
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
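# For hybrid search both branches have now appended into all_documents; the
# reranker below produces the final ordering and applies score_threshold / top_k.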
if retrieval_model['search_method'] == 'hybrid_search':
hybrid_rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
model_name=retrieval_model['reranking_model']['reranking_model_name']
)
all_documents = hybrid_rerank.rerank(query, all_documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
retrieval_model['top_k'])
start = time.perf_counter()
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 10,
'filter': {
'group_id': [dataset.id]
}
}
)
end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
......@@ -67,7 +116,7 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(dataset, embeddings, query, documents)
return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
......@@ -99,7 +148,7 @@ class HitTestingService:
record = {
"segment": segment,
"score": document.metadata['score'],
"score": document.metadata.get('score', None),
"tsne_position": tsne_position_data[i]
}
......@@ -136,3 +185,11 @@ class HitTestingService:
tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])})
return tsne_position_data
@classmethod
def hit_testing_args_check(cls, args):
query = args['query']
if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters')
from typing import Optional
from flask import current_app, Flask
from langchain.embeddings.base import Embeddings
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from models.dataset import Dataset
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class RetrievalService:
@classmethod
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings):
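# Runs on a worker thread, so a fresh Flask app context is pushed explicitly.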
with flask_app.app_context():
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': top_k,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
if documents:
if reranking_model and search_method == 'semantic_search':
rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=reranking_model['reranking_provider_name'],
model_name=reranking_model['reranking_model_name']
)
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
else:
all_documents.extend(documents)
@classmethod
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context():
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search_by_full_text_index(
query,
search_type='similarity_score_threshold',
top_k=top_k
)
if documents:
if reranking_model and search_method == 'full_text_search':
rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=reranking_model['reranking_provider_name'],
model_name=reranking_model['reranking_model_name']
)
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
else:
all_documents.extend(documents)
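A minimal sketch of how these helpers are driven from a request thread, mirroring HitTestingService above; `dataset` and `embeddings` are assumed to be built as in that service, and the query value is illustrative:
import threading
from flask import current_app

all_documents = []
kwargs = {
    'flask_app': current_app._get_current_object(),  # resolve the real app before leaving this thread
    'dataset': dataset,                              # a Dataset row loaded elsewhere
    'query': 'how do I enable reranking?',           # illustrative
    'top_k': 4,
    'score_threshold': None,
    'reranking_model': None,   # no per-branch rerank; hybrid rerank happens after the join
    'all_documents': all_documents,
    'search_method': 'hybrid_search',
    'embeddings': embeddings,                        # a CacheEmbedding, as in HitTestingService
}

threads = [
    threading.Thread(target=RetrievalService.embedding_search, kwargs=kwargs),
    threading.Thread(target=RetrievalService.full_text_index_search, kwargs=kwargs),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
# all_documents now holds the union of both branches, ready for reranking.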