Unverified commit 74b2260b, authored by Jyong and committed by GitHub

fix score_threshold_enabled name (#1626)

Co-authored-by: jyong <jyong@dify.ai>
parent 603e55f2
...@@ -40,7 +40,7 @@ default_retrieval_model = { ...@@ -40,7 +40,7 @@ default_retrieval_model = {
'reranking_model_name': '' 'reranking_model_name': ''
}, },
'top_k': 2, 'top_k': 2,
'score_threshold_enable': False 'score_threshold_enabled': False
} }
class OrchestratorRuleParser: class OrchestratorRuleParser:
...@@ -220,8 +220,8 @@ class OrchestratorRuleParser: ...@@ -220,8 +220,8 @@ class OrchestratorRuleParser:
# top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens) # top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
score_threshold = None score_threshold = None
score_threshold_enable = retrieval_model_config.get("score_threshold_enable") score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enable: if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold") score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset( tool = DatasetRetrieverTool.from_dataset(
...@@ -239,7 +239,7 @@ class OrchestratorRuleParser: ...@@ -239,7 +239,7 @@ class OrchestratorRuleParser:
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
tenant_id=kwargs['tenant_id'], tenant_id=kwargs['tenant_id'],
top_k=dataset_configs.get('top_k', 2), top_k=dataset_configs.get('top_k', 2),
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None, score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enabled', False) else None,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)], callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task, conversation_message_task=conversation_message_task,
return_resource=return_resource, return_resource=return_resource,
......
...@@ -24,7 +24,7 @@ default_retrieval_model = { ...@@ -24,7 +24,7 @@ default_retrieval_model = {
'reranking_model_name': '' 'reranking_model_name': ''
}, },
'top_k': 2, 'top_k': 2,
'score_threshold_enable': False 'score_threshold_enabled': False
} }
...@@ -216,7 +216,7 @@ class DatasetMultiRetrieverTool(BaseTool): ...@@ -216,7 +216,7 @@ class DatasetMultiRetrieverTool(BaseTool):
'embeddings': embeddings, 'embeddings': embeddings,
'score_threshold': retrieval_model[ 'score_threshold': retrieval_model[
'score_threshold'] if retrieval_model[ 'score_threshold'] if retrieval_model[
'score_threshold_enable'] else None, 'score_threshold_enabled'] else None,
'top_k': self.top_k, 'top_k': self.top_k,
'reranking_model': retrieval_model[ 'reranking_model': retrieval_model[
'reranking_model'] if retrieval_model[ 'reranking_model'] if retrieval_model[
......
...@@ -25,7 +25,7 @@ default_retrieval_model = { ...@@ -25,7 +25,7 @@ default_retrieval_model = {
'reranking_model_name': '' 'reranking_model_name': ''
}, },
'top_k': 2, 'top_k': 2,
'score_threshold_enable': False 'score_threshold_enabled': False
} }
...@@ -110,7 +110,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -110,7 +110,7 @@ class DatasetRetrieverTool(BaseTool):
'query': query, 'query': query,
'top_k': self.top_k, 'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None, 'score_threshold_enabled'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None, 'reranking_enable'] else None,
'all_documents': documents, 'all_documents': documents,
...@@ -129,7 +129,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -129,7 +129,7 @@ class DatasetRetrieverTool(BaseTool):
'search_method': retrieval_model['search_method'], 'search_method': retrieval_model['search_method'],
'embeddings': embeddings, 'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enable'] else None, 'score_threshold_enabled'] else None,
'top_k': self.top_k, 'top_k': self.top_k,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None, 'reranking_enable'] else None,
...@@ -148,7 +148,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -148,7 +148,7 @@ class DatasetRetrieverTool(BaseTool):
model_name=retrieval_model['reranking_model']['reranking_model_name'] model_name=retrieval_model['reranking_model']['reranking_model_name']
) )
documents = hybrid_rerank.rerank(query, documents, documents = hybrid_rerank.rerank(query, documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
self.top_k) self.top_k)
else: else:
documents = [] documents = []
......
...@@ -22,7 +22,7 @@ dataset_retrieval_model_fields = { ...@@ -22,7 +22,7 @@ dataset_retrieval_model_fields = {
'reranking_enable': fields.Boolean, 'reranking_enable': fields.Boolean,
'reranking_model': fields.Nested(reranking_model_fields), 'reranking_model': fields.Nested(reranking_model_fields),
'top_k': fields.Integer, 'top_k': fields.Integer,
'score_threshold_enable': fields.Boolean, 'score_threshold_enabled': fields.Boolean,
'score_threshold': fields.Float 'score_threshold': fields.Float
} }
......
...@@ -104,7 +104,7 @@ class Dataset(db.Model): ...@@ -104,7 +104,7 @@ class Dataset(db.Model):
'reranking_model_name': '' 'reranking_model_name': ''
}, },
'top_k': 2, 'top_k': 2,
'score_threshold_enable': False 'score_threshold_enabled': False
} }
return self.retrieval_model if self.retrieval_model else default_retrieval_model return self.retrieval_model if self.retrieval_model else default_retrieval_model
......
...@@ -485,7 +485,7 @@ class DocumentService: ...@@ -485,7 +485,7 @@ class DocumentService:
'reranking_model_name': '' 'reranking_model_name': ''
}, },
'top_k': 2, 'top_k': 2,
'score_threshold_enable': False 'score_threshold_enabled': False
} }
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
...@@ -769,7 +769,7 @@ class DocumentService: ...@@ -769,7 +769,7 @@ class DocumentService:
'reranking_model_name': '' 'reranking_model_name': ''
}, },
'top_k': 2, 'top_k': 2,
'score_threshold_enable': False 'score_threshold_enabled': False
} }
retrieval_model = default_retrieval_model retrieval_model = default_retrieval_model
# save dataset # save dataset
......
...@@ -25,7 +25,7 @@ default_retrieval_model = { ...@@ -25,7 +25,7 @@ default_retrieval_model = {
'reranking_model_name': '' 'reranking_model_name': ''
}, },
'top_k': 2, 'top_k': 2,
'score_threshold_enable': False 'score_threshold_enabled': False
} }
class HitTestingService: class HitTestingService:
...@@ -64,7 +64,7 @@ class HitTestingService: ...@@ -64,7 +64,7 @@ class HitTestingService:
'dataset_id': str(dataset.id), 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': retrieval_model['top_k'], 'top_k': retrieval_model['top_k'],
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents, 'all_documents': all_documents,
'search_method': retrieval_model['search_method'], 'search_method': retrieval_model['search_method'],
...@@ -81,7 +81,7 @@ class HitTestingService: ...@@ -81,7 +81,7 @@ class HitTestingService:
'query': query, 'query': query,
'search_method': retrieval_model['search_method'], 'search_method': retrieval_model['search_method'],
'embeddings': embeddings, 'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
'top_k': retrieval_model['top_k'], 'top_k': retrieval_model['top_k'],
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
'all_documents': all_documents 'all_documents': all_documents
...@@ -99,7 +99,7 @@ class HitTestingService: ...@@ -99,7 +99,7 @@ class HitTestingService:
model_name=retrieval_model['reranking_model']['reranking_model_name'] model_name=retrieval_model['reranking_model']['reranking_model_name']
) )
all_documents = hybrid_rerank.rerank(query, all_documents, all_documents = hybrid_rerank.rerank(query, all_documents,
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
retrieval_model['top_k']) retrieval_model['top_k'])
end = time.perf_counter() end = time.perf_counter()
......
...@@ -15,7 +15,7 @@ default_retrieval_model = { ...@@ -15,7 +15,7 @@ default_retrieval_model = {
'reranking_model_name': '' 'reranking_model_name': ''
}, },
'top_k': 2, 'top_k': 2,
'score_threshold_enable': False 'score_threshold_enabled': False
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment