ai-tech / dify · Commits

Commit 01baae87, authored Jun 20, 2023 by John Wang
feat: recreate dataset when origin dataset format
parent f33056f4
Showing 7 changed files with 108 additions and 64 deletions (+108 -64)
api/core/index/base.py  +6 -0
api/core/index/keyword_table_index/keyword_table_index.py  +11 -11
api/core/index/vector_index/base.py  +60 -1
api/core/index/vector_index/qdrant_vector_index.py  +12 -11
api/core/index/vector_index/weaviate_vector_index.py  +13 -12
api/core/indexing_runner.py  +5 -28
api/tasks/document_indexing_update_task.py  +1 -1
api/core/index/base.py

@@ -4,8 +4,14 @@ from typing import List, Any
 from langchain.schema import Document, BaseRetriever

+from models.dataset import Dataset


 class BaseIndex(ABC):

+    def __init__(self, dataset: Dataset):
+        self.dataset = dataset
+
     @abstractmethod
     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         raise NotImplementedError
api/core/index/keyword_table_index/keyword_table_index.py

@@ -17,7 +17,7 @@ class KeywordTableConfig(BaseModel):
 class KeywordTableIndex(BaseIndex):
     def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
-        self._dataset = dataset
+        super().__init__(dataset)
         self._config = config

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:

@@ -29,11 +29,11 @@ class KeywordTableIndex(BaseIndex):
             keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

         dataset_keyword_table = DatasetKeywordTable(
-            dataset_id=self._dataset.id,
+            dataset_id=self.dataset.id,
             keyword_table=json.dumps({
                 '__type__': 'keyword_table',
                 '__data__': {
-                    "index_id": self._dataset.id,
+                    "index_id": self.dataset.id,
                     "summary": None,
                     "table": {}
                 }

@@ -70,7 +70,7 @@ class KeywordTableIndex(BaseIndex):
     def delete_by_document_id(self, document_id: str):
         # get segment ids by document_id
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == self._dataset.id,
+            DocumentSegment.dataset_id == self.dataset.id,
             DocumentSegment.document_id == document_id
         ).all()

@@ -98,7 +98,7 @@ class KeywordTableIndex(BaseIndex):
         documents = []
         for chunk_index in sorted_chunk_indices:
             segment = db.session.query(DocumentSegment).filter(
-                DocumentSegment.dataset_id == self._dataset.id,
+                DocumentSegment.dataset_id == self.dataset.id,
                 DocumentSegment.index_node_id == chunk_index
             ).first()

@@ -115,7 +115,7 @@ class KeywordTableIndex(BaseIndex):
         return documents

     def delete(self) -> None:
-        dataset_keyword_table = self._dataset.dataset_keyword_table
+        dataset_keyword_table = self.dataset.dataset_keyword_table
         if dataset_keyword_table:
             db.session.delete(dataset_keyword_table)
             db.session.commit()

@@ -124,26 +124,26 @@ class KeywordTableIndex(BaseIndex):
         keyword_table_dict = {
             '__type__': 'keyword_table',
             '__data__': {
-                "index_id": self._dataset.id,
+                "index_id": self.dataset.id,
                 "summary": None,
                 "table": keyword_table
             }
         }
-        self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
+        self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
         db.session.commit()

     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        dataset_keyword_table = self._dataset.dataset_keyword_table
+        dataset_keyword_table = self.dataset.dataset_keyword_table
         if dataset_keyword_table:
             if dataset_keyword_table.keyword_table_dict:
                 return dataset_keyword_table.keyword_table_dict['__data__']['table']
         else:
             dataset_keyword_table = DatasetKeywordTable(
-                dataset_id=self._dataset.id,
+                dataset_id=self.dataset.id,
                 keyword_table=json.dumps({
                     '__type__': 'keyword_table',
                     '__data__': {
-                        "index_id": self._dataset.id,
+                        "index_id": self.dataset.id,
                         "summary": None,
                         "table": {}
                     }
api/core/index/vector_index/base.py

 import json
 import logging
 from abc import abstractmethod
 from typing import List, Any, Tuple, cast

 from langchain.embeddings.base import Embeddings
 from langchain.schema import Document, BaseRetriever
 from langchain.vectorstores import VectorStore

 from core.index.base import BaseIndex
-from models.dataset import Dataset
+from extensions.ext_database import db
+from models.dataset import Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument


 class BaseVectorIndex(BaseIndex):

     def __init__(self, dataset: Dataset, embeddings: Embeddings):
+        super().__init__(dataset)
         self._embeddings = embeddings
         self._vector_store = None

     def get_type(self) -> str:
         raise NotImplementedError

@@ -69,6 +80,9 @@ class BaseVectorIndex(BaseIndex):
         return vector_store.as_retriever(**kwargs)

     def add_texts(self, texts: list[Document], **kwargs):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)

@@ -85,6 +99,9 @@ class BaseVectorIndex(BaseIndex):
         return vector_store.text_exists(id)

     def delete_by_ids(self, ids: list[str]) -> None:
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)

@@ -96,3 +113,45 @@ class BaseVectorIndex(BaseIndex):
         vector_store = cast(self._get_vector_store_class(), vector_store)

         vector_store.delete()
+
+    def _is_origin(self):
+        return False
+
+    def recreate_dataset(self, dataset: Dataset):
+        logging.debug(f"Recreating dataset {dataset.id}")
+
+        self.delete()
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+
+                documents.append(document)
+
+        self.create(documents)
+
+        dataset.index_struct = json.dumps(self.to_index_struct())
+
+        db.session.commit()
+
+        self.dataset = dataset
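The two hooks above make the migration automatic: any add_texts() or delete_by_ids() call on an index whose stored format is still the original one first rebuilds the whole dataset index from its completed, enabled segments. The sketch below illustrates how a concrete index opts in. It is a minimal hypothetical example, not part of the commit: DummyVectorIndex and the "Index_" prefix check are assumptions (the prefix mirrors the Qdrant naming used elsewhere in this commit).

# Hypothetical sketch only: a subclass that reports an origin-format index so
# BaseVectorIndex.add_texts() / delete_by_ids() rebuild it before writing.
from langchain.schema import Document

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex


class DummyVectorIndex(BaseVectorIndex):

    def _is_origin(self):
        # Assumed convention (not from the commit): collections created by the
        # old code lack the "Index_" prefix that the new naming scheme uses.
        struct = self.dataset.index_struct_dict
        if struct:
            return not struct['vector_store']['collection_name'].startswith('Index_')
        return False

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        ...  # build the backing vector store from `texts`


# Usage: with _is_origin() returning True, index.add_texts(documents) first runs
# recreate_dataset(), which re-reads every completed, enabled segment and rebuilds
# the index, and only then applies the new documents.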
api/core/index/vector_index/qdrant_vector_index.py

@@ -36,17 +36,15 @@ class QdrantConfig(BaseModel):
 class QdrantVectorIndex(BaseVectorIndex):
     def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
-        self._dataset = dataset
+        super().__init__(dataset, embeddings)
         self._client_config = config
-        self._embeddings = embeddings
-        self._vector_store = None

     def get_type(self) -> str:
         return 'qdrant'

     def get_index_name(self, dataset: Dataset) -> str:
-        if self._dataset.index_struct_dict:
-            return self._dataset.index_struct_dict['vector_store']['collection_name']
+        if self.dataset.index_struct_dict:
+            return self.dataset.index_struct_dict['vector_store']['collection_name']

         dataset_id = dataset.id
         return "Index_" + dataset_id.replace("-", "_")

@@ -54,7 +52,7 @@ class QdrantVectorIndex(BaseVectorIndex):
     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
-            "vector_store": {"collection_name": self.get_index_name(self._dataset)}
+            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
         }

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:

@@ -62,7 +60,7 @@ class QdrantVectorIndex(BaseVectorIndex):
         self._vector_store = QdrantVectorStore.from_documents(
             texts,
             self._embeddings,
-            collection_name=self.get_index_name(self._dataset),
+            collection_name=self.get_index_name(self.dataset),
             ids=uuids,
             content_payload_key='text',
             **self._client_config.to_qdrant_params()

@@ -81,7 +79,7 @@ class QdrantVectorIndex(BaseVectorIndex):
         return QdrantVectorStore(
             client=client,
-            collection_name=self.get_index_name(self._dataset),
+            collection_name=self.get_index_name(self.dataset),
             embeddings=self._embeddings,
             content_payload_key='text'
         )

@@ -90,6 +88,9 @@ class QdrantVectorIndex(BaseVectorIndex):
         return QdrantVectorStore

     def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)

@@ -98,15 +99,15 @@ class QdrantVectorIndex(BaseVectorIndex):
         vector_store.del_texts(models.Filter(
             must=[
                 models.FieldCondition(
-                    key="doc_id" if self._is_origin() else "metadata.document_id",
+                    key="metadata.document_id",
                     match=models.MatchValue(value=document_id),
                 ),
             ],
         ))

     def _is_origin(self):
-        if self._dataset.index_struct_dict:
-            class_prefix: str = self._dataset.index_struct_dict['vector_store']['collection_name']
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
             if not class_prefix.strip('Vector_'):
                 # original class_prefix
                 return True
api/core/index/vector_index/weaviate_vector_index.py

@@ -26,10 +26,8 @@ class WeaviateConfig(BaseModel):
 class WeaviateVectorIndex(BaseVectorIndex):
     def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
-        self._dataset = dataset
+        super().__init__(dataset, embeddings)
         self._client = self._init_client(config)
-        self._embeddings = embeddings
-        self._vector_store = None

     def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
         auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)

@@ -59,8 +57,8 @@ class WeaviateVectorIndex(BaseVectorIndex):
         return 'weaviate'

     def get_index_name(self, dataset: Dataset) -> str:
-        if self._dataset.index_struct_dict:
-            class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
             if not class_prefix.endswith('_Node'):
                 # original class_prefix
                 class_prefix += '_Node'

@@ -73,7 +71,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
-            "vector_store": {"class_prefix": self.get_index_name(self._dataset)}
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
         }

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:

@@ -82,7 +80,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
             texts,
             self._embeddings,
             client=self._client,
-            index_name=self.get_index_name(self._dataset),
+            index_name=self.get_index_name(self.dataset),
             uuids=uuids,
             by_text=False
         )

@@ -96,11 +94,11 @@ class WeaviateVectorIndex(BaseVectorIndex):
         attributes = ['doc_id', 'dataset_id', 'document_id']
         if self._is_origin():
-            attributes = ['doc_id', 'ref_doc_id']
+            attributes = ['doc_id']

         return WeaviateVectorStore(
             client=self._client,
-            index_name=self.get_index_name(self._dataset),
+            index_name=self.get_index_name(self.dataset),
             text_key='text',
             embedding=self._embeddings,
             attributes=attributes,

@@ -111,18 +109,21 @@ class WeaviateVectorIndex(BaseVectorIndex):
         return WeaviateVectorStore

     def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)

         vector_store.del_texts({
             "operator": "Equal",
-            "path": ["doc_id" if self._is_origin() else "document_id"],
+            "path": ["document_id"],
             "valueText": document_id
         })

     def _is_origin(self):
-        if self._dataset.index_struct_dict:
-            class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
             if not class_prefix.endswith('_Node'):
                 # original class_prefix
                 return True
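Both backends now detect origin-format datasets purely from the naming convention recorded in index_struct_dict. A tiny hypothetical helper (not in the commit, and with a made-up prefix value) showing the Weaviate-side rule, where legacy class prefixes were created without the '_Node' suffix:

# Hypothetical helper, for illustration only: mirrors the '_Node' suffix check
# that WeaviateVectorIndex._is_origin() performs on the stored class_prefix.
def is_origin_class_prefix(class_prefix: str) -> bool:
    return not class_prefix.endswith('_Node')


print(is_origin_class_prefix('Example_prefix_1234'))        # True  -> dataset gets recreated on next write/delete
print(is_origin_class_prefix('Example_prefix_1234_Node'))   # False -> already in the current format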
api/core/indexing_runner.py

@@ -488,28 +488,8 @@ class IndexingRunner:
         """
         Build the index for the document.
         """
-        model_credentials = LLMBuilder.get_model_credentials(
-            tenant_id=dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
-            model_name='text-embedding-ada-002'
-        )
-
-        embeddings = CacheEmbedding(OpenAIEmbeddings(
-            **model_credentials
-        ))
-
-        vector_index = VectorIndex(
-            dataset=dataset,
-            config=current_app.config,
-            embeddings=embeddings
-        )
-        keyword_table_index = KeywordTableIndex(
-            dataset=dataset,
-            config=KeywordTableConfig(
-                max_keywords_per_chunk=10
-            )
-        )
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()

@@ -526,14 +506,11 @@ class IndexingRunner:
         )

         # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts(chunk_documents)
+        if vector_index:
+            vector_index.add_texts(chunk_documents)

         # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            index.add_texts(chunk_documents)
+        keyword_table_index.add_texts(chunk_documents)

         document_ids = [document.metadata['doc_id'] for document in chunk_documents]
         db.session.query(DocumentSegment).filter(
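Because BaseVectorIndex.add_texts() now checks _is_origin() (see the base-class hunks earlier in this commit), the runner's single vector_index.add_texts(chunk_documents) call is also the point where an origin-format dataset gets rebuilt. A hedged sketch of that flow; the import path and the build_indexes wrapper are assumptions, while IndexBuilder.get_index and the 'high_quality' / 'economy' labels come from the diff above:

# Placeholder sketch tying the runner to the base-class behaviour added above.
from core.index.index_builder import IndexBuilder   # assumed import path, not verified


def build_indexes(dataset, chunk_documents):
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')    # may be None; the runner guards it
    keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

    if vector_index:
        # For an origin-format dataset, recreate_dataset() runs inside add_texts()
        # before the new chunk documents are written to the vector store.
        vector_index.add_texts(chunk_documents)

    keyword_table_index.add_texts(chunk_documents)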
api/tasks/document_indexing_update_task.py

@@ -54,7 +54,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         # delete from keyword index
         if index_node_ids:
-            vector_index.delete_by_ids(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)