Unverified commit 4588831b authored by Jyong, committed by GitHub

Feat/add retriever rerank (#1560)

Co-authored-by: jyong <jyong@dify.ai>
parent a4f37220
...@@ -8,6 +8,8 @@ import time
import uuid

import click
import qdrant_client
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
...@@ -484,6 +486,38 @@ def normalization_collections():
    click.echo(click.style('Congratulations! Restored {} dataset indexes.'.format(len(normalization_count)), fg='green'))
@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
def add_qdrant_full_text_index():
    click.echo(click.style('Start adding full text index.', fg='green'))
    binds = db.session.query(DatasetCollectionBinding).all()
    if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
        qdrant_url = current_app.config['QDRANT_URL']
        qdrant_api_key = current_app.config['QDRANT_API_KEY']
        client = qdrant_client.QdrantClient(
            qdrant_url,
            api_key=qdrant_api_key,  # For Qdrant Cloud; None for a local instance
        )
        for bind in binds:
            try:
                text_index_params = TextIndexParams(
                    type=TextIndexType.TEXT,
                    tokenizer=TokenizerType.MULTILINGUAL,
                    min_token_len=2,
                    max_token_len=20,
                    lowercase=True
                )
                client.create_payload_index(bind.collection_name, 'page_content',
                                            field_schema=text_index_params)
            except Exception as e:
                click.echo(
                    click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
                                fg='red'))
            click.echo(
                click.style(
                    'Congratulations! Added full text index to collection {}.'.format(bind.collection_name),
                    fg='green'))
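After running the command (e.g. `flask add-qdrant-full-text-index` from the api environment), the new index can be checked against Qdrant directly. A minimal sketch, assuming a local Qdrant instance and a placeholder collection name:

    # sketch: verify the payload index landed (collection name is a placeholder)
    import qdrant_client

    client = qdrant_client.QdrantClient('http://localhost:6333')
    info = client.get_collection('<collection_name>')
    print(info.payload_schema.get('page_content'))  # expect a text-type payload index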
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
    with flask_app.app_context():
        try:
...@@ -647,10 +681,10 @@ def update_app_model_configs(batch_size):
                pbar.update(len(data_batch))


@click.command('migrate_default_input_to_dataset_query_variable')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def migrate_default_input_to_dataset_query_variable(batch_size):
    click.secho("Starting...", fg='green')
    total_records = db.session.query(AppModelConfig) \
...@@ -731,3 +765,4 @@ def register_commands(app):
    app.cli.add_command(update_app_model_configs)
    app.cli.add_command(normalization_collections)
    app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
...@@ -170,6 +170,7 @@ class DatasetApi(Resource):
                            help='Invalid indexing technique.')
        parser.add_argument('permission', type=str, location='json', choices=(
            'only_me', 'all_team_members'), help='Invalid permission.')
        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin or owner
...@@ -401,6 +402,7 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource):
    resource_type = 'dataset'

    @setup_required
    @login_required
    @account_initialization_required
...@@ -436,6 +438,50 @@ class DatasetApiBaseUrlApi(Resource):
        }
class DatasetRetrievalSettingApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        vector_type = current_app.config['VECTOR_STORE']
        if vector_type == 'milvus':
            return {
                'retrieval_method': [
                    'semantic_search'
                ]
            }
        elif vector_type in ('qdrant', 'weaviate'):
            return {
                'retrieval_method': [
                    'semantic_search', 'full_text_search', 'hybrid_search'
                ]
            }
        else:
            raise ValueError("Unsupported vector db type.")


class DatasetRetrievalSettingMockApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, vector_type):
        if vector_type == 'milvus':
            return {
                'retrieval_method': [
                    'semantic_search'
                ]
            }
        elif vector_type in ('qdrant', 'weaviate'):
            return {
                'retrieval_method': [
                    'semantic_search', 'full_text_search', 'hybrid_search'
                ]
            }
        else:
            raise ValueError("Unsupported vector db type.")
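The console endpoint and its mock twin share the same dispatch on the vector store type; a small helper like the sketch below (not part of the commit) would capture the mapping in one place:

    # sketch: shared mapping for DatasetRetrievalSettingApi and the mock variant
    RETRIEVAL_METHODS_BY_VECTOR_STORE = {
        'milvus': ['semantic_search'],
        'qdrant': ['semantic_search', 'full_text_search', 'hybrid_search'],
        'weaviate': ['semantic_search', 'full_text_search', 'hybrid_search'],
    }

    def retrieval_methods_for(vector_type: str) -> dict:
        try:
            return {'retrieval_method': RETRIEVAL_METHODS_BY_VECTOR_STORE[vector_type]}
        except KeyError:
            raise ValueError("Unsupported vector db type.")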
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
...@@ -445,3 +491,5 @@ api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
...@@ -221,6 +221,8 @@ class DatasetDocumentListApi(Resource):
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                            location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                            location='json')
        args = parser.parse_args()

        if not dataset.indexing_technique and not args['indexing_technique']:
...@@ -263,6 +265,8 @@ class DatasetInitApi(Resource):
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                            location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                            location='json')
        args = parser.parse_args()

        if args['indexing_technique'] == 'high_quality':
            try:
......
...@@ -42,19 +42,18 @@ class HitTestingApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, location='json')
        args = parser.parse_args()

-       query = args['query']
-       if not query or len(query) > 250:
-           raise ValueError('Query is required and cannot exceed 250 characters')
        HitTestingService.hit_testing_args_check(args)

        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
-               query=query,
                query=args['query'],
                account=current_user,
-               limit=10,
                retrieval_model=args['retrieval_model'],
                limit=10
            )

            return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
......
...@@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
    def get(self):
        parser = reqparse.RequestParser()
        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                           choices=['text-generation', 'embeddings', 'speech2text'], location='args')
                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id
...@@ -71,18 +71,17 @@ class DefaultModelApi(Resource):
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
-       parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-       parser.add_argument('model_type', type=str, required=True, nullable=False,
-                           choices=['text-generation', 'embeddings', 'speech2text'], location='json')
-       parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
        parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
        args = parser.parse_args()

        provider_service = ProviderService()
        model_settings = args['model_settings']
        for model_setting in model_settings:
            provider_service.update_default_model_of_model_type(
                tenant_id=current_user.current_tenant_id,
                model_type=model_setting['model_type'],
                provider_name=model_setting['provider_name'],
                model_name=model_setting['model_name']
            )

        return {'result': 'success'}
......
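The reworked endpoint now takes a list of per-type default model settings instead of a single model; an illustrative body (provider and model names are examples, not prescribed values):

    # illustrative model_settings payload for the updated POST handler
    default_model_payload = {
        'model_settings': [
            {'model_type': 'text-generation', 'provider_name': 'openai', 'model_name': 'gpt-3.5-turbo'},
            {'model_type': 'reranking', 'provider_name': 'cohere', 'model_name': 'rerank-multilingual-v2.0'},
        ]
    }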
...@@ -36,6 +36,8 @@ class DocumentAddByTextApi(DatasetApiResource):
                            location='json')
        parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
                            location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                            location='json')
        args = parser.parse_args()
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
...@@ -95,6 +97,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                            location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                            location='json')
        args = parser.parse_args()
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
......
...@@ -14,7 +14,6 @@ from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
...@@ -60,7 +59,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.tools) == 1:
            tool = next(iter(self.tools))
-           tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'query': kwargs['input']})
            # output = ''
            # rst_json = json.loads(rst)
......
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator

from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    """
    A multi-dataset retrieval agent driven by a router.
    """
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values

    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used for this query.

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        if len(self.tools) == 0:
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.tools) == 1:
            tool = next(iter(self.tools))
            tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'query': kwargs['input']})
            # output = ''
            # rst_json = json.loads(rst)
            # for item in rst_json:
            #     output += f'{item["content"]}\n'
            return AgentFinish(return_values={"output": rst}, log=rst)

        if intermediate_steps:
            _, observation = intermediate_steps[-1]
            return AgentFinish(return_values={"output": observation}, log=observation)

        try:
            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
            if isinstance(agent_decision, AgentAction):
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            else:
                agent_decision.return_values['output'] = ''
            return agent_decision
        except Exception as e:
            new_exception = self.model_instance.handle_exceptions(e)
            raise new_exception

    def real_plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()
        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            functions=self.functions,
        )

        ai_message = AIMessage(
            content=result.content,
            additional_kwargs={
                'function_call': result.function_call
            }
        )
        agent_decision = _parse_ai_message(ai_message)
        return agent_decision

    async def aplan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        raise NotImplementedError()

    @classmethod
    def from_llm_and_tools(
            cls,
            model_instance: BaseLLM,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_instance=model_instance,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
...@@ -89,7 +89,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.dataset_tools) == 1:
            tool = next(iter(self.dataset_tools))
-           tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'query': kwargs['input']})
            return AgentFinish(return_values={"output": rst}, log=rst)
......
...@@ -18,6 +18,7 @@ from langchain.agents import AgentExecutor as LCAgentExecutor
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
...@@ -78,7 +79,7 @@ class AgentExecutor:
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.ROUTER:
-           self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
            self.configuration.tools = [t for t in self.configuration.tools
                                        if isinstance(t, (DatasetRetrieverTool, DatasetMultiRetrieverTool))]
            agent = MultiDatasetRouterAgent.from_llm_and_tools(
                model_instance=self.configuration.model_instance,
                tools=self.configuration.tools,
...@@ -86,7 +87,7 @@ class AgentExecutor:
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
-           self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
            self.configuration.tools = [t for t in self.configuration.tools
                                        if isinstance(t, (DatasetRetrieverTool, DatasetMultiRetrieverTool))]
            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
......
...@@ -10,8 +10,7 @@ from models.dataset import DocumentSegment
class DatasetIndexToolCallbackHandler:
    """Callback handler for dataset tool."""

-   def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
-       self.dataset_id = dataset_id
    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
        self.conversation_message_task = conversation_message_task

    def on_tool_end(self, documents: List[Document]) -> None:
...@@ -21,7 +20,6 @@ class DatasetIndexToolCallbackHandler:
            # add hit count to document segment
            db.session.query(DocumentSegment).filter(
-               DocumentSegment.dataset_id == self.dataset_id,
                DocumentSegment.index_node_id == doc_id
            ).update(
                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
......
...@@ -127,6 +127,7 @@ class Completion:
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                tenant_id=app.tenant_id,
                retriever_from=retriever_from
            )
......
...@@ -3,7 +3,7 @@ from pathlib import Path
from typing import List, Union, Optional

import requests
-from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
from langchain.schema import Document

from core.data_loader.loader.csv_loader import CSVLoader
...@@ -20,13 +20,13 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
class FileExtractor:
    @classmethod
-   def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

-           return cls.load_from_file(file_path, return_text, upload_file)
            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

    @classmethod
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
...@@ -44,10 +44,20 @@ class FileExtractor:
    @classmethod
    def load_from_file(cls, file_path: str, return_text: bool = False,
-                      upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
                       upload_file: Optional[UploadFile] = None,
                       is_automatic: bool = False) -> Union[List[Document] | str]:
        input_file = Path(file_path)
        delimiter = '\n'
        file_extension = input_file.suffix.lower()
        if is_automatic:
            loader = UnstructuredFileLoader(
                file_path, strategy="hi_res", mode="elements"
            )
            # loader = UnstructuredAPIFileLoader(
            #     file_path=filenames[0],
            #     api_key="FAKE_API_KEY",
            # )
        else:
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
......
...@@ -40,6 +40,13 @@ class BaseVectorIndex(BaseIndex):
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    @abstractmethod
    def search_by_full_text_index(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
        raise NotImplementedError

    def search(
            self, query: str,
            **kwargs: Any
......
-from typing import Optional, cast
from typing import cast, Any, List

from langchain.embeddings.base import Embeddings
-from langchain.schema import Document, BaseRetriever
-from langchain.vectorstores import VectorStore, milvus
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
-from core.vector_store.weaviate_vector_store import WeaviateVectorStore
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetCollectionBinding
from models.dataset import Dataset


class MilvusConfig(BaseModel):
...@@ -152,3 +150,7 @@ class MilvusVectorIndex(BaseVectorIndex):
                ),
            ],
        ))

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        # milvus/zilliz doesn't support bm25 full-text search
        return []
...@@ -191,3 +191,21 @@ class QdrantVectorIndex(BaseVectorIndex):
            return True

        return False

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        from qdrant_client.http import models
        return vector_store.similarity_search_by_bm25(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self.dataset.id),
                ),
                models.FieldCondition(
                    key="page_content",
                    match=models.MatchText(text=query),
                ),
            ],
        ), kwargs.get('top_k', 2))
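`similarity_search_by_bm25` is a method added to the patched Qdrant vector store elsewhere in this commit (its diff is collapsed); in terms of the raw client, the lookup appears to amount to a filtered scroll over the text-indexed payload field. A sketch under that assumption, with a placeholder collection name:

    # sketch: what the full-text lookup boils down to with the raw qdrant client
    from qdrant_client import QdrantClient
    from qdrant_client.http import models

    client = QdrantClient('http://localhost:6333')
    points, _next_page = client.scroll(
        collection_name='<collection_name>',  # placeholder
        scroll_filter=models.Filter(must=[
            models.FieldCondition(key='page_content', match=models.MatchText(text='rerank')),
        ]),
        limit=2,
        with_payload=True,
    )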
-from typing import Optional, cast
from typing import Optional, cast, Any, List

import requests
import weaviate
...@@ -26,6 +26,7 @@ class WeaviateConfig(BaseModel):
class WeaviateVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)

        self._client = self._init_client(config)
...@@ -148,3 +149,9 @@ class WeaviateVectorIndex(BaseVectorIndex):
            return True

        return False

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
...@@ -49,14 +49,14 @@ class IndexingRunner:
        if not dataset:
            raise ValueError("no dataset found")

-       # load file
-       text_docs = self._load_data(dataset_document)

        # get the process rule
        processing_rule = db.session.query(DatasetProcessRule). \
            filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
            first()

        # load file
        text_docs = self._load_data(dataset_document)

        # get splitter
        splitter = self._get_splitter(processing_rule)
...@@ -380,7 +380,7 @@ class IndexingRunner:
            "preview": preview_texts
        }

-   def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []
...@@ -396,7 +396,7 @@ class IndexingRunner:
            one_or_none()

        if file_detail:
-           text_docs = FileExtractor.load(file_detail)
            text_docs = FileExtractor.load(file_detail, is_automatic=False)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()
......
...@@ -9,6 +9,7 @@ from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
...@@ -140,6 +141,44 @@ class ModelFactory:
            name=model_name
        )

    @classmethod
    def get_reranking_model(cls,
                            tenant_id: str,
                            model_provider_name: Optional[str] = None,
                            model_name: Optional[str] = None) -> Optional[BaseReranking]:
        """
        Get a reranking model instance.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Reranking Model "
                                         f"in Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        # init reranking model
        model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_speech2text_model(cls,
                              tenant_id: str,
......
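Hedged usage of the new factory method: the tenant id is a placeholder, and the provider/model names come from the Cohere provider's fixed model list added later in this commit.

    # sketch: resolving a reranking model through the factory
    from core.model_providers.model_factory import ModelFactory

    reranker = ModelFactory.get_reranking_model(
        tenant_id='<tenant-id>',
        model_provider_name='cohere',
        model_name='rerank-multilingual-v2.0',
    )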
...@@ -72,6 +72,9 @@ class ModelProviderFactory:
        elif provider_name == 'localai':
            from core.model_providers.providers.localai_provider import LocalAIProvider
            return LocalAIProvider
        elif provider_name == 'cohere':
            from core.model_providers.providers.cohere_provider import CohereProvider
            return CohereProvider
        else:
            raise NotImplementedError
......
...@@ -17,7 +17,7 @@ class ModelType(enum.Enum):
    IMAGE = 'image'
    VIDEO = 'video'
    MODERATION = 'moderation'
    RERANKING = 'reranking'

    @staticmethod
    def value_of(value):
        for member in ModelType:
......
import logging
from abc import abstractmethod
from typing import Any, Optional, List

from langchain.schema import Document

from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider

logger = logging.getLogger(__name__)


class BaseReranking(BaseProviderModel):
    name: str
    type: ModelType = ModelType.RERANKING

    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
        super().__init__(model_provider, client)
        self.name = name

    @property
    def base_model_name(self) -> str:
        """
        Get the base model name.

        :return: str
        """
        return self.name

    @abstractmethod
    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float],
               top_k: Optional[int]) -> Optional[List[Document]]:
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        raise NotImplementedError
import logging
from typing import Optional, List

import cohere
import openai
from langchain.schema import Document

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.providers.base import BaseModelProvider


class CohereReranking(BaseReranking):

    def __init__(self, model_provider: BaseModelProvider, name: str):
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = cohere.Client(self.credentials.get('api_key'))

        super().__init__(model_provider, client, name)

    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float],
               top_k: Optional[int]) -> Optional[List[Document]]:
        # deduplicate candidates by doc_id before sending them to the rerank endpoint;
        # keep a parallel list so result.index maps back to the right Document
        docs = []
        doc_ids = []
        unique_documents = []
        for document in documents:
            if document.metadata['doc_id'] not in doc_ids:
                doc_ids.append(document.metadata['doc_id'])
                docs.append(document.page_content)
                unique_documents.append(document)

        results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)

        rerank_documents = []
        for result in results:
            # format document, attaching the relevance score to the metadata
            rerank_document = Document(
                page_content=result.document['text'],
                metadata={
                    "doc_id": unique_documents[result.index].metadata['doc_id'],
                    "doc_hash": unique_documents[result.index].metadata['doc_hash'],
                    "document_id": unique_documents[result.index].metadata['document_id'],
                    "dataset_id": unique_documents[result.index].metadata['dataset_id'],
                    'score': result.relevance_score
                }
            )
            # score threshold check
            if score_threshold is not None:
                if result.relevance_score >= score_threshold:
                    rerank_documents.append(rerank_document)
            else:
                rerank_documents.append(rerank_document)
        return rerank_documents

    def handle_exceptions(self, ex: Exception) -> Exception:
        # note: these branches mirror the OpenAI-based handlers used elsewhere in
        # the codebase; Cohere exceptions that match none of them fall through to
        # the final `return ex`
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
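A hedged usage sketch against the cohere v4-era SDK, matching the calls the class makes above (the API key and texts are placeholders):

    # sketch: direct use of the rerank endpoint as CohereReranking drives it
    import cohere

    co = cohere.Client('<COHERE_API_KEY>')
    results = co.rerank(
        query='what does a reranker do?',
        documents=['Rerankers reorder retrieved candidates.', 'Unrelated text.'],
        model='rerank-multilingual-v2.0',
        top_n=1,
    )
    for result in results:
        print(result.index, result.relevance_score, result.document['text'])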
import json
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.reranking.cohere_reranking import CohereReranking
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType


class CohereProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of the provider.
        """
        return 'cohere'

    def _get_text_generation_model_mode(self, model_name) -> str:
        return ModelMode.CHAT.value

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.RERANKING:
            return [
                {
                    'id': 'rerank-english-v2.0',
                    'name': 'rerank-english-v2.0'
                },
                {
                    'id': 'rerank-multilingual-v2.0',
                    'name': 'rerank-multilingual-v2.0'
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.RERANKING:
            model_class = CohereReranking
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Cohere api_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key'],
            }
            # todo validate
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Check whether the model credentials are valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        Encrypt model credentials for saving.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Get credentials for model use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
...@@ -13,5 +13,6 @@
    "huggingface_hub",
    "xinference",
    "openllm",
-   "localai"
    "localai",
    "cohere"
]
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "fixed"
}
\ No newline at end of file
-from typing import Optional
import json
import threading
from typing import Optional, List

from flask import Flask
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field

from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
...@@ -17,6 +23,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
...@@ -25,6 +32,16 @@ from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enable': False
}
class OrchestratorRuleParser:
    """Parse the orchestrator rule to entities."""
...@@ -34,7 +51,7 @@ class OrchestratorRuleParser:
        self.app_model_config = app_model_config

    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
-                         rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str,
                          retriever_from: str = 'dev') -> Optional[AgentExecutor]:
        if not self.app_model_config.agent_mode_dict:
            return None
...@@ -101,7 +118,8 @@ class OrchestratorRuleParser:
                rest_tokens=rest_tokens,
                return_resource=return_resource,
                retriever_from=retriever_from,
-               dataset_configs=dataset_configs
                dataset_configs=dataset_configs,
                tenant_id=tenant_id
            )

            if len(tools) == 0:
...@@ -132,6 +150,7 @@ class OrchestratorRuleParser:
        :return:
        """
        tools = []
        dataset_tools = []
        for tool_config in tool_configs:
            tool_type = list(tool_config.keys())[0]
            tool_val = list(tool_config.values())[0]
...@@ -140,7 +159,7 @@ class OrchestratorRuleParser:
            tool = None
            if tool_type == "dataset":
-               tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
                dataset_tools.append(tool_config)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "google_search":
...@@ -156,28 +175,35 @@ class OrchestratorRuleParser:
                else:
                    tool.callbacks = callbacks

                tools.append(tool)

        # build all dataset tools in one batch
        if len(dataset_tools) > 0:
            dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs)
            if dataset_retriever_tools:
                tools.extend(dataset_retriever_tools)

        return tools

-   def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
-                                 dataset_configs: dict, rest_tokens: int,
-                                 return_resource: bool = False, retriever_from: str = 'dev',
-                                 **kwargs) \
-           -> Optional[BaseTool]:
    def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask,
                                  return_resource: bool = False, retriever_from: str = 'dev',
                                  **kwargs) \
            -> Optional[List[BaseTool]]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset.

        :param tool_configs:
        :param conversation_message_task:
        :param return_resource:
        :param retriever_from:
        :return:
        """
        dataset_configs = kwargs['dataset_configs']
        retrieval_model = dataset_configs.get('retrieval_model', 'single')
        tools = []
        dataset_ids = []
        for tool_config in tool_configs:
            # get dataset from dataset id
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
-               Dataset.id == tool_config.get("id")
                Dataset.id == tool_config.get('dataset').get("id")
            ).first()

            if not dataset:
...@@ -185,16 +211,18 @@ class OrchestratorRuleParser:
            if dataset and dataset.available_document_count == 0:
                return None

            dataset_ids.append(dataset.id)
            if retrieval_model == 'single':
                # use a distinct name so the 'single'/'multiple' mode flag above is not shadowed
                retrieval_model_config = dataset.retrieval_model \
                    if dataset.retrieval_model else default_retrieval_model
                top_k = retrieval_model_config['top_k']
                # dynamically adjust top_k when the remaining token number is not enough to support top_k
                # top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)

                score_threshold = None
                score_threshold_enable = retrieval_model_config.get("score_threshold_enable")
                if score_threshold_enable:
                    score_threshold = retrieval_model_config.get("score_threshold")

                tool = DatasetRetrieverTool.from_dataset(
                    dataset=dataset,
...@@ -205,8 +233,23 @@ class OrchestratorRuleParser:
                    return_resource=return_resource,
                    retriever_from=retriever_from
                )
                tools.append(tool)

        if retrieval_model == 'multiple':
            tool = DatasetMultiRetrieverTool.from_dataset(
                dataset_ids=dataset_ids,
                tenant_id=kwargs['tenant_id'],
                top_k=dataset_configs.get('top_k', 2),
                score_threshold=dataset_configs.get('score_threshold', 0.5)
                if dataset_configs.get('score_threshold_enable', False) else None,
                callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
                conversation_message_task=conversation_message_task,
                return_resource=return_resource,
                retriever_from=retriever_from,
                reranking_provider_name=dataset_configs.get('reranking_model').get('reranking_provider_name'),
                reranking_model_name=dataset_configs.get('reranking_model').get('reranking_model_name')
            )
            tools.append(tool)

-       return tool
        return tools

    def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
        """
......
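For reference, an illustrative `dataset_configs` value for each branch handled in `to_dataset_retriever_tool` above (keys taken from the code; values are examples):

    # 'single': each dataset tool uses the dataset's own stored retrieval_model
    single_mode_configs = {'retrieval_model': 'single'}

    # 'multiple': one DatasetMultiRetrieverTool over all dataset ids, reranked
    multiple_mode_configs = {
        'retrieval_model': 'multiple',
        'top_k': 4,
        'score_threshold_enable': True,
        'score_threshold': 0.5,
        'reranking_model': {
            'reranking_provider_name': 'cohere',
            'reranking_model_name': 'rerank-multilingual-v2.0',
        },
    }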
This diff is collapsed.
import json
import threading
-from typing import Type, Optional
from typing import Type, Optional, List

from flask import current_app
from langchain.tools import BaseTool
...@@ -14,6 +15,18 @@ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitE
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document
from services.retrieval_service import RetrievalService

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enable': False
}
class DatasetRetrieverToolInput(BaseModel):
...@@ -56,7 +69,9 @@ class DatasetRetrieverTool(BaseTool):
        ).first()

        if not dataset:
-           return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
            return ''

        # get the retrieval model; if none is set on the dataset, use the default
        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
        if dataset.indexing_technique == "economy":
            # use keyword table query
...@@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool):
                return ''

            embeddings = CacheEmbedding(embedding_model)

-           vector_index = VectorIndex(
-               dataset=dataset,
-               config=current_app.config,
-               embeddings=embeddings
-           )

            documents = []
            threads = []
            if self.top_k > 0:
-               documents = vector_index.search(
-                   query,
-                   search_type='similarity_score_threshold',
-                   search_kwargs={
-                       'k': self.top_k,
-                       'score_threshold': self.score_threshold,
-                       'filter': {
-                           'group_id': [dataset.id]
-                       }
-                   }
-               )
                # retrieval source: semantic search
                if retrieval_model['search_method'] in ('semantic_search', 'hybrid_search'):
                    embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'dataset': dataset,
                        'query': query,
                        'top_k': self.top_k,
                        'score_threshold': retrieval_model['score_threshold']
                        if retrieval_model['score_threshold_enable'] else None,
                        'reranking_model': retrieval_model['reranking_model']
                        if retrieval_model['reranking_enable'] else None,
                        'all_documents': documents,
                        'search_method': retrieval_model['search_method'],
                        'embeddings': embeddings
                    })
                    threads.append(embedding_thread)
                    embedding_thread.start()

                # retrieval source: full-text search
                if retrieval_model['search_method'] in ('full_text_search', 'hybrid_search'):
                    full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'dataset': dataset,
                        'query': query,
                        'search_method': retrieval_model['search_method'],
                        'embeddings': embeddings,
                        'score_threshold': retrieval_model['score_threshold']
                        if retrieval_model['score_threshold_enable'] else None,
                        'top_k': self.top_k,
                        'reranking_model': retrieval_model['reranking_model']
                        if retrieval_model['reranking_enable'] else None,
                        'all_documents': documents
                    })
                    threads.append(full_text_index_thread)
                    full_text_index_thread.start()

                for thread in threads:
                    thread.join()

                # hybrid search: rerank after all documents have been retrieved
                if retrieval_model['search_method'] == 'hybrid_search':
                    hybrid_rerank = ModelFactory.get_reranking_model(
                        tenant_id=dataset.tenant_id,
                        model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
                        model_name=retrieval_model['reranking_model']['reranking_model_name']
                    )
                    documents = hybrid_rerank.rerank(
                        query, documents,
                        retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
                        self.top_k)
            else:
                documents = []

-           hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
            hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
            hit_callback.on_tool_end(documents)
            document_score_list = {}
            if dataset.indexing_technique != "economy":
......
from core.index.vector_index.milvus import Milvus from core.vector_store.vector.milvus import Milvus
class MilvusVectorStore(Milvus): class MilvusVectorStore(Milvus):
......
...@@ -4,7 +4,7 @@ from langchain.schema import Document ...@@ -4,7 +4,7 @@ from langchain.schema import Document
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal from qdrant_client.local.qdrant_local import QdrantLocal
from core.index.vector_index.qdrant import Qdrant from core.vector_store.vector.qdrant import Qdrant
class QdrantVectorStore(Qdrant): class QdrantVectorStore(Qdrant):
...@@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant): ...@@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant):
if isinstance(self.client, QdrantLocal): if isinstance(self.client, QdrantLocal):
self.client = cast(QdrantLocal, self.client) self.client = cast(QdrantLocal, self.client)
self.client._load() self.client._load()
...@@ -28,7 +28,7 @@ from langchain.docstore.document import Document ...@@ -28,7 +28,7 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType from qdrant_client.http.models import PayloadSchemaType, FilterSelector, TextIndexParams, TokenizerType, TextIndexType
if TYPE_CHECKING: if TYPE_CHECKING:
from qdrant_client import grpc # noqa from qdrant_client import grpc # noqa
...@@ -189,14 +189,25 @@ class Qdrant(VectorStore): ...@@ -189,14 +189,25 @@ class Qdrant(VectorStore):
texts, metadatas, ids, batch_size texts, metadatas, ids, batch_size
): ):
self.client.upsert( self.client.upsert(
collection_name=self.collection_name, points=points, **kwargs collection_name=self.collection_name, points=points
) )
added_ids.extend(batch_ids) added_ids.extend(batch_ids)
# if this is a new collection, create a payload index on group_id
if self.is_new_collection: if self.is_new_collection:
# create payload index
self.client.create_payload_index(self.collection_name, self.group_payload_key, self.client.create_payload_index(self.collection_name, self.group_payload_key,
field_schema=PayloadSchemaType.KEYWORD, field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD) field_type=PayloadSchemaType.KEYWORD)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
self.client.create_payload_index(self.collection_name, self.content_payload_key,
field_schema=text_index_params)
return added_ids return added_ids
@sync_call_fallback @sync_call_fallback
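As a quick sanity check after upserting into a new collection, the payload schema can be inspected from the client. This is a hedged sketch using the public qdrant-client API; the client and collection name are assumptions, not code from this change.

# Hypothetical verification sketch: confirm both payload indexes were created.
info = client.get_collection(collection_name=collection_name)
# Expect 'group_id' registered as a keyword index and 'page_content' as a text index.
print(info.payload_schema)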
...@@ -600,7 +611,7 @@ class Qdrant(VectorStore): ...@@ -600,7 +611,7 @@ class Qdrant(VectorStore):
limit=k, limit=k,
offset=offset, offset=offset,
with_payload=True, with_payload=True,
with_vectors=True, # Langchain does not expect vectors to be returned with_vectors=True,
score_threshold=score_threshold, score_threshold=score_threshold,
consistency=consistency, consistency=consistency,
**kwargs, **kwargs,
...@@ -615,6 +626,39 @@ class Qdrant(VectorStore): ...@@ -615,6 +626,39 @@ class Qdrant(VectorStore):
for result in results for result in results
] ]
def similarity_search_by_bm25(
self,
filter: Optional[MetadataFilter] = None,
k: int = 4
) -> List[Document]:
"""Return docs most similar by bm25.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
search_params: Additional search params
Returns:
List of documents most similar to the query text and distance for each.
"""
response = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=filter,
limit=k,
with_payload=True,
with_vectors=True
)
results = response[0]
documents = []
for result in results:
if result:
documents.append(self._document_from_scored_point(
result, self.content_payload_key, self.metadata_payload_key
))
return documents
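A minimal caller sketch for the new method, assuming `vector_store` is an instance of this patched Qdrant wrapper and the content payload key is the default 'page_content'; the query text and filter construction are illustrative, not part of this diff.

from qdrant_client.http import models

full_text_filter = models.Filter(
    must=[models.FieldCondition(key='page_content',
                                match=models.MatchText(text='invoice deadline'))]
)
docs = vector_store.similarity_search_by_bm25(filter=full_text_filter, k=4)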
@sync_call_fallback @sync_call_fallback
async def asimilarity_search_with_score_by_vector( async def asimilarity_search_with_score_by_vector(
self, self,
......
...@@ -12,6 +12,21 @@ dataset_fields = { ...@@ -12,6 +12,21 @@ dataset_fields = {
'created_at': TimestampField, 'created_at': TimestampField,
} }
reranking_model_fields = {
'reranking_provider_name': fields.String,
'reranking_model_name': fields.String
}
dataset_retrieval_model_fields = {
'search_method': fields.String,
'reranking_enable': fields.Boolean,
'reranking_model': fields.Nested(reranking_model_fields),
'top_k': fields.Integer,
'score_threshold_enable': fields.Boolean,
'score_threshold': fields.Float
}
dataset_detail_fields = { dataset_detail_fields = {
'id': fields.String, 'id': fields.String,
'name': fields.String, 'name': fields.String,
...@@ -29,7 +44,8 @@ dataset_detail_fields = { ...@@ -29,7 +44,8 @@ dataset_detail_fields = {
'updated_at': TimestampField, 'updated_at': TimestampField,
'embedding_model': fields.String, 'embedding_model': fields.String,
'embedding_model_provider': fields.String, 'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean 'embedding_available': fields.Boolean,
'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields)
} }
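The nested field maps above are consumed with flask_restful's marshal helper in the API layer; a short sketch, with `dataset` assumed to be a Dataset row loaded by the controller.

from flask_restful import marshal

payload = marshal(dataset, dataset_detail_fields)
# payload['retrieval_model_dict'] now exposes search_method, top_k,
# score_threshold_enable and the nested reranking_model block.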
dataset_query_detail_fields = { dataset_query_detail_fields = {
...@@ -41,3 +57,5 @@ dataset_query_detail_fields = { ...@@ -41,3 +57,5 @@ dataset_query_detail_fields = {
"created_by": fields.String, "created_by": fields.String,
"created_at": TimestampField "created_at": TimestampField
} }
"""add-dataset-retrival-model
Revision ID: fca025d3b60f
Revises: b3a09c049e8e
Create Date: 2023-11-03 13:08:23.246396
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'fca025d3b60f'
down_revision = '8fe468ba0ca5'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('sessions')
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_index('retrieval_model_idx', postgresql_using='gin')
batch_op.drop_column('retrieval_model')
op.create_table('sessions',
sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True),
sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True),
sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
sa.PrimaryKeyConstraint('id', name='sessions_pkey'),
sa.UniqueConstraint('session_id', name='sessions_session_id_key')
)
# ### end Alembic commands ###
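The GIN index created above keeps JSONB containment queries on the new column cheap; a hedged query sketch (not part of this migration):

# Hypothetical lookup: datasets that have opted into hybrid search.
from models.dataset import Dataset

hybrid_datasets = Dataset.query.filter(
    Dataset.retrieval_model.contains({'search_method': 'hybrid_search'})
).all()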
...@@ -3,7 +3,7 @@ import pickle ...@@ -3,7 +3,7 @@ import pickle
from json import JSONDecodeError from json import JSONDecodeError
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID, JSONB
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
...@@ -15,6 +15,7 @@ class Dataset(db.Model): ...@@ -15,6 +15,7 @@ class Dataset(db.Model):
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='dataset_pkey'), db.PrimaryKeyConstraint('id', name='dataset_pkey'),
db.Index('dataset_tenant_idx', 'tenant_id'), db.Index('dataset_tenant_idx', 'tenant_id'),
db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
) )
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy'] INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy']
...@@ -39,7 +40,7 @@ class Dataset(db.Model): ...@@ -39,7 +40,7 @@ class Dataset(db.Model):
embedding_model = db.Column(db.String(255), nullable=True) embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True) collection_binding_id = db.Column(UUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
@property @property
def dataset_keyword_table(self): def dataset_keyword_table(self):
...@@ -93,6 +94,20 @@ class Dataset(db.Model): ...@@ -93,6 +94,20 @@ class Dataset(db.Model):
return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
.filter(Document.dataset_id == self.id).scalar() .filter(Document.dataset_id == self.id).scalar()
@property
def retrieval_model_dict(self):
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
return self.retrieval_model if self.retrieval_model else default_retrieval_model
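The property above shields callers from NULL values written before this migration; a quick illustration, with `legacy_dataset` standing in for such a row:

# legacy_dataset.retrieval_model is None, so the defaults are returned.
config = legacy_dataset.retrieval_model_dict
assert config['search_method'] == 'semantic_search'
assert config['top_k'] == 2
assert config['reranking_enable'] is False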
class DatasetProcessRule(db.Model): class DatasetProcessRule(db.Model):
__tablename__ = 'dataset_process_rules' __tablename__ = 'dataset_process_rules'
...@@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model): ...@@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model):
], ],
'segmentation': { 'segmentation': {
'delimiter': '\n', 'delimiter': '\n',
'max_tokens': 1000 'max_tokens': 512
} }
} }
...@@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model): ...@@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model):
model_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(40), nullable=False)
collection_name = db.Column(db.String(64), nullable=False) collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
...@@ -160,7 +160,13 @@ class AppModelConfig(db.Model): ...@@ -160,7 +160,13 @@ class AppModelConfig(db.Model):
@property @property
def dataset_configs_dict(self) -> dict: def dataset_configs_dict(self) -> dict:
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}} if self.dataset_configs:
dataset_configs = json.loads(self.dataset_configs)
if 'retrieval_model' not in dataset_configs:
return {'retrieval_model': 'single'}
else:
return dataset_configs
return {'retrieval_model': 'single'}
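In effect, any stored dataset_configs that predates the retrieval_model key is treated as single-dataset retrieval; a behaviour sketch with a hypothetical AppModelConfig row:

import json

app_model_config.dataset_configs = json.dumps({'top_k': 2})
assert app_model_config.dataset_configs_dict == {'retrieval_model': 'single'}

app_model_config.dataset_configs = json.dumps({'retrieval_model': 'multiple'})
assert app_model_config.dataset_configs_dict == {'retrieval_model': 'multiple'}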
@property @property
def file_upload_dict(self) -> dict: def file_upload_dict(self) -> dict:
......
...@@ -23,7 +23,6 @@ boto3==1.28.17 ...@@ -23,7 +23,6 @@ boto3==1.28.17
tenacity==8.2.2 tenacity==8.2.2
cachetools~=5.3.0 cachetools~=5.3.0
weaviate-client~=3.21.0 weaviate-client~=3.21.0
qdrant_client~=1.1.6
mailchimp-transactional~=1.0.50 mailchimp-transactional~=1.0.50
scikit-learn==1.2.2 scikit-learn==1.2.2
sentry-sdk[flask]~=1.21.1 sentry-sdk[flask]~=1.21.1
...@@ -54,3 +53,5 @@ safetensors==0.3.2 ...@@ -54,3 +53,5 @@ safetensors==0.3.2
zhipuai==1.0.7 zhipuai==1.0.7
werkzeug==2.3.7 werkzeug==2.3.7
pymilvus==2.3.0 pymilvus==2.3.0
qdrant-client==1.6.4
cohere~=4.32
\ No newline at end of file
...@@ -470,7 +470,16 @@ class AppModelConfigService: ...@@ -470,7 +470,16 @@ class AppModelConfigService:
# dataset_configs # dataset_configs
if 'dataset_configs' not in config or not config["dataset_configs"]: if 'dataset_configs' not in config or not config["dataset_configs"]:
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}} config["dataset_configs"] = {'retrieval_model': 'single'}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if config["dataset_configs"]['retrieval_model'] == 'multiple':
if not config["dataset_configs"]['reranking_model']:
raise ValueError("reranking_model has not been set")
if not isinstance(config["dataset_configs"]['reranking_model'], dict):
raise ValueError("reranking_model must be of object type")
if not isinstance(config["dataset_configs"], dict): if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type") raise ValueError("dataset_configs must be of object type")
......
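A dataset_configs payload that passes the new validation when multiple-dataset retrieval is selected would look roughly like this; the reranking provider and model names are placeholders, not values required by this change.

config["dataset_configs"] = {
    'retrieval_model': 'multiple',
    'reranking_model': {
        'reranking_provider_name': 'cohere',
        'reranking_model_name': 'rerank-english-v2.0'
    }
}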
...@@ -173,6 +173,9 @@ class DatasetService: ...@@ -173,6 +173,9 @@ class DatasetService:
filtered_data['updated_by'] = user.id filtered_data['updated_by'] = user.id
filtered_data['updated_at'] = datetime.datetime.now() filtered_data['updated_at'] = datetime.datetime.now()
# update Retrieval model
filtered_data['retrieval_model'] = data['retrieval_model']
dataset.query.filter_by(id=dataset_id).update(filtered_data) dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.commit() db.session.commit()
...@@ -473,7 +476,19 @@ class DocumentService: ...@@ -473,7 +476,19 @@ class DocumentService:
embedding_model.name embedding_model.name
) )
dataset.collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
documents = [] documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
...@@ -733,6 +748,7 @@ class DocumentService: ...@@ -733,6 +748,7 @@ class DocumentService:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.") raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
embedding_model = None embedding_model = None
dataset_collection_binding_id = None dataset_collection_binding_id = None
retrieval_model = None
if document_data['indexing_technique'] == 'high_quality': if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model( embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id tenant_id=tenant_id
...@@ -742,6 +758,20 @@ class DocumentService: ...@@ -742,6 +758,20 @@ class DocumentService:
embedding_model.name embedding_model.name
) )
dataset_collection_binding_id = dataset_collection_binding.id dataset_collection_binding_id = dataset_collection_binding.id
if 'retrieval_model' in document_data and document_data['retrieval_model']:
retrieval_model = document_data['retrieval_model']
else:
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
retrieval_model = default_retrieval_model
# save dataset # save dataset
dataset = Dataset( dataset = Dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
...@@ -751,7 +781,8 @@ class DocumentService: ...@@ -751,7 +781,8 @@ class DocumentService:
created_by=account.id, created_by=account.id,
embedding_model=embedding_model.name if embedding_model else None, embedding_model=embedding_model.name if embedding_model else None,
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None, embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
collection_binding_id=dataset_collection_binding_id collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model
) )
db.session.add(dataset) db.session.add(dataset)
......
import json
import logging import logging
import threading
import time import time
from typing import List from typing import List
...@@ -9,16 +11,26 @@ from langchain.schema import Document ...@@ -9,16 +11,26 @@ from langchain.schema import Document
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
from core.embedding.cached_embedding import CacheEmbedding from core.embedding.cached_embedding import CacheEmbedding
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.retrieval_service import RetrievalService
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class HitTestingService: class HitTestingService:
@classmethod @classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
if dataset.available_document_count == 0 or dataset.available_segment_count == 0: if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return { return {
"query": { "query": {
...@@ -28,31 +40,68 @@ class HitTestingService: ...@@ -28,31 +40,68 @@ class HitTestingService:
"records": [] "records": []
} }
start = time.perf_counter()
# get the retrieval model; if it is not set, use the default
if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get embedding model
embedding_model = ModelFactory.get_embedding_model( embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider, model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model model_name=dataset.embedding_model
) )
embeddings = CacheEmbedding(embedding_model) embeddings = CacheEmbedding(embedding_model)
        vector_index = VectorIndex(
            dataset=dataset,
            config=current_app.config,
            embeddings=embeddings
        )
        all_documents = []
        threads = []

        # retrieval source with semantic search
        if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset': dataset,
                'query': query,
                'top_k': retrieval_model['top_k'],
                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
                'all_documents': all_documents,
                'search_method': retrieval_model['search_method'],
                'embeddings': embeddings
            })
            threads.append(embedding_thread)
            embedding_thread.start()

        # retrieval source with full text search
        if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset': dataset,
                'query': query,
                'search_method': retrieval_model['search_method'],
                'embeddings': embeddings,
                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
                'top_k': retrieval_model['top_k'],
                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
                'all_documents': all_documents
            })
            threads.append(full_text_index_thread)
            full_text_index_thread.start()

        for thread in threads:
            thread.join()

        if retrieval_model['search_method'] == 'hybrid_search':
            hybrid_rerank = ModelFactory.get_reranking_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
                model_name=retrieval_model['reranking_model']['reranking_model_name']
            )
all_documents = hybrid_rerank.rerank(query, all_documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
retrieval_model['top_k'])
start = time.perf_counter()
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 10,
'filter': {
'group_id': [dataset.id]
}
}
)
end = time.perf_counter() end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
...@@ -67,7 +116,7 @@ class HitTestingService: ...@@ -67,7 +116,7 @@ class HitTestingService:
db.session.add(dataset_query) db.session.add(dataset_query)
db.session.commit() db.session.commit()
return cls.compact_retrieve_response(dataset, embeddings, query, documents) return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
@classmethod @classmethod
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]): def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
...@@ -99,7 +148,7 @@ class HitTestingService: ...@@ -99,7 +148,7 @@ class HitTestingService:
record = { record = {
"segment": segment, "segment": segment,
"score": document.metadata['score'], "score": document.metadata.get('score', None),
"tsne_position": tsne_position_data[i] "tsne_position": tsne_position_data[i]
} }
...@@ -136,3 +185,11 @@ class HitTestingService: ...@@ -136,3 +185,11 @@ class HitTestingService:
tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])}) tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])})
return tsne_position_data return tsne_position_data
@classmethod
def hit_testing_args_check(cls, args):
query = args['query']
if not query or len(query) > 250:
raise ValueError('Query is required and cannot exceed 250 characters')
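A hedged caller sketch for the extended retrieve signature, with `dataset` and `current_user` assumed to be loaded by the controller and all argument values illustrative:

response = HitTestingService.retrieve(
    dataset=dataset,
    query='how do refunds work?',
    account=current_user,
    retrieval_model={
        'search_method': 'full_text_search',
        'reranking_enable': False,
        'reranking_model': {
            'reranking_provider_name': '',
            'reranking_model_name': ''
        },
        'top_k': 5,
        'score_threshold_enable': False
    },
    limit=10
)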
from typing import Optional
from flask import current_app, Flask
from langchain.embeddings.base import Embeddings
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory
from models.dataset import Dataset
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enable': False
}
class RetrievalService:
@classmethod
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context():
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': top_k,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
}
)
if documents:
if reranking_model and search_method == 'semantic_search':
rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=reranking_model['reranking_provider_name'],
model_name=reranking_model['reranking_model_name']
)
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
else:
all_documents.extend(documents)
@classmethod
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context():
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search_by_full_text_index(
query,
search_type='similarity_score_threshold',
top_k=top_k
)
if documents:
if reranking_model and search_method == 'full_text_search':
rerank = ModelFactory.get_reranking_model(
tenant_id=dataset.tenant_id,
model_provider_name=reranking_model['reranking_provider_name'],
model_name=reranking_model['reranking_model_name']
)
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
else:
all_documents.extend(documents)
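Both helpers append into a caller-owned list, which is what lets HitTestingService and DatasetRetrieverTool run them in parallel and merge the results before an optional rerank. A minimal sketch of that pattern, assuming `dataset`, `query` and `embeddings` already exist in the caller's scope:

import threading
from flask import current_app

all_documents = []
shared_kwargs = {
    'flask_app': current_app._get_current_object(),
    'dataset': dataset,
    'query': query,
    'top_k': 4,
    'score_threshold': None,
    'reranking_model': None,
    'all_documents': all_documents,
    'search_method': 'hybrid_search',
    'embeddings': embeddings
}
threads = [
    threading.Thread(target=RetrievalService.embedding_search, kwargs=shared_kwargs),
    threading.Thread(target=RetrievalService.full_text_index_search, kwargs=shared_kwargs),
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()
# all_documents now holds the merged semantic + full text candidates,
# ready to be passed to a reranking model.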