Unverified Commit 22bc9ddc authored by WangBooth's avatar WangBooth Committed by GitHub

Hotfix/fix documents index mismatch error in rerank (#1662)

Co-authored-by: 's avatarbaomi.wbm <baomi.wbm@dtwave-inc.com>
parent 04237756
import logging import logging
from typing import Optional, List from typing import List, Optional
import cohere import cohere
import openai import openai
from langchain.schema import Document from core.model_providers.error import (LLMAPIConnectionError,
LLMAPIUnavailableError,
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ LLMAuthorizationError,
LLMRateLimitError, LLMAuthorizationError LLMBadRequestError, LLMRateLimitError)
from core.model_providers.models.reranking.base import BaseReranking from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.providers.base import BaseModelProvider from core.model_providers.providers.base import BaseModelProvider
from langchain.schema import Document
class CohereReranking(BaseReranking): class CohereReranking(BaseReranking):
...@@ -26,10 +27,14 @@ class CohereReranking(BaseReranking): ...@@ -26,10 +27,14 @@ class CohereReranking(BaseReranking):
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
docs = [] docs = []
doc_id = [] doc_id = []
unique_documents = []
for document in documents: for document in documents:
if document.metadata['doc_id'] not in doc_id: if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id']) doc_id.append(document.metadata['doc_id'])
docs.append(document.page_content) docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k) results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
rerank_documents = [] rerank_documents = []
......
...@@ -23,11 +23,14 @@ class XinferenceReranking(BaseReranking): ...@@ -23,11 +23,14 @@ class XinferenceReranking(BaseReranking):
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
docs = [] docs = []
doc_id = [] doc_id = []
unique_documents = []
for document in documents: for document in documents:
if document.metadata['doc_id'] not in doc_id: if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id']) doc_id.append(document.metadata['doc_id'])
docs.append(document.page_content) docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
model = self.client.get_model(self.credentials['model_uid']) model = self.client.get_model(self.credentials['model_uid'])
response = model.rerank(query=query, documents=docs, top_n=top_k) response = model.rerank(query=query, documents=docs, top_n=top_k)
rerank_documents = [] rerank_documents = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment