Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
a5b80c9d
Unverified
Commit
a5b80c9d
authored
Nov 22, 2023
by
Jyong
Committed by
GitHub
Nov 22, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix/multi thread parameter (#1604)
parent
f704094a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
8 deletions
+15
-8
dataset_multi_retriever_tool.py
api/core/tool/dataset_multi_retriever_tool.py
+2
-2
dataset_retriever_tool.py
api/core/tool/dataset_retriever_tool.py
+2
-2
hit_testing_service.py
api/services/hit_testing_service.py
+2
-2
retrieval_service.py
api/services/retrieval_service.py
+9
-2
No files found.
api/core/tool/dataset_multi_retriever_tool.py
View file @
a5b80c9d
...
@@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool):
...
@@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool):
'search_method'
]
==
'hybrid_search'
:
'search_method'
]
==
'hybrid_search'
:
embedding_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
embedding_search
,
kwargs
=
{
embedding_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
embedding_search
,
kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'flask_app'
:
current_app
.
_get_current_object
(),
'dataset
'
:
dataset
,
'dataset
_id'
:
str
(
dataset
.
id
)
,
'query'
:
query
,
'query'
:
query
,
'top_k'
:
self
.
top_k
,
'top_k'
:
self
.
top_k
,
'score_threshold'
:
self
.
score_threshold
,
'score_threshold'
:
self
.
score_threshold
,
...
@@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool):
...
@@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool):
full_text_index_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
full_text_index_search
,
full_text_index_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
full_text_index_search
,
kwargs
=
{
kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'flask_app'
:
current_app
.
_get_current_object
(),
'dataset
'
:
dataset
,
'dataset
_id'
:
str
(
dataset
.
id
)
,
'query'
:
query
,
'query'
:
query
,
'search_method'
:
'hybrid_search'
,
'search_method'
:
'hybrid_search'
,
'embeddings'
:
embeddings
,
'embeddings'
:
embeddings
,
...
...
api/core/tool/dataset_retriever_tool.py
View file @
a5b80c9d
...
@@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool):
...
@@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool):
if
retrieval_model
[
'search_method'
]
==
'semantic_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
if
retrieval_model
[
'search_method'
]
==
'semantic_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
embedding_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
embedding_search
,
kwargs
=
{
embedding_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
embedding_search
,
kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'flask_app'
:
current_app
.
_get_current_object
(),
'dataset
'
:
dataset
,
'dataset
_id'
:
str
(
dataset
.
id
)
,
'query'
:
query
,
'query'
:
query
,
'top_k'
:
self
.
top_k
,
'top_k'
:
self
.
top_k
,
'score_threshold'
:
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
'score_threshold'
:
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
...
@@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool):
...
@@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool):
if
retrieval_model
[
'search_method'
]
==
'full_text_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
if
retrieval_model
[
'search_method'
]
==
'full_text_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
full_text_index_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
full_text_index_search
,
kwargs
=
{
full_text_index_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
full_text_index_search
,
kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'flask_app'
:
current_app
.
_get_current_object
(),
'dataset
'
:
dataset
,
'dataset
_id'
:
str
(
dataset
.
id
)
,
'query'
:
query
,
'query'
:
query
,
'search_method'
:
retrieval_model
[
'search_method'
],
'search_method'
:
retrieval_model
[
'search_method'
],
'embeddings'
:
embeddings
,
'embeddings'
:
embeddings
,
...
...
api/services/hit_testing_service.py
View file @
a5b80c9d
...
@@ -61,7 +61,7 @@ class HitTestingService:
...
@@ -61,7 +61,7 @@ class HitTestingService:
if
retrieval_model
[
'search_method'
]
==
'semantic_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
if
retrieval_model
[
'search_method'
]
==
'semantic_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
embedding_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
embedding_search
,
kwargs
=
{
embedding_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
embedding_search
,
kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'flask_app'
:
current_app
.
_get_current_object
(),
'dataset
'
:
dataset
,
'dataset
_id'
:
str
(
dataset
.
id
)
,
'query'
:
query
,
'query'
:
query
,
'top_k'
:
retrieval_model
[
'top_k'
],
'top_k'
:
retrieval_model
[
'top_k'
],
'score_threshold'
:
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
'score_threshold_enable'
]
else
None
,
'score_threshold'
:
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
'score_threshold_enable'
]
else
None
,
...
@@ -77,7 +77,7 @@ class HitTestingService:
...
@@ -77,7 +77,7 @@ class HitTestingService:
if
retrieval_model
[
'search_method'
]
==
'full_text_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
if
retrieval_model
[
'search_method'
]
==
'full_text_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
full_text_index_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
full_text_index_search
,
kwargs
=
{
full_text_index_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
full_text_index_search
,
kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'flask_app'
:
current_app
.
_get_current_object
(),
'dataset
'
:
dataset
,
'dataset
_id'
:
str
(
dataset
.
id
)
,
'query'
:
query
,
'query'
:
query
,
'search_method'
:
retrieval_model
[
'search_method'
],
'search_method'
:
retrieval_model
[
'search_method'
],
'embeddings'
:
embeddings
,
'embeddings'
:
embeddings
,
...
...
api/services/retrieval_service.py
View file @
a5b80c9d
...
@@ -4,6 +4,7 @@ from flask import current_app, Flask
...
@@ -4,6 +4,7 @@ from flask import current_app, Flask
from
langchain.embeddings.base
import
Embeddings
from
langchain.embeddings.base
import
Embeddings
from
core.index.vector_index.vector_index
import
VectorIndex
from
core.index.vector_index.vector_index
import
VectorIndex
from
core.model_providers.model_factory
import
ModelFactory
from
core.model_providers.model_factory
import
ModelFactory
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
from
models.dataset
import
Dataset
default_retrieval_model
=
{
default_retrieval_model
=
{
...
@@ -21,10 +22,13 @@ default_retrieval_model = {
...
@@ -21,10 +22,13 @@ default_retrieval_model = {
class
RetrievalService
:
class
RetrievalService
:
@
classmethod
@
classmethod
def
embedding_search
(
cls
,
flask_app
:
Flask
,
dataset
:
Dataset
,
query
:
str
,
def
embedding_search
(
cls
,
flask_app
:
Flask
,
dataset
_id
:
str
,
query
:
str
,
top_k
:
int
,
score_threshold
:
Optional
[
float
],
reranking_model
:
Optional
[
dict
],
top_k
:
int
,
score_threshold
:
Optional
[
float
],
reranking_model
:
Optional
[
dict
],
all_documents
:
list
,
search_method
:
str
,
embeddings
:
Embeddings
):
all_documents
:
list
,
search_method
:
str
,
embeddings
:
Embeddings
):
with
flask_app
.
app_context
():
with
flask_app
.
app_context
():
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
Dataset
.
id
==
dataset_id
)
.
first
()
vector_index
=
VectorIndex
(
vector_index
=
VectorIndex
(
dataset
=
dataset
,
dataset
=
dataset
,
...
@@ -56,10 +60,13 @@ class RetrievalService:
...
@@ -56,10 +60,13 @@ class RetrievalService:
all_documents
.
extend
(
documents
)
all_documents
.
extend
(
documents
)
@
classmethod
@
classmethod
def
full_text_index_search
(
cls
,
flask_app
:
Flask
,
dataset
:
Dataset
,
query
:
str
,
def
full_text_index_search
(
cls
,
flask_app
:
Flask
,
dataset
_id
:
str
,
query
:
str
,
top_k
:
int
,
score_threshold
:
Optional
[
float
],
reranking_model
:
Optional
[
dict
],
top_k
:
int
,
score_threshold
:
Optional
[
float
],
reranking_model
:
Optional
[
dict
],
all_documents
:
list
,
search_method
:
str
,
embeddings
:
Embeddings
):
all_documents
:
list
,
search_method
:
str
,
embeddings
:
Embeddings
):
with
flask_app
.
app_context
():
with
flask_app
.
app_context
():
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
Dataset
.
id
==
dataset_id
)
.
first
()
vector_index
=
VectorIndex
(
vector_index
=
VectorIndex
(
dataset
=
dataset
,
dataset
=
dataset
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment