ai-tech / dify · Commits

Commit 0578c1b6
authored Jun 19, 2023 by John Wang

feat: replace using new index builder

parent fb5118f0

Showing 27 changed files with 379 additions and 344 deletions
api/core/chain/llm_router_chain.py  +4 -2
api/core/chain/sensitive_word_avoidance_chain.py  +7 -2
api/core/chain/tool_chain.py  +12 -3
api/core/index/base.py  +2 -2
api/core/index/index.py  +41 -0
api/core/index/keyword_table_index/keyword_table_index.py  +2 -2
api/core/index/vector_index/base.py  +4 -2
api/core/index/vector_index/qdrant_vector_index.py  +1 -1
api/core/index/vector_index/vector_index.py  +3 -3
api/core/index/vector_index/weaviate_vector_index.py  +1 -1
api/core/indexing_runner.py  +132 -84
api/core/llm/streamable_azure_chat_open_ai.py  +11 -2
api/core/llm/streamable_chat_open_ai.py  +11 -2
api/core/prompt/prompts.py  +0 -19
api/services/app_model_config_service.py  +0 -1
api/tasks/add_document_to_index_task.py  +33 -48
api/tasks/add_segment_to_index_task.py  +16 -32
api/tasks/clean_dataset_task.py  +7 -8
api/tasks/clean_document_task.py  +7 -6
api/tasks/clean_notion_document_task.py  +7 -6
api/tasks/deal_dataset_vector_index_task.py  +35 -40
api/tasks/document_indexing_sync_task.py  +11 -20
api/tasks/document_indexing_task.py  +6 -16
api/tasks/document_indexing_update_task.py  +11 -20
api/tasks/recover_document_indexing_task.py  +4 -9
api/tasks/remove_document_from_index_task.py  +5 -6
api/tasks/remove_segment_from_index_task.py  +6 -7
api/core/chain/llm_router_chain.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
 from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from pydantic import root_validator
@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
             raise ValueError

     def _call(
             self,
-            inputs: Dict[str, Any]
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         output = cast(
             Dict[str, Any],
api/core/chain/sensitive_word_avoidance_chain.py

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
             return self.canned_response
         return text

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         return {self.output_key: output}
api/core/chain/tool_chain.py

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool
@@ -30,12 +31,20 @@ class ToolChain(Chain):
         """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)
api/core/index/base.py

@@ -7,11 +7,11 @@ from langchain.schema import Document, BaseRetriever
 class BaseIndex(ABC):

     @abstractmethod
-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         raise NotImplementedError

     @abstractmethod
-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         raise NotImplementedError

     @abstractmethod
api/core/index/index.py  (new file, mode 100644)

+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
+from models.dataset import Dataset
+
+
+class IndexBuilder:
+    @classmethod
+    def get_index(cls, dataset: Dataset, high_quality: str):
+        if high_quality == "high_quality":
+            if dataset.indexing_technique != 'high_quality':
+                return None
+
+            model_credentials = LLMBuilder.get_model_credentials(
+                tenant_id=dataset.tenant_id,
+                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+                model_name='text-embedding-ada-002'
+            )
+
+            embeddings = CacheEmbedding(OpenAIEmbeddings(
+                **model_credentials
+            ))
+
+            return VectorIndex(
+                dataset=dataset,
+                config=current_app.config,
+                embeddings=embeddings
+            )
+        elif high_quality == "economy":
+            return KeywordTableIndex(
+                dataset=dataset,
+                config=KeywordTableConfig(
+                    max_keywords_per_chunk=10
+                )
+            )
+        else:
+            raise ValueError('Unknown indexing technique')
\ No newline at end of file
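
The builder above is the single entry point the rest of this commit switches to. A minimal usage sketch, assuming a loaded Dataset row and a list of langchain Document chunks (the helper name index_documents is illustrative, not part of the commit):

    from core.index.index import IndexBuilder

    def index_documents(dataset, documents):
        # 'high_quality' returns a VectorIndex, or None when the dataset's
        # indexing_technique is not 'high_quality', so callers must check.
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        if vector_index:
            vector_index.add_texts(documents)

        # 'economy' always returns the keyword-table index.
        kw_index = IndexBuilder.get_index(dataset, 'economy')
        if kw_index:
            kw_index.add_texts(documents)

This is the same write pattern the task files below adopt in place of constructing VectorIndex and KeywordTableIndex directly.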
api/core/index/keyword_table_index/keyword_table_index.py

@@ -20,7 +20,7 @@ class KeywordTableIndex(BaseIndex):
         self._dataset = dataset
         self._config = config

-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         keyword_table_handler = JiebaKeywordTableHandler()
         keyword_table = {}
         for text in texts:
@@ -37,7 +37,7 @@ class KeywordTableIndex(BaseIndex):
         return self

-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         keyword_table_handler = JiebaKeywordTableHandler()
         keyword_table = self._get_dataset_keyword_table()
api/core/index/vector_index/base.py

@@ -67,11 +67,13 @@ class BaseVectorIndex(BaseIndex):
         return vector_store.as_retriever(**kwargs)

-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)

-        texts = self._filter_duplicate_texts(texts)
+        if kwargs.get('duplicate_check', False):
+            texts = self._filter_duplicate_texts(texts)

         uuids = self._get_uuids(texts)
         vector_store.add_documents(texts, uuids=uuids)
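
Duplicate filtering is now opt-in through the duplicate_check keyword rather than unconditional. A short sketch of how a caller would use it, assuming an index obtained from IndexBuilder (the function name re_add_segments is illustrative):

    from langchain.schema import Document
    from core.index.index import IndexBuilder

    def re_add_segments(dataset, documents: list[Document]):
        # When documents may already be partly indexed (e.g. re-enabling a
        # segment), pass duplicate_check=True so add_texts filters out
        # doc_ids that already exist before writing to the vector store.
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts(documents, duplicate_check=True)

Fresh indexing paths (such as IndexingRunner._build_index below) omit the flag and write every chunk as-is.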
api/core/index/vector_index/qdrant_vector_index.py

@@ -53,7 +53,7 @@ class QdrantVectorIndex(BaseVectorIndex):
                 "vector_store": {"collection_name": self.get_index_name(self._dataset.get_id())}
             }

-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         uuids = self._get_uuids(texts)
         self._vector_store = QdrantVectorStore.from_documents(
             texts,
api/core/index/vector_index/vector_index.py

@@ -51,14 +51,14 @@ class VectorIndex:
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         if not self._dataset.index_struct_dict:
-            self._vector_index.create(texts)
+            self._vector_index.create(texts, **kwargs)
             self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
             db.session.commit()
             return

-        self._vector_index.add_texts(texts)
+        self._vector_index.add_texts(texts, **kwargs)

     def __getattr__(self, name):
         if self._vector_index is not None:
api/core/index/vector_index/weaviate_vector_index.py

@@ -62,7 +62,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
                 "vector_store": {"class_prefix": self.get_index_name(self._dataset.get_id())}
             }

-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         uuids = self._get_uuids(texts)
         self._vector_store = WeaviateVectorStore.from_documents(
             texts,
api/core/indexing_runner.py

 import datetime
 import json
+import logging
 import re
 import time
 import uuid
@@ -15,8 +16,10 @@ from core.data_loader.file_extractor import FileExtractor
 from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.embedding.cached_embedding import CacheEmbedding
+from core.index.index import IndexBuilder
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.llm_builder import LLMBuilder
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.llm.token_calculator import TokenCalculator
@@ -39,6 +42,58 @@ class IndexingRunner:
     def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
         for dataset_document in dataset_documents:
+            try:
+                # get dataset
+                dataset = Dataset.query.filter_by(
+                    id=dataset_document.dataset_id
+                ).first()
+
+                if not dataset:
+                    raise ValueError("no dataset found")
+
+                # load file
+                text_docs = self._load_data(dataset_document)
+
+                # get the process rule
+                processing_rule = db.session.query(DatasetProcessRule). \
+                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+                    first()
+
+                # get splitter
+                splitter = self._get_splitter(processing_rule)
+
+                # split to documents
+                documents = self._step_split(
+                    text_docs=text_docs,
+                    splitter=splitter,
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    processing_rule=processing_rule
+                )
+
+                # build index
+                self._build_index(
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    documents=documents
+                )
+            except DocumentIsPausedException:
+                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+            except ProviderTokenNotInitError as e:
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e.description)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+            except Exception as e:
+                logging.exception("consume document failed")
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+
+    def run_in_splitting_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is splitting."""
+        try:
             # get dataset
             dataset = Dataset.query.filter_by(
                 id=dataset_document.dataset_id
@@ -47,6 +102,15 @@ class IndexingRunner:
             if not dataset:
                 raise ValueError("no dataset found")

+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+            db.session.delete(document_segments)
+            db.session.commit()
+
             # load file
             text_docs = self._load_data(dataset_document)
@@ -73,92 +137,73 @@ class IndexingRunner:
(the former bodies of run_in_splitting_status and run_in_indexing_status are kept but now wrapped in try/except; resulting code:)
                 dataset_document=dataset_document,
                 documents=documents
             )
         except DocumentIsPausedException:
             raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
         except ProviderTokenNotInitError as e:
             dataset_document.indexing_status = 'error'
             dataset_document.error = str(e.description)
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()
         except Exception as e:
             logging.exception("consume document failed")
             dataset_document.indexing_status = 'error'
             dataset_document.error = str(e)
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()

     def run_in_indexing_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is indexing."""
         try:
             # get dataset
             dataset = Dataset.query.filter_by(
                 id=dataset_document.dataset_id
             ).first()

             if not dataset:
                 raise ValueError("no dataset found")

             # get exist document_segment list and delete
             document_segments = DocumentSegment.query.filter_by(
                 dataset_id=dataset.id,
                 document_id=dataset_document.id
             ).all()

             documents = []
             if document_segments:
                 for document_segment in document_segments:
                     # transform segment to node
                     if document_segment.status != "completed":
                         document = Document(
                             page_content=document_segment.content,
                             metadata={
                                 "doc_id": document_segment.index_node_id,
                                 "doc_hash": document_segment.index_node_hash,
                                 "document_id": document_segment.document_id,
                                 "dataset_id": document_segment.dataset_id,
                             }
                         )

                         documents.append(document)

             # build index
             self._build_index(
                 dataset=dataset,
                 dataset_document=dataset_document,
                 documents=documents
             )
         except DocumentIsPausedException:
             raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
         except ProviderTokenNotInitError as e:
             dataset_document.indexing_status = 'error'
             dataset_document.error = str(e.description)
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()
         except Exception as e:
             logging.exception("consume document failed")
             dataset_document.indexing_status = 'error'
             dataset_document.error = str(e)
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()

     def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """
@@ -481,11 +526,14 @@ class IndexingRunner:
             )

             # save vector index
-            if dataset.indexing_technique == "high_quality":
-                vector_index.add_texts(chunk_documents)
+            index = IndexBuilder.get_index(dataset, 'high_quality')
+            if index:
+                index.add_texts(chunk_documents)

             # save keyword index
-            keyword_table_index.add_texts(chunk_documents)
+            index = IndexBuilder.get_index(dataset, 'economy')
+            if index:
+                index.add_texts(chunk_documents)

             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
api/core/llm/streamable_azure_chat_open_ai.py

+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -69,7 +70,11 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         return message_tokens

     def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         self.callback_manager.on_llm_start(
             {"name": self.__class__.__name__},
             [(message.type + ": " + message.content) for message in messages],
@@ -87,7 +92,11 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         return chat_result

     async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         if self.callback_manager.is_async:
             await self.callback_manager.on_llm_start(
api/core/llm/streamable_chat_open_ai.py

 import os

+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import ChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -71,7 +72,11 @@ class StreamableChatOpenAI(ChatOpenAI):
         return message_tokens

     def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         self.callback_manager.on_llm_start(
             {"name": self.__class__.__name__},
             [(message.type + ": " + message.content) for message in messages],
             verbose=self.verbose
@@ -88,7 +93,11 @@ class StreamableChatOpenAI(ChatOpenAI):
         return chat_result

     async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         if self.callback_manager.is_async:
             await self.callback_manager.on_llm_start(
api/core/prompt/prompts.py

-from llama_index import QueryKeywordExtractPrompt
-
 CONVERSATION_TITLE_PROMPT = (
     "Human:{query}\n-----\n"
     "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "[\"question1\",\"question2\",\"question3\"]\n"
 )

-QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
-    "A question is provided below. Given the question, extract up to {max_keywords} "
-    "keywords from the text. Focus on extracting the keywords that we can use "
-    "to best lookup answers to the question. Avoid stopwords."
-    "I am not sure which language the following question is in. "
-    "If the user asked the question in Chinese, please return the keywords in Chinese. "
-    "If the user asked the question in English, please return the keywords in English.\n"
-    "---------------------\n"
-    "{question}\n"
-    "---------------------\n"
-    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
-)
-
-QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
-    QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
-)
-
 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
 the model prompt that best suits the input.
 You will be provided with the prompt, variables, and an opening statement.
api/services/app_model_config_service.py

@@ -4,7 +4,6 @@ import uuid
 from core.constant import llm_constant
 from models.account import Account
 from services.dataset_service import DatasetService
-from services.errors.account import NoPermissionError


 class AppModelConfigService:
api/tasks/add_document_to_index_task.py

@@ -4,96 +4,81 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment, Document
+from models.dataset import DocumentSegment
+from models.dataset import Document as DatasetDocument


 @shared_task
-def add_document_to_index_task(document_id: str):
+def add_document_to_index_task(dataset_document_id: str):
     """
     Async Add document to index
     :param document_id:

     Usage: add_document_to_index.delay(document_id)
     """
-    logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green'))
+    logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
     start_at = time.perf_counter()

-    document = db.session.query(Document).filter(Document.id == document_id).first()
-    if not document:
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
+    if not dataset_document:
         raise NotFound('Document not found')

-    if document.indexing_status != 'completed':
+    if dataset_document.indexing_status != 'completed':
         return

-    indexing_cache_key = 'document_{}_indexing'.format(document.id)
+    indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)

     try:
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.document_id == document.id,
+            DocumentSegment.document_id == dataset_document.id,
             DocumentSegment.enabled == True
         ) \
            .order_by(DocumentSegment.position.asc()).all()

-        nodes = []
-        previous_node = None
+        documents = []

         for segment in segments:
-            relationships = {
-                DocumentRelationship.SOURCE: document.id
-            }
-
-            if previous_node:
-                relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-                previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-            node = Node(
-                doc_id=segment.index_node_id,
-                doc_hash=segment.index_node_hash,
-                text=segment.content,
-                extra_info=None,
-                node_info=None,
-                relationships=relationships
-            )
-
-            previous_node = node
-            nodes.append(node)
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
+            )
+
+            documents.append(document)

-        dataset = document.dataset
+        dataset = dataset_document.dataset

         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=nodes,
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents)

         # save keyword index
-        keyword_table_index.add_nodes(nodes)
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts(documents)

         end_at = time.perf_counter()
         logging.info(
-            click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+            click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
     except Exception as e:
         logging.exception("add document to index failed")
-        document.enabled = False
-        document.disabled_at = datetime.datetime.utcnow()
-        document.status = 'error'
-        document.error = str(e)
+        dataset_document.enabled = False
+        dataset_document.disabled_at = datetime.datetime.utcnow()
+        dataset_document.status = 'error'
+        dataset_document.error = str(e)
         db.session.commit()
     finally:
         redis_client.delete(indexing_cache_key)
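
The segment-to-document mapping above replaces the old llama_index Node construction, and the same mapping recurs in the other tasks in this commit. A standalone sketch of that conversion (the helper name segment_to_document is illustrative):

    from langchain.schema import Document

    def segment_to_document(segment) -> Document:
        # A DocumentSegment row becomes a plain langchain Document; the ids
        # and hash move into metadata instead of Node relationships.
        return Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )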
api/tasks/add_segment_to_index_task.py

@@ -4,12 +4,10 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -36,25 +34,14 @@ def add_segment_to_index_task(segment_id: str):
     indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

     try:
-        relationships = {
-            DocumentRelationship.SOURCE: segment.document_id,
-        }
-
-        previous_segment = segment.previous_segment
-        if previous_segment:
-            relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-        next_segment = segment.next_segment
-        if next_segment:
-            relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-
-        node = Node(
-            doc_id=segment.index_node_id,
-            doc_hash=segment.index_node_hash,
-            text=segment.content,
-            extra_info=None,
-            node_info=None,
-            relationships=relationships
-        )
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            }
+        )

         dataset = segment.dataset
@@ -62,18 +49,15 @@ def add_segment_to_index_task(segment_id: str):
         if not dataset:
             raise Exception('Segment has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=[node],
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts([document], duplicate_check=True)

         # save keyword index
-        keyword_table_index.add_nodes([node])
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts([document])

         end_at = time.perf_counter()
         logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
api/tasks/clean_dataset_task.py

@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin
@@ -33,19 +32,19 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
             index_struct=index_struct
         )

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
         index_doc_ids = [document.id for document in documents]
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
+
         # delete from vector index
-        if dataset.indexing_technique == "high_quality":
+        if vector_index:
             for index_doc_id in index_doc_ids:
                 try:
-                    vector_index.del_doc(index_doc_id)
+                    vector_index.delete_by_document_id(index_doc_id)
                 except Exception:
                     logging.exception("Delete doc index failed when dataset deleted.")
                     continue
@@ -53,7 +52,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         # delete from keyword index
         if index_node_ids:
             try:
-                keyword_table_index.del_nodes(index_node_ids)
+                kw_index.delete_by_ids(index_node_ids)
             except Exception:
                 logging.exception("Delete nodes index failed when dataset deleted.")
api/tasks/clean_document_task.py

@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset
@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)

         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)

         db.session.commit()
         end_at = time.perf_counter()
         logging.info(
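
The cleanup tasks in this commit converge on one deletion pattern. A minimal sketch of it, assuming index_node_ids has already been collected from DocumentSegment rows (the helper name clean_document_indexes is illustrative):

    from core.index.index import IndexBuilder

    def clean_document_indexes(dataset, document_id: str, index_node_ids: list):
        # Vector entries are removed by owning document id; keyword-table
        # entries are removed by individual index node ids.
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        kw_index = IndexBuilder.get_index(dataset, 'economy')

        if vector_index:
            vector_index.delete_by_document_id(document_id)

        if index_node_ids:
            kw_index.delete_by_ids(index_node_ids)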
api/tasks/clean_notion_document_task.py

@@ -5,8 +5,7 @@ from typing import List
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, Document
@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         for document_id in document_ids:
             document = db.session.query(Document).filter(
                 Document.id == document_id
             ).first()
             db.session.delete(document)

             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]

             # delete from vector index
-            vector_index.del_nodes(index_node_ids)
+            if vector_index:
+                vector_index.delete_by_document_id(document_id)

             # delete from keyword index
             if index_node_ids:
-                keyword_table_index.del_nodes(index_node_ids)
+                kw_index.delete_by_ids(index_node_ids)

             for segment in segments:
                 db.session.delete(segment)
api/tasks/deal_dataset_vector_index_task.py

@@ -3,10 +3,12 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs.node_v2 import DocumentRelationship, Node
+from langchain.schema import Document

-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
-from models.dataset import DocumentSegment, Document, Dataset
+from models.dataset import DocumentSegment, Dataset
+from models.dataset import Document as DatasetDocument


 @shared_task
@@ -26,48 +28,41 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
         ).first()
         if not dataset:
             raise Exception('Dataset not found')

-        documents = Document.query.filter_by(dataset_id=dataset_id).all()
-        if documents:
-            vector_index = VectorIndex(dataset=dataset)
-            for document in documents:
-                # delete from vector index
-                if action == "remove":
-                    vector_index.del_doc(document.id)
-                elif action == "add":
-                    segments = db.session.query(DocumentSegment).filter(
-                        DocumentSegment.document_id == document.id,
-                        DocumentSegment.enabled == True
-                    ).order_by(DocumentSegment.position.asc()).all()
-                    nodes = []
-                    previous_node = None
-                    for segment in segments:
-                        relationships = {
-                            DocumentRelationship.SOURCE: document.id
-                        }
-                        if previous_node:
-                            relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-                            previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-                        node = Node(
-                            doc_id=segment.index_node_id,
-                            doc_hash=segment.index_node_hash,
-                            text=segment.content,
-                            extra_info=None,
-                            node_info=None,
-                            relationships=relationships
-                        )
-                        previous_node = node
-                        nodes.append(node)
-                    # save vector index
-                    vector_index.add_nodes(
-                        nodes=nodes,
-                        duplicate_check=True
-                    )
+        dataset_documents = DatasetDocument.query.filter_by(dataset_id=dataset_id).all()
+        if dataset_documents:
+            # save vector index
+            index = IndexBuilder.get_index(dataset, 'high_quality')
+            if index:
+                for dataset_document in dataset_documents:
+                    # delete from vector index
+                    if action == "remove":
+                        index.delete_by_document_id(dataset_document.id)
+                    elif action == "add":
+                        segments = db.session.query(DocumentSegment).filter(
+                            DocumentSegment.document_id == dataset_document.id,
+                            DocumentSegment.enabled == True
+                        ).order_by(DocumentSegment.position.asc()).all()
+
+                        documents = []
+                        for segment in segments:
+                            document = Document(
+                                page_content=segment.content,
+                                metadata={
+                                    "doc_id": segment.index_node_id,
+                                    "doc_hash": segment.index_node_hash,
+                                    "document_id": segment.document_id,
+                                    "dataset_id": segment.dataset_id,
+                                }
+                            )
+
+                            documents.append(document)
+
+                        # save vector index
+                        index.add_texts(
+                            documents,
+                            duplicate_check=True
+                        )

         end_at = time.perf_counter()
         logging.info(
             click.style('Deal dataset vector index: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
api/tasks/document_indexing_sync_task.py

@@ -7,10 +7,8 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound
 from core.data_loader.loader.notion import NotionLoader
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
 from models.source import DataSourceBinding
@@ -77,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)

         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)
@@ -98,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
     except Exception:
         logging.exception("Cleaned document when document update data source or process rule failed")

     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/document_indexing_task.py

@@ -7,7 +7,6 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document
@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
     documents = []
+    start_at = time.perf_counter()
     for document_id in document_ids:
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-        start_at = time.perf_counter()

         document = db.session.query(Document).filter(
             Document.id == document_id,
@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()
-        logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/document_indexing_update_task.py  View file @ 0578c1b6

...
@@ -6,10 +6,8 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
...
@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
 
         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_ids(index_node_ids)
 
         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            vector_index.delete_by_ids(index_node_ids)
 
         for segment in segments:
             db.session.delete(segment)
...
@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
     except Exception:
         logging.exception("Cleaned document when document update data source or process rule failed")
 
     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
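The pattern this commit repeats across the tasks is visible in the middle hunk: IndexBuilder.get_index(dataset, 'high_quality') replaces direct VectorIndex construction, IndexBuilder.get_index(dataset, 'economy') replaces KeywordTableIndex, and deletion happens through delete_by_ids on whatever the factory returns. A minimal sketch of that clean-up step, assuming get_index may return None (which is what the new "if vector_index:" guard suggests); note the sketch uses kw_index for the keyword-index deletion, as the remove_* tasks below do, whereas the hunk above calls vector_index again at that point:

from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment


def clear_document_segments(dataset: Dataset, document_id: str) -> None:
    """Illustrative helper mirroring the clean-up step before a document is re-indexed."""
    # 'high_quality' -> vector index (may be None), 'economy' -> keyword-table index
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    segments = db.session.query(DocumentSegment).filter(
        DocumentSegment.document_id == document_id
    ).all()
    index_node_ids = [segment.index_node_id for segment in segments]

    # delete from vector index (guarded because the factory may not return one)
    if vector_index:
        vector_index.delete_by_ids(index_node_ids)

    # delete from keyword index (kw_index form, as used by the remove_* tasks below)
    if index_node_ids:
        kw_index.delete_by_ids(index_node_ids)

    for segment in segments:
        db.session.delete(segment)
    # the surrounding task is assumed to commit the session afterwards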
api/tasks/recover_document_indexing_task.py  View file @ 0578c1b6

-import datetime
 import logging
 import time
...
@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
         indexing_runner.run_in_indexing_status(document)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
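recover_document_indexing_task keeps the same shape, only with run_in_indexing_status and the pared-down except clauses. Like the other tasks it is dispatched through Celery; a hypothetical invocation, with placeholder ids and an import path assumed from the repository layout:

from tasks.recover_document_indexing_task import recover_document_indexing_task

# Placeholder identifiers; in practice these come from the Dataset and Document rows.
dataset_id = '00000000-0000-0000-0000-000000000000'
document_id = '11111111-1111-1111-1111-111111111111'

# Re-enqueue indexing for a document that was left in an indexing state.
recover_document_indexing_task.delay(dataset_id, document_id)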
api/tasks/remove_document_from_index_task.py  View file @ 0578c1b6

...
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment, Document
...
@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
         # delete from vector index
-        vector_index.del_doc(document.id)
+        vector_index.delete_by_document_id(document.id)
 
         # delete from keyword index
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
 
         end_at = time.perf_counter()
         logging.info(
...
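The hunk above shows the two deletion granularities the new index API exposes: the vector index can drop a whole document in one call (delete_by_document_id), while the keyword index is still cleared per index-node id (delete_by_ids). A condensed sketch, with the helper name and parameters being illustrative:

from core.index.index import IndexBuilder
from models.dataset import Dataset


def remove_document_from_indexes(dataset: Dataset, document_id: str, index_node_ids: list) -> None:
    """Illustrative helper; index_node_ids are the DocumentSegment.index_node_id values."""
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    # Vector store: a single call removes every node belonging to the document.
    vector_index.delete_by_document_id(document_id)

    # Keyword table: deletion stays per index-node id.
    if index_node_ids:
        kw_index.delete_by_ids(index_node_ids)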
api/tasks/remove_segment_from_index_task.py  View file @ 0578c1b6

...
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
...
@@ -38,15 +37,15 @@ def remove_segment_from_index_task(segment_id: str):
         if not dataset:
             raise Exception('Segment has no dataset')
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
         # delete from vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.del_nodes([segment.index_node_id])
+        if vector_index:
+            vector_index.delete_by_ids([segment.index_node_id])
 
         # delete from keyword index
-        keyword_table_index.del_nodes([segment.index_node_id])
+        kw_index.delete_by_ids([segment.index_node_id])
 
         end_at = time.perf_counter()
         logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
...
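Taken together, the call sites in these tasks imply the shape of the new builder: 'high_quality' yields the vector index or nothing (hence the "if vector_index:" guards that replace the old dataset.indexing_technique check), while 'economy' always yields the keyword-table index. A hypothetical reconstruction inferred only from those call sites; the actual api/core/index/index.py may differ, and the old-style constructors appear purely for illustration:

from typing import Optional, Union

from core.index.keyword_table_index import KeywordTableIndex   # old import path, as removed above
from core.index.vector_index import VectorIndex                # old import path, as removed above
from models.dataset import Dataset


def get_index_sketch(dataset: Dataset, indexing_technique: str) -> Optional[Union[VectorIndex, KeywordTableIndex]]:
    """Hypothetical stand-in for IndexBuilder.get_index, inferred from the call sites above."""
    if indexing_technique == 'high_quality':
        if dataset.indexing_technique != 'high_quality':
            # Economy-only datasets have no vector index, which is why the call sites
            # guard the result with "if vector_index:".
            return None
        return VectorIndex(dataset=dataset)
    if indexing_technique == 'economy':
        # Every dataset keeps a keyword table, so this branch always returns an index.
        return KeywordTableIndex(dataset=dataset)
    raise ValueError('Unknown indexing technique: {}'.format(indexing_technique))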