ai-tech / dify · Commits

Commit a55ba6e6 (unverified)
Fix/ignore economy dataset (#1043)

Authored Aug 28, 2023 by Jyong; committed by GitHub on Aug 28, 2023.
Co-authored-by: jyong <jyong@dify.ai>
Parent: f9bec1ed
Showing 13 changed files with 320 additions and 205 deletions (+320 -205).
api/controllers/console/datasets/datasets.py                              +17  -14
api/controllers/console/datasets/datasets_document.py                     +23  -23
api/controllers/console/datasets/datasets_segments.py                     +48  -53
api/core/docstore/dataset_docstore.py                                      +8   -7
api/core/index/index.py                                                   +18   -1
api/core/indexing_runner.py                                               +46  -38
api/events/event_handlers/create_document_index.py                         +0   -1
api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py  +46   -0
api/models/dataset.py                                                      +2   -4
api/services/dataset_service.py                                           +99  -55
api/tasks/batch_create_segment_to_index_task.py                            +7   -5
api/tasks/clean_dataset_task.py                                            +4   -2
api/tasks/deal_dataset_vector_index_task.py                                +2   -2
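The per-file diffs below all follow one recurring pattern: an embedding model is resolved, validated, and billed only when the dataset (or the incoming request) uses the 'high_quality' indexing technique, while 'economy' datasets skip the provider lookup entirely. A minimal sketch of that guard, paraphrased from the diff rather than copied from any single file (the helper name resolve_embedding_model is hypothetical):

from core.model_providers.model_factory import ModelFactory

def resolve_embedding_model(dataset):
    # Economy datasets never touch the embedding provider; only
    # high-quality datasets resolve a concrete embedding model.
    embedding_model = None
    if dataset.indexing_technique == 'high_quality':
        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=dataset.tenant_id,
            model_provider_name=dataset.embedding_model_provider,
            model_name=dataset.embedding_model
        )
    return embedding_model

Callers then fall back to a cheap default when no model is returned, e.g. counting zero tokens, which is exactly the shape of most hunks that follow.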
api/controllers/console/datasets/datasets.py

@@ -92,11 +92,14 @@ class DatasetListApi(Resource):
                 model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
-            item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
-            if item_model in model_names:
-                item['embedding_available'] = True
+            if item['indexing_technique'] == 'high_quality':
+                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
+                if item_model in model_names:
+                    item['embedding_available'] = True
+                else:
+                    item['embedding_available'] = False
             else:
-                item['embedding_available'] = False
+                item['embedding_available'] = True
         response = {
             'data': data,
             'has_more': len(datasets) == limit,

@@ -122,14 +125,6 @@ class DatasetListApi(Resource):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
         try:
             dataset = DatasetService.create_empty_dataset(

@@ -167,6 +162,11 @@ class DatasetApi(Resource):
     @account_initialization_required
     def patch(self, dataset_id):
         dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         parser = reqparse.RequestParser()
         parser.add_argument('name', nullable=False,

@@ -254,6 +254,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')

@@ -275,7 +276,8 @@ class DatasetIndexingEstimateApi(Resource):
             try:
                 response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
                                                                   args['process_rule'], args['doc_form'],
-                                                                  args['doc_language'], args['dataset_id'])
+                                                                  args['doc_language'], args['dataset_id'],
+                                                                  args['indexing_technique'])
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "

@@ -290,7 +292,8 @@ class DatasetIndexingEstimateApi(Resource):
                 response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                     args['info_list']['notion_info_list'],
                                                                     args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'])
+                                                                    args['doc_language'], args['dataset_id'],
+                                                                    args['indexing_technique'])
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
api/controllers/console/datasets/datasets_document.py

@@ -285,20 +285,6 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:

@@ -339,15 +325,17 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
         args = parser.parse_args()
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
+        if args['indexing_technique'] == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         # validate args
         DocumentService.document_create_args_validate(args)

@@ -729,6 +717,12 @@ class DocumentDeleteApi(DocumentResource):
     def delete(self, dataset_id, document_id):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         document = self.get_document(dataset_id, document_id)
         try:

@@ -791,6 +785,12 @@ class DocumentStatusApi(DocumentResource):
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         document = self.get_document(dataset_id, document_id)
         # The role of the current user in the ta table must be admin or owner
api/controllers/console/datasets/datasets_segments.py

@@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()

@@ -158,20 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),

@@ -244,18 +245,19 @@ class DatasetDocumentSegmentAddApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
         # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         try:
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:

@@ -284,25 +286,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        # check embedding model setting
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),

@@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)

@@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
-        try:
-            ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-        except LLMBadRequestError:
-            raise ProviderNotInitializeError(
-                f"No Embedding Model available. Please configure a valid provider "
-                f"in the Settings -> Model Provider.")
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
         # get file from request
         file = request.files['file']
         # check file
api/core/docstore/dataset_docstore.py

@@ -67,12 +67,13 @@ class DatesetDocumentStore:
         if max_position is None:
             max_position = 0
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=self._dataset.tenant_id,
-            model_provider_name=self._dataset.embedding_model_provider,
-            model_name=self._dataset.embedding_model
-        )
+        embedding_model = None
+        if self._dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=self._dataset.tenant_id,
+                model_provider_name=self._dataset.embedding_model_provider,
+                model_name=self._dataset.embedding_model
+            )

         for doc in docs:
             if not isinstance(doc, Document):

@@ -88,7 +89,7 @@ class DatesetDocumentStore:
             )

             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(doc.page_content)
+            tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0

             if not segment_document:
                 max_position += 1
api/core/index/index.py

+import json
+
 from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings

 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.openai_model import OpenAIModel
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.dataset import Dataset
+from models.provider import Provider, ProviderType


 class IndexBuilder:

@@ -35,4 +43,13 @@ class IndexBuilder:
                 )
             )
         else:
-            raise ValueError('Unknown indexing technique')
\ No newline at end of file
+            raise ValueError('Unknown indexing technique')
+
+    @classmethod
+    def get_default_high_quality_index(cls, dataset: Dataset):
+        embeddings = OpenAIEmbeddings(openai_api_key=' ')
+        return VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
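The new get_default_high_quality_index classmethod builds a VectorIndex around a placeholder OpenAIEmbeddings object (the blank API key suggests the embeddings are never actually invoked, presumably because deleting an index does not require embedding anything). This lets cleanup code drop a dataset's vector index even when the tenant no longer has a working embedding provider. A hedged usage sketch, mirroring how clean_dataset_task below calls it:

if dataset.indexing_technique == 'high_quality':
    vector_index = IndexBuilder.get_default_high_quality_index(dataset)
    try:
        vector_index.delete()
    except Exception:
        pass  # cleanup here is best-effort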
api/core/indexing_runner.py

@@ -217,25 +217,29 @@ class IndexingRunner:
         db.session.commit()

     def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
-                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                               doc_form: str = None, doc_language: str = 'English',
+                               dataset_id: str = None, indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )
         tokens = 0
         preview_texts = []
         total_segments = 0

@@ -263,8 +267,8 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
-                tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
+                if indexing_technique == 'high_quality' or embedding_model:
+                    tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))

         if doc_form and doc_form == 'qa_model':
             text_generation_model = ModelFactory.get_text_generation_model(

@@ -286,32 +290,35 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }

     def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
-                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
+                                 doc_form: str = None, doc_language: str = 'English',
+                                 dataset_id: str = None, indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
+        embedding_model = None
         if dataset_id:
             dataset = Dataset.query.filter_by(
                 id=dataset_id
             ).first()
             if not dataset:
                 raise ValueError('Dataset not found.')
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
         else:
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=tenant_id
-            )
+            if indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=tenant_id
+                )

         # load data from notion
         tokens = 0
         preview_texts = []

@@ -356,8 +363,8 @@ class IndexingRunner:
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
-                tokens += embedding_model.get_num_tokens(document.page_content)
+                if indexing_technique == 'high_quality' or embedding_model:
+                    tokens += embedding_model.get_num_tokens(document.page_content)

         if doc_form and doc_form == 'qa_model':
             text_generation_model = ModelFactory.get_text_generation_model(

@@ -379,8 +386,8 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
-            "currency": embedding_model.get_currency(),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
+            "currency": embedding_model.get_currency() if embedding_model else 'USD',
             "preview": preview_texts
         }

@@ -657,12 +664,13 @@ class IndexingRunner:
         """
         vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()

@@ -672,11 +680,11 @@ class IndexingRunner:
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
             chunk_documents = documents[i:i + chunk_size]
-            tokens += sum(
-                embedding_model.get_num_tokens(document.page_content)
-                for document in chunk_documents
-            )
+            if dataset.indexing_technique == 'high_quality' or embedding_model:
+                tokens += sum(
+                    embedding_model.get_num_tokens(document.page_content)
+                    for document in chunk_documents
+                )

             # save vector index
             if vector_index:
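With indexing_technique now defaulting to 'economy', an estimate for an economy dataset resolves no embedding model and reports zero token cost. A hedged illustration of the resulting call and response shape (argument values are examples, not taken from the diff):

estimate = indexing_runner.file_indexing_estimate(
    tenant_id, file_details, processing_rule,
    doc_form='text_model', doc_language='English',
    dataset_id=None, indexing_technique='economy')
# -> {"total_segments": ..., "tokens": 0, "total_price": 0,
#     "currency": "USD", "preview": [...]}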
api/events/event_handlers/create_document_index.py

 from events.dataset_event import dataset_was_deleted
 from events.event_handlers.document_index_event import document_index_created
-from tasks.clean_dataset_task import clean_dataset_task
 import datetime
 import logging
 import time
api/migrations/versions/4bcffcd64aa4_update_dataset_model_field_null_.py (new file, mode 100644)

"""update_dataset_model_field_null_available

Revision ID: 4bcffcd64aa4
Revises: 853f9b9cd3b6
Create Date: 2023-08-28 20:58:50.077056

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = '4bcffcd64aa4'
down_revision = '853f9b9cd3b6'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.alter_column('embedding_model',
                              existing_type=sa.VARCHAR(length=255),
                              nullable=True,
                              existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
        batch_op.alter_column('embedding_model_provider',
                              existing_type=sa.VARCHAR(length=255),
                              nullable=True,
                              existing_server_default=sa.text("'openai'::character varying"))
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.alter_column('embedding_model_provider',
                              existing_type=sa.VARCHAR(length=255),
                              nullable=False,
                              existing_server_default=sa.text("'openai'::character varying"))
        batch_op.alter_column('embedding_model',
                              existing_type=sa.VARCHAR(length=255),
                              nullable=False,
                              existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
    # ### end Alembic commands ###
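This migration relaxes datasets.embedding_model and datasets.embedding_model_provider to nullable so that economy datasets can store NULL instead of the old OpenAI-flavored defaults. A hedged sketch of applying the revision programmatically with Alembic (the config path is an assumption; dify normally runs migrations through its Flask-Migrate setup):

from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")  # path is an assumption
command.upgrade(alembic_cfg, "4bcffcd64aa4")      # apply
# command.downgrade(alembic_cfg, "853f9b9cd3b6")  # revert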
api/models/dataset.py

@@ -36,10 +36,8 @@ class Dataset(db.Model):
     updated_by = db.Column(UUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
-    embedding_model = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
-    embedding_model_provider = db.Column(db.String(
-        255), nullable=False, server_default=db.text("'openai'::character varying"))
+    embedding_model = db.Column(db.String(255), nullable=True)
+    embedding_model_provider = db.Column(db.String(255), nullable=True)

     @property
     def dataset_keyword_table(self):
api/services/dataset_service.py

@@ -10,6 +10,7 @@ from flask import current_app
 from sqlalchemy import func

 from core.index.index import IndexBuilder
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from flask_login import current_user

@@ -91,16 +92,18 @@ class DatasetService:
         if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(
                 f'Dataset with name {name} already exists.')
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=current_user.current_tenant_id
-        )
+        embedding_model = None
+        if indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
         dataset = Dataset(name=name, indexing_technique=indexing_technique)
         # dataset = Dataset(name=name, provider=provider, config=config)
         dataset.created_by = account.id
         dataset.updated_by = account.id
         dataset.tenant_id = tenant_id
-        dataset.embedding_model_provider = embedding_model.model_provider.provider_name
-        dataset.embedding_model = embedding_model.name
+        dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
+        dataset.embedding_model = embedding_model.name if embedding_model else None
         db.session.add(dataset)
         db.session.commit()
         return dataset

@@ -115,6 +118,23 @@ class DatasetService:
         else:
             return dataset

+    @staticmethod
+    def check_dataset_model_setting(dataset):
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ValueError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ValueError(f"The dataset in unavailable, due to: "
+                                 f"{ex.description}")
+
     @staticmethod
     def update_dataset(dataset_id, data, user):
         dataset = DatasetService.get_dataset(dataset_id)

@@ -124,6 +144,19 @@ class DatasetService:
             if data['indexing_technique'] == 'economy':
                 deal_dataset_vector_index_task.delay(dataset_id, 'remove')
             elif data['indexing_technique'] == 'high_quality':
+                # check embedding model setting
+                try:
+                    ModelFactory.get_embedding_model(
+                        tenant_id=current_user.current_tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except LLMBadRequestError:
+                    raise ValueError(
+                        f"No Embedding Model available. Please configure a valid provider "
+                        f"in the Settings -> Model Provider.")
+                except ProviderTokenNotInitError as ex:
+                    raise ValueError(ex.description)
                 deal_dataset_vector_index_task.delay(dataset_id, 'add')
         filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}

@@ -397,23 +430,23 @@ class DocumentService:
         # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
-            count = 0
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
-                count = len(upload_file_list)
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info['pages'])
-            documents_count = DocumentService.get_tenant_documents_count()
-            total_count = documents_count + count
-            tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if total_count > tenant_document_count:
-                raise ValueError(f"over document limit {tenant_document_count}.")
+            if 'original_document_id' not in document_data or not document_data['original_document_id']:
+                count = 0
+                if document_data["data_source"]["type"] == "upload_file":
+                    upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+                    count = len(upload_file_list)
+                elif document_data["data_source"]["type"] == "notion_import":
+                    notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info['pages'])
+                documents_count = DocumentService.get_tenant_documents_count()
+                total_count = documents_count + count
+                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+                if total_count > tenant_document_count:
+                    raise ValueError(f"over document limit {tenant_document_count}.")
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
             dataset.data_source_type = document_data["data_source"]["type"]
-            db.session.commit()

         if not dataset.indexing_technique:
             if 'indexing_technique' not in document_data \

@@ -421,6 +454,13 @@ class DocumentService:
                 raise ValueError("Indexing technique is required")

             dataset.indexing_technique = document_data["indexing_technique"]
+            if document_data["indexing_technique"] == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id
+                )
+                dataset.embedding_model = embedding_model.name
+                dataset.embedding_model_provider = embedding_model.model_provider.provider_name

         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))

@@ -466,11 +506,11 @@ class DocumentService:
                     "upload_file_id": file_id,
                 }
                 document = DocumentService.build_document(dataset, dataset_process_rule.id,
                                                           document_data["data_source"]["type"],
                                                           document_data["doc_form"],
                                                           document_data["doc_language"],
                                                           data_source_info, created_from, position,
                                                           account, file_name, batch)
                 db.session.add(document)
                 db.session.flush()
                 document_ids.append(document.id)

@@ -512,11 +552,11 @@ class DocumentService:
                         "type": page['type']
                     }
                     document = DocumentService.build_document(dataset, dataset_process_rule.id,
                                                               document_data["data_source"]["type"],
                                                               document_data["doc_form"],
                                                               document_data["doc_language"],
                                                               data_source_info, created_from, position,
                                                               account, page['page_name'], batch)
                     db.session.add(document)
                     db.session.flush()
                     document_ids.append(document.id)

@@ -536,9 +576,9 @@ class DocumentService:
     @staticmethod
     def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
                        document_language: str, data_source_info: dict, created_from: str, position: int,
                        account: Account,
                       name: str, batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,

@@ -567,6 +607,7 @@ class DocumentService:
     def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                         account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                         created_from: str = 'web'):
+        DatasetService.check_dataset_model_setting(dataset)
         document = DocumentService.get_document(dataset.id, document_data["original_document_id"])

         if document.display_status != 'available':
             raise ValueError("Document is not available")

@@ -674,9 +715,11 @@ class DocumentService:
         tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
         if total_count > tenant_document_count:
             raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=tenant_id
-        )
+        embedding_model = None
+        if document_data['indexing_technique'] == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=tenant_id
+            )

         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,

@@ -684,8 +727,8 @@ class DocumentService:
             data_source_type=document_data["data_source"]["type"],
             indexing_technique=document_data["indexing_technique"],
             created_by=account.id,
-            embedding_model=embedding_model.name,
-            embedding_model_provider=embedding_model.model_provider.provider_name
+            embedding_model=embedding_model.name if embedding_model else None,
+            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
         )
         db.session.add(dataset)

@@ -903,15 +946,15 @@ class SegmentService:
         content = args['content']
         doc_id = str(uuid.uuid4())
         segment_hash = helper.generate_text_hash(content)
-        embedding_model = ModelFactory.get_embedding_model(
-            tenant_id=dataset.tenant_id,
-            model_provider_name=dataset.embedding_model_provider,
-            model_name=dataset.embedding_model
-        )
-        # calc embedding use tokens
-        tokens = embedding_model.get_num_tokens(content)
+        tokens = 0
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+            # calc embedding use tokens
+            tokens = embedding_model.get_num_tokens(content)
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document_id == document.id
         ).scalar()

@@ -973,15 +1016,16 @@ class SegmentService:
             kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
         else:
             segment_hash = helper.generate_text_hash(content)
-            embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=dataset.tenant_id,
-                model_provider_name=dataset.embedding_model_provider,
-                model_name=dataset.embedding_model
-            )
-            # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
+            tokens = 0
+            if dataset.indexing_technique == 'high_quality':
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+                # calc embedding use tokens
+                tokens = embedding_model.get_num_tokens(content)
             segment.content = content
             segment.index_node_hash = segment_hash
             segment.word_count = len(content)

@@ -1013,7 +1057,7 @@ class SegmentService:
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise ValueError("Segment is deleting.")
         # enabled segment need to delete index
         if segment.enabled:
             # send delete segment index task
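The new DatasetService.check_dataset_model_setting helper is what the controllers above call before reading or mutating a dataset: it is a no-op for economy datasets and raises ValueError when a high-quality dataset's embedding model can no longer be resolved. A minimal usage sketch, following the controller pattern in this commit:

dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
    raise NotFound("Dataset not found.")
# Raises ValueError if the dataset is high_quality and its
# embedding provider/model is missing or misconfigured.
DatasetService.check_dataset_model_setting(dataset)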
api/tasks/batch_create_segment_to_index_task.py

@@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
             raise ValueError('Document is not available.')
         document_segments = []
-        for segment in content:
-            content = segment['content']
-            doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content)
-            embedding_model = ModelFactory.get_embedding_model(
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=dataset.tenant_id,
                 model_provider_name=dataset.embedding_model_provider,
                 model_name=dataset.embedding_model
             )
+        for segment in content:
+            content = segment['content']
+            doc_id = str(uuid.uuid4())
+            segment_hash = helper.generate_text_hash(content)
             # calc embedding use tokens
-            tokens = embedding_model.get_num_tokens(content)
+            tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
             max_position = db.session.query(func.max(DocumentSegment.position)).filter(
                 DocumentSegment.document_id == dataset_document.id
             ).scalar()
api/tasks/clean_dataset_task.py

@@ -3,8 +3,10 @@ import time
 import click
 from celery import shared_task
+from flask import current_app
 from core.index.index import IndexBuilder
+from core.index.vector_index.vector_index import VectorIndex
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin, Document

@@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()

-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
         kw_index = IndexBuilder.get_index(dataset, 'economy')

         # delete from vector index
-        if vector_index:
+        if dataset.indexing_technique == 'high_quality':
+            vector_index = IndexBuilder.get_default_high_quality_index(dataset)
             try:
                 vector_index.delete()
             except Exception:
api/tasks/deal_dataset_vector_index_task.py

@@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             raise Exception('Dataset not found')

         if action == "remove":
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
             index.delete()
         elif action == "add":
             dataset_documents = db.session.query(DatasetDocument).filter(

@@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
         if dataset_documents:
             # save vector index
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
             documents = []
             for dataset_document in dataset_documents:
                 # delete from vector index
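Flipping ignore_high_quality_check from True to False means IndexBuilder.get_index now hands back a vector index only when the dataset itself is high_quality, so this task no longer touches a vector store for economy datasets. A hedged sketch of the implied behavior (the None return for economy datasets follows from get_index's existing contract, which is an assumption here, consistent with the "if vector_index:" guards elsewhere in this commit):

index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
if index:  # None for economy datasets, so they are skipped
    index.delete()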