ai-tech / dify · Commit b615aa79

Commit b615aa79 authored Jul 21, 2023 by jyong

Merge branch 'feat/milvus-support' into deploy/dev

# Conflicts:
#   api/core/tool/dataset_index_tool.py

Parents: b7695ffc, 4f1b4b73

Showing 18 changed files with 832 additions and 26 deletions (+832 -26).
api/controllers/console/datasets/datasets_document.py          +6    -0
api/controllers/console/datasets/datasets_segments.py         +87    -3
api/controllers/console/datasets/hit_testing.py                +1    -0
api/core/data_loader/loader/excel.py                           +1    -1
api/core/docstore/dataset_docstore.py                          +6    -1
api/core/generator/llm_generator.py                           +22    -4
api/core/index/vector_index/test-embedding.py                +123    -0
api/core/indexing_runner.py                                   +93    -5
api/core/prompt/prompts.py                                    +11    -0
api/core/tool/dataset_index_tool.py                          +102    -0
api/migrations/versions/8d2d099ceb74_add_qa_model_support.py  +42    -0
api/models/dataset.py                                          +7    -1
api/services/completion_service.py                             +1    -0
api/services/dataset_service.py                               +92    -5
api/tasks/create_segment_to_index_task.py                     +98    -0
api/tasks/enable_segment_to_index_task.py                      +6    -6
api/tasks/generate_test_task.py                               +24    -0
api/tasks/update_segment_index_task.py                       +110    -0
api/controllers/console/datasets/datasets_document.py

@@ -60,6 +60,7 @@ document_fields = {
     'display_status': fields.String,
     'word_count': fields.Integer,
     'hit_count': fields.Integer,
+    'doc_form': fields.String,
 }

 document_with_segments_fields = {

@@ -86,6 +87,7 @@ document_with_segments_fields = {
     'total_segments': fields.Integer
 }


 class DocumentResource(Resource):

     def get_document(self, dataset_id: str, document_id: str) -> Document:
         dataset = DatasetService.get_dataset(dataset_id)

@@ -269,6 +271,7 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument('process_rule', type=dict, required=False, location='json')
         parser.add_argument('duplicate', type=bool, nullable=False, location='json')
         parser.add_argument('original_document_id', type=str, required=False, location='json')
+        parser.add_argument('doc_form', type=str, default='text_model', required=True, nullable=False, location='json')
         args = parser.parse_args()

         if not dataset.indexing_technique and not args['indexing_technique']:

@@ -313,6 +316,7 @@ class DatasetInitApi(Resource):
                             nullable=False, location='json')
         parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('doc_form', type=str, default='text_model', required=True, nullable=False, location='json')
         args = parser.parse_args()

         # validate args

@@ -488,6 +492,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
                 DocumentSegment.status != 're_segment'
             ).count()
             document.completed_segments = completed_segments
             document.total_segments = total_segments
+            if document.is_paused:
+                document.indexing_status = 'paused'
             documents_status.append(marshal(document, self.document_status_fields))
         data = {
             'data': documents_status
api/controllers/console/datasets/datasets_segments.py

@@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
 from libs.helper import TimestampField
-from services.dataset_service import DatasetService, DocumentService
-from tasks.add_segment_to_index_task import add_segment_to_index_task
+from services.dataset_service import DatasetService, DocumentService, SegmentService
+from tasks.enable_segment_to_index_task import enable_segment_to_index_task
 from tasks.remove_segment_from_index_task import remove_segment_from_index_task

 segment_fields = {

@@ -24,6 +24,7 @@ segment_fields = {
     'position': fields.Integer,
     'document_id': fields.String,
     'content': fields.String,
+    'answer': fields.String,
     'word_count': fields.Integer,
     'tokens': fields.Integer,
     'keywords': fields.List(fields.String),

@@ -125,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource):
         return {
             'data': marshal(segments, segment_fields),
+            'doc_form': document.doc_form,
             'has_more': has_more,
             'limit': limit,
             'total': total

@@ -180,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
             # Set cache to prevent indexing the same segment multiple times
             redis_client.setex(indexing_cache_key, 600, 1)

-            add_segment_to_index_task.delay(segment.id)
+            enable_segment_to_index_task.delay(segment.id)

             return {'result': 'success'}, 200
         elif action == "disable":

@@ -202,7 +204,89 @@ class DatasetDocumentSegmentApi(Resource):
             raise InvalidActionError()


class DatasetDocumentSegmentAddApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id, document_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.create_segment(args, document)
        return {
            'data': marshal(segment, segment_fields),
            'doc_form': document.doc_form
        }, 200


class DatasetDocumentSegmentUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id, document_id, segment_id):
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        # check segment
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound('Segment not found.')
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.update_segment(args, segment, document)
        return {
            'data': marshal(segment, segment_fields),
            'doc_form': document.doc_form
        }, 200


api.add_resource(DatasetDocumentSegmentListApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
                 '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentAddApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
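For orientation, a rough sketch of how a client could call the two new console endpoints registered above, using Python's requests. The host, authentication header, and UUIDs are placeholders and not part of this commit; the real console API uses its own session-based authentication, so adapt accordingly.

    import requests

    BASE = 'https://your-dify-host/console/api'            # placeholder host
    HEADERS = {'Authorization': 'Bearer <console-token>'}  # placeholder auth, adjust to your deployment

    dataset_id = '11111111-1111-1111-1111-111111111111'    # placeholder UUIDs
    document_id = '22222222-2222-2222-2222-222222222222'
    segment_id = '33333333-3333-3333-3333-333333333333'

    # add a segment; 'answer' is required when the document's doc_form is 'qa_model'
    resp = requests.post(
        f'{BASE}/datasets/{dataset_id}/documents/{document_id}/segment',
        headers=HEADERS,
        json={'content': 'What is Milvus?', 'answer': 'An open-source vector database.'}
    )
    print(resp.json()['doc_form'], resp.json()['data'])    # marshalled segment is returned under 'data'

    # update an existing segment
    requests.patch(
        f'{BASE}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}',
        headers=HEADERS,
        json={'content': 'What is Milvus?', 'answer': 'A vector database used for the high-quality index.'}
    )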
api/controllers/console/datasets/hit_testing.py

@@ -28,6 +28,7 @@ segment_fields = {
     'position': fields.Integer,
     'document_id': fields.String,
     'content': fields.String,
+    'answer': fields.String,
     'word_count': fields.Integer,
     'tokens': fields.Integer,
     'keywords': fields.List(fields.String),
api/core/data_loader/loader/excel.py

@@ -39,7 +39,7 @@ class ExcelLoader(BaseLoader):
                 row_dict = dict(zip(keys, list(map(str, row))))
                 row_dict = {k: v for k, v in row_dict.items() if v}
                 item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
-                document = Document(page_content=item)
+                document = Document(page_content=item, metadata={'source': self._file_path})
                 data.append(document)

         return data
api/core/docstore/dataset_docstore.py

@@ -68,7 +68,7 @@ class DatesetDocumentStore:
         self, docs: Sequence[Document], allow_update: bool = True
     ) -> None:
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
-            DocumentSegment.document == self._document_id
+            DocumentSegment.document_id == self._document_id
         ).scalar()

         if max_position is None:

@@ -105,9 +105,14 @@ class DatesetDocumentStore:
                     tokens=tokens,
                     created_by=self._user_id,
                 )

                 if 'answer' in doc.metadata and doc.metadata['answer']:
                     segment_document.answer = doc.metadata.pop('answer', '')

                 db.session.add(segment_document)
             else:
                 segment_document.content = doc.page_content
                 if 'answer' in doc.metadata and doc.metadata['answer']:
                     segment_document.answer = doc.metadata.pop('answer', '')
                 segment_document.index_node_hash = doc.metadata['doc_hash']
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
api/core/generator/llm_generator.py

@@ -2,7 +2,7 @@ import logging
 from langchain import PromptTemplate
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import HumanMessage, OutputParserException, BaseMessage
+from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage

 from core.constant import llm_constant
 from core.llm.llm_builder import LLMBuilder

@@ -12,8 +12,8 @@ from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorO
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
 from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
-from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT
+from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
+    GENERATOR_QA_PROMPT

 # gpt-3.5-turbo works not well
 generate_base_model = 'text-davinci-003'

@@ -31,7 +31,8 @@ class LLMGenerator:
         llm: StreamableOpenAI = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
-            max_tokens=50
+            max_tokens=50,
+            timeout=600
         )

         if isinstance(llm, BaseChatModel):

@@ -185,3 +186,20 @@ class LLMGenerator:
         }

         return rule_config

    @classmethod
    def generate_qa_document(cls, tenant_id: str, query):
        prompt = GENERATOR_QA_PROMPT
        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            max_tokens=2000
        )

        if isinstance(llm, BaseChatModel):
            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        total_token = response.llm_output['token_usage']['total_tokens']
        return answer.strip()
api/core/index/vector_index/test-embedding.py  (new file)

import numpy as np
import sklearn.decomposition
import pickle
import time

# Apply 'Algorithm 1' to the ada-002 embeddings to make them isotropic, taken from the paper:
# ALL-BUT-THE-TOP: SIMPLE AND EFFECTIVE POST-PROCESSING FOR WORD REPRESENTATIONS
# Jiaqi Mu, Pramod Viswanath
# This uses Principal Component Analysis (PCA) to 'evenly distribute' the embedding vectors (make them isotropic)
# For more information on PCA, see https://jamesmccaffrey.wordpress.com/2021/07/16/computing-pca-using-numpy-without-scikit/

# get the file pointer of the pickle containing the embeddings
fp = open('/path/to/your/data/Embedding-Latest.pkl', 'rb')

# the embedding data here is a dict consisting of key / value pairs
# the key is the hash of the message (SHA3-256), the value is the embedding from ada-002 (array of dimension 1536)
# the hash can be used to look up the original text in a database
E = pickle.load(fp)  # load the data into memory

# separate the keys (hashes) and values (embeddings) into separate vectors
K = list(E.keys())  # vector of all the hash values
X = np.array(list(E.values()))  # vector of all the embeddings, converted to numpy arrays

# list the total number of embeddings
# this can be truncated if there are too many embeddings to do PCA on
print(f"Total number of embeddings: {len(X)}")

# get dimension of embeddings, used later
Dim = len(X[0])

# flash out the first few embeddings
print("First two embeddings are: ")
print(X[0])
print(f"First embedding length: {len(X[0])}")
print(X[1])
print(f"Second embedding length: {len(X[1])}")

# compute the mean of all the embeddings, and flash the result
mu = np.mean(X, axis=0)  # same as mu in paper
print(f"Mean embedding vector: {mu}")
print(f"Mean embedding vector length: {len(mu)}")

# subtract the mean vector from each embedding vector ... vectorized in numpy
X_tilde = X - mu  # same as v_tilde(w) in paper

# do the heavy lifting of extracting the principal components
# note that this is a function of the embeddings you currently have here, and this set may grow over time
# therefore the PCA basis vectors may change over time, and your final isotropic embeddings may drift over time
# but the drift should stabilize after you have extracted enough embedding data to characterize the nature of the embedding engine
print(f"Performing PCA on the normalized embeddings ...")
pca = sklearn.decomposition.PCA()  # new object
TICK = time.time()  # start timer
pca.fit(X_tilde)  # do the heavy lifting!
TOCK = time.time()  # end timer
DELTA = TOCK - TICK
print(f"PCA finished in {DELTA} seconds ...")

# dimensional reduction stage (the only hyperparameter)
# pick max dimension of PCA components to express embeddings
# in general this is some integer less than or equal to the dimension of your embeddings
# it could be set as a high percentile, say 95th percentile of pca.explained_variance_ratio_
# but just hardcoding a constant here
D = 15  # hyperparameter on dimension (out of 1536 for ada-002), paper recommends D = Dim/100

# form the set of v_prime(w), which is the final embedding
# this could be vectorized in numpy to speed it up, but coding it directly here in a double for-loop to avoid errors and to be transparent
E_prime = dict()  # output dict of the new embeddings
N = len(X_tilde)
N10 = round(N / 10)
U = pca.components_  # set of PCA basis vectors, sorted by most significant to least significant
print(f"Shape of full set of PCA components {U.shape}")
U = U[0:D, :]  # take the top D dimensions (or take them all if D is the size of the embedding vector)
print(f"Shape of downselected PCA components {U.shape}")
for ii in range(N):
    v_tilde = X_tilde[ii]
    v = X[ii]
    v_projection = np.zeros(Dim)  # start to build the projection
    # project the original embedding onto the PCA basis vectors, use only first D dimensions
    for jj in range(D):
        u_jj = U[jj, :]  # vector
        v_jj = np.dot(u_jj, v)  # scalar
        v_projection += v_jj * u_jj  # vector
    v_prime = v_tilde - v_projection  # final embedding vector
    v_prime = v_prime / np.linalg.norm(v_prime)  # create unit vector
    E_prime[K[ii]] = v_prime
    if (ii % N10 == 0) or (ii == N - 1):
        print(f"Finished with {ii + 1} embeddings out of {N} ({round(100 * ii / N)}% done)")

# save as new pickle
print("Saving new pickle ...")
embeddingName = '/path/to/your/data/Embedding-Latest-Isotropic.pkl'
with open(embeddingName, 'wb') as f:  # Python 3: open(..., 'wb')
    pickle.dump([E_prime, mu, U], f)
print(embeddingName)
print("Done!")


# When working with live data with a new embedding from ada-002, be sure to transform it first with this function before comparing it
def projectEmbedding(v, mu, U):
    v = np.array(v)
    v_tilde = v - mu
    v_projection = np.zeros(len(v))  # start to build the projection
    # project the original embedding onto the PCA basis vectors, use only first D dimensions
    for u in U:
        v_jj = np.dot(u, v)  # scalar
        v_projection += v_jj * u  # vector
    v_prime = v_tilde - v_projection  # final embedding vector
    v_prime = v_prime / np.linalg.norm(v_prime)  # create unit vector
    return v_prime
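To make the intent of projectEmbedding concrete, here is a minimal usage sketch. The pickle path mirrors the one written above; the query vector is a stand-in for a real ada-002 embedding, not data from this commit.

    import pickle
    import numpy as np

    # load the isotropic embeddings plus the mean vector and truncated PCA basis saved by the script above
    with open('/path/to/your/data/Embedding-Latest-Isotropic.pkl', 'rb') as f:
        E_prime, mu, U = pickle.load(f)

    # `query_embedding` stands in for a fresh ada-002 vector for a search query (illustrative only)
    query_embedding = np.random.rand(1536)
    q = projectEmbedding(query_embedding, mu, U)

    # all stored vectors are unit length, so cosine similarity reduces to a dot product
    scores = {key: float(np.dot(q, v)) for key, v in E_prime.items()}
    best_key = max(scores, key=scores.get)  # hash of the closest stored message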
api/core/indexing_runner.py

@@ -6,6 +6,7 @@ import time
 import uuid
 from typing import Optional, List, cast

 import openai
 from flask import current_app
 from flask_login import current_user
 from langchain.embeddings import OpenAIEmbeddings

@@ -16,6 +17,7 @@ from core.data_loader.file_extractor import FileExtractor
 from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
 from core.embedding.cached_embedding import CacheEmbedding
 from core.generator.llm_generator import LLMGenerator
 from core.index.index import IndexBuilder
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex

@@ -70,7 +72,13 @@ class IndexingRunner:
             dataset_document=dataset_document,
             processing_rule=processing_rule
         )
         # new_documents = []
         # for document in documents:
         #     response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content)
         #     document_qa_list = self.format_split_text(response)
         #     for result in document_qa_list:
         #         document = Document(page_content=result['question'], metadata={'source': result['answer']})
         #         new_documents.append(document)
         # build index
         self._build_index(
             dataset=dataset,

@@ -91,6 +99,22 @@ class IndexingRunner:
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()

    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
        matches = re.findall(regex, text, re.MULTILINE)
        result = []
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })
        return result

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:

@@ -225,7 +249,7 @@ class IndexingRunner:
         splitter = self._get_splitter(processing_rule)

         # split to documents
-        documents = self._split_to_documents(
+        documents = self._split_to_documents_for_estimate(
             text_docs=text_docs,
             splitter=splitter,
             processing_rule=processing_rule

@@ -285,7 +309,7 @@ class IndexingRunner:
         splitter = self._get_splitter(processing_rule)

         # split to documents
-        documents = self._split_to_documents(
+        documents = self._split_to_documents_for_estimate(
             text_docs=documents,
             splitter=splitter,
             processing_rule=processing_rule

@@ -391,7 +415,9 @@ class IndexingRunner:
         documents = self._split_to_documents(
             text_docs=text_docs,
             splitter=splitter,
-            processing_rule=processing_rule
+            processing_rule=processing_rule,
+            tenant_id=dataset.tenant_id,
+            document_form=dataset_document.doc_form
         )

         # save node to document segment

@@ -428,7 +454,7 @@ class IndexingRunner:
         return documents

     def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
-                            processing_rule: DatasetProcessRule) -> List[Document]:
+                            processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
         """
         Split the text documents into nodes.
         """

@@ -445,7 +471,52 @@ class IndexingRunner:
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                if document_form == 'text_model':
                    # text model document
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document.page_content)

                    document.metadata['doc_id'] = doc_id
                    document.metadata['doc_hash'] = hash

                    split_documents.append(document)
                elif document_form == 'qa_model':
                    # qa model document
                    response = LLMGenerator.generate_qa_document(tenant_id, document.page_content)
                    document_qa_list = self.format_split_text(response)
                    qa_documents = []
                    for result in document_qa_list:
                        qa_document = Document(page_content=result['question'], metadata=document.metadata.copy())
                        doc_id = str(uuid.uuid4())
                        hash = helper.generate_text_hash(result['question'])
                        qa_document.metadata['answer'] = result['answer']
                        qa_document.metadata['doc_id'] = doc_id
                        qa_document.metadata['doc_hash'] = hash
                        qa_documents.append(qa_document)
                    split_documents.extend(qa_documents)

            all_documents.extend(split_documents)

        return all_documents

    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue

                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)

@@ -487,6 +558,23 @@ class IndexingRunner:
         return text

    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex matching the Q and A pairs
        matches = re.findall(regex, text, re.MULTILINE)  # collect all matches
        result = []  # store the final result
        for match in matches:
            q = match[0]
            a = match[1]
            if q and a:
                # only keep the pair if both the question and the answer are present
                result.append({
                    "question": q,
                    "answer": re.sub(r"\n\s*", "\n", a.strip())
                })
        return result

    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
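The Q/A parsing above depends entirely on the model honouring the "Q1: ... A1: ..." layout requested by GENERATOR_QA_PROMPT. A small standalone sketch of the same regex, run against an illustrative response string (not real model output), shows what format_split_text expects to receive:

    import re

    sample_response = (
        "Q1: What does the qa_model document form store?\n"
        "A1: It stores a generated question as the segment content and the answer alongside it.\n"
        "Q2: Which task indexes a newly created segment?\n"
        "A2: create_segment_to_index_task.\n"
    )

    regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
    pairs = [
        {"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())}
        for q, a in re.findall(regex, sample_response, re.MULTILINE)
    ]
    # pairs -> [{'question': 'What does the qa_model document form store?', 'answer': '...'}, ...]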
api/core/prompt/prompts.py

@@ -43,6 +43,17 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "[\"question1\",\"question2\",\"question3\"]\n"
 )

GENERATOR_QA_PROMPT = (
    "You are the questioner.\n"
    "The user will send a long text. \nPlease think step by step."
    'Step 1: Understand and summarize the main content of this text.\n'
    'Step 2: What key information or concepts are mentioned in this text?\n'
    'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
    'Step 4: Generate 20 questions and answers based on these key information and concepts.'
    'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
    "Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
)

RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
the model prompt that best suits the input.
You will be provided with the prompt, variables, and an opening statement.
api/core/tool/dataset_index_tool.py  (new file)

from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset, DocumentSegment


class DatasetTool(BaseTool):
    """Tool for querying a Dataset."""

    dataset: Dataset
    k: int = 2

    def _run(self, tool_input: str) -> str:
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
            return str("\n".join([document.page_content for document in documents]))
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            vector_index = VectorIndex(
                dataset=self.dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                tool_input,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

            hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
            hit_callback.on_tool_end(documents)

            document_context_list = []
            index_node_ids = [document.metadata['doc_id'] for document in documents]
            segments = DocumentSegment.query.filter(
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True,
                DocumentSegment.index_node_id.in_(index_node_ids)
            ).all()

            if segments:
                for segment in segments:
                    if segment.answer:
                        document_context_list.append(segment.answer)
                    else:
                        document_context_list.append(segment.content)

            return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(
            **model_credentials
        ))

        vector_index = VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        documents = await vector_index.asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={
                'k': 10
            }
        )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))
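A minimal sketch of how this tool could be exercised in isolation inside a Flask application context. The tool name, description, and dataset id are placeholders; langchain's BaseTool requires name and description at construction, and the surrounding agent plumbing that dify normally provides is omitted here.

    from core.tool.dataset_index_tool import DatasetTool
    from models.dataset import Dataset

    # assumes an active application/request context so current_app.config is available
    dataset = Dataset.query.filter_by(id='your-dataset-id').first()  # placeholder id

    tool = DatasetTool(
        name='dataset-tool',                                  # placeholder tool name
        description='Query the dataset for relevant context.',
        dataset=dataset,
        k=2
    )

    # BaseTool.run dispatches to _run; returns segment answers for qa_model segments,
    # otherwise the segment content joined with newlines
    context = tool.run('How do I enable the QA model form?')
    print(context)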
api/migrations/versions/8d2d099ceb74_add_qa_model_support.py  (new file)

"""add_qa_model_support

Revision ID: 8d2d099ceb74
Revises: a5b56fb053ef
Create Date: 2023-07-18 15:25:15.293438

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '8d2d099ceb74'
down_revision = 'a5b56fb053ef'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('document_segments', schema=None) as batch_op:
        batch_op.add_column(sa.Column('answer', sa.Text(), nullable=True))
        batch_op.add_column(sa.Column('updated_by', postgresql.UUID(), nullable=True))
        batch_op.add_column(sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False))

    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.add_column(sa.Column('doc_form', sa.String(length=255), server_default=sa.text("'text_model'::character varying"), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.drop_column('doc_form')

    with op.batch_alter_table('document_segments', schema=None) as batch_op:
        batch_op.drop_column('updated_at')
        batch_op.drop_column('updated_by')
        batch_op.drop_column('answer')

    # ### end Alembic commands ###
api/models/dataset.py

@@ -206,6 +206,8 @@ class Document(db.Model):
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
     doc_type = db.Column(db.String(40), nullable=True)
     doc_metadata = db.Column(db.JSON, nullable=True)
+    doc_form = db.Column(db.String(255), nullable=False,
+                         server_default=db.text("'text_model'::character varying"))

     DATA_SOURCES = ['upload_file', 'notion_import']

@@ -308,6 +310,7 @@ class DocumentSegment(db.Model):
     document_id = db.Column(UUID, nullable=False)
     position = db.Column(db.Integer, nullable=False)
     content = db.Column(db.Text, nullable=False)
+    answer = db.Column(db.Text, nullable=True)
     word_count = db.Column(db.Integer, nullable=False)
     tokens = db.Column(db.Integer, nullable=False)

@@ -327,6 +330,9 @@ class DocumentSegment(db.Model):
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_by = db.Column(UUID, nullable=True)
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     indexing_at = db.Column(db.DateTime, nullable=True)
     completed_at = db.Column(db.DateTime, nullable=True)
     error = db.Column(db.Text, nullable=True)

@@ -442,4 +448,4 @@ class Embedding(db.Model):
         self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

     def get_embedding(self) -> list[float]:
         return pickle.loads(self.embedding)
api/services/completion_service.py

@@ -199,6 +199,7 @@ class CompletionService:
         conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()

         # run
         Completion.generate(
             task_id=generate_task_id,
             app=app_model,
api/services/dataset_service.py

@@ -3,16 +3,20 @@ import logging
 import datetime
 import time
 import random
 import uuid
 from typing import Optional, List

 from flask import current_app
 from sqlalchemy import func

 from core.llm.token_calculator import TokenCalculator
 from extensions.ext_redis import redis_client
 from flask_login import current_user

 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
 from libs import helper
 from models.account import Account
 from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
 from models.model import UploadFile

@@ -25,6 +29,9 @@ from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.create_segment_to_index_task import create_segment_to_index_task
 from tasks.update_segment_index_task import update_segment_index_task


 class DatasetService:

@@ -308,6 +315,7 @@ class DocumentService:
         ).all()

         return documents

     @staticmethod
     def get_document_file_detail(file_id: str):
         file_detail = db.session.query(UploadFile). \

@@ -440,6 +448,7 @@ class DocumentService:
             }

             document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                      document_data["data_source"]["type"],
                                                      document_data["doc_form"],
                                                      data_source_info, created_from, position,
                                                      account, file_name, batch)
             db.session.add(document)

@@ -484,6 +493,7 @@ class DocumentService:
                     }

                     document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                              document_data["data_source"]["type"],
                                                              document_data["doc_form"],
                                                              data_source_info, created_from, position,
                                                              account, page['page_name'], batch)
                     # if page['type'] == 'database':

@@ -514,8 +524,9 @@ class DocumentService:
         return documents, batch

     @staticmethod
-    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
-                      created_from: str, position: int, account: Account, name: str, batch: str):
+    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
+                      data_source_info: dict, created_from: str, position: int, account: Account, name: str,
+                      batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,

@@ -527,6 +538,7 @@ class DocumentService:
             name=name,
             created_from=created_from,
             created_by=account.id,
+            doc_form=document_form
         )

         return document

@@ -618,6 +630,7 @@ class DocumentService:
             document.splitting_completed_at = None
             document.updated_at = datetime.datetime.utcnow()
             document.created_from = created_from
+            document.doc_form = document_data['doc_form']
             db.session.add(document)
             db.session.commit()
             # update document segment

@@ -667,7 +680,7 @@ class DocumentService:
             DocumentService.data_source_args_validate(args)
             DocumentService.process_rule_args_validate(args)
         else:
             if ('data_source' not in args and not args['data_source']) \
                     and ('process_rule' not in args and not args['process_rule']):
                 raise ValueError("Data source or Process rule is required")
             else:

@@ -694,10 +707,12 @@ class DocumentService:
                 raise ValueError("Data source info is required")

             if args['data_source']['type'] == 'upload_file':
                 if 'file_info_list' not in args['data_source']['info_list'] or not \
                         args['data_source']['info_list']['file_info_list']:
                     raise ValueError("File source info is required")
             if args['data_source']['type'] == 'notion_import':
                 if 'notion_info_list' not in args['data_source']['info_list'] or not \
                         args['data_source']['info_list']['notion_info_list']:
                     raise ValueError("Notion source info is required")

     @classmethod

@@ -843,3 +858,75 @@ class DocumentService:
         if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
             raise ValueError("Process rule segmentation max_tokens is invalid")


class SegmentService:
    @classmethod
    def segment_create_args_validate(cls, args: dict, document: Document):
        if document.doc_form == 'qa_model':
            if 'answer' not in args or not args['answer']:
                raise ValueError("Answer is required")

    @classmethod
    def create_segment(cls, args: dict, document: Document):
        content = args['content']
        doc_id = str(uuid.uuid4())
        segment_hash = helper.generate_text_hash(content)

        # calc embedding use tokens
        tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
            DocumentSegment.document_id == document.id
        ).scalar()
        segment_document = DocumentSegment(
            tenant_id=current_user.current_tenant_id,
            dataset_id=document.dataset_id,
            document_id=document.id,
            index_node_id=doc_id,
            index_node_hash=segment_hash,
            position=max_position + 1 if max_position else 1,
            content=content,
            word_count=len(content),
            tokens=tokens,
            created_by=current_user.id
        )
        if document.doc_form == 'qa_model':
            segment_document.answer = args['answer']

        db.session.add(segment_document)
        db.session.commit()
        indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id)
        redis_client.setex(indexing_cache_key, 600, 1)
        create_segment_to_index_task.delay(segment_document.id)
        return segment_document

    @classmethod
    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document):
        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
        cache_result = redis_client.get(indexing_cache_key)
        if cache_result is not None:
            raise ValueError("Segment is indexing, please try again later")
        content = args['content']
        if segment.content == content:
            if document.doc_form == 'qa_model':
                segment.answer = args['answer']
            db.session.add(segment)
            db.session.commit()
        else:
            segment_hash = helper.generate_text_hash(content)

            # calc embedding use tokens
            tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
            segment.content = content
            segment.index_node_hash = segment_hash
            segment.word_count = len(content)
            segment.tokens = tokens
            segment.status = 'updating'
            segment.updated_by = current_user.id
            segment.updated_at = datetime.datetime.utcnow()
            if document.doc_form == 'qa_model':
                segment.answer = args['answer']
            db.session.add(segment)
            db.session.commit()
            # update segment index task
            redis_client.setex(indexing_cache_key, 600, 1)
            update_segment_index_task.delay(segment.id)
        return segment
api/tasks/create_segment_to_index_task.py  (new file)

import datetime
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment


@shared_task
def create_segment_to_index_task(segment_id: str):
    """
    Async create segment to index
    :param segment_id:

    Usage: create_segment_to_index_task.delay(segment_id)
    """
    logging.info(click.style('Start create segment to index: {}'.format(segment_id), fg='green'))
    start_at = time.perf_counter()

    segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
    if not segment:
        raise NotFound('Segment not found')

    if segment.status != 'waiting':
        return

    indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

    try:
        # update segment status to indexing
        update_params = {
            DocumentSegment.status: "indexing",
            DocumentSegment.indexing_at: datetime.datetime.utcnow()
        }
        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
        db.session.commit()

        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )

        dataset = segment.dataset

        if not dataset:
            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
            return

        dataset_document = segment.document

        if not dataset_document:
            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
            return

        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
            return

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts([document], duplicate_check=True)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            index.add_texts([document])

        # update segment to completed
        update_params = {
            DocumentSegment.status: "completed",
            DocumentSegment.completed_at: datetime.datetime.utcnow()
        }
        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
        db.session.commit()

        end_at = time.perf_counter()
        logging.info(click.style('Segment created to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
    except Exception as e:
        logging.exception("create segment to index failed")
        segment.enabled = False
        segment.disabled_at = datetime.datetime.utcnow()
        segment.status = 'error'
        segment.error = str(e)
        db.session.commit()
    finally:
        redis_client.delete(indexing_cache_key)
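For context, a minimal sketch of how this task is expected to be triggered, mirroring what SegmentService.create_segment does above; the segment id is a placeholder and a running Celery worker is assumed.

    from extensions.ext_redis import redis_client
    from tasks.create_segment_to_index_task import create_segment_to_index_task

    segment_id = 'your-segment-id'  # placeholder
    indexing_cache_key = 'segment_{}_indexing'.format(segment_id)
    redis_client.setex(indexing_cache_key, 600, 1)  # guard against indexing the same segment twice for 10 minutes
    create_segment_to_index_task.delay(segment_id)  # runs asynchronously on the Celery worker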
api/tasks/add_segment_to_index_task.py → api/tasks/enable_segment_to_index_task.py

@@ -14,14 +14,14 @@ from models.dataset import DocumentSegment

 @shared_task
-def add_segment_to_index_task(segment_id: str):
+def enable_segment_to_index_task(segment_id: str):
     """
-    Async Add segment to index
+    Async enable segment to index
     :param segment_id:

-    Usage: add_segment_to_index.delay(segment_id)
+    Usage: enable_segment_to_index_task.delay(segment_id)
     """
-    logging.info(click.style('Start add segment to index: {}'.format(segment_id), fg='green'))
+    logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green'))
     start_at = time.perf_counter()

     segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()

@@ -71,9 +71,9 @@ def add_segment_to_index_task(segment_id: str):
             index.add_texts([document])

         end_at = time.perf_counter()
-        logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
+        logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
     except Exception as e:
-        logging.exception("add segment to index failed")
+        logging.exception("enable segment to index failed")
         segment.enabled = False
         segment.disabled_at = datetime.datetime.utcnow()
         segment.status = 'error'
api/tasks/generate_test_task.py  (new file)

import logging
import time

import click
import requests
from celery import shared_task

from core.generator.llm_generator import LLMGenerator


@shared_task
def generate_test_task():
    logging.info(click.style('Start generate test', fg='green'))
    start_at = time.perf_counter()

    try:
        # res = requests.post('https://api.openai.com/v1/chat/completions')
        answer = LLMGenerator.generate_conversation_name('84b2202c-c359-46b7-a810-bce50feaa4d1', 'avb', 'ccc')
        print(f'answer: {answer}')

        end_at = time.perf_counter()
        logging.info(click.style('Conversation test, latency: {}'.format(end_at - start_at), fg='green'))
    except Exception:
        logging.exception("generate test failed")
api/tasks/update_segment_index_task.py  (new file)

import datetime
import logging
import time

import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment


@shared_task
def update_segment_index_task(segment_id: str):
    """
    Async update segment index
    :param segment_id:

    Usage: update_segment_index_task.delay(segment_id)
    """
    logging.info(click.style('Start update segment index: {}'.format(segment_id), fg='green'))
    start_at = time.perf_counter()

    segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
    if not segment:
        raise NotFound('Segment not found')

    if segment.status != 'updating':
        return

    indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

    try:
        dataset = segment.dataset

        if not dataset:
            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
            return

        dataset_document = segment.document

        if not dataset_document:
            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
            return

        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
            return

        # update segment status to indexing
        update_params = {
            DocumentSegment.status: "indexing",
            DocumentSegment.indexing_at: datetime.datetime.utcnow()
        }
        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
        db.session.commit()

        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        kw_index = IndexBuilder.get_index(dataset, 'economy')

        # delete from vector index
        if vector_index:
            vector_index.delete_by_ids([segment.index_node_id])

        # delete from keyword index
        kw_index.delete_by_ids([segment.index_node_id])

        # add new index
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )

        # save vector index
        index = IndexBuilder.get_index(dataset, 'high_quality')
        if index:
            index.add_texts([document], duplicate_check=True)

        # save keyword index
        index = IndexBuilder.get_index(dataset, 'economy')
        if index:
            index.add_texts([document])

        # update segment to completed
        update_params = {
            DocumentSegment.status: "completed",
            DocumentSegment.completed_at: datetime.datetime.utcnow()
        }
        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
        db.session.commit()

        end_at = time.perf_counter()
        logging.info(click.style('Segment update index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
    except Exception as e:
        logging.exception("update segment index failed")
        segment.enabled = False
        segment.disabled_at = datetime.datetime.utcnow()
        segment.status = 'error'
        segment.error = str(e)
        db.session.commit()
    finally:
        redis_client.delete(indexing_cache_key)