ai-tech / dify — Commit 269a465f (unverified)

Authored Sep 18, 2023 by Jyong; committed by GitHub on Sep 18, 2023.

Feat/improve vector database logic (#1193)

Co-authored-by: jyong <jyong@dify.ai>

Parent: 60e0bbd7

Showing 14 changed files with 463 additions and 46 deletions (+463 −46).
api/commands.py (+124 −17)
api/core/index/base.py (+8 −0)
api/core/index/keyword_table_index/keyword_table_index.py (+32 −0)
api/core/index/vector_index/base.py (+57 −1)
api/core/index/vector_index/milvus_vector_index.py (+13 −0)
api/core/index/vector_index/qdrant.py (+38 −2)
api/core/index/vector_index/qdrant_vector_index.py (+58 −20)
api/core/index/vector_index/weaviate_vector_index.py (+14 −0)
api/core/tool/dataset_retriever_tool.py (+4 −2)
api/core/vector_store/qdrant_vector_store.py (+5 −0)
api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py (+47 −0)
api/models/dataset.py (+18 −0)
api/services/dataset_service.py (+41 −3)
api/services/hit_testing_service.py (+4 −1)
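At a high level, this commit lets datasets that use the same embedding provider/model write into one shared Qdrant collection (tracked by the new dataset_collection_bindings table) instead of one collection per dataset; every stored point is tagged with a group_id payload equal to its dataset id so reads and deletes stay scoped to a single dataset. A minimal sketch of the resulting query pattern, assuming a local Qdrant instance and placeholder names (not part of the diff):

```python
# Sketch of the query pattern this commit enables (assumed endpoint, vector size,
# and placeholder ids; the collection/payload names follow the diff below).
from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")   # assumed local instance
query_vector = [0.0] * 1536                           # assumed embedding size

hits = client.search(
    collection_name="Vector_index_<binding uuid>_Node",   # shared collection from a binding
    query_vector=query_vector,
    query_filter=models.Filter(
        must=[
            models.FieldCondition(
                key="group_id",                            # payload key introduced by this commit
                match=models.MatchValue(value="<dataset id>"),
            )
        ]
    ),
    limit=4,
)
```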
api/commands.py

(Diff collapsed in this view: +124 −17.)
api/core/index/base.py

```
@@ -16,6 +16,10 @@ class BaseIndex(ABC):
    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        raise NotImplementedError

    @abstractmethod
    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        raise NotImplementedError

    @abstractmethod
    def add_texts(self, texts: list[Document], **kwargs):
        raise NotImplementedError

@@ -28,6 +32,10 @@ class BaseIndex(ABC):
    def delete_by_ids(self, ids: list[str]) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_group_id(self, group_id: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        raise NotImplementedError
```
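The two new abstract methods mean every index backend must be able to build an index into an explicitly named (possibly shared) collection and to remove everything belonging to one group. A hypothetical toy class following the same contract, just to illustrate it (it does not subclass the real BaseIndex, and only the methods touched by this commit are shown):

```python
# Hypothetical illustration of the extended contract; not part of the codebase.
from langchain.schema import Document


class ToyIndex:
    def __init__(self):
        self._collections: dict[str, list[Document]] = {}

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> "ToyIndex":
        # Index into a caller-chosen collection instead of a per-dataset default.
        self._collections.setdefault(collection_name, []).extend(texts)
        return self

    def delete_by_group_id(self, group_id: str) -> None:
        # Remove every document tagged with the given group (dataset) id.
        for name, docs in self._collections.items():
            self._collections[name] = [d for d in docs if d.metadata.get("group_id") != group_id]
```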
api/core/index/keyword_table_index/keyword_table_index.py

```
@@ -46,6 +46,32 @@ class KeywordTableIndex(BaseIndex):
        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        keyword_table_handler = JiebaKeywordTableHandler()
        keyword_table = {}
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        dataset_keyword_table = DatasetKeywordTable(
            dataset_id=self.dataset.id,
            keyword_table=json.dumps({
                '__type__': 'keyword_table',
                '__data__': {
                    "index_id": self.dataset.id,
                    "summary": None,
                    "table": {}
                }
            }, cls=SetEncoder)
        )
        db.session.add(dataset_keyword_table)
        db.session.commit()

        self._save_dataset_keyword_table(keyword_table)

        return self

    def add_texts(self, texts: list[Document], **kwargs):
        keyword_table_handler = JiebaKeywordTableHandler()

@@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
            db.session.delete(dataset_keyword_table)
            db.session.commit()

    def delete_by_group_id(self, group_id: str) -> None:
        dataset_keyword_table = self.dataset.dataset_keyword_table
        if dataset_keyword_table:
            db.session.delete(dataset_keyword_table)
            db.session.commit()

    def _save_dataset_keyword_table(self, keyword_table):
        keyword_table_dict = {
            '__type__': 'keyword_table',
```
api/core/index/vector_index/base.py

```
@@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
-from models.dataset import Dataset, DocumentSegment
+from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument

@@ -110,6 +110,12 @@ class BaseVectorIndex(BaseIndex):
        for node_id in ids:
            vector_store.del_text(node_id)

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

@@ -243,3 +249,53 @@ class BaseVectorIndex(BaseIndex):
                raise e

        logging.info(f"Dataset {dataset.id} recreate successfully.")

    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"restore dataset in_one,_dataset {dataset.id}")

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            try:
                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} recreate successfully.")

    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"delete original collection: {dataset.id}")

        self.delete()

        dataset.collection_binding_id = dataset_collection_binding.id
        db.session.add(dataset)
        db.session.commit()

        logging.info(f"Dataset {dataset.id} recreate successfully.")
```
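restore_dataset_in_one rebuilds a dataset's completed, enabled segments as Documents and writes them into the binding's shared collection; delete_original_collection then drops the old per-dataset collection and records the binding on the dataset. A sketch of how a maintenance routine might chain the two (the actual driver presumably lives in api/commands.py, whose diff is collapsed above; the function name here is hypothetical):

```python
# Hypothetical migration driver; `index` is assumed to be a concrete BaseVectorIndex
# for the dataset and `binding` a DatasetCollectionBinding row set up elsewhere.
def migrate_to_shared_collection(index, dataset, binding):
    # 1. Re-index the dataset's completed, enabled segments into the shared collection.
    index.restore_dataset_in_one(dataset, binding)
    # 2. Drop the old per-dataset collection and point the dataset at the binding.
    index.delete_original_collection(dataset, binding)
```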
api/core/index/vector_index/milvus_vector_index.py

```
@@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=collection_name,
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
```
api/core/index/vector_index/qdrant.py

```
@@ -28,6 +28,7 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType

if TYPE_CHECKING:
    from qdrant_client import grpc  # noqa

@@ -84,6 +85,7 @@ class Qdrant(VectorStore):
    CONTENT_KEY = "page_content"
    METADATA_KEY = "metadata"
    GROUP_KEY = "group_id"
    VECTOR_NAME = None

    def __init__(

@@ -93,9 +95,12 @@ class Qdrant(VectorStore):
        embeddings: Optional[Embeddings] = None,
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        group_payload_key: str = GROUP_KEY,
        group_id: str = None,
        distance_strategy: str = "COSINE",
        vector_name: Optional[str] = VECTOR_NAME,
        embedding_function: Optional[Callable] = None,  # deprecated
        is_new_collection: bool = False
    ):
        """Initialize with necessary components."""
        try:

@@ -129,7 +134,10 @@ class Qdrant(VectorStore):
        self.collection_name = collection_name
        self.content_payload_key = content_payload_key or self.CONTENT_KEY
        self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
        self.group_payload_key = group_payload_key or self.GROUP_KEY
        self.vector_name = vector_name or self.VECTOR_NAME
        self.group_id = group_id
        self.is_new_collection = is_new_collection

        if embedding_function is not None:
            warnings.warn(

@@ -170,6 +178,8 @@ class Qdrant(VectorStore):
            batch_size:
                How many vectors upload per-request.
                Default: 64
            group_id:
                collection group

        Returns:
            List of ids from adding the texts into the vectorstore.

@@ -182,7 +192,11 @@ class Qdrant(VectorStore):
                collection_name=self.collection_name, points=points, **kwargs
            )

            added_ids.extend(batch_ids)
        # if is new collection, create payload index on group_id
        if self.is_new_collection:
            self.client.create_payload_index(self.collection_name, self.group_payload_key,
                                             field_schema=PayloadSchemaType.KEYWORD,
                                             field_type=PayloadSchemaType.KEYWORD)
        return added_ids

    @sync_call_fallback

@@ -970,6 +984,8 @@ class Qdrant(VectorStore):
        distance_func: str = "Cosine",
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        group_payload_key: str = GROUP_KEY,
        group_id: str = None,
        vector_name: Optional[str] = VECTOR_NAME,
        batch_size: int = 64,
        shard_number: Optional[int] = None,

@@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
            metadata_payload_key:
                A payload key used to store the metadata of the document.
                Default: "metadata"
            group_payload_key:
                A payload key used to store the content of the document.
                Default: "group_id"
            group_id:
                collection group id
            vector_name:
                Name of the vector to be used internally in Qdrant.
                Default: None

@@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
            distance_func,
            content_payload_key,
            metadata_payload_key,
            group_payload_key,
            group_id,
            vector_name,
            shard_number,
            replication_factor,

@@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
        distance_func: str = "Cosine",
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        group_payload_key: str = GROUP_KEY,
        group_id: str = None,
        vector_name: Optional[str] = VECTOR_NAME,
        shard_number: Optional[int] = None,
        replication_factor: Optional[int] = None,

@@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
        vector_size = len(partial_embeddings[0])
        collection_name = collection_name or uuid.uuid4().hex
        distance_func = distance_func.upper()
        is_new_collection = False
        client = qdrant_client.QdrantClient(
            location=location,
            url=url,

@@ -1454,6 +1480,7 @@ class Qdrant(VectorStore):
                init_from=init_from,
                timeout=timeout,  # type: ignore[arg-type]
            )
            is_new_collection = True
        qdrant = cls(
            client=client,
            collection_name=collection_name,

@@ -1462,6 +1489,9 @@ class Qdrant(VectorStore):
            metadata_payload_key=metadata_payload_key,
            distance_strategy=distance_func,
            vector_name=vector_name,
            group_id=group_id,
            group_payload_key=group_payload_key,
            is_new_collection=is_new_collection
        )
        return qdrant

@@ -1516,6 +1546,8 @@ class Qdrant(VectorStore):
        metadatas: Optional[List[dict]],
        content_payload_key: str,
        metadata_payload_key: str,
        group_id: str,
        group_payload_key: str
    ) -> List[dict]:
        payloads = []
        for i, text in enumerate(texts):

@@ -1529,6 +1561,7 @@ class Qdrant(VectorStore):
                {
                    content_payload_key: text,
                    metadata_payload_key: metadata,
                    group_payload_key: group_id
                }
            )

@@ -1578,7 +1611,7 @@ class Qdrant(VectorStore):
            else:
                out.append(
                    rest.FieldCondition(
-                       key=f"{self.metadata_payload_key}.{key}",
+                       key=key,
                        match=rest.MatchValue(value=value),
                    )
                )

@@ -1654,6 +1687,7 @@ class Qdrant(VectorStore):
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 64,
        group_id: Optional[str] = None,
    ) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
        from qdrant_client.http import models as rest

@@ -1684,6 +1718,8 @@ class Qdrant(VectorStore):
                    batch_metadatas,
                    self.content_payload_key,
                    self.metadata_payload_key,
                    self.group_id,
                    self.group_payload_key
                ),
            )
        ]
```
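Taken together, the wrapper changes mean every stored point now carries a third top-level payload key next to page_content and metadata, and a keyword payload index is created on it for freshly created collections. Roughly, one point's payload now looks like this (illustrative placeholder values; the keys follow _build_payloads and the metadata built in restore_dataset_in_one):

```python
# Illustrative payload shape for a single point after this change (values are placeholders).
payload = {
    "page_content": "chunk text ...",
    "metadata": {
        "doc_id": "segment index_node_id",
        "doc_hash": "segment index_node_hash",
        "document_id": "owning document id",
        "dataset_id": "owning dataset id",
    },
    "group_id": "owning dataset id",  # filtered on at query time; keyword-indexed for new collections
}
```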
api/core/index/vector_index/qdrant_vector_index.py

```
@@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
-from models.dataset import Dataset
from extensions.ext_database import db
+from models.dataset import Dataset, DatasetCollectionBinding


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    root_path: Optional[str]

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')

@@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
        return 'qdrant'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'
            return class_prefix
        if dataset.collection_binding_id:
            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
                one_or_none()
            if dataset_collection_binding:
                return dataset_collection_binding.collection_name
            else:
                raise ValueError('Dataset Collection Bindings is not exist!')
        else:
            if self.dataset.index_struct_dict:
                class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
                return class_prefix

            dataset_id = dataset.id
            return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {

@@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
            collection_name=self.get_index_name(self.dataset),
            ids=uuids,
            content_payload_key='page_content',
            group_id=self.dataset.id,
            group_payload_key='group_id',
            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                                       max_indexing_threads=0, on_disk=False),
            **self._client_config.to_qdrant_params()
        )

        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=collection_name,
            ids=uuids,
            content_payload_key='page_content',
            group_id=self.dataset.id,
            group_payload_key='group_id',
            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                                       max_indexing_threads=0, on_disk=False),
            **self._client_config.to_qdrant_params()
        )

@@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
        if self._vector_store:
            return self._vector_store

        attributes = ['doc_id', 'dataset_id', 'document_id']
        if self._is_origin():
            attributes = ['doc_id']

        client = qdrant_client.QdrantClient(
            **self._client_config.to_qdrant_params()
        )

@@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
            client=client,
            collection_name=self.get_index_name(self.dataset),
            embeddings=self._embeddings,
-           content_payload_key='page_content'
+           content_payload_key='page_content',
+           group_id=self.dataset.id,
+           group_payload_key='group_id'
        )

    def _get_vector_store_class(self) -> type:
        return QdrantVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

@@ -114,9 +139,6 @@ class QdrantVectorIndex(BaseVectorIndex):
            ))

    def delete_by_ids(self, ids: list[str]) -> None:
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

@@ -132,6 +154,22 @@ class QdrantVectorIndex(BaseVectorIndex):
            ],
        ))

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=group_id),
                ),
            ],
        ))

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
```
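The HnswConfigDiff passed when creating collections appears to follow Qdrant's documented multitenancy optimization: m=0 skips building the single global HNSW graph, and payload_m=16 together with the keyword payload index on group_id builds per-group graphs instead, so searches filtered by group_id stay fast in a large shared collection. The same arguments as in the diff, annotated for reference:

```python
# The HNSW settings used above, annotated (same values as in the diff; the comments
# reflect Qdrant's documented behaviour for payload-based multitenancy).
from qdrant_client.http.models import HnswConfigDiff

hnsw_config = HnswConfigDiff(
    m=0,                      # do not build the global HNSW graph
    payload_m=16,             # build per-payload-value (per group_id) graphs instead
    ef_construct=100,
    full_scan_threshold=10000,
    max_indexing_threads=0,   # 0 = let Qdrant choose automatically
    on_disk=False,
)
```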
api/core/index/vector_index/weaviate_vector_index.py

```
@@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
```
api/core/tool/dataset_retriever_tool.py

```
@@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool):
    return_resource: str
    retriever_from: str

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description

@@ -94,7 +93,10 @@ class DatasetRetrieverTool(BaseTool):
                query,
                search_type='similarity_score_threshold',
                search_kwargs={
-                   'k': self.k
+                   'k': self.k,
+                   'filter': {
+                       'group_id': [dataset.id]
+                   }
                }
            )
        else:
```
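The retriever tool now adds a group_id filter so similarity search only returns chunks belonging to the tool's own dataset, even when several datasets share one collection. A rough usage sketch; the dataset object and the extra keyword arguments are assumed (from_dataset is defined in this file, the query string is just an example):

```python
# Rough usage sketch; `dataset` is assumed to be a loaded Dataset row and the
# keyword arguments are assumed to be accepted via from_dataset's **kwargs.
tool = DatasetRetrieverTool.from_dataset(dataset, k=3)

# The query is scoped to this dataset through the 'group_id' filter added above.
context = tool.run("example question about the dataset's documents")
```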
api/core/vector_store/qdrant_vector_store.py

```
@@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
        self.client.delete_collection(collection_name=self.collection_name)

    def delete_group(self):
        self._reload_if_needed()

        self.client.delete_collection(collection_name=self.collection_name)

    @classmethod
    def _document_from_scored_point(
        cls,
```
api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py (new file, mode 100644)

```
"""add_dataset_collection_binding

Revision ID: 6e2cfb077b04
Revises: 77e83833755c
Create Date: 2023-09-13 22:16:48.027810

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '6e2cfb077b04'
down_revision = '77e83833755c'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('dataset_collection_bindings',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('provider_name', sa.String(length=40), nullable=False),
        sa.Column('model_name', sa.String(length=40), nullable=False),
        sa.Column('collection_name', sa.String(length=64), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey')
    )
    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
        batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('collection_binding_id')

    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
        batch_op.drop_index('provider_model_name_idx')

    op.drop_table('dataset_collection_bindings')
    # ### end Alembic commands ###
```
api/models/dataset.py

```
@@ -38,6 +38,8 @@ class Dataset(db.Model):
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(UUID, nullable=True)

    @property
    def dataset_keyword_table(self):

@@ -445,3 +447,19 @@ class Embedding(db.Model):
    def get_embedding(self) -> list[float]:
        return pickle.loads(self.embedding)


class DatasetCollectionBinding(db.Model):
    __tablename__ = 'dataset_collection_bindings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
        db.Index('provider_model_name_idx', 'provider_name', 'model_name')
    )

    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
```
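DatasetCollectionBinding rows are looked up by (provider_name, model_name) through the provider_model_name_idx index, and Dataset.collection_binding_id points at the binding, so every dataset embedded with the same model resolves to the same physical collection. A small query sketch, assuming an application context with the Flask-SQLAlchemy session available and example provider/model values:

```python
# Sketch (example provider/model values): find the binding for an embedding model,
# then list every dataset that shares its collection.
binding = db.session.query(DatasetCollectionBinding).filter(
    DatasetCollectionBinding.provider_name == 'openai',
    DatasetCollectionBinding.model_name == 'text-embedding-ada-002',
).first()

if binding:
    shared_datasets = db.session.query(Dataset).filter(
        Dataset.collection_binding_id == binding.id
    ).all()
```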
api/services/dataset_service.py

```
@@ -20,7 +20,8 @@ from events.document_event import document_was_deleted
from extensions.ext_database import db
from libs import helper
from models.account import Account
-from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
+from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
+    DatasetCollectionBinding
from models.model import UploadFile
from models.source import DataSourceBinding
from services.errors.account import NoPermissionError

@@ -147,6 +148,7 @@ class DatasetService:
                action = 'remove'
                filtered_data['embedding_model'] = None
                filtered_data['embedding_model_provider'] = None
                filtered_data['collection_binding_id'] = None
            elif data['indexing_technique'] == 'high_quality':
                action = 'add'
                # get embedding model setting

@@ -156,6 +158,11 @@ class DatasetService:
                    )
                    filtered_data['embedding_model'] = embedding_model.name
                    filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                        embedding_model.model_provider.provider_name,
                        embedding_model.name
                    )
                    filtered_data['collection_binding_id'] = dataset_collection_binding.id
                except LLMBadRequestError:
                    raise ValueError(
                        f"No Embedding Model available. Please configure a valid provider "

@@ -464,7 +471,11 @@ class DocumentService:
                )
                dataset.embedding_model = embedding_model.name
                dataset.embedding_model_provider = embedding_model.model_provider.provider_name
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    embedding_model.model_provider.provider_name,
                    embedding_model.name
                )
                dataset.collection_binding_id = dataset_collection_binding.id

        documents = []
        batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))

@@ -720,10 +731,16 @@ class DocumentService:
        if total_count > tenant_document_count:
            raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
        embedding_model = None
        dataset_collection_binding_id = None
        if document_data['indexing_technique'] == 'high_quality':
            embedding_model = ModelFactory.get_embedding_model(
                tenant_id=tenant_id
            )
            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                embedding_model.model_provider.provider_name,
                embedding_model.name
            )
            dataset_collection_binding_id = dataset_collection_binding.id
        # save dataset
        dataset = Dataset(
            tenant_id=tenant_id,

@@ -732,7 +749,8 @@ class DocumentService:
            indexing_technique=document_data["indexing_technique"],
            created_by=account.id,
            embedding_model=embedding_model.name if embedding_model else None,
-           embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
+           embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
+           collection_binding_id=dataset_collection_binding_id
        )

        db.session.add(dataset)

@@ -1069,3 +1087,23 @@ class SegmentService:
            delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
            db.session.delete(segment)
            db.session.commit()


class DatasetCollectionBindingService:
    @classmethod
    def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
            filter(DatasetCollectionBinding.provider_name == provider_name,
                   DatasetCollectionBinding.model_name == model_name). \
            order_by(DatasetCollectionBinding.created_at). \
            first()

        if not dataset_collection_binding:
            dataset_collection_binding = DatasetCollectionBinding(
                provider_name=provider_name,
                model_name=model_name,
                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
            )
            db.session.add(dataset_collection_binding)
            db.session.flush()

        return dataset_collection_binding
```
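get_dataset_collection_binding is get-or-create: the first dataset indexed with a given provider/model pair creates the binding (and thereby the shared collection name), and later datasets simply reuse it. Example call with placeholder provider and model names:

```python
# Placeholder provider/model names; the service creates the binding on first use.
binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    'openai', 'text-embedding-ada-002'
)

dataset.collection_binding_id = binding.id
db.session.add(dataset)
db.session.commit()
```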
api/services/hit_testing_service.py

```
@@ -47,7 +47,10 @@ class HitTestingService:
            query,
            search_type='similarity_score_threshold',
            search_kwargs={
-               'k': 10
+               'k': 10,
+               'filter': {
+                   'group_id': [dataset.id]
+               }
            }
        )
        end = time.perf_counter()
```