ai-tech / dify · Commits · 0578c1b6

Commit 0578c1b6
authored Jun 19, 2023 by John Wang
parent fb5118f0

feat: replace using new index builder

Showing 27 changed files with 379 additions and 344 deletions (+379, -344)
Changed files:

  api/core/chain/llm_router_chain.py                          +4    -2
  api/core/chain/sensitive_word_avoidance_chain.py            +7    -2
  api/core/chain/tool_chain.py                                +12   -3
  api/core/index/base.py                                      +2    -2
  api/core/index/index.py                                     +41   -0
  api/core/index/keyword_table_index/keyword_table_index.py   +2    -2
  api/core/index/vector_index/base.py                         +4    -2
  api/core/index/vector_index/qdrant_vector_index.py          +1    -1
  api/core/index/vector_index/vector_index.py                 +3    -3
  api/core/index/vector_index/weaviate_vector_index.py        +1    -1
  api/core/indexing_runner.py                                 +132  -84
  api/core/llm/streamable_azure_chat_open_ai.py               +11   -2
  api/core/llm/streamable_chat_open_ai.py                     +11   -2
  api/core/prompt/prompts.py                                  +0    -19
  api/services/app_model_config_service.py                    +0    -1
  api/tasks/add_document_to_index_task.py                     +33   -48
  api/tasks/add_segment_to_index_task.py                      +16   -32
  api/tasks/clean_dataset_task.py                             +7    -8
  api/tasks/clean_document_task.py                            +7    -6
  api/tasks/clean_notion_document_task.py                     +7    -6
  api/tasks/deal_dataset_vector_index_task.py                 +35   -40
  api/tasks/document_indexing_sync_task.py                    +11   -20
  api/tasks/document_indexing_task.py                         +6    -16
  api/tasks/document_indexing_update_task.py                  +11   -20
  api/tasks/recover_document_indexing_task.py                 +4    -9
  api/tasks/remove_document_from_index_task.py                +5    -6
  api/tasks/remove_segment_from_index_task.py                 +6    -7
api/core/chain/llm_router_chain.py  (+4, -2)

@@ -4,6 +4,7 @@ from __future__ import annotations
 from typing import Any, Dict, List, Optional, Type, cast, NamedTuple

 from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from pydantic import root_validator

@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
             raise ValueError

     def _call(
-            self,
-            inputs: Dict[str, Any]
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         output = cast(
             Dict[str, Any],
api/core/chain/sensitive_word_avoidance_chain.py  (+7, -2)

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain

@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
                 return self.canned_response
         return text

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         return {self.output_key: output}
api/core/chain/tool_chain.py  (+12, -3)

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool

@@ -30,12 +31,20 @@ class ToolChain(Chain):
         """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)
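All three chain classes above converge on the same run_manager-aware _call/_acall signature. As a point of reference only, a minimal chain written against that interface could look like the sketch below; EchoChain is illustrative and not part of this commit, and it assumes the langchain version the project depended on at the time (which exposes CallbackManagerForChainRun from langchain.callbacks.manager).

# Illustrative only: a minimal Chain subclass using the run_manager-aware
# signature adopted by the chains in this commit.
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class EchoChain(Chain):
    input_key: str = "input"
    output_key: str = "output"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # Echo the input back; a real chain would do its work here and could
        # report intermediate steps through run_manager when it is provided.
        return {self.output_key: inputs[self.input_key]}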
api/core/index/base.py  (+2, -2)

@@ -7,11 +7,11 @@ from langchain.schema import Document, BaseRetriever
 class BaseIndex(ABC):

     @abstractmethod
-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         raise NotImplementedError

     @abstractmethod
-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         raise NotImplementedError

     @abstractmethod
api/core/index/index.py  (new file, +41, -0)

from flask import current_app
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class IndexBuilder:
    @classmethod
    def get_index(cls, dataset: Dataset, high_quality: str):
        if high_quality == "high_quality":
            if dataset.indexing_technique != 'high_quality':
                return None

            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            return VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )
        elif high_quality == "economy":
            return KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=10
                )
            )
        else:
            raise ValueError('Unknown indexing technique')

(no newline at end of file)
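The task modules later in this diff consume this factory in the same way throughout; a minimal usage sketch is shown below (the index_documents helper is illustrative and not part of the commit).

# Illustrative helper mirroring how the updated tasks call the factory.
from core.index.index import IndexBuilder


def index_documents(dataset, documents):
    # 'high_quality' returns a VectorIndex only when the dataset is configured
    # for high-quality indexing; otherwise it returns None, so callers guard.
    index = IndexBuilder.get_index(dataset, 'high_quality')
    if index:
        index.add_texts(documents, duplicate_check=True)

    # 'economy' returns the keyword-table index.
    index = IndexBuilder.get_index(dataset, 'economy')
    if index:
        index.add_texts(documents)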
api/core/index/keyword_table_index/keyword_table_index.py  (+2, -2)

@@ -20,7 +20,7 @@ class KeywordTableIndex(BaseIndex):
         self._dataset = dataset
         self._config = config

-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         keyword_table_handler = JiebaKeywordTableHandler()

         keyword_table = {}
         for text in texts:

@@ -37,7 +37,7 @@ class KeywordTableIndex(BaseIndex):

         return self

-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         keyword_table_handler = JiebaKeywordTableHandler()

         keyword_table = self._get_dataset_keyword_table()
api/core/index/vector_index/base.py  (+4, -2)

@@ -67,11 +67,13 @@ class BaseVectorIndex(BaseIndex):

         return vector_store.as_retriever(**kwargs)

-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)

-        texts = self._filter_duplicate_texts(texts)
+        if kwargs.get('duplicate_check', False):
+            texts = self._filter_duplicate_texts(texts)

         uuids = self._get_uuids(texts)
         vector_store.add_documents(texts, uuids=uuids)
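With the kwargs plumbing above, duplicate filtering becomes opt-in per call rather than unconditional. A hedged sketch of the call shape follows; add_with_optional_dedup is illustrative, and index stands for any BaseVectorIndex subclass built elsewhere (for example by IndexBuilder).

# Illustrative only: the duplicate_check kwarg defaults to off, so existing
# call sites keep their previous behaviour unless they opt in.
from langchain.schema import Document


def add_with_optional_dedup(index, texts: list[Document], dedup: bool = False):
    # When dedup is True, BaseVectorIndex.add_texts filters documents that are
    # already present (via _filter_duplicate_texts) before writing.
    index.add_texts(texts, duplicate_check=dedup)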
api/core/index/vector_index/qdrant_vector_index.py  (+1, -1)

@@ -53,7 +53,7 @@ class QdrantVectorIndex(BaseVectorIndex):
             "vector_store": {"collection_name": self.get_index_name(self._dataset.get_id())}
         }

-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         uuids = self._get_uuids(texts)
         self._vector_store = QdrantVectorStore.from_documents(
             texts,
api/core/index/vector_index/vector_index.py  (+3, -3)

@@ -51,14 +51,14 @@ class VectorIndex:
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

-    def add_texts(self, texts: list[Document]):
+    def add_texts(self, texts: list[Document], **kwargs):
         if not self._dataset.index_struct_dict:
-            self._vector_index.create(texts)
+            self._vector_index.create(texts, **kwargs)
             self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
             db.session.commit()

             return

-        self._vector_index.add_texts(texts)
+        self._vector_index.add_texts(texts, **kwargs)

     def __getattr__(self, name):
         if self._vector_index is not None:
api/core/index/vector_index/weaviate_vector_index.py  (+1, -1)

@@ -62,7 +62,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
             "vector_store": {"class_prefix": self.get_index_name(self._dataset.get_id())}
         }

-    def create(self, texts: list[Document]) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         uuids = self._get_uuids(texts)
         self._vector_store = WeaviateVectorStore.from_documents(
             texts,
api/core/indexing_runner.py  (+132, -84)

 import datetime
 import json
 import logging
 import re
 import time
 import uuid

@@ -15,8 +16,10 @@ from core.data_loader.file_extractor import FileExtractor
 from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.index import IndexBuilder
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.llm.error import ProviderTokenNotInitError
 from core.llm.llm_builder import LLMBuilder
 from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.llm.token_calculator import TokenCalculator

@@ -39,6 +42,58 @@ class IndexingRunner:
     def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
         for dataset_document in dataset_documents:
+            try:
+                # get dataset
+                dataset = Dataset.query.filter_by(
+                    id=dataset_document.dataset_id
+                ).first()
+
+                if not dataset:
+                    raise ValueError("no dataset found")
+
+                # load file
+                text_docs = self._load_data(dataset_document)
+
+                # get the process rule
+                processing_rule = db.session.query(DatasetProcessRule). \
+                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+                    first()
+
+                # get splitter
+                splitter = self._get_splitter(processing_rule)
+
+                # split to documents
+                documents = self._step_split(
+                    text_docs=text_docs,
+                    splitter=splitter,
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    processing_rule=processing_rule
+                )
+
+                # build index
+                self._build_index(
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    documents=documents
+                )
+            except DocumentIsPausedException:
+                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+            except ProviderTokenNotInitError as e:
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e.description)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+            except Exception as e:
+                logging.exception("consume document failed")
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+
+    def run_in_splitting_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is splitting."""
+        try:
             # get dataset
             dataset = Dataset.query.filter_by(
                 id=dataset_document.dataset_id

@@ -47,6 +102,15 @@ class IndexingRunner:
             if not dataset:
                 raise ValueError("no dataset found")

+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+            db.session.delete(document_segments)
+            db.session.commit()
+
             # load file
             text_docs = self._load_data(dataset_document)

@@ -73,92 +137,73 @@ class IndexingRunner:
                 dataset_document=dataset_document,
                 documents=documents
             )

-    def run_in_splitting_status(self, dataset_document: DatasetDocument):
-        """Run the indexing process when the index_status is splitting."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=dataset_document.dataset_id
-        ).first()
-
-        if not dataset:
-            raise ValueError("no dataset found")
-
-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=dataset_document.id
-        ).all()
-        db.session.delete(document_segments)
-        db.session.commit()
-
-        # load file
-        text_docs = self._load_data(dataset_document)
-
-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
-            first()
-
-        # get splitter
-        splitter = self._get_splitter(processing_rule)
-
-        # split to documents
-        documents = self._step_split(
-            text_docs=text_docs,
-            splitter=splitter,
-            dataset=dataset,
-            dataset_document=dataset_document,
-            processing_rule=processing_rule
-        )
-
-        # build index
-        self._build_index(
-            dataset=dataset,
-            dataset_document=dataset_document,
-            documents=documents
-        )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()

     def run_in_indexing_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is indexing."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=dataset_document.dataset_id
-        ).first()
+        try:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=dataset_document.dataset_id
+            ).first()

-        if not dataset:
-            raise ValueError("no dataset found")
+            if not dataset:
+                raise ValueError("no dataset found")

-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=dataset_document.id
-        ).all()
-        documents = []
-        if document_segments:
-            for document_segment in document_segments:
-                # transform segment to node
-                if document_segment.status != "completed":
-                    document = Document(
-                        page_content=document_segment.content,
-                        metadata={
-                            "doc_id": document_segment.index_node_id,
-                            "doc_hash": document_segment.index_node_hash,
-                            "document_id": document_segment.document_id,
-                            "dataset_id": document_segment.dataset_id,
-                        }
-                    )
-                    documents.append(document)
-
-        # build index
-        self._build_index(
-            dataset=dataset,
-            dataset_document=dataset_document,
-            documents=documents
-        )
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+
+            documents = []
+            if document_segments:
+                for document_segment in document_segments:
+                    # transform segment to node
+                    if document_segment.status != "completed":
+                        document = Document(
+                            page_content=document_segment.content,
+                            metadata={
+                                "doc_id": document_segment.index_node_id,
+                                "doc_hash": document_segment.index_node_hash,
+                                "document_id": document_segment.document_id,
+                                "dataset_id": document_segment.dataset_id,
+                            }
+                        )
+                        documents.append(document)
+
+            # build index
+            self._build_index(
+                dataset=dataset,
+                dataset_document=dataset_document,
+                documents=documents
+            )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()

     def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """

@@ -481,11 +526,14 @@ class IndexingRunner:
             )

             # save vector index
-            if dataset.indexing_technique == "high_quality":
-                vector_index.add_texts(chunk_documents)
+            index = IndexBuilder.get_index(dataset, 'high_quality')
+            if index:
+                index.add_texts(chunk_documents)

             # save keyword index
-            keyword_table_index.add_texts(chunk_documents)
+            index = IndexBuilder.get_index(dataset, 'economy')
+            if index:
+                index.add_texts(chunk_documents)

             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
api/core/llm/streamable_azure_chat_open_ai.py  (+11, -2)

+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any

@@ -69,7 +70,11 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         return message_tokens

     def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         self.callback_manager.on_llm_start(
             {"name": self.__class__.__name__},
             [(message.type + ": " + message.content) for message in messages],

@@ -87,7 +92,11 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         return chat_result

     async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         if self.callback_manager.is_async:
             await self.callback_manager.on_llm_start(
api/core/llm/streamable_chat_open_ai.py  (+11, -2)

 import os
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import ChatOpenAI
 from typing import Optional, List, Dict, Any

@@ -71,7 +72,11 @@ class StreamableChatOpenAI(ChatOpenAI):
         return message_tokens

     def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         self.callback_manager.on_llm_start(
             {"name": self.__class__.__name__},
             [(message.type + ": " + message.content) for message in messages],
             verbose=self.verbose

@@ -88,7 +93,11 @@ class StreamableChatOpenAI(ChatOpenAI):
         return chat_result

     async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         if self.callback_manager.is_async:
             await self.callback_manager.on_llm_start(
api/core/prompt/prompts.py  (+0, -19)

-from llama_index import QueryKeywordExtractPrompt
-
 CONVERSATION_TITLE_PROMPT = (
     "Human:{query}\n-----\n"
     "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"

@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "[\"question1\",\"question2\",\"question3\"]\n"
 )

-QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
-    "A question is provided below. Given the question, extract up to {max_keywords} "
-    "keywords from the text. Focus on extracting the keywords that we can use "
-    "to best lookup answers to the question. Avoid stopwords."
-    "I am not sure which language the following question is in. "
-    "If the user asked the question in Chinese, please return the keywords in Chinese. "
-    "If the user asked the question in English, please return the keywords in English.\n"
-    "---------------------\n"
-    "{question}\n"
-    "---------------------\n"
-    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
-)
-
-QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
-    QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
-)
-
 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
 the model prompt that best suits the input.
 You will be provided with the prompt, variables, and an opening statement.
api/services/app_model_config_service.py  (+0, -1)

@@ -4,7 +4,6 @@ import uuid
 from core.constant import llm_constant
 from models.account import Account
-from services.dataset_service import DatasetService
 from services.errors.account import NoPermissionError


 class AppModelConfigService:
api/tasks/add_document_to_index_task.py  (+33, -48)

@@ -4,96 +4,81 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment, Document
+from models.dataset import DocumentSegment
+from models.dataset import Document as DatasetDocument


 @shared_task
-def add_document_to_index_task(document_id: str):
+def add_document_to_index_task(dataset_document_id: str):
     """
     Async Add document to index
     :param document_id:

     Usage: add_document_to_index.delay(document_id)
     """
-    logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green'))
+    logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
     start_at = time.perf_counter()

-    document = db.session.query(Document).filter(Document.id == document_id).first()
-    if not document:
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
+    if not dataset_document:
         raise NotFound('Document not found')

-    if document.indexing_status != 'completed':
+    if dataset_document.indexing_status != 'completed':
         return

-    indexing_cache_key = 'document_{}_indexing'.format(document.id)
+    indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)

     try:
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.document_id == document.id,
+            DocumentSegment.document_id == dataset_document.id,
             DocumentSegment.enabled == True
         ) \
             .order_by(DocumentSegment.position.asc()).all()

-        nodes = []
-        previous_node = None
+        documents = []
         for segment in segments:
-            relationships = {
-                DocumentRelationship.SOURCE: document.id
-            }
-
-            if previous_node:
-                relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-                previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-            node = Node(
-                doc_id=segment.index_node_id,
-                doc_hash=segment.index_node_hash,
-                text=segment.content,
-                extra_info=None,
-                node_info=None,
-                relationships=relationships
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
             )

-            previous_node = node
-            nodes.append(node)
+            documents.append(document)

-        dataset = document.dataset
+        dataset = dataset_document.dataset

         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(nodes=nodes, duplicate_check=True)
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents)

         # save keyword index
-        keyword_table_index.add_nodes(nodes)
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts(documents)

         end_at = time.perf_counter()
         logging.info(
-            click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+            click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
     except Exception as e:
         logging.exception("add document to index failed")
-        document.enabled = False
-        document.disabled_at = datetime.datetime.utcnow()
-        document.status = 'error'
-        document.error = str(e)
+        dataset_document.enabled = False
+        dataset_document.disabled_at = datetime.datetime.utcnow()
+        dataset_document.status = 'error'
+        dataset_document.error = str(e)
         db.session.commit()
     finally:
         redis_client.delete(indexing_cache_key)
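The segment-to-Document conversion above is repeated in several of the tasks touched by this commit; pulled out on its own it reads as below (segment_to_document is an illustrative helper, not something the commit adds).

# Illustrative only: how a DocumentSegment row maps onto a langchain Document,
# replacing the former llama_index Node (doc_id/doc_hash/text) representation.
from langchain.schema import Document


def segment_to_document(segment) -> Document:
    return Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        }
    )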
api/tasks/add_segment_to_index_task.py  (+16, -32)

@@ -4,12 +4,10 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment

@@ -36,25 +34,14 @@ def add_segment_to_index_task(segment_id: str):
     indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

     try:
-        relationships = {
-            DocumentRelationship.SOURCE: segment.document_id,
-        }
-
-        previous_segment = segment.previous_segment
-        if previous_segment:
-            relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-        next_segment = segment.next_segment
-        if next_segment:
-            relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-
-        node = Node(
-            doc_id=segment.index_node_id,
-            doc_hash=segment.index_node_hash,
-            text=segment.content,
-            extra_info=None,
-            node_info=None,
-            relationships=relationships
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            }
         )

         dataset = segment.dataset

@@ -62,18 +49,15 @@ def add_segment_to_index_task(segment_id: str):
         if not dataset:
             raise Exception('Segment has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(nodes=[node], duplicate_check=True)
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts([document], duplicate_check=True)

         # save keyword index
-        keyword_table_index.add_nodes([node])
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts([document])

         end_at = time.perf_counter()
         logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
api/tasks/clean_dataset_task.py  (+7, -8)

@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin

@@ -33,19 +32,19 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         index_struct=index_struct
     )

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
-
     documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
     index_doc_ids = [document.id for document in documents]
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
     index_node_ids = [segment.index_node_id for segment in segments]

+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')
+
     # delete from vector index
-    if dataset.indexing_technique == "high_quality":
+    if vector_index:
         for index_doc_id in index_doc_ids:
             try:
-                vector_index.del_doc(index_doc_id)
+                vector_index.delete_by_document_id(index_doc_id)
             except Exception:
                 logging.exception("Delete doc index failed when dataset deleted.")
                 continue

@@ -53,7 +52,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
     # delete from keyword index
     if index_node_ids:
         try:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
         except Exception:
             logging.exception("Delete nodes index failed when dataset deleted.")
api/tasks/clean_document_task.py  (+7, -6)

@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset

@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)

         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)

         db.session.commit()
         end_at = time.perf_counter()
         logging.info(
api/tasks/clean_notion_document_task.py  (+7, -6)

@@ -5,8 +5,7 @@ from typing import List
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, Document

@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
+
         for document_id in document_ids:
             document = db.session.query(Document).filter(Document.id == document_id).first()
             db.session.delete(document)

             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]

             # delete from vector index
-            vector_index.del_nodes(index_node_ids)
+            if vector_index:
+                vector_index.delete_by_document_id(document_id)

             # delete from keyword index
             if index_node_ids:
-                keyword_table_index.del_nodes(index_node_ids)
+                kw_index.delete_by_ids(index_node_ids)

             for segment in segments:
                 db.session.delete(segment)
api/tasks/deal_dataset_vector_index_task.py  (+35, -40)

@@ -3,10 +3,12 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs.node_v2 import DocumentRelationship, Node

-from core.index.vector_index import VectorIndex
+from langchain.schema import Document
+
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
-from models.dataset import DocumentSegment, Document, Dataset
+from models.dataset import DocumentSegment, Dataset
+from models.dataset import Document as DatasetDocument


 @shared_task

@@ -26,48 +28,41 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
         ).first()

     if not dataset:
         raise Exception('Dataset not found')

-    documents = Document.query.filter_by(dataset_id=dataset_id).all()
-    if documents:
-        vector_index = VectorIndex(dataset=dataset)
-        for document in documents:
-            # delete from vector index
-            if action == "remove":
-                vector_index.del_doc(document.id)
-            elif action == "add":
-                segments = db.session.query(DocumentSegment).filter(
-                    DocumentSegment.document_id == document.id,
-                    DocumentSegment.enabled == True
-                ).order_by(DocumentSegment.position.asc()).all()
-
-                nodes = []
-                previous_node = None
-                for segment in segments:
-                    relationships = {
-                        DocumentRelationship.SOURCE: document.id
-                    }
-
-                    if previous_node:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-                        previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-                    node = Node(
-                        doc_id=segment.index_node_id,
-                        doc_hash=segment.index_node_hash,
-                        text=segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
-                    )
-
-                    previous_node = node
-                    nodes.append(node)
-
-                # save vector index
-                vector_index.add_nodes(nodes=nodes, duplicate_check=True)
+    dataset_documents = DatasetDocument.query.filter_by(dataset_id=dataset_id).all()
+    if dataset_documents:
+        # save vector index
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            for dataset_document in dataset_documents:
+                # delete from vector index
+                if action == "remove":
+                    index.delete_by_document_id(dataset_document.id)
+                elif action == "add":
+                    segments = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.document_id == dataset_document.id,
+                        DocumentSegment.enabled == True
+                    ).order_by(DocumentSegment.position.asc()).all()
+
+                    documents = []
+                    for segment in segments:
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
+                        )
+
+                        documents.append(document)
+
+                    # save vector index
+                    index.add_texts(documents, duplicate_check=True)

     end_at = time.perf_counter()
     logging.info(
         click.style('Deal dataset vector index: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
api/tasks/document_indexing_sync_task.py  (+11, -20)

@@ -7,10 +7,8 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound

 from core.data_loader.loader.notion import NotionLoader
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
 from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
 from models.source import DataSourceBinding

@@ -77,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             if not dataset:
                 raise Exception('Dataset not found')

-            vector_index = VectorIndex(dataset=dataset)
-            keyword_table_index = KeywordTableIndex(dataset=dataset)
+            vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+            kw_index = IndexBuilder.get_index(dataset, 'economy')

             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]

             # delete from vector index
-            vector_index.del_nodes(index_node_ids)
+            if vector_index:
+                vector_index.delete_by_document_id(document_id)

             # delete from keyword index
             if index_node_ids:
-                keyword_table_index.del_nodes(index_node_ids)
+                kw_index.delete_by_ids(index_node_ids)

             for segment in segments:
                 db.session.delete(segment)

@@ -98,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                 click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
         except Exception:
             logging.exception("Cleaned document when document update data source or process rule failed")

     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/document_indexing_task.py  (+6, -16)

@@ -7,7 +7,6 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound

 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document

@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
     documents = []
+    start_at = time.perf_counter()

     for document_id in document_ids:
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-        start_at = time.perf_counter()

         document = db.session.query(Document).filter(
             Document.id == document_id,

@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()
-        logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/document_indexing_update_task.py  (+11, -20)

@@ -6,10 +6,8 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
 from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment

@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_ids(index_node_ids)

         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            vector_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)

@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
     except Exception:
         logging.exception("Cleaned document when document update data source or process rule failed")

     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/recover_document_indexing_task.py  (+4, -9)

-import datetime
 import logging
 import time

@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
         indexing_runner.run_in_indexing_status(document)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/remove_document_from_index_task.py  (+5, -6)

@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment, Document

@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         # delete from vector index
-        vector_index.del_doc(document.id)
+        vector_index.delete_by_document_id(document.id)

         # delete from keyword index
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         end_at = time.perf_counter()
         logging.info(
api/tasks/remove_segment_from_index_task.py  (+6, -7)

@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment

@@ -38,15 +37,15 @@ def remove_segment_from_index_task(segment_id: str):
         if not dataset:
             raise Exception('Segment has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         # delete from vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.del_nodes([segment.index_node_id])
+        if vector_index:
+            vector_index.delete_by_ids([segment.index_node_id])

         # delete from keyword index
-        keyword_table_index.del_nodes([segment.index_node_id])
+        kw_index.delete_by_ids([segment.index_node_id])

         end_at = time.perf_counter()
         logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))