Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
4588831b
Unverified
Commit
4588831b
authored
Nov 17, 2023
by
Jyong
Committed by
GitHub
Nov 17, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Feat/add retriever rerank (#1560)
Co-authored-by:
jyong
<
jyong@dify.ai
>
parent
a4f37220
Changes
44
Hide whitespace changes
Inline
Side-by-side
Showing
44 changed files
with
1903 additions
and
168 deletions
+1903
-168
commands.py
api/commands.py
+46
-11
datasets.py
api/controllers/console/datasets/datasets.py
+48
-0
datasets_document.py
api/controllers/console/datasets/datasets_document.py
+4
-0
hit_testing.py
api/controllers/console/datasets/hit_testing.py
+5
-6
models.py
api/controllers/console/workspace/models.py
+10
-11
document.py
api/controllers/service_api/dataset/document.py
+4
-0
multi_dataset_router_agent.py
api/core/agent/agent/multi_dataset_router_agent.py
+0
-2
retirver_dataset_agent.py
api/core/agent/agent/output_parser/retirver_dataset_agent.py
+158
-0
structed_multi_dataset_router_agent.py
api/core/agent/agent/structed_multi_dataset_router_agent.py
+0
-1
agent_executor.py
api/core/agent/agent_executor.py
+3
-2
index_tool_callback_handler.py
api/core/callback_handler/index_tool_callback_handler.py
+1
-3
completion.py
api/core/completion.py
+1
-0
file_extractor.py
api/core/data_loader/file_extractor.py
+28
-18
base.py
api/core/index/vector_index/base.py
+7
-0
milvus_vector_index.py
api/core/index/vector_index/milvus_vector_index.py
+9
-7
qdrant_vector_index.py
api/core/index/vector_index/qdrant_vector_index.py
+18
-0
weaviate_vector_index.py
api/core/index/vector_index/weaviate_vector_index.py
+8
-1
indexing_runner.py
api/core/indexing_runner.py
+5
-5
model_factory.py
api/core/model_providers/model_factory.py
+39
-0
model_provider_factory.py
api/core/model_providers/model_provider_factory.py
+3
-0
model_params.py
api/core/model_providers/models/entity/model_params.py
+1
-1
__init__.py
api/core/model_providers/models/reranking/__init__.py
+0
-0
base.py
api/core/model_providers/models/reranking/base.py
+36
-0
cohere_reranking.py
...core/model_providers/models/reranking/cohere_reranking.py
+73
-0
cohere_provider.py
api/core/model_providers/providers/cohere_provider.py
+152
-0
_providers.json
api/core/model_providers/rules/_providers.json
+2
-1
cohere.json
api/core/model_providers/rules/cohere.json
+7
-0
orchestrator_rule_parser.py
api/core/orchestrator_rule_parser.py
+85
-42
dataset_multi_retriever_tool.py
api/core/tool/dataset_multi_retriever_tool.py
+227
-0
dataset_retriever_tool.py
api/core/tool/dataset_retriever_tool.py
+68
-19
milvus_vector_store.py
api/core/vector_store/milvus_vector_store.py
+1
-1
qdrant_vector_store.py
api/core/vector_store/qdrant_vector_store.py
+2
-1
milvus.py
api/core/vector_store/vector/milvus.py
+0
-0
qdrant.py
api/core/vector_store/vector/qdrant.py
+47
-3
weaviate.py
api/core/vector_store/vector/weaviate.py
+505
-0
dataset_fields.py
api/fields/dataset_fields.py
+19
-1
fca025d3b60f_add_dataset_retrival_model.py
...tions/versions/fca025d3b60f_add_dataset_retrival_model.py
+43
-0
dataset.py
api/models/dataset.py
+18
-4
model.py
api/models/model.py
+7
-1
requirements.txt
api/requirements.txt
+3
-2
app_model_config_service.py
api/services/app_model_config_service.py
+10
-1
dataset_service.py
api/services/dataset_service.py
+33
-2
hit_testing_service.py
api/services/hit_testing_service.py
+79
-22
retrieval_service.py
api/services/retrieval_service.py
+88
-0
No files found.
api/commands.py
View file @
4588831b
...
@@ -8,6 +8,8 @@ import time
...
@@ -8,6 +8,8 @@ import time
import
uuid
import
uuid
import
click
import
click
import
qdrant_client
from
qdrant_client.http.models
import
TextIndexParams
,
TextIndexType
,
TokenizerType
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
flask
import
current_app
,
Flask
from
flask
import
current_app
,
Flask
from
langchain.embeddings
import
OpenAIEmbeddings
from
langchain.embeddings
import
OpenAIEmbeddings
...
@@ -484,6 +486,38 @@ def normalization_collections():
...
@@ -484,6 +486,38 @@ def normalization_collections():
click
.
echo
(
click
.
style
(
'Congratulations! restore {} dataset indexes.'
.
format
(
len
(
normalization_count
)),
fg
=
'green'
))
click
.
echo
(
click
.
style
(
'Congratulations! restore {} dataset indexes.'
.
format
(
len
(
normalization_count
)),
fg
=
'green'
))
@
click
.
command
(
'add-qdrant-full-text-index'
,
help
=
'add qdrant full text index'
)
def
add_qdrant_full_text_index
():
click
.
echo
(
click
.
style
(
'Start add full text index.'
,
fg
=
'green'
))
binds
=
db
.
session
.
query
(
DatasetCollectionBinding
)
.
all
()
if
binds
and
current_app
.
config
[
'VECTOR_STORE'
]
==
'qdrant'
:
qdrant_url
=
current_app
.
config
[
'QDRANT_URL'
]
qdrant_api_key
=
current_app
.
config
[
'QDRANT_API_KEY'
]
client
=
qdrant_client
.
QdrantClient
(
qdrant_url
,
api_key
=
qdrant_api_key
,
# For Qdrant Cloud, None for local instance
)
for
bind
in
binds
:
try
:
text_index_params
=
TextIndexParams
(
type
=
TextIndexType
.
TEXT
,
tokenizer
=
TokenizerType
.
MULTILINGUAL
,
min_token_len
=
2
,
max_token_len
=
20
,
lowercase
=
True
)
client
.
create_payload_index
(
bind
.
collection_name
,
'page_content'
,
field_schema
=
text_index_params
)
except
Exception
as
e
:
click
.
echo
(
click
.
style
(
'Create full text index error: {} {}'
.
format
(
e
.
__class__
.
__name__
,
str
(
e
)),
fg
=
'red'
))
click
.
echo
(
click
.
style
(
'Congratulations! add collection {} full text index successful.'
.
format
(
bind
.
collection_name
),
fg
=
'green'
))
def
deal_dataset_vector
(
flask_app
:
Flask
,
dataset
:
Dataset
,
normalization_count
:
list
):
def
deal_dataset_vector
(
flask_app
:
Flask
,
dataset
:
Dataset
,
normalization_count
:
list
):
with
flask_app
.
app_context
():
with
flask_app
.
app_context
():
try
:
try
:
...
@@ -647,10 +681,10 @@ def update_app_model_configs(batch_size):
...
@@ -647,10 +681,10 @@ def update_app_model_configs(batch_size):
pbar
.
update
(
len
(
data_batch
))
pbar
.
update
(
len
(
data_batch
))
@
click
.
command
(
'migrate_default_input_to_dataset_query_variable'
)
@
click
.
command
(
'migrate_default_input_to_dataset_query_variable'
)
@
click
.
option
(
"--batch-size"
,
default
=
500
,
help
=
"Number of records to migrate in each batch."
)
@
click
.
option
(
"--batch-size"
,
default
=
500
,
help
=
"Number of records to migrate in each batch."
)
def
migrate_default_input_to_dataset_query_variable
(
batch_size
):
def
migrate_default_input_to_dataset_query_variable
(
batch_size
):
click
.
secho
(
"Starting..."
,
fg
=
'green'
)
click
.
secho
(
"Starting..."
,
fg
=
'green'
)
total_records
=
db
.
session
.
query
(
AppModelConfig
)
\
total_records
=
db
.
session
.
query
(
AppModelConfig
)
\
...
@@ -658,13 +692,13 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
...
@@ -658,13 +692,13 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
.
filter
(
App
.
mode
==
'completion'
)
\
.
filter
(
App
.
mode
==
'completion'
)
\
.
filter
(
AppModelConfig
.
dataset_query_variable
==
None
)
\
.
filter
(
AppModelConfig
.
dataset_query_variable
==
None
)
\
.
count
()
.
count
()
if
total_records
==
0
:
if
total_records
==
0
:
click
.
secho
(
"No data to migrate."
,
fg
=
'green'
)
click
.
secho
(
"No data to migrate."
,
fg
=
'green'
)
return
return
num_batches
=
(
total_records
+
batch_size
-
1
)
//
batch_size
num_batches
=
(
total_records
+
batch_size
-
1
)
//
batch_size
with
tqdm
(
total
=
total_records
,
desc
=
"Migrating Data"
)
as
pbar
:
with
tqdm
(
total
=
total_records
,
desc
=
"Migrating Data"
)
as
pbar
:
for
i
in
range
(
num_batches
):
for
i
in
range
(
num_batches
):
offset
=
i
*
batch_size
offset
=
i
*
batch_size
...
@@ -697,14 +731,14 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
...
@@ -697,14 +731,14 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
for
form
in
user_input_form
:
for
form
in
user_input_form
:
paragraph
=
form
.
get
(
'paragraph'
)
paragraph
=
form
.
get
(
'paragraph'
)
if
paragraph
\
if
paragraph
\
and
paragraph
.
get
(
'variable'
)
==
'query'
:
and
paragraph
.
get
(
'variable'
)
==
'query'
:
data
.
dataset_query_variable
=
'query'
data
.
dataset_query_variable
=
'query'
break
break
if
paragraph
\
if
paragraph
\
and
paragraph
.
get
(
'variable'
)
==
'default_input'
:
and
paragraph
.
get
(
'variable'
)
==
'default_input'
:
data
.
dataset_query_variable
=
'default_input'
data
.
dataset_query_variable
=
'default_input'
break
break
db
.
session
.
commit
()
db
.
session
.
commit
()
...
@@ -712,7 +746,7 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
...
@@ -712,7 +746,7 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
click
.
secho
(
f
"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}"
,
click
.
secho
(
f
"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}"
,
fg
=
'red'
)
fg
=
'red'
)
continue
continue
click
.
secho
(
f
"Successfully migrated batch {i + 1}/{num_batches}."
,
fg
=
'green'
)
click
.
secho
(
f
"Successfully migrated batch {i + 1}/{num_batches}."
,
fg
=
'green'
)
pbar
.
update
(
len
(
data_batch
))
pbar
.
update
(
len
(
data_batch
))
...
@@ -731,3 +765,4 @@ def register_commands(app):
...
@@ -731,3 +765,4 @@ def register_commands(app):
app
.
cli
.
add_command
(
update_app_model_configs
)
app
.
cli
.
add_command
(
update_app_model_configs
)
app
.
cli
.
add_command
(
normalization_collections
)
app
.
cli
.
add_command
(
normalization_collections
)
app
.
cli
.
add_command
(
migrate_default_input_to_dataset_query_variable
)
app
.
cli
.
add_command
(
migrate_default_input_to_dataset_query_variable
)
app
.
cli
.
add_command
(
add_qdrant_full_text_index
)
api/controllers/console/datasets/datasets.py
View file @
4588831b
...
@@ -170,6 +170,7 @@ class DatasetApi(Resource):
...
@@ -170,6 +170,7 @@ class DatasetApi(Resource):
help
=
'Invalid indexing technique.'
)
help
=
'Invalid indexing technique.'
)
parser
.
add_argument
(
'permission'
,
type
=
str
,
location
=
'json'
,
choices
=
(
parser
.
add_argument
(
'permission'
,
type
=
str
,
location
=
'json'
,
choices
=
(
'only_me'
,
'all_team_members'
),
help
=
'Invalid permission.'
)
'only_me'
,
'all_team_members'
),
help
=
'Invalid permission.'
)
parser
.
add_argument
(
'retrieval_model'
,
type
=
dict
,
location
=
'json'
,
help
=
'Invalid retrieval model.'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# The role of the current user in the ta table must be admin or owner
# The role of the current user in the ta table must be admin or owner
...
@@ -401,6 +402,7 @@ class DatasetApiKeyApi(Resource):
...
@@ -401,6 +402,7 @@ class DatasetApiKeyApi(Resource):
class
DatasetApiDeleteApi
(
Resource
):
class
DatasetApiDeleteApi
(
Resource
):
resource_type
=
'dataset'
resource_type
=
'dataset'
@
setup_required
@
setup_required
@
login_required
@
login_required
@
account_initialization_required
@
account_initialization_required
...
@@ -436,6 +438,50 @@ class DatasetApiBaseUrlApi(Resource):
...
@@ -436,6 +438,50 @@ class DatasetApiBaseUrlApi(Resource):
}
}
class
DatasetRetrievalSettingApi
(
Resource
):
@
setup_required
@
login_required
@
account_initialization_required
def
get
(
self
):
vector_type
=
current_app
.
config
[
'VECTOR_STORE'
]
if
vector_type
==
'milvus'
:
return
{
'retrieval_method'
:
[
'semantic_search'
]
}
elif
vector_type
==
'qdrant'
or
vector_type
==
'weaviate'
:
return
{
'retrieval_method'
:
[
'semantic_search'
,
'full_text_search'
,
'hybrid_search'
]
}
else
:
raise
ValueError
(
"Unsupported vector db type."
)
class
DatasetRetrievalSettingMockApi
(
Resource
):
@
setup_required
@
login_required
@
account_initialization_required
def
get
(
self
,
vector_type
):
if
vector_type
==
'milvus'
:
return
{
'retrieval_method'
:
[
'semantic_search'
]
}
elif
vector_type
==
'qdrant'
or
vector_type
==
'weaviate'
:
return
{
'retrieval_method'
:
[
'semantic_search'
,
'full_text_search'
,
'hybrid_search'
]
}
else
:
raise
ValueError
(
"Unsupported vector db type."
)
api
.
add_resource
(
DatasetListApi
,
'/datasets'
)
api
.
add_resource
(
DatasetListApi
,
'/datasets'
)
api
.
add_resource
(
DatasetApi
,
'/datasets/<uuid:dataset_id>'
)
api
.
add_resource
(
DatasetApi
,
'/datasets/<uuid:dataset_id>'
)
api
.
add_resource
(
DatasetQueryApi
,
'/datasets/<uuid:dataset_id>/queries'
)
api
.
add_resource
(
DatasetQueryApi
,
'/datasets/<uuid:dataset_id>/queries'
)
...
@@ -445,3 +491,5 @@ api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing
...
@@ -445,3 +491,5 @@ api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing
api
.
add_resource
(
DatasetApiKeyApi
,
'/datasets/api-keys'
)
api
.
add_resource
(
DatasetApiKeyApi
,
'/datasets/api-keys'
)
api
.
add_resource
(
DatasetApiDeleteApi
,
'/datasets/api-keys/<uuid:api_key_id>'
)
api
.
add_resource
(
DatasetApiDeleteApi
,
'/datasets/api-keys/<uuid:api_key_id>'
)
api
.
add_resource
(
DatasetApiBaseUrlApi
,
'/datasets/api-base-info'
)
api
.
add_resource
(
DatasetApiBaseUrlApi
,
'/datasets/api-base-info'
)
api
.
add_resource
(
DatasetRetrievalSettingApi
,
'/datasets/retrieval-setting'
)
api
.
add_resource
(
DatasetRetrievalSettingMockApi
,
'/datasets/retrieval-setting/<string:vector_type>'
)
api/controllers/console/datasets/datasets_document.py
View file @
4588831b
...
@@ -221,6 +221,8 @@ class DatasetDocumentListApi(Resource):
...
@@ -221,6 +221,8 @@ class DatasetDocumentListApi(Resource):
parser
.
add_argument
(
'doc_form'
,
type
=
str
,
default
=
'text_model'
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
parser
.
add_argument
(
'doc_form'
,
type
=
str
,
default
=
'text_model'
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
parser
.
add_argument
(
'doc_language'
,
type
=
str
,
default
=
'English'
,
required
=
False
,
nullable
=
False
,
parser
.
add_argument
(
'doc_language'
,
type
=
str
,
default
=
'English'
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
location
=
'json'
)
parser
.
add_argument
(
'retrieval_model'
,
type
=
dict
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
not
dataset
.
indexing_technique
and
not
args
[
'indexing_technique'
]:
if
not
dataset
.
indexing_technique
and
not
args
[
'indexing_technique'
]:
...
@@ -263,6 +265,8 @@ class DatasetInitApi(Resource):
...
@@ -263,6 +265,8 @@ class DatasetInitApi(Resource):
parser
.
add_argument
(
'doc_form'
,
type
=
str
,
default
=
'text_model'
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
parser
.
add_argument
(
'doc_form'
,
type
=
str
,
default
=
'text_model'
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
parser
.
add_argument
(
'doc_language'
,
type
=
str
,
default
=
'English'
,
required
=
False
,
nullable
=
False
,
parser
.
add_argument
(
'doc_language'
,
type
=
str
,
default
=
'English'
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
location
=
'json'
)
parser
.
add_argument
(
'retrieval_model'
,
type
=
dict
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
[
'indexing_technique'
]
==
'high_quality'
:
if
args
[
'indexing_technique'
]
==
'high_quality'
:
try
:
try
:
...
...
api/controllers/console/datasets/hit_testing.py
View file @
4588831b
...
@@ -42,19 +42,18 @@ class HitTestingApi(Resource):
...
@@ -42,19 +42,18 @@ class HitTestingApi(Resource):
parser
=
reqparse
.
RequestParser
()
parser
=
reqparse
.
RequestParser
()
parser
.
add_argument
(
'query'
,
type
=
str
,
location
=
'json'
)
parser
.
add_argument
(
'query'
,
type
=
str
,
location
=
'json'
)
parser
.
add_argument
(
'retrieval_model'
,
type
=
dict
,
required
=
False
,
location
=
'json'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
query
=
args
[
'query'
]
HitTestingService
.
hit_testing_args_check
(
args
)
if
not
query
or
len
(
query
)
>
250
:
raise
ValueError
(
'Query is required and cannot exceed 250 characters'
)
try
:
try
:
response
=
HitTestingService
.
retrieve
(
response
=
HitTestingService
.
retrieve
(
dataset
=
dataset
,
dataset
=
dataset
,
query
=
query
,
query
=
args
[
'query'
]
,
account
=
current_user
,
account
=
current_user
,
limit
=
10
,
retrieval_model
=
args
[
'retrieval_model'
],
limit
=
10
)
)
return
{
"query"
:
response
[
'query'
],
'records'
:
marshal
(
response
[
'records'
],
hit_testing_record_fields
)}
return
{
"query"
:
response
[
'query'
],
'records'
:
marshal
(
response
[
'records'
],
hit_testing_record_fields
)}
...
...
api/controllers/console/workspace/models.py
View file @
4588831b
...
@@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
...
@@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
def
get
(
self
):
def
get
(
self
):
parser
=
reqparse
.
RequestParser
()
parser
=
reqparse
.
RequestParser
()
parser
.
add_argument
(
'model_type'
,
type
=
str
,
required
=
True
,
nullable
=
False
,
parser
.
add_argument
(
'model_type'
,
type
=
str
,
required
=
True
,
nullable
=
False
,
choices
=
[
'text-generation'
,
'embeddings'
,
'speech2text'
],
location
=
'args'
)
choices
=
[
'text-generation'
,
'embeddings'
,
'speech2text'
,
'reranking'
],
location
=
'args'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
tenant_id
=
current_user
.
current_tenant_id
tenant_id
=
current_user
.
current_tenant_id
...
@@ -71,19 +71,18 @@ class DefaultModelApi(Resource):
...
@@ -71,19 +71,18 @@ class DefaultModelApi(Resource):
@
account_initialization_required
@
account_initialization_required
def
post
(
self
):
def
post
(
self
):
parser
=
reqparse
.
RequestParser
()
parser
=
reqparse
.
RequestParser
()
parser
.
add_argument
(
'model_name'
,
type
=
str
,
required
=
True
,
nullable
=
False
,
location
=
'json'
)
parser
.
add_argument
(
'model_settings'
,
type
=
list
,
required
=
True
,
nullable
=
False
,
location
=
'json'
)
parser
.
add_argument
(
'model_type'
,
type
=
str
,
required
=
True
,
nullable
=
False
,
choices
=
[
'text-generation'
,
'embeddings'
,
'speech2text'
],
location
=
'json'
)
parser
.
add_argument
(
'provider_name'
,
type
=
str
,
required
=
True
,
nullable
=
False
,
location
=
'json'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
provider_service
=
ProviderService
()
provider_service
=
ProviderService
()
provider_service
.
update_default_model_of_model_type
(
model_settings
=
args
[
'model_settings'
]
tenant_id
=
current_user
.
current_tenant_id
,
for
model_setting
in
model_settings
:
model_type
=
args
[
'model_type'
],
provider_service
.
update_default_model_of_model_type
(
provider_name
=
args
[
'provider_name'
],
tenant_id
=
current_user
.
current_tenant_id
,
model_name
=
args
[
'model_name'
]
model_type
=
model_setting
[
'model_type'
],
)
provider_name
=
model_setting
[
'provider_name'
],
model_name
=
model_setting
[
'model_name'
]
)
return
{
'result'
:
'success'
}
return
{
'result'
:
'success'
}
...
...
api/controllers/service_api/dataset/document.py
View file @
4588831b
...
@@ -36,6 +36,8 @@ class DocumentAddByTextApi(DatasetApiResource):
...
@@ -36,6 +36,8 @@ class DocumentAddByTextApi(DatasetApiResource):
location
=
'json'
)
location
=
'json'
)
parser
.
add_argument
(
'indexing_technique'
,
type
=
str
,
choices
=
Dataset
.
INDEXING_TECHNIQUE_LIST
,
nullable
=
False
,
parser
.
add_argument
(
'indexing_technique'
,
type
=
str
,
choices
=
Dataset
.
INDEXING_TECHNIQUE_LIST
,
nullable
=
False
,
location
=
'json'
)
location
=
'json'
)
parser
.
add_argument
(
'retrieval_model'
,
type
=
dict
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
dataset_id
=
str
(
dataset_id
)
dataset_id
=
str
(
dataset_id
)
tenant_id
=
str
(
tenant_id
)
tenant_id
=
str
(
tenant_id
)
...
@@ -95,6 +97,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
...
@@ -95,6 +97,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
parser
.
add_argument
(
'doc_form'
,
type
=
str
,
default
=
'text_model'
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
parser
.
add_argument
(
'doc_form'
,
type
=
str
,
default
=
'text_model'
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
parser
.
add_argument
(
'doc_language'
,
type
=
str
,
default
=
'English'
,
required
=
False
,
nullable
=
False
,
parser
.
add_argument
(
'doc_language'
,
type
=
str
,
default
=
'English'
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
location
=
'json'
)
parser
.
add_argument
(
'retrieval_model'
,
type
=
dict
,
required
=
False
,
nullable
=
False
,
location
=
'json'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
dataset_id
=
str
(
dataset_id
)
dataset_id
=
str
(
dataset_id
)
tenant_id
=
str
(
tenant_id
)
tenant_id
=
str
(
tenant_id
)
...
...
api/core/agent/agent/multi_dataset_router_agent.py
View file @
4588831b
...
@@ -14,7 +14,6 @@ from pydantic import root_validator
...
@@ -14,7 +14,6 @@ from pydantic import root_validator
from
core.model_providers.models.entity.message
import
to_prompt_messages
from
core.model_providers.models.entity.message
import
to_prompt_messages
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.third_party.langchain.llms.fake
import
FakeLLM
from
core.third_party.langchain.llms.fake
import
FakeLLM
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
class
MultiDatasetRouterAgent
(
OpenAIFunctionsAgent
):
class
MultiDatasetRouterAgent
(
OpenAIFunctionsAgent
):
...
@@ -60,7 +59,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
...
@@ -60,7 +59,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
return
AgentFinish
(
return_values
=
{
"output"
:
''
},
log
=
''
)
return
AgentFinish
(
return_values
=
{
"output"
:
''
},
log
=
''
)
elif
len
(
self
.
tools
)
==
1
:
elif
len
(
self
.
tools
)
==
1
:
tool
=
next
(
iter
(
self
.
tools
))
tool
=
next
(
iter
(
self
.
tools
))
tool
=
cast
(
DatasetRetrieverTool
,
tool
)
rst
=
tool
.
run
(
tool_input
=
{
'query'
:
kwargs
[
'input'
]})
rst
=
tool
.
run
(
tool_input
=
{
'query'
:
kwargs
[
'input'
]})
# output = ''
# output = ''
# rst_json = json.loads(rst)
# rst_json = json.loads(rst)
...
...
api/core/agent/agent/output_parser/retirver_dataset_agent.py
0 → 100644
View file @
4588831b
import
json
from
typing
import
Tuple
,
List
,
Any
,
Union
,
Sequence
,
Optional
,
cast
from
langchain.agents
import
OpenAIFunctionsAgent
,
BaseSingleActionAgent
from
langchain.agents.openai_functions_agent.base
import
_format_intermediate_steps
,
_parse_ai_message
from
langchain.callbacks.base
import
BaseCallbackManager
from
langchain.callbacks.manager
import
Callbacks
from
langchain.prompts.chat
import
BaseMessagePromptTemplate
from
langchain.schema
import
AgentAction
,
AgentFinish
,
SystemMessage
,
Generation
,
LLMResult
,
AIMessage
from
langchain.schema.language_model
import
BaseLanguageModel
from
langchain.tools
import
BaseTool
from
pydantic
import
root_validator
from
core.model_providers.models.entity.message
import
to_prompt_messages
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.third_party.langchain.llms.fake
import
FakeLLM
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
class
MultiDatasetRouterAgent
(
OpenAIFunctionsAgent
):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_instance
:
BaseLLM
class
Config
:
"""Configuration for this pydantic object."""
arbitrary_types_allowed
=
True
@
root_validator
def
validate_llm
(
cls
,
values
:
dict
)
->
dict
:
return
values
def
should_use_agent
(
self
,
query
:
str
):
"""
return should use agent
:param query:
:return:
"""
return
True
def
plan
(
self
,
intermediate_steps
:
List
[
Tuple
[
AgentAction
,
str
]],
callbacks
:
Callbacks
=
None
,
**
kwargs
:
Any
,
)
->
Union
[
AgentAction
,
AgentFinish
]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if
len
(
self
.
tools
)
==
0
:
return
AgentFinish
(
return_values
=
{
"output"
:
''
},
log
=
''
)
elif
len
(
self
.
tools
)
==
1
:
tool
=
next
(
iter
(
self
.
tools
))
tool
=
cast
(
DatasetRetrieverTool
,
tool
)
rst
=
tool
.
run
(
tool_input
=
{
'query'
:
kwargs
[
'input'
]})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return
AgentFinish
(
return_values
=
{
"output"
:
rst
},
log
=
rst
)
if
intermediate_steps
:
_
,
observation
=
intermediate_steps
[
-
1
]
return
AgentFinish
(
return_values
=
{
"output"
:
observation
},
log
=
observation
)
try
:
agent_decision
=
self
.
real_plan
(
intermediate_steps
,
callbacks
,
**
kwargs
)
if
isinstance
(
agent_decision
,
AgentAction
):
tool_inputs
=
agent_decision
.
tool_input
if
isinstance
(
tool_inputs
,
dict
)
and
'query'
in
tool_inputs
and
'chat_history'
not
in
kwargs
:
tool_inputs
[
'query'
]
=
kwargs
[
'input'
]
agent_decision
.
tool_input
=
tool_inputs
else
:
agent_decision
.
return_values
[
'output'
]
=
''
return
agent_decision
except
Exception
as
e
:
new_exception
=
self
.
model_instance
.
handle_exceptions
(
e
)
raise
new_exception
def
real_plan
(
self
,
intermediate_steps
:
List
[
Tuple
[
AgentAction
,
str
]],
callbacks
:
Callbacks
=
None
,
**
kwargs
:
Any
,
)
->
Union
[
AgentAction
,
AgentFinish
]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad
=
_format_intermediate_steps
(
intermediate_steps
)
selected_inputs
=
{
k
:
kwargs
[
k
]
for
k
in
self
.
prompt
.
input_variables
if
k
!=
"agent_scratchpad"
}
full_inputs
=
dict
(
**
selected_inputs
,
agent_scratchpad
=
agent_scratchpad
)
prompt
=
self
.
prompt
.
format_prompt
(
**
full_inputs
)
messages
=
prompt
.
to_messages
()
prompt_messages
=
to_prompt_messages
(
messages
)
result
=
self
.
model_instance
.
run
(
messages
=
prompt_messages
,
functions
=
self
.
functions
,
)
ai_message
=
AIMessage
(
content
=
result
.
content
,
additional_kwargs
=
{
'function_call'
:
result
.
function_call
}
)
agent_decision
=
_parse_ai_message
(
ai_message
)
return
agent_decision
async
def
aplan
(
self
,
intermediate_steps
:
List
[
Tuple
[
AgentAction
,
str
]],
callbacks
:
Callbacks
=
None
,
**
kwargs
:
Any
,
)
->
Union
[
AgentAction
,
AgentFinish
]:
raise
NotImplementedError
()
@
classmethod
def
from_llm_and_tools
(
cls
,
model_instance
:
BaseLLM
,
tools
:
Sequence
[
BaseTool
],
callback_manager
:
Optional
[
BaseCallbackManager
]
=
None
,
extra_prompt_messages
:
Optional
[
List
[
BaseMessagePromptTemplate
]]
=
None
,
system_message
:
Optional
[
SystemMessage
]
=
SystemMessage
(
content
=
"You are a helpful AI assistant."
),
**
kwargs
:
Any
,
)
->
BaseSingleActionAgent
:
prompt
=
cls
.
create_prompt
(
extra_prompt_messages
=
extra_prompt_messages
,
system_message
=
system_message
,
)
return
cls
(
model_instance
=
model_instance
,
llm
=
FakeLLM
(
response
=
''
),
prompt
=
prompt
,
tools
=
tools
,
callback_manager
=
callback_manager
,
**
kwargs
,
)
api/core/agent/agent/structed_multi_dataset_router_agent.py
View file @
4588831b
...
@@ -89,7 +89,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
...
@@ -89,7 +89,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
return
AgentFinish
(
return_values
=
{
"output"
:
''
},
log
=
''
)
return
AgentFinish
(
return_values
=
{
"output"
:
''
},
log
=
''
)
elif
len
(
self
.
dataset_tools
)
==
1
:
elif
len
(
self
.
dataset_tools
)
==
1
:
tool
=
next
(
iter
(
self
.
dataset_tools
))
tool
=
next
(
iter
(
self
.
dataset_tools
))
tool
=
cast
(
DatasetRetrieverTool
,
tool
)
rst
=
tool
.
run
(
tool_input
=
{
'query'
:
kwargs
[
'input'
]})
rst
=
tool
.
run
(
tool_input
=
{
'query'
:
kwargs
[
'input'
]})
return
AgentFinish
(
return_values
=
{
"output"
:
rst
},
log
=
rst
)
return
AgentFinish
(
return_values
=
{
"output"
:
rst
},
log
=
rst
)
...
...
api/core/agent/agent_executor.py
View file @
4588831b
...
@@ -18,6 +18,7 @@ from langchain.agents import AgentExecutor as LCAgentExecutor
...
@@ -18,6 +18,7 @@ from langchain.agents import AgentExecutor as LCAgentExecutor
from
core.helper
import
moderation
from
core.helper
import
moderation
from
core.model_providers.error
import
LLMError
from
core.model_providers.error
import
LLMError
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.tool.dataset_multi_retriever_tool
import
DatasetMultiRetrieverTool
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
...
@@ -78,7 +79,7 @@ class AgentExecutor:
...
@@ -78,7 +79,7 @@ class AgentExecutor:
verbose
=
True
verbose
=
True
)
)
elif
self
.
configuration
.
strategy
==
PlanningStrategy
.
ROUTER
:
elif
self
.
configuration
.
strategy
==
PlanningStrategy
.
ROUTER
:
self
.
configuration
.
tools
=
[
t
for
t
in
self
.
configuration
.
tools
if
isinstance
(
t
,
DatasetRetrieverTool
)]
self
.
configuration
.
tools
=
[
t
for
t
in
self
.
configuration
.
tools
if
isinstance
(
t
,
DatasetRetrieverTool
)
or
isinstance
(
t
,
DatasetMultiRetrieverTool
)
]
agent
=
MultiDatasetRouterAgent
.
from_llm_and_tools
(
agent
=
MultiDatasetRouterAgent
.
from_llm_and_tools
(
model_instance
=
self
.
configuration
.
model_instance
,
model_instance
=
self
.
configuration
.
model_instance
,
tools
=
self
.
configuration
.
tools
,
tools
=
self
.
configuration
.
tools
,
...
@@ -86,7 +87,7 @@ class AgentExecutor:
...
@@ -86,7 +87,7 @@ class AgentExecutor:
verbose
=
True
verbose
=
True
)
)
elif
self
.
configuration
.
strategy
==
PlanningStrategy
.
REACT_ROUTER
:
elif
self
.
configuration
.
strategy
==
PlanningStrategy
.
REACT_ROUTER
:
self
.
configuration
.
tools
=
[
t
for
t
in
self
.
configuration
.
tools
if
isinstance
(
t
,
DatasetRetrieverTool
)]
self
.
configuration
.
tools
=
[
t
for
t
in
self
.
configuration
.
tools
if
isinstance
(
t
,
DatasetRetrieverTool
)
or
isinstance
(
t
,
DatasetMultiRetrieverTool
)
]
agent
=
StructuredMultiDatasetRouterAgent
.
from_llm_and_tools
(
agent
=
StructuredMultiDatasetRouterAgent
.
from_llm_and_tools
(
model_instance
=
self
.
configuration
.
model_instance
,
model_instance
=
self
.
configuration
.
model_instance
,
tools
=
self
.
configuration
.
tools
,
tools
=
self
.
configuration
.
tools
,
...
...
api/core/callback_handler/index_tool_callback_handler.py
View file @
4588831b
...
@@ -10,8 +10,7 @@ from models.dataset import DocumentSegment
...
@@ -10,8 +10,7 @@ from models.dataset import DocumentSegment
class
DatasetIndexToolCallbackHandler
:
class
DatasetIndexToolCallbackHandler
:
"""Callback handler for dataset tool."""
"""Callback handler for dataset tool."""
def
__init__
(
self
,
dataset_id
:
str
,
conversation_message_task
:
ConversationMessageTask
)
->
None
:
def
__init__
(
self
,
conversation_message_task
:
ConversationMessageTask
)
->
None
:
self
.
dataset_id
=
dataset_id
self
.
conversation_message_task
=
conversation_message_task
self
.
conversation_message_task
=
conversation_message_task
def
on_tool_end
(
self
,
documents
:
List
[
Document
])
->
None
:
def
on_tool_end
(
self
,
documents
:
List
[
Document
])
->
None
:
...
@@ -21,7 +20,6 @@ class DatasetIndexToolCallbackHandler:
...
@@ -21,7 +20,6 @@ class DatasetIndexToolCallbackHandler:
# add hit count to document segment
# add hit count to document segment
db
.
session
.
query
(
DocumentSegment
)
.
filter
(
db
.
session
.
query
(
DocumentSegment
)
.
filter
(
DocumentSegment
.
dataset_id
==
self
.
dataset_id
,
DocumentSegment
.
index_node_id
==
doc_id
DocumentSegment
.
index_node_id
==
doc_id
)
.
update
(
)
.
update
(
{
DocumentSegment
.
hit_count
:
DocumentSegment
.
hit_count
+
1
},
{
DocumentSegment
.
hit_count
:
DocumentSegment
.
hit_count
+
1
},
...
...
api/core/completion.py
View file @
4588831b
...
@@ -127,6 +127,7 @@ class Completion:
...
@@ -127,6 +127,7 @@ class Completion:
memory
=
memory
,
memory
=
memory
,
rest_tokens
=
rest_tokens_for_context_and_memory
,
rest_tokens
=
rest_tokens_for_context_and_memory
,
chain_callback
=
chain_callback
,
chain_callback
=
chain_callback
,
tenant_id
=
app
.
tenant_id
,
retriever_from
=
retriever_from
retriever_from
=
retriever_from
)
)
...
...
api/core/data_loader/file_extractor.py
View file @
4588831b
...
@@ -3,7 +3,7 @@ from pathlib import Path
...
@@ -3,7 +3,7 @@ from pathlib import Path
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
import
requests
import
requests
from
langchain.document_loaders
import
TextLoader
,
Docx2txtLoader
from
langchain.document_loaders
import
TextLoader
,
Docx2txtLoader
,
UnstructuredFileLoader
,
UnstructuredAPIFileLoader
from
langchain.schema
import
Document
from
langchain.schema
import
Document
from
core.data_loader.loader.csv_loader
import
CSVLoader
from
core.data_loader.loader.csv_loader
import
CSVLoader
...
@@ -20,13 +20,13 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
...
@@ -20,13 +20,13 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
class
FileExtractor
:
class
FileExtractor
:
@
classmethod
@
classmethod
def
load
(
cls
,
upload_file
:
UploadFile
,
return_text
:
bool
=
False
)
->
Union
[
List
[
Document
]
|
str
]:
def
load
(
cls
,
upload_file
:
UploadFile
,
return_text
:
bool
=
False
,
is_automatic
:
bool
=
False
)
->
Union
[
List
[
Document
]
|
str
]:
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
suffix
=
Path
(
upload_file
.
key
)
.
suffix
suffix
=
Path
(
upload_file
.
key
)
.
suffix
file_path
=
f
"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
file_path
=
f
"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage
.
download
(
upload_file
.
key
,
file_path
)
storage
.
download
(
upload_file
.
key
,
file_path
)
return
cls
.
load_from_file
(
file_path
,
return_text
,
upload_file
)
return
cls
.
load_from_file
(
file_path
,
return_text
,
upload_file
,
is_automatic
)
@
classmethod
@
classmethod
def
load_from_url
(
cls
,
url
:
str
,
return_text
:
bool
=
False
)
->
Union
[
List
[
Document
]
|
str
]:
def
load_from_url
(
cls
,
url
:
str
,
return_text
:
bool
=
False
)
->
Union
[
List
[
Document
]
|
str
]:
...
@@ -44,24 +44,34 @@ class FileExtractor:
...
@@ -44,24 +44,34 @@ class FileExtractor:
@
classmethod
@
classmethod
def
load_from_file
(
cls
,
file_path
:
str
,
return_text
:
bool
=
False
,
def
load_from_file
(
cls
,
file_path
:
str
,
return_text
:
bool
=
False
,
upload_file
:
Optional
[
UploadFile
]
=
None
)
->
Union
[
List
[
Document
]
|
str
]:
upload_file
:
Optional
[
UploadFile
]
=
None
,
is_automatic
:
bool
=
False
)
->
Union
[
List
[
Document
]
|
str
]:
input_file
=
Path
(
file_path
)
input_file
=
Path
(
file_path
)
delimiter
=
'
\n
'
delimiter
=
'
\n
'
file_extension
=
input_file
.
suffix
.
lower
()
file_extension
=
input_file
.
suffix
.
lower
()
if
file_extension
==
'.xlsx'
:
if
is_automatic
:
loader
=
ExcelLoader
(
file_path
)
loader
=
UnstructuredFileLoader
(
elif
file_extension
==
'.pdf'
:
file_path
,
strategy
=
"hi_res"
,
mode
=
"elements"
loader
=
PdfLoader
(
file_path
,
upload_file
=
upload_file
)
)
elif
file_extension
in
[
'.md'
,
'.markdown'
]:
# loader = UnstructuredAPIFileLoader(
loader
=
MarkdownLoader
(
file_path
,
autodetect_encoding
=
True
)
# file_path=filenames[0],
elif
file_extension
in
[
'.htm'
,
'.html'
]:
# api_key="FAKE_API_KEY",
loader
=
HTMLLoader
(
file_path
)
# )
elif
file_extension
==
'.docx'
:
loader
=
Docx2txtLoader
(
file_path
)
elif
file_extension
==
'.csv'
:
loader
=
CSVLoader
(
file_path
,
autodetect_encoding
=
True
)
else
:
else
:
# txt
if
file_extension
==
'.xlsx'
:
loader
=
TextLoader
(
file_path
,
autodetect_encoding
=
True
)
loader
=
ExcelLoader
(
file_path
)
elif
file_extension
==
'.pdf'
:
loader
=
PdfLoader
(
file_path
,
upload_file
=
upload_file
)
elif
file_extension
in
[
'.md'
,
'.markdown'
]:
loader
=
MarkdownLoader
(
file_path
,
autodetect_encoding
=
True
)
elif
file_extension
in
[
'.htm'
,
'.html'
]:
loader
=
HTMLLoader
(
file_path
)
elif
file_extension
==
'.docx'
:
loader
=
Docx2txtLoader
(
file_path
)
elif
file_extension
==
'.csv'
:
loader
=
CSVLoader
(
file_path
,
autodetect_encoding
=
True
)
else
:
# txt
loader
=
TextLoader
(
file_path
,
autodetect_encoding
=
True
)
return
delimiter
.
join
([
document
.
page_content
for
document
in
loader
.
load
()])
if
return_text
else
loader
.
load
()
return
delimiter
.
join
([
document
.
page_content
for
document
in
loader
.
load
()])
if
return_text
else
loader
.
load
()
api/core/index/vector_index/base.py
View file @
4588831b
...
@@ -40,6 +40,13 @@ class BaseVectorIndex(BaseIndex):
...
@@ -40,6 +40,13 @@ class BaseVectorIndex(BaseIndex):
def
_get_vector_store_class
(
self
)
->
type
:
def
_get_vector_store_class
(
self
)
->
type
:
raise
NotImplementedError
raise
NotImplementedError
    @abstractmethod
    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        """Full-text (keyword/BM25-style) search over the index.

        Concrete vector-index backends implement this; backends without
        full-text support return an empty list.

        :param query: raw query text
        :param kwargs: backend-specific options (e.g. 'top_k')
        :return: matching documents
        """
        raise NotImplementedError
def
search
(
def
search
(
self
,
query
:
str
,
self
,
query
:
str
,
**
kwargs
:
Any
**
kwargs
:
Any
...
...
api/core/index/vector_index/milvus_vector_index.py
View file @
4588831b
from
typing
import
Optional
,
ca
st
from
typing
import
cast
,
Any
,
Li
st
from
langchain.embeddings.base
import
Embeddings
from
langchain.embeddings.base
import
Embeddings
from
langchain.schema
import
Document
,
BaseRetriever
from
langchain.schema
import
Document
from
langchain.vectorstores
import
VectorStore
,
milvus
from
langchain.vectorstores
import
VectorStore
from
pydantic
import
BaseModel
,
root_validator
from
pydantic
import
BaseModel
,
root_validator
from
core.index.base
import
BaseIndex
from
core.index.base
import
BaseIndex
from
core.index.vector_index.base
import
BaseVectorIndex
from
core.index.vector_index.base
import
BaseVectorIndex
from
core.vector_store.milvus_vector_store
import
MilvusVectorStore
from
core.vector_store.milvus_vector_store
import
MilvusVectorStore
from
core.vector_store.weaviate_vector_store
import
WeaviateVectorStore
from
models.dataset
import
Dataset
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
,
DatasetCollectionBinding
class
MilvusConfig
(
BaseModel
):
class
MilvusConfig
(
BaseModel
):
...
@@ -74,7 +72,7 @@ class MilvusVectorIndex(BaseVectorIndex):
...
@@ -74,7 +72,7 @@ class MilvusVectorIndex(BaseVectorIndex):
index_params
=
{
index_params
=
{
'metric_type'
:
'IP'
,
'metric_type'
:
'IP'
,
'index_type'
:
"HNSW"
,
'index_type'
:
"HNSW"
,
'params'
:
{
"M"
:
8
,
"efConstruction"
:
64
}
'params'
:
{
"M"
:
8
,
"efConstruction"
:
64
}
}
}
self
.
_vector_store
=
MilvusVectorStore
.
from_documents
(
self
.
_vector_store
=
MilvusVectorStore
.
from_documents
(
texts
,
texts
,
...
@@ -152,3 +150,7 @@ class MilvusVectorIndex(BaseVectorIndex):
...
@@ -152,3 +150,7 @@ class MilvusVectorIndex(BaseVectorIndex):
),
),
],
],
))
))
    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        """Full-text search is not supported by this backend.

        :param query: raw query text (ignored)
        :param kwargs: ignored
        :return: always an empty list
        """
        # milvus/zilliz doesn't support bm25 search
        return []
api/core/index/vector_index/qdrant_vector_index.py
View file @
4588831b
...
@@ -191,3 +191,21 @@ class QdrantVectorIndex(BaseVectorIndex):
...
@@ -191,3 +191,21 @@ class QdrantVectorIndex(BaseVectorIndex):
return
True
return
True
return
False
return
False
    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        """BM25-style full-text search against the Qdrant collection.

        Filters points to this dataset's group and applies a full-text
        match on the stored page content.

        :param query: raw query text matched against the 'page_content' payload field
        :param kwargs: 'top_k' limits the number of results (default 2)
        :return: matching documents
        """
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        # local import keeps qdrant a soft dependency for other backends
        from qdrant_client.http import models
        return vector_store.similarity_search_by_bm25(models.Filter(
            must=[
                # restrict to this dataset's points
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self.dataset.id),
                ),
                # full-text match on the document body
                models.FieldCondition(
                    key="page_content",
                    match=models.MatchText(text=query),
                )
            ],
        ), kwargs.get('top_k', 2))
api/core/index/vector_index/weaviate_vector_index.py
View file @
4588831b
from
typing
import
Optional
,
cast
from
typing
import
Optional
,
cast
,
Any
,
List
import
requests
import
requests
import
weaviate
import
weaviate
...
@@ -26,6 +26,7 @@ class WeaviateConfig(BaseModel):
...
@@ -26,6 +26,7 @@ class WeaviateConfig(BaseModel):
class
WeaviateVectorIndex
(
BaseVectorIndex
):
class
WeaviateVectorIndex
(
BaseVectorIndex
):
def
__init__
(
self
,
dataset
:
Dataset
,
config
:
WeaviateConfig
,
embeddings
:
Embeddings
):
def
__init__
(
self
,
dataset
:
Dataset
,
config
:
WeaviateConfig
,
embeddings
:
Embeddings
):
super
()
.
__init__
(
dataset
,
embeddings
)
super
()
.
__init__
(
dataset
,
embeddings
)
self
.
_client
=
self
.
_init_client
(
config
)
self
.
_client
=
self
.
_init_client
(
config
)
...
@@ -148,3 +149,9 @@ class WeaviateVectorIndex(BaseVectorIndex):
...
@@ -148,3 +149,9 @@ class WeaviateVectorIndex(BaseVectorIndex):
return
True
return
True
return
False
return
False
    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        """BM25 full-text search against the Weaviate class for this dataset.

        :param query: raw query text
        :param kwargs: 'top_k' limits the number of results (default 2);
            remaining kwargs are forwarded to the vector store
        :return: matching documents
        """
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
api/core/indexing_runner.py
View file @
4588831b
...
@@ -49,14 +49,14 @@ class IndexingRunner:
...
@@ -49,14 +49,14 @@ class IndexingRunner:
if
not
dataset
:
if
not
dataset
:
raise
ValueError
(
"no dataset found"
)
raise
ValueError
(
"no dataset found"
)
# load file
text_docs
=
self
.
_load_data
(
dataset_document
)
# get the process rule
# get the process rule
processing_rule
=
db
.
session
.
query
(
DatasetProcessRule
)
.
\
processing_rule
=
db
.
session
.
query
(
DatasetProcessRule
)
.
\
filter
(
DatasetProcessRule
.
id
==
dataset_document
.
dataset_process_rule_id
)
.
\
filter
(
DatasetProcessRule
.
id
==
dataset_document
.
dataset_process_rule_id
)
.
\
first
()
first
()
# load file
text_docs
=
self
.
_load_data
(
dataset_document
)
# get splitter
# get splitter
splitter
=
self
.
_get_splitter
(
processing_rule
)
splitter
=
self
.
_get_splitter
(
processing_rule
)
...
@@ -380,7 +380,7 @@ class IndexingRunner:
...
@@ -380,7 +380,7 @@ class IndexingRunner:
"preview"
:
preview_texts
"preview"
:
preview_texts
}
}
def
_load_data
(
self
,
dataset_document
:
DatasetDocument
)
->
List
[
Document
]:
def
_load_data
(
self
,
dataset_document
:
DatasetDocument
,
automatic
:
bool
=
False
)
->
List
[
Document
]:
# load file
# load file
if
dataset_document
.
data_source_type
not
in
[
"upload_file"
,
"notion_import"
]:
if
dataset_document
.
data_source_type
not
in
[
"upload_file"
,
"notion_import"
]:
return
[]
return
[]
...
@@ -396,7 +396,7 @@ class IndexingRunner:
...
@@ -396,7 +396,7 @@ class IndexingRunner:
one_or_none
()
one_or_none
()
if
file_detail
:
if
file_detail
:
text_docs
=
FileExtractor
.
load
(
file_detail
)
text_docs
=
FileExtractor
.
load
(
file_detail
,
is_automatic
=
False
)
elif
dataset_document
.
data_source_type
==
'notion_import'
:
elif
dataset_document
.
data_source_type
==
'notion_import'
:
loader
=
NotionLoader
.
from_document
(
dataset_document
)
loader
=
NotionLoader
.
from_document
(
dataset_document
)
text_docs
=
loader
.
load
()
text_docs
=
loader
.
load
()
...
...
api/core/model_providers/model_factory.py
View file @
4588831b
...
@@ -9,6 +9,7 @@ from core.model_providers.models.embedding.base import BaseEmbedding
...
@@ -9,6 +9,7 @@ from core.model_providers.models.embedding.base import BaseEmbedding
from
core.model_providers.models.entity.model_params
import
ModelKwargs
,
ModelType
from
core.model_providers.models.entity.model_params
import
ModelKwargs
,
ModelType
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.moderation.base
import
BaseModeration
from
core.model_providers.models.moderation.base
import
BaseModeration
from
core.model_providers.models.reranking.base
import
BaseReranking
from
core.model_providers.models.speech2text.base
import
BaseSpeech2Text
from
core.model_providers.models.speech2text.base
import
BaseSpeech2Text
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.provider
import
TenantDefaultModel
from
models.provider
import
TenantDefaultModel
...
@@ -140,6 +141,44 @@ class ModelFactory:
...
@@ -140,6 +141,44 @@ class ModelFactory:
name
=
model_name
name
=
model_name
)
)
    @classmethod
    def get_reranking_model(cls,
                            tenant_id: str,
                            model_provider_name: Optional[str] = None,
                            model_name: Optional[str] = None) -> Optional[BaseReranking]:
        """
        get reranking model.

        When neither provider nor model name is given, falls back to the
        tenant's configured default reranking model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider to use; None to use the tenant default
        :param model_name: model to use; None to use the tenant default
        :return: an initialized reranking model instance
        :raises LLMBadRequestError: no default reranking model is configured
        :raises ProviderTokenNotInitError: provider credentials are missing
        """
        # NOTE: fallback triggers only when BOTH names are None; a partial
        # specification is passed through as-is.
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Reranking Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init reranking model
        model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)

        return model_class(
            model_provider=model_provider,
            name=model_name
        )
@
classmethod
@
classmethod
def
get_speech2text_model
(
cls
,
def
get_speech2text_model
(
cls
,
tenant_id
:
str
,
tenant_id
:
str
,
...
...
api/core/model_providers/model_provider_factory.py
View file @
4588831b
...
@@ -72,6 +72,9 @@ class ModelProviderFactory:
...
@@ -72,6 +72,9 @@ class ModelProviderFactory:
elif
provider_name
==
'localai'
:
elif
provider_name
==
'localai'
:
from
core.model_providers.providers.localai_provider
import
LocalAIProvider
from
core.model_providers.providers.localai_provider
import
LocalAIProvider
return
LocalAIProvider
return
LocalAIProvider
elif
provider_name
==
'cohere'
:
from
core.model_providers.providers.cohere_provider
import
CohereProvider
return
CohereProvider
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
...
...
api/core/model_providers/models/entity/model_params.py
View file @
4588831b
...
@@ -17,7 +17,7 @@ class ModelType(enum.Enum):
...
@@ -17,7 +17,7 @@ class ModelType(enum.Enum):
IMAGE
=
'image'
IMAGE
=
'image'
VIDEO
=
'video'
VIDEO
=
'video'
MODERATION
=
'moderation'
MODERATION
=
'moderation'
RERANKING
=
'reranking'
@
staticmethod
@
staticmethod
def
value_of
(
value
):
def
value_of
(
value
):
for
member
in
ModelType
:
for
member
in
ModelType
:
...
...
api/core/model_providers/models/reranking/__init__.py
0 → 100644
View file @
4588831b
api/core/model_providers/models/reranking/base.py
0 → 100644
View file @
4588831b
from
abc
import
abstractmethod
from
typing
import
Any
,
Optional
,
List
from
langchain.schema
import
Document
from
core.model_providers.models.base
import
BaseProviderModel
from
core.model_providers.models.entity.model_params
import
ModelType
from
core.model_providers.providers.base
import
BaseModelProvider
import
logging
logger
=
logging
.
getLogger
(
__name__
)
class BaseReranking(BaseProviderModel):
    """Abstract base class for reranking models.

    Concrete providers (e.g. Cohere) subclass this and implement
    `rerank` and `handle_exceptions`.
    """

    # provider-side model identifier
    name: str
    # model category; fixed for all reranking subclasses
    type: ModelType = ModelType.RERANKING

    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
        """Bind the provider-specific client and remember the model name."""
        super().__init__(model_provider, client)
        self.name = name

    @property
    def base_model_name(self) -> str:
        """
        get base model name

        :return: str
        """
        return self.name

    @abstractmethod
    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float],
               top_k: Optional[int]) -> Optional[List[Document]]:
        """Reorder `documents` by relevance to `query`.

        :param query: the search query
        :param documents: candidate documents to rerank
        :param score_threshold: drop results scoring below this value (None keeps all)
        :param top_k: maximum number of results to return
        :return: reranked documents, most relevant first
        """
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map a provider-specific exception to a core LLM error type."""
        raise NotImplementedError
api/core/model_providers/models/reranking/cohere_reranking.py
0 → 100644
View file @
4588831b
import
logging
from
typing
import
Optional
,
List
import
cohere
import
openai
from
langchain.schema
import
Document
from
core.model_providers.error
import
LLMBadRequestError
,
LLMAPIConnectionError
,
LLMAPIUnavailableError
,
\
LLMRateLimitError
,
LLMAuthorizationError
from
core.model_providers.models.reranking.base
import
BaseReranking
from
core.model_providers.providers.base
import
BaseModelProvider
class CohereReranking(BaseReranking):
    """Reranking model backed by Cohere's `rerank` endpoint."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        """Create a Cohere client from the provider's stored credentials.

        :param model_provider: provider holding the encrypted credentials
        :param name: Cohere rerank model name (e.g. 'rerank-english-v2.0')
        """
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = cohere.Client(self.credentials.get('api_key'))

        super().__init__(model_provider, client, name)

    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float],
               top_k: Optional[int]) -> Optional[List[Document]]:
        """Rerank `documents` against `query` via Cohere.

        Duplicate documents (same metadata['doc_id']) are collapsed before
        being sent to the API.

        :param query: search query text
        :param documents: candidate documents; each must carry 'doc_id',
            'doc_hash', 'document_id' and 'dataset_id' in metadata
        :param score_threshold: if set, drop results whose relevance_score
            is below it
        :param top_k: maximum number of results requested from Cohere
        :return: reranked documents, most relevant first, with 'score'
            added to metadata
        """
        # Deduplicate by doc_id, keeping the surviving Document objects in a
        # parallel list. BUGFIX: Cohere's result.index refers to the
        # deduplicated payload we send, so source metadata must be looked up
        # in `unique_documents`, not in the original `documents` list —
        # otherwise any skipped duplicate shifts every later index.
        docs = []
        unique_documents = []
        seen_doc_ids = set()
        for document in documents:
            if document.metadata['doc_id'] not in seen_doc_ids:
                seen_doc_ids.add(document.metadata['doc_id'])
                unique_documents.append(document)
                docs.append(document.page_content)

        results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)

        rerank_documents = []
        for result in results:
            # map the API result back to its source document
            source_document = unique_documents[result.index]

            # format document
            rerank_document = Document(
                page_content=result.document['text'],
                metadata={
                    "doc_id": source_document.metadata['doc_id'],
                    "doc_hash": source_document.metadata['doc_hash'],
                    "document_id": source_document.metadata['document_id'],
                    "dataset_id": source_document.metadata['dataset_id'],
                    'score': result.relevance_score
                }
            )

            # score threshold check
            if score_threshold is None or result.relevance_score >= score_threshold:
                rerank_documents.append(rerank_document)

        return rerank_documents

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Translate SDK exceptions into core LLM error types.

        NOTE(review): these branches check `openai.error` types even though
        the underlying client is Cohere's SDK — presumably copied from an
        OpenAI-based provider; confirm whether cohere exception types should
        be mapped here instead.
        """
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
api/core/model_providers/providers/cohere_provider.py
0 → 100644
View file @
4588831b
import
json
from
json
import
JSONDecodeError
from
typing
import
Type
from
langchain.schema
import
HumanMessage
from
core.helper
import
encrypter
from
core.model_providers.models.base
import
BaseProviderModel
from
core.model_providers.models.entity.model_params
import
ModelKwargsRules
,
KwargRule
,
ModelType
,
ModelMode
from
core.model_providers.models.reranking.cohere_reranking
import
CohereReranking
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
from
models.provider
import
ProviderType
class CohereProvider(BaseModelProvider):
    """Model provider for Cohere; currently exposes only reranking models."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'cohere'

    def _get_text_generation_model_mode(self, model_name) -> str:
        # no text-generation models are exposed; mode is a fixed placeholder
        return ModelMode.CHAT.value

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        # only reranking models are offered; any other type yields nothing
        if model_type == ModelType.RERANKING:
            return [
                {
                    'id': 'rerank-english-v2.0',
                    'name': 'rerank-english-v2.0'
                },
                {
                    'id': 'rerank-multilingual-v2.0',
                    'name': 'rerank-multilingual-v2.0'
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type: only ModelType.RERANKING is supported
        :return: the model class for the given type
        :raises NotImplementedError: for any non-reranking model type
        """
        if model_type == ModelType.RERANKING:
            model_class = CohereReranking
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return: kwarg rules; generation-style penalties/max_tokens are disabled
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.

        :param credentials: must contain 'api_key'
        :raises CredentialsValidateFailedError: when 'api_key' is missing
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Cohere api_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key'],
            }
            # todo validate
            # NOTE(review): only presence of the key is checked here; no API
            # round-trip is made to verify the key actually works.
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """Encrypt the api_key in place with the tenant's key and return the dict."""
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """Return decrypted (optionally obfuscated) credentials for CUSTOM providers.

        :param obfuscated: when True, mask the api_key for display
        :return: credentials dict, or {} for non-custom provider types
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                # malformed/absent stored config falls back to an empty key
                credentials = {
                    'api_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        # Cohere usage counts against the tenant quota
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        Per-model credentials are not used (provider-level key only), so this
        is a no-op.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        Per-model credentials are not used, so nothing is stored.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return: empty dict
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        Delegates to the provider-level credentials — all models share one key.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return: provider credentials dict
        """
        return self.get_provider_credentials(obfuscated)
api/core/model_providers/rules/_providers.json
View file @
4588831b
...
@@ -13,5 +13,6 @@
...
@@ -13,5 +13,6 @@
"huggingface_hub"
,
"huggingface_hub"
,
"xinference"
,
"xinference"
,
"openllm"
,
"openllm"
,
"localai"
"localai"
,
"cohere"
]
]
api/core/model_providers/rules/cohere.json
0 → 100644
View file @
4588831b
{
"support_provider_types"
:
[
"custom"
],
"system_config"
:
null
,
"model_flexibility"
:
"fixed"
}
\ No newline at end of file
api/core/orchestrator_rule_parser.py
View file @
4588831b
from
typing
import
Optional
import
json
import
threading
from
typing
import
Optional
,
List
from
flask
import
Flask
from
langchain
import
WikipediaAPIWrapper
from
langchain
import
WikipediaAPIWrapper
from
langchain.callbacks.manager
import
Callbacks
from
langchain.callbacks.manager
import
Callbacks
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.tools
import
BaseTool
,
Tool
,
WikipediaQueryRun
from
langchain.tools
import
BaseTool
,
Tool
,
WikipediaQueryRun
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
from
core.agent.agent.multi_dataset_router_agent
import
MultiDatasetRouterAgent
from
core.agent.agent.output_parser.structured_chat
import
StructuredChatOutputParser
from
core.agent.agent.structed_multi_dataset_router_agent
import
StructuredMultiDatasetRouterAgent
from
core.agent.agent_executor
import
AgentExecutor
,
PlanningStrategy
,
AgentConfiguration
from
core.agent.agent_executor
import
AgentExecutor
,
PlanningStrategy
,
AgentConfiguration
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
...
@@ -17,6 +23,7 @@ from core.model_providers.model_factory import ModelFactory
...
@@ -17,6 +23,7 @@ from core.model_providers.model_factory import ModelFactory
from
core.model_providers.models.entity.model_params
import
ModelKwargs
,
ModelMode
from
core.model_providers.models.entity.model_params
import
ModelKwargs
,
ModelMode
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.tool.current_datetime_tool
import
DatetimeTool
from
core.tool.current_datetime_tool
import
DatetimeTool
from
core.tool.dataset_multi_retriever_tool
import
DatasetMultiRetrieverTool
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
from
core.tool.provider.serpapi_provider
import
SerpAPIToolProvider
from
core.tool.provider.serpapi_provider
import
SerpAPIToolProvider
from
core.tool.serpapi_wrapper
import
OptimizedSerpAPIWrapper
,
OptimizedSerpAPIInput
from
core.tool.serpapi_wrapper
import
OptimizedSerpAPIWrapper
,
OptimizedSerpAPIInput
...
@@ -25,6 +32,16 @@ from extensions.ext_database import db
...
@@ -25,6 +32,16 @@ from extensions.ext_database import db
from
models.dataset
import
Dataset
,
DatasetProcessRule
from
models.dataset
import
Dataset
,
DatasetProcessRule
from
models.model
import
AppModelConfig
from
models.model
import
AppModelConfig
default_retrieval_model
=
{
'search_method'
:
'semantic_search'
,
'reranking_enable'
:
False
,
'reranking_model'
:
{
'reranking_provider_name'
:
''
,
'reranking_model_name'
:
''
},
'top_k'
:
2
,
'score_threshold_enable'
:
False
}
class
OrchestratorRuleParser
:
class
OrchestratorRuleParser
:
"""Parse the orchestrator rule to entities."""
"""Parse the orchestrator rule to entities."""
...
@@ -34,7 +51,7 @@ class OrchestratorRuleParser:
...
@@ -34,7 +51,7 @@ class OrchestratorRuleParser:
self
.
app_model_config
=
app_model_config
self
.
app_model_config
=
app_model_config
def
to_agent_executor
(
self
,
conversation_message_task
:
ConversationMessageTask
,
memory
:
Optional
[
BaseChatMemory
],
def
to_agent_executor
(
self
,
conversation_message_task
:
ConversationMessageTask
,
memory
:
Optional
[
BaseChatMemory
],
rest_tokens
:
int
,
chain_callback
:
MainChainGatherCallbackHandler
,
rest_tokens
:
int
,
chain_callback
:
MainChainGatherCallbackHandler
,
tenant_id
:
str
,
retriever_from
:
str
=
'dev'
)
->
Optional
[
AgentExecutor
]:
retriever_from
:
str
=
'dev'
)
->
Optional
[
AgentExecutor
]:
if
not
self
.
app_model_config
.
agent_mode_dict
:
if
not
self
.
app_model_config
.
agent_mode_dict
:
return
None
return
None
...
@@ -101,7 +118,8 @@ class OrchestratorRuleParser:
...
@@ -101,7 +118,8 @@ class OrchestratorRuleParser:
rest_tokens
=
rest_tokens
,
rest_tokens
=
rest_tokens
,
return_resource
=
return_resource
,
return_resource
=
return_resource
,
retriever_from
=
retriever_from
,
retriever_from
=
retriever_from
,
dataset_configs
=
dataset_configs
dataset_configs
=
dataset_configs
,
tenant_id
=
tenant_id
)
)
if
len
(
tools
)
==
0
:
if
len
(
tools
)
==
0
:
...
@@ -123,7 +141,7 @@ class OrchestratorRuleParser:
...
@@ -123,7 +141,7 @@ class OrchestratorRuleParser:
return
chain
return
chain
def
to_tools
(
self
,
tool_configs
:
list
,
callbacks
:
Callbacks
=
None
,
**
kwargs
)
->
list
[
BaseTool
]:
def
to_tools
(
self
,
tool_configs
:
list
,
callbacks
:
Callbacks
=
None
,
**
kwargs
)
->
list
[
BaseTool
]:
"""
"""
Convert app agent tool configs to tools
Convert app agent tool configs to tools
...
@@ -132,6 +150,7 @@ class OrchestratorRuleParser:
...
@@ -132,6 +150,7 @@ class OrchestratorRuleParser:
:return:
:return:
"""
"""
tools
=
[]
tools
=
[]
dataset_tools
=
[]
for
tool_config
in
tool_configs
:
for
tool_config
in
tool_configs
:
tool_type
=
list
(
tool_config
.
keys
())[
0
]
tool_type
=
list
(
tool_config
.
keys
())[
0
]
tool_val
=
list
(
tool_config
.
values
())[
0
]
tool_val
=
list
(
tool_config
.
values
())[
0
]
...
@@ -140,7 +159,7 @@ class OrchestratorRuleParser:
...
@@ -140,7 +159,7 @@ class OrchestratorRuleParser:
tool
=
None
tool
=
None
if
tool_type
==
"dataset"
:
if
tool_type
==
"dataset"
:
tool
=
self
.
to_dataset_retriever_tool
(
tool_config
=
tool_val
,
**
kwargs
)
dataset_tools
.
append
(
tool_config
)
elif
tool_type
==
"web_reader"
:
elif
tool_type
==
"web_reader"
:
tool
=
self
.
to_web_reader_tool
(
tool_config
=
tool_val
,
**
kwargs
)
tool
=
self
.
to_web_reader_tool
(
tool_config
=
tool_val
,
**
kwargs
)
elif
tool_type
==
"google_search"
:
elif
tool_type
==
"google_search"
:
...
@@ -156,57 +175,81 @@ class OrchestratorRuleParser:
...
@@ -156,57 +175,81 @@ class OrchestratorRuleParser:
else
:
else
:
tool
.
callbacks
=
callbacks
tool
.
callbacks
=
callbacks
tools
.
append
(
tool
)
tools
.
append
(
tool
)
# format dataset tool
if
len
(
dataset_tools
)
>
0
:
dataset_retriever_tools
=
self
.
to_dataset_retriever_tool
(
tool_configs
=
dataset_tools
,
**
kwargs
)
if
dataset_retriever_tools
:
tools
.
extend
(
dataset_retriever_tools
)
return
tools
return
tools
def
to_dataset_retriever_tool
(
self
,
tool_config
:
dict
,
conversation_message_task
:
ConversationMessageTask
,
def
to_dataset_retriever_tool
(
self
,
tool_configs
:
List
,
conversation_message_task
:
ConversationMessageTask
,
dataset_configs
:
dict
,
rest_tokens
:
int
,
return_resource
:
bool
=
False
,
retriever_from
:
str
=
'dev'
,
return_resource
:
bool
=
False
,
retriever_from
:
str
=
'dev'
,
**
kwargs
)
\
**
kwargs
)
\
->
Optional
[
BaseTool
]:
->
Optional
[
List
[
BaseTool
]
]:
"""
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_configs:
:param tool_config:
:param dataset_configs:
:param conversation_message_task:
:param conversation_message_task:
:param return_resource:
:param return_resource:
:param retriever_from:
:param retriever_from:
:return:
:return:
"""
"""
# get dataset from dataset id
dataset_configs
=
kwargs
[
'dataset_configs'
]
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
retrieval_model
=
dataset_configs
.
get
(
'retrieval_model'
,
'single'
)
Dataset
.
tenant_id
==
self
.
tenant_id
,
tools
=
[]
Dataset
.
id
==
tool_config
.
get
(
"id"
)
dataset_ids
=
[]
)
.
first
()
tenant_id
=
None
for
tool_config
in
tool_configs
:
if
not
dataset
:
# get dataset from dataset id
return
None
dataset
=
db
.
session
.
query
(
Dataset
)
.
filter
(
Dataset
.
tenant_id
==
self
.
tenant_id
,
if
dataset
and
dataset
.
available_document_count
==
0
and
dataset
.
available_document_count
==
0
:
Dataset
.
id
==
tool_config
.
get
(
'dataset'
)
.
get
(
"id"
)
return
None
)
.
first
()
top_k
=
dataset_configs
.
get
(
"top_k"
,
2
)
# dynamically adjust top_k when the remaining token number is not enough to support top_k
top_k
=
self
.
_dynamic_calc_retrieve_k
(
dataset
=
dataset
,
top_k
=
top_k
,
rest_tokens
=
rest_tokens
)
score_threshold
=
None
if
not
dataset
:
score_threshold_config
=
dataset_configs
.
get
(
"score_threshold"
)
return
None
if
score_threshold_config
and
score_threshold_config
.
get
(
"enable"
):
score_threshold
=
score_threshold_config
.
get
(
"value"
)
tool
=
DatasetRetrieverTool
.
from_dataset
(
if
dataset
and
dataset
.
available_document_count
==
0
and
dataset
.
available_document_count
==
0
:
dataset
=
dataset
,
return
None
top_k
=
top_k
,
dataset_ids
.
append
(
dataset
.
id
)
score_threshold
=
score_threshold
,
if
retrieval_model
==
'single'
:
callbacks
=
[
DatasetToolCallbackHandler
(
conversation_message_task
)],
retrieval_model
=
dataset
.
retrieval_model
if
dataset
.
retrieval_model
else
default_retrieval_model
conversation_message_task
=
conversation_message_task
,
top_k
=
retrieval_model
[
'top_k'
]
return_resource
=
return_resource
,
retriever_from
=
retriever_from
# dynamically adjust top_k when the remaining token number is not enough to support top_k
)
# top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
score_threshold
=
None
score_threshold_enable
=
retrieval_model
.
get
(
"score_threshold_enable"
)
if
score_threshold_enable
:
score_threshold
=
retrieval_model
.
get
(
"score_threshold"
)
tool
=
DatasetRetrieverTool
.
from_dataset
(
dataset
=
dataset
,
top_k
=
top_k
,
score_threshold
=
score_threshold
,
callbacks
=
[
DatasetToolCallbackHandler
(
conversation_message_task
)],
conversation_message_task
=
conversation_message_task
,
return_resource
=
return_resource
,
retriever_from
=
retriever_from
)
tools
.
append
(
tool
)
if
retrieval_model
==
'multiple'
:
tool
=
DatasetMultiRetrieverTool
.
from_dataset
(
dataset_ids
=
dataset_ids
,
tenant_id
=
kwargs
[
'tenant_id'
],
top_k
=
dataset_configs
.
get
(
'top_k'
,
2
),
score_threshold
=
dataset_configs
.
get
(
'score_threshold'
,
0.5
)
if
dataset_configs
.
get
(
'score_threshold_enable'
,
False
)
else
None
,
callbacks
=
[
DatasetToolCallbackHandler
(
conversation_message_task
)],
conversation_message_task
=
conversation_message_task
,
return_resource
=
return_resource
,
retriever_from
=
retriever_from
,
reranking_provider_name
=
dataset_configs
.
get
(
'reranking_model'
)
.
get
(
'reranking_provider_name'
),
reranking_model_name
=
dataset_configs
.
get
(
'reranking_model'
)
.
get
(
'reranking_model_name'
)
)
tools
.
append
(
tool
)
return
tool
return
tool
s
def
to_web_reader_tool
(
self
,
tool_config
:
dict
,
agent_model_instance
:
BaseLLM
,
**
kwargs
)
->
Optional
[
BaseTool
]:
def
to_web_reader_tool
(
self
,
tool_config
:
dict
,
agent_model_instance
:
BaseLLM
,
**
kwargs
)
->
Optional
[
BaseTool
]:
"""
"""
...
...
api/core/tool/dataset_multi_retriever_tool.py
0 → 100644
View file @
4588831b
import
json
import
threading
from
typing
import
Type
,
Optional
,
List
from
flask
import
current_app
,
Flask
from
langchain.tools
import
BaseTool
from
pydantic
import
Field
,
BaseModel
from
core.callback_handler.index_tool_callback_handler
import
DatasetIndexToolCallbackHandler
from
core.conversation_message_task
import
ConversationMessageTask
from
core.embedding.cached_embedding
import
CacheEmbedding
from
core.index.keyword_table_index.keyword_table_index
import
KeywordTableIndex
,
KeywordTableConfig
from
core.model_providers.error
import
LLMBadRequestError
,
ProviderTokenNotInitError
from
core.model_providers.model_factory
import
ModelFactory
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
,
DocumentSegment
,
Document
from
services.retrieval_service
import
RetrievalService
# Fallback retrieval configuration used when a dataset has no
# `retrieval_model` of its own (see `_retriever` below): plain semantic
# search, top 2 results, no reranking, no score threshold.
default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enable': False
}
class DatasetMultiRetrieverToolInput(BaseModel):
    """Input schema for DatasetMultiRetrieverTool."""

    # Single free-text query forwarded to every configured dataset.
    query: str = Field(..., description="dataset multi retriever and rerank")
class DatasetMultiRetrieverTool(BaseTool):
    """Tool for querying multi dataset."""
    name: str = "dataset-"
    args_schema: Type[BaseModel] = DatasetMultiRetrieverToolInput
    description: str = "dataset multi retriever and rerank. "
    # Workspace that owns the datasets; scopes every DB query below.
    tenant_id: str
    # Datasets searched in parallel by `_run`.
    dataset_ids: List[str]
    # Number of documents kept after the final rerank.
    top_k: int = 2
    # Minimum rerank score; None disables threshold filtering.
    score_threshold: Optional[float] = None
    reranking_provider_name: str
    reranking_model_name: str
    conversation_message_task: ConversationMessageTask
    # When True, citation metadata is emitted through the hit callback.
    return_resource: bool
    # Origin tag ('dev' adds extra debugging fields to citations).
    retriever_from: str

    @classmethod
    def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs):
        """Alternate constructor: name the tool after the tenant and bind datasets."""
        return cls(
            name=f'dataset-{tenant_id}',
            tenant_id=tenant_id,
            dataset_ids=dataset_ids,
            **kwargs
        )

    def _run(self, query: str) -> str:
        """Search all datasets in parallel, rerank the union, return joined segment text."""
        threads = []
        all_documents = []
        # One worker thread per dataset; each worker appends hits to the shared
        # `all_documents` list. NOTE(review): relies on list.append being atomic
        # under the GIL — confirm no stronger synchronization is needed.
        for dataset_id in self.dataset_ids:
            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset_id,
                'query': query,
                'all_documents': all_documents
            })
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()
        # do rerank for searched documents
        rerank = ModelFactory.get_reranking_model(
            tenant_id=self.tenant_id,
            model_provider_name=self.reranking_provider_name,
            model_name=self.reranking_model_name
        )
        all_documents = rerank.rerank(query, all_documents, self.score_threshold, self.top_k)

        hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
        hit_callback.on_tool_end(all_documents)
        document_context_list = []
        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
        # Resolve the reranked index nodes back to their stored segments.
        segments = DocumentSegment.query.filter(
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True,
            DocumentSegment.index_node_id.in_(index_node_ids)
        ).all()

        if segments:
            # Preserve the rerank ordering when rendering segment text; nodes
            # missing from `index_node_ids` sort last.
            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
            sorted_segments = sorted(segments,
                                     key=lambda segment: index_node_id_to_position.get(
                                         segment.index_node_id, float('inf')))
            for segment in sorted_segments:
                if segment.answer:
                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                else:
                    document_context_list.append(segment.content)

            if self.return_resource:
                # Build one citation entry per segment whose dataset and
                # document are still present/enabled.
                context_list = []
                resource_number = 1
                for segment in sorted_segments:
                    dataset = Dataset.query.filter_by(
                        id=segment.dataset_id
                    ).first()
                    document = Document.query.filter(Document.id == segment.document_id,
                                                     Document.enabled == True,
                                                     Document.archived == False,
                                                     ).first()
                    if dataset and document:
                        source = {
                            'position': resource_number,
                            'dataset_id': dataset.id,
                            'dataset_name': dataset.name,
                            'document_id': document.id,
                            'document_name': document.name,
                            'data_source_type': document.data_source_type,
                            'segment_id': segment.id,
                            'retriever_from': self.retriever_from
                        }
                        if self.retriever_from == 'dev':
                            # Extra fields surfaced only in the dev console.
                            source['hit_count'] = segment.hit_count
                            source['word_count'] = segment.word_count
                            source['segment_position'] = segment.position
                            source['index_node_hash'] = segment.index_node_hash
                        if segment.answer:
                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                        else:
                            source['content'] = segment.content
                        context_list.append(source)
                        resource_number += 1
                hit_callback.return_retriever_resource_info(context_list)

        return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        # Async execution is not supported for this tool.
        raise NotImplementedError()

    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List):
        """Worker body: retrieve from one dataset and extend `all_documents` in place.

        Runs inside its own Flask app context because it executes on a
        background thread.
        """
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
                Dataset.id == dataset_id
            ).first()

            if not dataset:
                return []
            # get retrieval model , if the model is not setting , using default
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query
                kw_table_index = KeywordTableIndex(
                    dataset=dataset,
                    config=KeywordTableConfig(
                        max_keywords_per_chunk=5
                    )
                )
                documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
                if documents:
                    all_documents.extend(documents)
            else:
                try:
                    embedding_model = ModelFactory.get_embedding_model(
                        tenant_id=dataset.tenant_id,
                        model_provider_name=dataset.embedding_model_provider,
                        model_name=dataset.embedding_model
                    )
                except LLMBadRequestError:
                    return []
                except ProviderTokenNotInitError:
                    return []
                embeddings = CacheEmbedding(embedding_model)

                documents = []
                threads = []
                if self.top_k > 0:
                    # retrieval_model source with semantic
                    if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
                        # NOTE(review): search_method is forced to 'hybrid_search'
                        # and reranking_model to None here — presumably so the
                        # per-dataset search skips its own rerank (the multi-tool
                        # reranks the union in `_run`); confirm this is intended.
                        embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                            'flask_app': current_app._get_current_object(),
                            'dataset': dataset,
                            'query': query,
                            'top_k': self.top_k,
                            'score_threshold': self.score_threshold,
                            'reranking_model': None,
                            'all_documents': documents,
                            'search_method': 'hybrid_search',
                            'embeddings': embeddings
                        })
                        threads.append(embedding_thread)
                        embedding_thread.start()

                    # retrieval_model source with full text
                    if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
                        full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                            'flask_app': current_app._get_current_object(),
                            'dataset': dataset,
                            'query': query,
                            'search_method': 'hybrid_search',
                            'embeddings': embeddings,
                            'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
                                'score_threshold_enable'] else None,
                            'top_k': self.top_k,
                            'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
                                'reranking_enable'] else None,
                            'all_documents': documents
                        })
                        threads.append(full_text_index_thread)
                        full_text_index_thread.start()

                    for thread in threads:
                        thread.join()

                    all_documents.extend(documents)
api/core/tool/dataset_retriever_tool.py
View file @
4588831b
import
json
import
json
from
typing
import
Type
,
Optional
import
threading
from
typing
import
Type
,
Optional
,
List
from
flask
import
current_app
from
flask
import
current_app
from
langchain.tools
import
BaseTool
from
langchain.tools
import
BaseTool
...
@@ -14,6 +15,18 @@ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitE
...
@@ -14,6 +15,18 @@ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitE
from
core.model_providers.model_factory
import
ModelFactory
from
core.model_providers.model_factory
import
ModelFactory
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
,
DocumentSegment
,
Document
from
models.dataset
import
Dataset
,
DocumentSegment
,
Document
from
services.retrieval_service
import
RetrievalService
default_retrieval_model
=
{
'search_method'
:
'semantic_search'
,
'reranking_enable'
:
False
,
'reranking_model'
:
{
'reranking_provider_name'
:
''
,
'reranking_model_name'
:
''
},
'top_k'
:
2
,
'score_threshold_enable'
:
False
}
class
DatasetRetrieverToolInput
(
BaseModel
):
class
DatasetRetrieverToolInput
(
BaseModel
):
...
@@ -56,7 +69,9 @@ class DatasetRetrieverTool(BaseTool):
...
@@ -56,7 +69,9 @@ class DatasetRetrieverTool(BaseTool):
)
.
first
()
)
.
first
()
if
not
dataset
:
if
not
dataset
:
return
f
'[{self.name} failed to find dataset with id {self.dataset_id}.]'
return
''
# get retrieval model , if the model is not setting , using default
retrieval_model
=
dataset
.
retrieval_model
if
dataset
.
retrieval_model
else
default_retrieval_model
if
dataset
.
indexing_technique
==
"economy"
:
if
dataset
.
indexing_technique
==
"economy"
:
# use keyword table query
# use keyword table query
...
@@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool):
...
@@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool):
return
''
return
''
embeddings
=
CacheEmbedding
(
embedding_model
)
embeddings
=
CacheEmbedding
(
embedding_model
)
vector_index
=
VectorIndex
(
dataset
=
dataset
,
config
=
current_app
.
config
,
embeddings
=
embeddings
)
documents
=
[]
threads
=
[]
if
self
.
top_k
>
0
:
if
self
.
top_k
>
0
:
documents
=
vector_index
.
search
(
# retrieval source with semantic
query
,
if
retrieval_model
[
'search_method'
]
==
'semantic_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
search_type
=
'similarity_score_threshold'
,
embedding_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
embedding_search
,
kwargs
=
{
search_kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'k'
:
self
.
top_k
,
'dataset'
:
dataset
,
'score_threshold'
:
self
.
score_threshold
,
'query'
:
query
,
'filter'
:
{
'top_k'
:
self
.
top_k
,
'group_id'
:
[
dataset
.
id
]
'score_threshold'
:
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
}
'score_threshold_enable'
]
else
None
,
}
'reranking_model'
:
retrieval_model
[
'reranking_model'
]
if
retrieval_model
[
)
'reranking_enable'
]
else
None
,
'all_documents'
:
documents
,
'search_method'
:
retrieval_model
[
'search_method'
],
'embeddings'
:
embeddings
})
threads
.
append
(
embedding_thread
)
embedding_thread
.
start
()
# retrieval_model source with full text
if
retrieval_model
[
'search_method'
]
==
'full_text_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
full_text_index_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
full_text_index_search
,
kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'dataset'
:
dataset
,
'query'
:
query
,
'search_method'
:
retrieval_model
[
'search_method'
],
'embeddings'
:
embeddings
,
'score_threshold'
:
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
'score_threshold_enable'
]
else
None
,
'top_k'
:
self
.
top_k
,
'reranking_model'
:
retrieval_model
[
'reranking_model'
]
if
retrieval_model
[
'reranking_enable'
]
else
None
,
'all_documents'
:
documents
})
threads
.
append
(
full_text_index_thread
)
full_text_index_thread
.
start
()
for
thread
in
threads
:
thread
.
join
()
# hybrid search: rerank after all documents have been searched
if
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
hybrid_rerank
=
ModelFactory
.
get_reranking_model
(
tenant_id
=
dataset
.
tenant_id
,
model_provider_name
=
retrieval_model
[
'reranking_model'
][
'reranking_provider_name'
],
model_name
=
retrieval_model
[
'reranking_model'
][
'reranking_model_name'
]
)
documents
=
hybrid_rerank
.
rerank
(
query
,
documents
,
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
'score_threshold_enable'
]
else
None
,
self
.
top_k
)
else
:
else
:
documents
=
[]
documents
=
[]
hit_callback
=
DatasetIndexToolCallbackHandler
(
dataset
.
id
,
self
.
conversation_message_task
)
hit_callback
=
DatasetIndexToolCallbackHandler
(
self
.
conversation_message_task
)
hit_callback
.
on_tool_end
(
documents
)
hit_callback
.
on_tool_end
(
documents
)
document_score_list
=
{}
document_score_list
=
{}
if
dataset
.
indexing_technique
!=
"economy"
:
if
dataset
.
indexing_technique
!=
"economy"
:
...
...
api/core/vector_store/milvus_vector_store.py
View file @
4588831b
from
core.
index.vector_index
.milvus
import
Milvus
from
core.
vector_store.vector
.milvus
import
Milvus
class
MilvusVectorStore
(
Milvus
):
class
MilvusVectorStore
(
Milvus
):
...
...
api/core/vector_store/qdrant_vector_store.py
View file @
4588831b
...
@@ -4,7 +4,7 @@ from langchain.schema import Document
...
@@ -4,7 +4,7 @@ from langchain.schema import Document
from
qdrant_client.http.models
import
Filter
,
PointIdsList
,
FilterSelector
from
qdrant_client.http.models
import
Filter
,
PointIdsList
,
FilterSelector
from
qdrant_client.local.qdrant_local
import
QdrantLocal
from
qdrant_client.local.qdrant_local
import
QdrantLocal
from
core.
index.vector_index
.qdrant
import
Qdrant
from
core.
vector_store.vector
.qdrant
import
Qdrant
class
QdrantVectorStore
(
Qdrant
):
class
QdrantVectorStore
(
Qdrant
):
...
@@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant):
...
@@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant):
if
isinstance
(
self
.
client
,
QdrantLocal
):
if
isinstance
(
self
.
client
,
QdrantLocal
):
self
.
client
=
cast
(
QdrantLocal
,
self
.
client
)
self
.
client
=
cast
(
QdrantLocal
,
self
.
client
)
self
.
client
.
_load
()
self
.
client
.
_load
()
api/core/
index/vector_index
/milvus.py
→
api/core/
vector_store/vector
/milvus.py
View file @
4588831b
File moved
api/core/
index/vector_index
/qdrant.py
→
api/core/
vector_store/vector
/qdrant.py
View file @
4588831b
...
@@ -28,7 +28,7 @@ from langchain.docstore.document import Document
...
@@ -28,7 +28,7 @@ from langchain.docstore.document import Document
from
langchain.embeddings.base
import
Embeddings
from
langchain.embeddings.base
import
Embeddings
from
langchain.vectorstores
import
VectorStore
from
langchain.vectorstores
import
VectorStore
from
langchain.vectorstores.utils
import
maximal_marginal_relevance
from
langchain.vectorstores.utils
import
maximal_marginal_relevance
from
qdrant_client.http.models
import
PayloadSchemaType
from
qdrant_client.http.models
import
PayloadSchemaType
,
FilterSelector
,
TextIndexParams
,
TokenizerType
,
TextIndexType
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
qdrant_client
import
grpc
# noqa
from
qdrant_client
import
grpc
# noqa
...
@@ -189,14 +189,25 @@ class Qdrant(VectorStore):
...
@@ -189,14 +189,25 @@ class Qdrant(VectorStore):
texts
,
metadatas
,
ids
,
batch_size
texts
,
metadatas
,
ids
,
batch_size
):
):
self
.
client
.
upsert
(
self
.
client
.
upsert
(
collection_name
=
self
.
collection_name
,
points
=
points
,
**
kwargs
collection_name
=
self
.
collection_name
,
points
=
points
)
)
added_ids
.
extend
(
batch_ids
)
added_ids
.
extend
(
batch_ids
)
# if is new collection, create payload index on group_id
# if is new collection, create payload index on group_id
if
self
.
is_new_collection
:
if
self
.
is_new_collection
:
# create payload index
self
.
client
.
create_payload_index
(
self
.
collection_name
,
self
.
group_payload_key
,
self
.
client
.
create_payload_index
(
self
.
collection_name
,
self
.
group_payload_key
,
field_schema
=
PayloadSchemaType
.
KEYWORD
,
field_schema
=
PayloadSchemaType
.
KEYWORD
,
field_type
=
PayloadSchemaType
.
KEYWORD
)
field_type
=
PayloadSchemaType
.
KEYWORD
)
# create full text index
text_index_params
=
TextIndexParams
(
type
=
TextIndexType
.
TEXT
,
tokenizer
=
TokenizerType
.
MULTILINGUAL
,
min_token_len
=
2
,
max_token_len
=
20
,
lowercase
=
True
)
self
.
client
.
create_payload_index
(
self
.
collection_name
,
self
.
content_payload_key
,
field_schema
=
text_index_params
)
return
added_ids
return
added_ids
@
sync_call_fallback
@
sync_call_fallback
...
@@ -600,7 +611,7 @@ class Qdrant(VectorStore):
...
@@ -600,7 +611,7 @@ class Qdrant(VectorStore):
limit
=
k
,
limit
=
k
,
offset
=
offset
,
offset
=
offset
,
with_payload
=
True
,
with_payload
=
True
,
with_vectors
=
True
,
# Langchain does not expect vectors to be returned
with_vectors
=
True
,
score_threshold
=
score_threshold
,
score_threshold
=
score_threshold
,
consistency
=
consistency
,
consistency
=
consistency
,
**
kwargs
,
**
kwargs
,
...
@@ -615,6 +626,39 @@ class Qdrant(VectorStore):
...
@@ -615,6 +626,39 @@ class Qdrant(VectorStore):
for
result
in
results
for
result
in
results
]
]
    def similarity_search_by_bm25(self, filter: Optional[MetadataFilter] = None, k: int = 4) -> List[Document]:
        """Return documents matching the given payload filter.

        NOTE(review): despite the name, this issues a plain ``scroll`` with
        ``scroll_filter`` — no BM25 scoring is applied by this call itself;
        confirm whether ranking is expected to happen upstream.

        Args:
            filter: Filter by metadata. Defaults to None.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of documents matching the filter; no relevance ordering is
            guaranteed by this call.
        """
        response = self.client.scroll(
            collection_name=self.collection_name,
            scroll_filter=filter,
            limit=k,
            with_payload=True,
            with_vectors=True
        )
        # scroll() returns (points, next_page_offset); only the points are used.
        results = response[0]
        documents = []
        for result in results:
            if result:
                documents.append(self._document_from_scored_point(
                    result, self.content_payload_key, self.metadata_payload_key
                ))
        return documents
@
sync_call_fallback
@
sync_call_fallback
async
def
asimilarity_search_with_score_by_vector
(
async
def
asimilarity_search_with_score_by_vector
(
self
,
self
,
...
...
api/core/vector_store/vector/weaviate.py
0 → 100644
View file @
4588831b
"""Wrapper around weaviate vector database."""
from
__future__
import
annotations
import
datetime
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Type
from
uuid
import
uuid4
import
numpy
as
np
from
langchain.docstore.document
import
Document
from
langchain.embeddings.base
import
Embeddings
from
langchain.utils
import
get_from_dict_or_env
from
langchain.vectorstores.base
import
VectorStore
from
langchain.vectorstores.utils
import
maximal_marginal_relevance
def
_default_schema
(
index_name
:
str
)
->
Dict
:
return
{
"class"
:
index_name
,
"properties"
:
[
{
"name"
:
"text"
,
"dataType"
:
[
"text"
],
}
],
}
def _create_weaviate_client(**kwargs: Any) -> Any:
    """Return the caller-supplied client, or build one from URL and API key.

    A ``client`` kwarg short-circuits everything; otherwise the URL is read
    from kwargs or the ``WEAVIATE_URL`` environment variable, and an optional
    API key from kwargs or ``WEAVIATE_API_KEY``.
    """
    existing = kwargs.get("client")
    if existing is not None:
        return existing

    weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")

    try:
        # the weaviate api key param should not be mandatory
        weaviate_api_key = get_from_dict_or_env(
            kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
        )
    except ValueError:
        weaviate_api_key = None

    try:
        import weaviate
    except ImportError:
        raise ValueError(
            "Could not import weaviate python package. "
            "Please install it with `pip install weaviate-client`"
        )

    if weaviate_api_key is not None:
        auth = weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
    else:
        auth = None

    return weaviate.Client(weaviate_url, auth_client_secret=auth)
def
_default_score_normalizer
(
val
:
float
)
->
float
:
return
1
-
1
/
(
1
+
np
.
exp
(
val
))
def
_json_serializable
(
value
:
Any
)
->
Any
:
if
isinstance
(
value
,
datetime
.
datetime
):
return
value
.
isoformat
()
return
value
class
Weaviate
(
VectorStore
):
"""Wrapper around Weaviate vector database.
To use, you should have the ``weaviate-client`` python package installed.
Example:
.. code-block:: python
import weaviate
from langchain.vectorstores import Weaviate
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
weaviate = Weaviate(client, index_name, text_key)
"""
    def __init__(
        self,
        client: Any,
        index_name: str,
        text_key: str,
        embedding: Optional[Embeddings] = None,
        attributes: Optional[List[str]] = None,
        relevance_score_fn: Optional[
            Callable[[float], float]
        ] = _default_score_normalizer,
        by_text: bool = True,
    ):
        """Initialize with Weaviate client.

        Args:
            client: A ``weaviate.Client`` instance.
            index_name: Weaviate class (index) to operate on.
            text_key: Property name holding the document text.
            embedding: Optional client-side embedding function.
            attributes: Extra properties to request in every query.
            relevance_score_fn: Maps raw scores to relevance scores.
            by_text: When True, searches use near_text; otherwise vectors.

        Raises:
            ValueError: If ``weaviate-client`` is missing or ``client`` is not
                a ``weaviate.Client``.
        """
        try:
            import weaviate
        except ImportError:
            raise ValueError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            )
        if not isinstance(client, weaviate.Client):
            raise ValueError(
                f"client should be an instance of weaviate.Client, got {type(client)}"
            )
        self._client = client
        self._index_name = index_name
        self._embedding = embedding
        self._text_key = text_key
        # The text property is always fetched; `attributes` may add more.
        self._query_attrs = [self._text_key]
        self.relevance_score_fn = relevance_score_fn
        self._by_text = by_text
        if attributes is not None:
            self._query_attrs.extend(attributes)
    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Client-side embedding function, or ``None`` if not configured."""
        return self._embedding
def
_select_relevance_score_fn
(
self
)
->
Callable
[[
float
],
float
]:
return
(
self
.
relevance_score_fn
if
self
.
relevance_score_fn
else
_default_score_normalizer
)
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Upload texts with metadata (properties) to Weaviate.

        Returns the list of object UUIDs that were written.
        """
        from weaviate.util import get_valid_uuid

        ids = []
        embeddings: Optional[List[List[float]]] = None
        if self._embedding:
            # embed_documents needs a concrete list, and we iterate twice.
            if not isinstance(texts, list):
                texts = list(texts)
            embeddings = self._embedding.embed_documents(texts)

        # Batch context flushes the queued objects on exit.
        with self._client.batch as batch:
            for i, text in enumerate(texts):
                data_properties = {self._text_key: text}
                if metadatas is not None:
                    # Datetimes must be ISO strings for Weaviate's JSON payload.
                    for key, val in metadatas[i].items():
                        data_properties[key] = _json_serializable(val)

                # Allow for ids (consistent w/ other methods)
                # # Or uuids (backwards compatble w/ existing arg)
                # If the UUID of one of the objects already exists
                # then the existing object will be replaced by the new object.
                # Precedence: explicit "uuids" > "ids" > random UUID.
                _id = get_valid_uuid(uuid4())
                if "uuids" in kwargs:
                    _id = kwargs["uuids"][i]
                elif "ids" in kwargs:
                    _id = kwargs["ids"][i]

                batch.add_data_object(
                    data_object=data_properties,
                    class_name=self._index_name,
                    uuid=_id,
                    # None lets a server-side vectorizer supply the vector.
                    vector=embeddings[i] if embeddings else None,
                )
                ids.append(_id)
        return ids
    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Dispatches on ``self._by_text``: near_text search, or embed the query
        client-side and search by vector.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.

        Raises:
            ValueError: If ``_by_text`` is False but no embedding is set.
        """
        if self._by_text:
            return self.similarity_search_by_text(query, k, **kwargs)
        else:
            if self._embedding is None:
                raise ValueError(
                    "_embedding cannot be None for similarity_search when "
                    "_by_text=False"
                )
            embedding = self._embedding.embed_query(query)
            return self.similarity_search_by_vector(embedding, k, **kwargs)
    def similarity_search_by_text(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query via Weaviate near_text.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.

        Raises:
            ValueError: If the GraphQL response contains errors.
        """
        content: Dict[str, Any] = {"concepts": [query]}
        # Optional certainty cutoff, passed as "search_distance".
        if kwargs.get("search_distance"):
            content["certainty"] = kwargs.get("search_distance")
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
        if kwargs.get("where_filter"):
            query_obj = query_obj.with_where(kwargs.get("where_filter"))
        if kwargs.get("additional"):
            query_obj = query_obj.with_additional(kwargs.get("additional"))
        result = query_obj.with_near_text(content).with_limit(k).do()
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")
        docs = []
        for res in result["data"]["Get"][self._index_name]:
            # The text property becomes page_content; the rest is metadata.
            text = res.pop(self._text_key)
            docs.append(Document(page_content=text, metadata=res))
        return docs
    def similarity_search_by_bm25(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs using BM25F.

        NOTE(review): the ``concepts`` dict built here is passed as
        ``with_bm25(query=content)`` — the client's BM25 query normally takes
        the plain query string; confirm the dict form is accepted.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.

        Raises:
            ValueError: If the GraphQL response contains errors.
        """
        content: Dict[str, Any] = {"concepts": [query]}
        if kwargs.get("search_distance"):
            content["certainty"] = kwargs.get("search_distance")
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
        if kwargs.get("where_filter"):
            query_obj = query_obj.with_where(kwargs.get("where_filter"))
        if kwargs.get("additional"):
            query_obj = query_obj.with_additional(kwargs.get("additional"))
        result = query_obj.with_bm25(query=content).with_limit(k).do()
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")
        docs = []
        for res in result["data"]["Get"][self._index_name]:
            text = res.pop(self._text_key)
            docs.append(Document(page_content=text, metadata=res))
        return docs
    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Look up similar documents by embedding vector in Weaviate.

        Args:
            embedding: Query embedding vector.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the vector.

        Raises:
            ValueError: If the GraphQL response contains errors.
        """
        vector = {"vector": embedding}
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
        if kwargs.get("where_filter"):
            query_obj = query_obj.with_where(kwargs.get("where_filter"))
        if kwargs.get("additional"):
            query_obj = query_obj.with_additional(kwargs.get("additional"))
        result = query_obj.with_near_vector(vector).with_limit(k).do()
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")
        docs = []
        for res in result["data"]["Get"][self._index_name]:
            # The text property becomes page_content; the rest is metadata.
            text = res.pop(self._text_key)
            docs.append(Document(page_content=text, metadata=res))
        return docs
    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.

        Raises:
            ValueError: If no embedding function is configured (MMR requires
                client-side query embedding).
        """
        if self._embedding is not None:
            embedding = self._embedding.embed_query(query)
        else:
            raise ValueError(
                "max_marginal_relevance_search requires a suitable Embeddings object"
            )

        return self.max_marginal_relevance_search_by_vector(
            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
        )
    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        vector = {"vector": embedding}
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
        if kwargs.get("where_filter"):
            query_obj = query_obj.with_where(kwargs.get("where_filter"))
        # Fetch fetch_k candidates with their vectors; MMR reranks client-side.
        results = (
            query_obj.with_additional("vector")
            .with_near_vector(vector)
            .with_limit(fetch_k)
            .do()
        )
        payload = results["data"]["Get"][self._index_name]
        embeddings = [result["_additional"]["vector"] for result in payload]
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
        )

        docs = []
        for idx in mmr_selected:
            text = payload[idx].pop(self._text_key)
            # Drop the vector before using the remaining payload as metadata.
            payload[idx].pop("_additional")
            meta = payload[idx]
            docs.append(Document(page_content=text, metadata=meta))
        return docs
    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """
        Return list of documents most similar to the query
        text and cosine distance in float for each.
        Lower score represents more similarity.

        NOTE(review): the score is computed below as np.dot of the stored
        vector and the query embedding — for that, a HIGHER value usually
        means more similar, which contradicts the "lower score" claim above;
        confirm the intended score semantics.

        Raises:
            ValueError: If no embedding function is configured, or the
                GraphQL response contains errors.
        """
        if self._embedding is None:
            raise ValueError(
                "_embedding cannot be None for similarity_search_with_score"
            )
        content: Dict[str, Any] = {"concepts": [query]}
        if kwargs.get("search_distance"):
            content["certainty"] = kwargs.get("search_distance")
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
        embedded_query = self._embedding.embed_query(query)
        if not self._by_text:
            vector = {"vector": embedded_query}
            result = (
                query_obj.with_near_vector(vector)
                .with_limit(k)
                .with_additional("vector")
                .do()
            )
        else:
            result = (
                query_obj.with_near_text(content)
                .with_limit(k)
                .with_additional("vector")
                .do()
            )

        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        docs_and_scores = []
        for res in result["data"]["Get"][self._index_name]:
            text = res.pop(self._text_key)
            # Score is the dot product of the stored vector and query embedding.
            score = np.dot(res["_additional"]["vector"], embedded_query)
            docs_and_scores.append((Document(page_content=text, metadata=res), score))
        return docs_and_scores
@classmethod
def from_texts(
    cls: Type[Weaviate],
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    **kwargs: Any,
) -> Weaviate:
    """Build a Weaviate wrapper from raw texts.

    Embeds each text (when an embedding model is supplied), creates the
    target class in Weaviate if it does not exist yet, and batch-inserts
    one object per text before returning the configured wrapper.

    Example:
        .. code-block:: python

            from langchain.vectorstores.weaviate import Weaviate
            from langchain.embeddings import OpenAIEmbeddings
            embeddings = OpenAIEmbeddings()
            weaviate = Weaviate.from_texts(
                texts,
                embeddings,
                weaviate_url="http://localhost:8080"
            )
    """
    weaviate_client = _create_weaviate_client(**kwargs)
    from weaviate.util import get_valid_uuid

    index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
    text_key = "text"
    vectors = embedding.embed_documents(texts) if embedding else None
    schema = _default_schema(index_name)
    attributes = list(metadatas[0].keys()) if metadatas else None

    # Create the class only when it is not already present in the schema.
    if not weaviate_client.schema.contains(schema):
        weaviate_client.schema.create_class(schema)

    with weaviate_client.batch as batch:
        for i, text in enumerate(texts):
            properties = {text_key: text}
            if metadatas is not None:
                properties.update(metadatas[i])

            # An object inserted under an existing UUID replaces the old one.
            object_id = kwargs["uuids"][i] if "uuids" in kwargs else get_valid_uuid(uuid4())

            # Without a client-side vector, Weaviate vectorizes the object
            # itself; that only works when the server runs a vectorizer
            # module such as text2vec-contextionary.
            insert_kwargs = {
                "uuid": object_id,
                "data_object": properties,
                "class_name": index_name,
            }
            if vectors is not None:
                insert_kwargs["vector"] = vectors[i]
            batch.add_data_object(**insert_kwargs)
        batch.flush()

    return cls(
        weaviate_client,
        index_name,
        text_key,
        embedding=embedding,
        attributes=attributes,
        relevance_score_fn=kwargs.get("relevance_score_fn"),
        by_text=kwargs.get("by_text", False),
    )
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
    """Delete stored objects from Weaviate by their UUIDs.

    Args:
        ids: List of object UUIDs to delete. Must be provided.

    Raises:
        ValueError: If ``ids`` is None.
    """
    if ids is None:
        raise ValueError("No ids provided to delete.")

    # TODO: Check if this can be done in bulk
    for object_id in ids:  # renamed from `id` to avoid shadowing the builtin
        self._client.data_object.delete(uuid=object_id)
api/fields/dataset_fields.py
View file @
4588831b
...
@@ -12,6 +12,21 @@ dataset_fields = {
...
@@ -12,6 +12,21 @@ dataset_fields = {
'created_at'
:
TimestampField
,
'created_at'
:
TimestampField
,
}
}
reranking_model_fields
=
{
'reranking_provider_name'
:
fields
.
String
,
'reranking_model_name'
:
fields
.
String
}
dataset_retrieval_model_fields
=
{
'search_method'
:
fields
.
String
,
'reranking_enable'
:
fields
.
Boolean
,
'reranking_model'
:
fields
.
Nested
(
reranking_model_fields
),
'top_k'
:
fields
.
Integer
,
'score_threshold_enable'
:
fields
.
Boolean
,
'score_threshold'
:
fields
.
Float
}
dataset_detail_fields
=
{
dataset_detail_fields
=
{
'id'
:
fields
.
String
,
'id'
:
fields
.
String
,
'name'
:
fields
.
String
,
'name'
:
fields
.
String
,
...
@@ -29,7 +44,8 @@ dataset_detail_fields = {
...
@@ -29,7 +44,8 @@ dataset_detail_fields = {
'updated_at'
:
TimestampField
,
'updated_at'
:
TimestampField
,
'embedding_model'
:
fields
.
String
,
'embedding_model'
:
fields
.
String
,
'embedding_model_provider'
:
fields
.
String
,
'embedding_model_provider'
:
fields
.
String
,
'embedding_available'
:
fields
.
Boolean
'embedding_available'
:
fields
.
Boolean
,
'retrieval_model_dict'
:
fields
.
Nested
(
dataset_retrieval_model_fields
)
}
}
dataset_query_detail_fields
=
{
dataset_query_detail_fields
=
{
...
@@ -41,3 +57,5 @@ dataset_query_detail_fields = {
...
@@ -41,3 +57,5 @@ dataset_query_detail_fields = {
"created_by"
:
fields
.
String
,
"created_by"
:
fields
.
String
,
"created_at"
:
TimestampField
"created_at"
:
TimestampField
}
}
api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py
0 → 100644
View file @
4588831b
"""add-dataset-retrival-model
Revision ID: fca025d3b60f
Revises: b3a09c049e8e
Create Date: 2023-11-03 13:08:23.246396
"""
from
alembic
import
op
import
sqlalchemy
as
sa
from
sqlalchemy.dialects
import
postgresql
# revision identifiers, used by Alembic.
revision
=
'fca025d3b60f'
down_revision
=
'8fe468ba0ca5'
branch_labels
=
None
depends_on
=
None
def upgrade():
    """Apply the migration: add `datasets.retrieval_model` (JSONB) with a GIN index."""
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): dropping the unrelated 'sessions' table here looks like an
    # autogenerate artifact (it is recreated in downgrade()); confirm it is
    # intentional before running against production data.
    op.drop_table('sessions')
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        # Per-dataset retrieval settings (search method, rerank config, top_k, ...)
        # stored as JSONB; nullable so existing rows fall back to app defaults.
        batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
        # GIN index keeps JSONB containment queries on retrieval_model fast.
        batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin')

    # ### end Alembic commands ###
def downgrade():
    """Revert the migration: drop the retrieval_model column/index and restore 'sessions'."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_index('retrieval_model_idx', postgresql_using='gin')
        batch_op.drop_column('retrieval_model')

    # Recreate the 'sessions' table dropped in upgrade().
    op.create_table('sessions',
    sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
    sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True),
    sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True),
    sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
    sa.PrimaryKeyConstraint('id', name='sessions_pkey'),
    sa.UniqueConstraint('session_id', name='sessions_session_id_key')
    )
    # ### end Alembic commands ###
api/models/dataset.py
View file @
4588831b
...
@@ -3,7 +3,7 @@ import pickle
...
@@ -3,7 +3,7 @@ import pickle
from
json
import
JSONDecodeError
from
json
import
JSONDecodeError
from
sqlalchemy
import
func
from
sqlalchemy
import
func
from
sqlalchemy.dialects.postgresql
import
UUID
from
sqlalchemy.dialects.postgresql
import
UUID
,
JSONB
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.account
import
Account
from
models.account
import
Account
...
@@ -15,6 +15,7 @@ class Dataset(db.Model):
...
@@ -15,6 +15,7 @@ class Dataset(db.Model):
__table_args__
=
(
__table_args__
=
(
db
.
PrimaryKeyConstraint
(
'id'
,
name
=
'dataset_pkey'
),
db
.
PrimaryKeyConstraint
(
'id'
,
name
=
'dataset_pkey'
),
db
.
Index
(
'dataset_tenant_idx'
,
'tenant_id'
),
db
.
Index
(
'dataset_tenant_idx'
,
'tenant_id'
),
db
.
Index
(
'retrieval_model_idx'
,
"retrieval_model"
,
postgresql_using
=
'gin'
)
)
)
INDEXING_TECHNIQUE_LIST
=
[
'high_quality'
,
'economy'
]
INDEXING_TECHNIQUE_LIST
=
[
'high_quality'
,
'economy'
]
...
@@ -39,7 +40,7 @@ class Dataset(db.Model):
...
@@ -39,7 +40,7 @@ class Dataset(db.Model):
embedding_model
=
db
.
Column
(
db
.
String
(
255
),
nullable
=
True
)
embedding_model
=
db
.
Column
(
db
.
String
(
255
),
nullable
=
True
)
embedding_model_provider
=
db
.
Column
(
db
.
String
(
255
),
nullable
=
True
)
embedding_model_provider
=
db
.
Column
(
db
.
String
(
255
),
nullable
=
True
)
collection_binding_id
=
db
.
Column
(
UUID
,
nullable
=
True
)
collection_binding_id
=
db
.
Column
(
UUID
,
nullable
=
True
)
retrieval_model
=
db
.
Column
(
JSONB
,
nullable
=
True
)
@
property
@
property
def
dataset_keyword_table
(
self
):
def
dataset_keyword_table
(
self
):
...
@@ -93,6 +94,20 @@ class Dataset(db.Model):
...
@@ -93,6 +94,20 @@ class Dataset(db.Model):
return
Document
.
query
.
with_entities
(
func
.
coalesce
(
func
.
sum
(
Document
.
word_count
)))
\
return
Document
.
query
.
with_entities
(
func
.
coalesce
(
func
.
sum
(
Document
.
word_count
)))
\
.
filter
(
Document
.
dataset_id
==
self
.
id
)
.
scalar
()
.
filter
(
Document
.
dataset_id
==
self
.
id
)
.
scalar
()
@
property
def
retrieval_model_dict
(
self
):
default_retrieval_model
=
{
'search_method'
:
'semantic_search'
,
'reranking_enable'
:
False
,
'reranking_model'
:
{
'reranking_provider_name'
:
''
,
'reranking_model_name'
:
''
},
'top_k'
:
2
,
'score_threshold_enable'
:
False
}
return
self
.
retrieval_model
if
self
.
retrieval_model
else
default_retrieval_model
class
DatasetProcessRule
(
db
.
Model
):
class
DatasetProcessRule
(
db
.
Model
):
__tablename__
=
'dataset_process_rules'
__tablename__
=
'dataset_process_rules'
...
@@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model):
...
@@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model):
],
],
'segmentation'
:
{
'segmentation'
:
{
'delimiter'
:
'
\n
'
,
'delimiter'
:
'
\n
'
,
'max_tokens'
:
1000
'max_tokens'
:
512
}
}
}
}
...
@@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model):
...
@@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model):
model_name
=
db
.
Column
(
db
.
String
(
40
),
nullable
=
False
)
model_name
=
db
.
Column
(
db
.
String
(
40
),
nullable
=
False
)
collection_name
=
db
.
Column
(
db
.
String
(
64
),
nullable
=
False
)
collection_name
=
db
.
Column
(
db
.
String
(
64
),
nullable
=
False
)
created_at
=
db
.
Column
(
db
.
DateTime
,
nullable
=
False
,
server_default
=
db
.
text
(
'CURRENT_TIMESTAMP(0)'
))
created_at
=
db
.
Column
(
db
.
DateTime
,
nullable
=
False
,
server_default
=
db
.
text
(
'CURRENT_TIMESTAMP(0)'
))
api/models/model.py
View file @
4588831b
...
@@ -160,7 +160,13 @@ class AppModelConfig(db.Model):
...
@@ -160,7 +160,13 @@ class AppModelConfig(db.Model):
@
property
@
property
def
dataset_configs_dict
(
self
)
->
dict
:
def
dataset_configs_dict
(
self
)
->
dict
:
return
json
.
loads
(
self
.
dataset_configs
)
if
self
.
dataset_configs
else
{
"top_k"
:
2
,
"score_threshold"
:
{
"enable"
:
False
}}
if
self
.
dataset_configs
:
dataset_configs
=
json
.
loads
(
self
.
dataset_configs
)
if
'retrieval_model'
not
in
dataset_configs
:
return
{
'retrieval_model'
:
'single'
}
else
:
return
dataset_configs
return
{
'retrieval_model'
:
'single'
}
@
property
@
property
def
file_upload_dict
(
self
)
->
dict
:
def
file_upload_dict
(
self
)
->
dict
:
...
...
api/requirements.txt
View file @
4588831b
...
@@ -23,7 +23,6 @@ boto3==1.28.17
...
@@ -23,7 +23,6 @@ boto3==1.28.17
tenacity==8.2.2
tenacity==8.2.2
cachetools~=5.3.0
cachetools~=5.3.0
weaviate-client~=3.21.0
weaviate-client~=3.21.0
qdrant_client~=1.1.6
mailchimp-transactional~=1.0.50
mailchimp-transactional~=1.0.50
scikit-learn==1.2.2
scikit-learn==1.2.2
sentry-sdk[flask]~=1.21.1
sentry-sdk[flask]~=1.21.1
...
@@ -53,4 +52,6 @@ xinference-client~=0.5.4
...
@@ -53,4 +52,6 @@ xinference-client~=0.5.4
safetensors==0.3.2
safetensors==0.3.2
zhipuai==1.0.7
zhipuai==1.0.7
werkzeug==2.3.7
werkzeug==2.3.7
pymilvus==2.3.0
pymilvus==2.3.0
\ No newline at end of file
qdrant-client==1.6.4
cohere~=4.32
\ No newline at end of file
api/services/app_model_config_service.py
View file @
4588831b
...
@@ -470,7 +470,16 @@ class AppModelConfigService:
...
@@ -470,7 +470,16 @@ class AppModelConfigService:
# dataset_configs
# dataset_configs
if
'dataset_configs'
not
in
config
or
not
config
[
"dataset_configs"
]:
if
'dataset_configs'
not
in
config
or
not
config
[
"dataset_configs"
]:
config
[
"dataset_configs"
]
=
{
"top_k"
:
2
,
"score_threshold"
:
{
"enable"
:
False
}}
config
[
"dataset_configs"
]
=
{
'retrieval_model'
:
'single'
}
if
not
isinstance
(
config
[
"dataset_configs"
],
dict
):
raise
ValueError
(
"dataset_configs must be of object type"
)
if
config
[
"dataset_configs"
][
'retrieval_model'
]
==
'multiple'
:
if
not
config
[
"dataset_configs"
][
'reranking_model'
]:
raise
ValueError
(
"reranking_model has not been set"
)
if
not
isinstance
(
config
[
"dataset_configs"
][
'reranking_model'
],
dict
):
raise
ValueError
(
"reranking_model must be of object type"
)
if
not
isinstance
(
config
[
"dataset_configs"
],
dict
):
if
not
isinstance
(
config
[
"dataset_configs"
],
dict
):
raise
ValueError
(
"dataset_configs must be of object type"
)
raise
ValueError
(
"dataset_configs must be of object type"
)
...
...
api/services/dataset_service.py
View file @
4588831b
...
@@ -173,6 +173,9 @@ class DatasetService:
...
@@ -173,6 +173,9 @@ class DatasetService:
filtered_data
[
'updated_by'
]
=
user
.
id
filtered_data
[
'updated_by'
]
=
user
.
id
filtered_data
[
'updated_at'
]
=
datetime
.
datetime
.
now
()
filtered_data
[
'updated_at'
]
=
datetime
.
datetime
.
now
()
# update Retrieval model
filtered_data
[
'retrieval_model'
]
=
data
[
'retrieval_model'
]
dataset
.
query
.
filter_by
(
id
=
dataset_id
)
.
update
(
filtered_data
)
dataset
.
query
.
filter_by
(
id
=
dataset_id
)
.
update
(
filtered_data
)
db
.
session
.
commit
()
db
.
session
.
commit
()
...
@@ -473,7 +476,19 @@ class DocumentService:
...
@@ -473,7 +476,19 @@ class DocumentService:
embedding_model
.
name
embedding_model
.
name
)
)
dataset
.
collection_binding_id
=
dataset_collection_binding
.
id
dataset
.
collection_binding_id
=
dataset_collection_binding
.
id
if
not
dataset
.
retrieval_model
:
default_retrieval_model
=
{
'search_method'
:
'semantic_search'
,
'reranking_enable'
:
False
,
'reranking_model'
:
{
'reranking_provider_name'
:
''
,
'reranking_model_name'
:
''
},
'top_k'
:
2
,
'score_threshold_enable'
:
False
}
dataset
.
retrieval_model
=
document_data
.
get
(
'retrieval_model'
)
if
document_data
.
get
(
'retrieval_model'
)
else
default_retrieval_model
documents
=
[]
documents
=
[]
batch
=
time
.
strftime
(
'
%
Y
%
m
%
d
%
H
%
M
%
S'
)
+
str
(
random
.
randint
(
100000
,
999999
))
batch
=
time
.
strftime
(
'
%
Y
%
m
%
d
%
H
%
M
%
S'
)
+
str
(
random
.
randint
(
100000
,
999999
))
...
@@ -733,6 +748,7 @@ class DocumentService:
...
@@ -733,6 +748,7 @@ class DocumentService:
raise
ValueError
(
f
"All your documents have overed limit {tenant_document_count}."
)
raise
ValueError
(
f
"All your documents have overed limit {tenant_document_count}."
)
embedding_model
=
None
embedding_model
=
None
dataset_collection_binding_id
=
None
dataset_collection_binding_id
=
None
retrieval_model
=
None
if
document_data
[
'indexing_technique'
]
==
'high_quality'
:
if
document_data
[
'indexing_technique'
]
==
'high_quality'
:
embedding_model
=
ModelFactory
.
get_embedding_model
(
embedding_model
=
ModelFactory
.
get_embedding_model
(
tenant_id
=
tenant_id
tenant_id
=
tenant_id
...
@@ -742,6 +758,20 @@ class DocumentService:
...
@@ -742,6 +758,20 @@ class DocumentService:
embedding_model
.
name
embedding_model
.
name
)
)
dataset_collection_binding_id
=
dataset_collection_binding
.
id
dataset_collection_binding_id
=
dataset_collection_binding
.
id
if
'retrieval_model'
in
document_data
and
document_data
[
'retrieval_model'
]:
retrieval_model
=
document_data
[
'retrieval_model'
]
else
:
default_retrieval_model
=
{
'search_method'
:
'semantic_search'
,
'reranking_enable'
:
False
,
'reranking_model'
:
{
'reranking_provider_name'
:
''
,
'reranking_model_name'
:
''
},
'top_k'
:
2
,
'score_threshold_enable'
:
False
}
retrieval_model
=
default_retrieval_model
# save dataset
# save dataset
dataset
=
Dataset
(
dataset
=
Dataset
(
tenant_id
=
tenant_id
,
tenant_id
=
tenant_id
,
...
@@ -751,7 +781,8 @@ class DocumentService:
...
@@ -751,7 +781,8 @@ class DocumentService:
created_by
=
account
.
id
,
created_by
=
account
.
id
,
embedding_model
=
embedding_model
.
name
if
embedding_model
else
None
,
embedding_model
=
embedding_model
.
name
if
embedding_model
else
None
,
embedding_model_provider
=
embedding_model
.
model_provider
.
provider_name
if
embedding_model
else
None
,
embedding_model_provider
=
embedding_model
.
model_provider
.
provider_name
if
embedding_model
else
None
,
collection_binding_id
=
dataset_collection_binding_id
collection_binding_id
=
dataset_collection_binding_id
,
retrieval_model
=
retrieval_model
)
)
db
.
session
.
add
(
dataset
)
db
.
session
.
add
(
dataset
)
...
@@ -768,7 +799,7 @@ class DocumentService:
...
@@ -768,7 +799,7 @@ class DocumentService:
return
dataset
,
documents
,
batch
return
dataset
,
documents
,
batch
@
classmethod
@
classmethod
def
document_create_args_validate
(
cls
,
args
:
dict
):
def
document_create_args_validate
(
cls
,
args
:
dict
):
if
'original_document_id'
not
in
args
or
not
args
[
'original_document_id'
]:
if
'original_document_id'
not
in
args
or
not
args
[
'original_document_id'
]:
DocumentService
.
data_source_args_validate
(
args
)
DocumentService
.
data_source_args_validate
(
args
)
DocumentService
.
process_rule_args_validate
(
args
)
DocumentService
.
process_rule_args_validate
(
args
)
...
...
api/services/hit_testing_service.py
View file @
4588831b
import
json
import
logging
import
logging
import
threading
import
time
import
time
from
typing
import
List
from
typing
import
List
...
@@ -9,16 +11,26 @@ from langchain.schema import Document
...
@@ -9,16 +11,26 @@ from langchain.schema import Document
from
sklearn.manifold
import
TSNE
from
sklearn.manifold
import
TSNE
from
core.embedding.cached_embedding
import
CacheEmbedding
from
core.embedding.cached_embedding
import
CacheEmbedding
from
core.index.vector_index.vector_index
import
VectorIndex
from
core.model_providers.model_factory
import
ModelFactory
from
core.model_providers.model_factory
import
ModelFactory
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.account
import
Account
from
models.account
import
Account
from
models.dataset
import
Dataset
,
DocumentSegment
,
DatasetQuery
from
models.dataset
import
Dataset
,
DocumentSegment
,
DatasetQuery
from
services.retrieval_service
import
RetrievalService
default_retrieval_model
=
{
'search_method'
:
'semantic_search'
,
'reranking_enable'
:
False
,
'reranking_model'
:
{
'reranking_provider_name'
:
''
,
'reranking_model_name'
:
''
},
'top_k'
:
2
,
'score_threshold_enable'
:
False
}
class
HitTestingService
:
class
HitTestingService
:
@
classmethod
@
classmethod
def
retrieve
(
cls
,
dataset
:
Dataset
,
query
:
str
,
account
:
Account
,
limit
:
int
=
10
)
->
dict
:
def
retrieve
(
cls
,
dataset
:
Dataset
,
query
:
str
,
account
:
Account
,
retrieval_model
:
dict
,
limit
:
int
=
10
)
->
dict
:
if
dataset
.
available_document_count
==
0
or
dataset
.
available_segment_count
==
0
:
if
dataset
.
available_document_count
==
0
or
dataset
.
available_segment_count
==
0
:
return
{
return
{
"query"
:
{
"query"
:
{
...
@@ -28,31 +40,68 @@ class HitTestingService:
...
@@ -28,31 +40,68 @@ class HitTestingService:
"records"
:
[]
"records"
:
[]
}
}
start
=
time
.
perf_counter
()
# get retrieval model , if the model is not setting , using default
if
not
retrieval_model
:
retrieval_model
=
dataset
.
retrieval_model
if
dataset
.
retrieval_model
else
default_retrieval_model
# get embedding model
embedding_model
=
ModelFactory
.
get_embedding_model
(
embedding_model
=
ModelFactory
.
get_embedding_model
(
tenant_id
=
dataset
.
tenant_id
,
tenant_id
=
dataset
.
tenant_id
,
model_provider_name
=
dataset
.
embedding_model_provider
,
model_provider_name
=
dataset
.
embedding_model_provider
,
model_name
=
dataset
.
embedding_model
model_name
=
dataset
.
embedding_model
)
)
embeddings
=
CacheEmbedding
(
embedding_model
)
embeddings
=
CacheEmbedding
(
embedding_model
)
vector_index
=
VectorIndex
(
all_documents
=
[]
dataset
=
dataset
,
threads
=
[]
config
=
current_app
.
config
,
embeddings
=
embeddings
# retrieval_model source with semantic
)
if
retrieval_model
[
'search_method'
]
==
'semantic_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
embedding_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
embedding_search
,
kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'dataset'
:
dataset
,
'query'
:
query
,
'top_k'
:
retrieval_model
[
'top_k'
],
'score_threshold'
:
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
'score_threshold_enable'
]
else
None
,
'reranking_model'
:
retrieval_model
[
'reranking_model'
]
if
retrieval_model
[
'reranking_enable'
]
else
None
,
'all_documents'
:
all_documents
,
'search_method'
:
retrieval_model
[
'search_method'
],
'embeddings'
:
embeddings
})
threads
.
append
(
embedding_thread
)
embedding_thread
.
start
()
# retrieval source with full text
if
retrieval_model
[
'search_method'
]
==
'full_text_search'
or
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
full_text_index_thread
=
threading
.
Thread
(
target
=
RetrievalService
.
full_text_index_search
,
kwargs
=
{
'flask_app'
:
current_app
.
_get_current_object
(),
'dataset'
:
dataset
,
'query'
:
query
,
'search_method'
:
retrieval_model
[
'search_method'
],
'embeddings'
:
embeddings
,
'score_threshold'
:
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
'score_threshold_enable'
]
else
None
,
'top_k'
:
retrieval_model
[
'top_k'
],
'reranking_model'
:
retrieval_model
[
'reranking_model'
]
if
retrieval_model
[
'reranking_enable'
]
else
None
,
'all_documents'
:
all_documents
})
threads
.
append
(
full_text_index_thread
)
full_text_index_thread
.
start
()
for
thread
in
threads
:
thread
.
join
()
if
retrieval_model
[
'search_method'
]
==
'hybrid_search'
:
hybrid_rerank
=
ModelFactory
.
get_reranking_model
(
tenant_id
=
dataset
.
tenant_id
,
model_provider_name
=
retrieval_model
[
'reranking_model'
][
'reranking_provider_name'
],
model_name
=
retrieval_model
[
'reranking_model'
][
'reranking_model_name'
]
)
all_documents
=
hybrid_rerank
.
rerank
(
query
,
all_documents
,
retrieval_model
[
'score_threshold'
]
if
retrieval_model
[
'score_threshold_enable'
]
else
None
,
retrieval_model
[
'top_k'
])
start
=
time
.
perf_counter
()
documents
=
vector_index
.
search
(
query
,
search_type
=
'similarity_score_threshold'
,
search_kwargs
=
{
'k'
:
10
,
'filter'
:
{
'group_id'
:
[
dataset
.
id
]
}
}
)
end
=
time
.
perf_counter
()
end
=
time
.
perf_counter
()
logging
.
debug
(
f
"Hit testing retrieve in {end - start:0.4f} seconds"
)
logging
.
debug
(
f
"Hit testing retrieve in {end - start:0.4f} seconds"
)
...
@@ -67,7 +116,7 @@ class HitTestingService:
...
@@ -67,7 +116,7 @@ class HitTestingService:
db
.
session
.
add
(
dataset_query
)
db
.
session
.
add
(
dataset_query
)
db
.
session
.
commit
()
db
.
session
.
commit
()
return
cls
.
compact_retrieve_response
(
dataset
,
embeddings
,
query
,
documents
)
return
cls
.
compact_retrieve_response
(
dataset
,
embeddings
,
query
,
all_
documents
)
@
classmethod
@
classmethod
def
compact_retrieve_response
(
cls
,
dataset
:
Dataset
,
embeddings
:
Embeddings
,
query
:
str
,
documents
:
List
[
Document
]):
def
compact_retrieve_response
(
cls
,
dataset
:
Dataset
,
embeddings
:
Embeddings
,
query
:
str
,
documents
:
List
[
Document
]):
...
@@ -99,7 +148,7 @@ class HitTestingService:
...
@@ -99,7 +148,7 @@ class HitTestingService:
record
=
{
record
=
{
"segment"
:
segment
,
"segment"
:
segment
,
"score"
:
document
.
metadata
[
'score'
]
,
"score"
:
document
.
metadata
.
get
(
'score'
,
None
)
,
"tsne_position"
:
tsne_position_data
[
i
]
"tsne_position"
:
tsne_position_data
[
i
]
}
}
...
@@ -136,3 +185,11 @@ class HitTestingService:
...
@@ -136,3 +185,11 @@ class HitTestingService:
tsne_position_data
.
append
({
'x'
:
float
(
data_tsne
[
i
][
0
]),
'y'
:
float
(
data_tsne
[
i
][
1
])})
tsne_position_data
.
append
({
'x'
:
float
(
data_tsne
[
i
][
0
]),
'y'
:
float
(
data_tsne
[
i
][
1
])})
return
tsne_position_data
return
tsne_position_data
@classmethod
def hit_testing_args_check(cls, args):
    """Validate hit-testing request arguments.

    Requires a non-empty 'query' of at most 250 characters; raises
    ValueError otherwise.
    """
    query_text = args['query']

    # Reject empty/missing queries and queries longer than 250 characters.
    if not (query_text and len(query_text) <= 250):
        raise ValueError('Query is required and cannot exceed 250 characters')
api/services/retrieval_service.py
0 → 100644
View file @
4588831b
from
typing
import
Optional
from
flask
import
current_app
,
Flask
from
langchain.embeddings.base
import
Embeddings
from
core.index.vector_index.vector_index
import
VectorIndex
from
core.model_providers.model_factory
import
ModelFactory
from
models.dataset
import
Dataset
# Fallback retrieval configuration applied when a dataset has no
# retrieval_model stored.
# NOTE(review): the same literal is duplicated in hit_testing_service.py and
# dataset_service.py — consider sharing one definition so the defaults cannot
# drift apart.
default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enable': False
}
class RetrievalService:
    """Dataset retrieval helpers designed to run inside worker threads.

    Each search method pushes its (optionally reranked) results into the
    caller-supplied ``all_documents`` list so that several retrieval
    strategies can run concurrently and merge their output.
    """

    @classmethod
    def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                         all_documents: list, search_method: str, embeddings: Embeddings):
        """Semantic (vector) search over the dataset's vector index.

        Reranks only for pure 'semantic_search'; for hybrid search the
        caller reranks the merged result set itself.
        """
        with flask_app.app_context():
            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                query,
                search_type='similarity_score_threshold',
                search_kwargs={
                    'k': top_k,
                    'score_threshold': score_threshold,
                    'filter': {
                        'group_id': [dataset.id]
                    }
                }
            )

            cls._append_documents(
                dataset=dataset,
                query=query,
                documents=documents,
                score_threshold=score_threshold,
                reranking_model=reranking_model if search_method == 'semantic_search' else None,
                all_documents=all_documents
            )

    @classmethod
    def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                               all_documents: list, search_method: str, embeddings: Embeddings):
        """Full-text search over the dataset's index.

        Reranks only for pure 'full_text_search'; for hybrid search the
        caller reranks the merged result set itself.
        """
        with flask_app.app_context():
            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search_by_full_text_index(
                query,
                search_type='similarity_score_threshold',
                top_k=top_k
            )

            cls._append_documents(
                dataset=dataset,
                query=query,
                documents=documents,
                score_threshold=score_threshold,
                reranking_model=reranking_model if search_method == 'full_text_search' else None,
                all_documents=all_documents
            )

    @classmethod
    def _append_documents(cls, dataset: Dataset, query: str, documents: list,
                          score_threshold: Optional[float], reranking_model: Optional[dict],
                          all_documents: list):
        """Append *documents* to *all_documents*, reranking first when a
        reranking model config is supplied.

        Extracted to remove the duplicated rerank wiring that both search
        methods previously carried. No-op when *documents* is empty.
        """
        if not documents:
            return
        if reranking_model:
            rerank = ModelFactory.get_reranking_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=reranking_model['reranking_provider_name'],
                model_name=reranking_model['reranking_model_name']
            )
            all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
        else:
            all_documents.extend(documents)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment