ai-tech / dify / Commits

Commit 2b150ffd
authored Jul 18, 2023 by jyong
parent 018511b8

add update segment and support qa segment

Showing 9 changed files with 260 additions and 69 deletions (+260 -69)
api/controllers/console/datasets/datasets_document.py   +2   -0
api/controllers/console/datasets/datasets_segments.py   +87  -3
api/core/docstore/dataset_docstore.py                    +6   -1
api/core/generator/llm_generator.py                      +1   -1
api/core/indexing_runner.py                              +61  -52
api/models/dataset.py                                    +3   -0
api/services/dataset_service.py                          +93  -5
api/tasks/deal_dataset_vector_index_task.py              +1   -1
api/tasks/enable_segment_to_index_task.py                +6   -6
api/controllers/console/datasets/datasets_document.py

@@ -60,6 +60,7 @@ document_fields = {
     'display_status': fields.String,
     'word_count': fields.Integer,
     'hit_count': fields.Integer,
+    'doc_form': fields.String,
 }

 document_with_segments_fields = {
@@ -86,6 +87,7 @@ document_with_segments_fields = {
     'total_segments': fields.Integer
 }


 class DocumentResource(Resource):

     def get_document(self, dataset_id: str, document_id: str) -> Document:
         dataset = DatasetService.get_dataset(dataset_id)
api/controllers/console/datasets/datasets_segments.py

@@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
 from libs.helper import TimestampField
-from services.dataset_service import DatasetService, DocumentService
-from tasks.add_segment_to_index_task import add_segment_to_index_task
+from services.dataset_service import DatasetService, DocumentService, SegmentService
+from tasks.enable_segment_to_index_task import enable_segment_to_index_task
 from tasks.remove_segment_from_index_task import remove_segment_from_index_task

 segment_fields = {
@@ -24,6 +24,7 @@ segment_fields = {
     'position': fields.Integer,
     'document_id': fields.String,
     'content': fields.String,
+    'answer': fields.String,
     'word_count': fields.Integer,
     'tokens': fields.Integer,
     'keywords': fields.List(fields.String),
@@ -125,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource):
         return {
             'data': marshal(segments, segment_fields),
+            'doc_form': document.doc_form,
             'has_more': has_more,
             'limit': limit,
             'total': total
@@ -180,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
             # Set cache to prevent indexing the same segment multiple times
             redis_client.setex(indexing_cache_key, 600, 1)

-            add_segment_to_index_task.delay(segment.id)
+            enable_segment_to_index_task.delay(segment.id)

             return {'result': 'success'}, 200
         elif action == "disable":
@@ -202,7 +204,89 @@ class DatasetDocumentSegmentApi(Resource):
         raise InvalidActionError()


+class DatasetDocumentSegmentAddApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, dataset_id, document_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
+        args = parser.parse_args()
+        SegmentService.segment_create_args_validate(args, document)
+        segment = SegmentService.create_segment(args, document)
+        return {
+            'data': marshal(segment, segment_fields),
+            'doc_form': document.doc_form
+        }, 200
+
+
+class DatasetDocumentSegmentUpdateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id),
+            DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound('Segment not found.')
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
+        args = parser.parse_args()
+        SegmentService.segment_create_args_validate(args, document)
+        segment = SegmentService.update_segment(args, segment, document)
+        return {
+            'data': marshal(segment, segment_fields),
+            'doc_form': document.doc_form
+        }, 200
+
+
 api.add_resource(DatasetDocumentSegmentListApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
 api.add_resource(DatasetDocumentSegmentApi,
                  '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
+api.add_resource(DatasetDocumentSegmentApi,
+                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
+api.add_resource(DatasetDocumentSegmentApi,
+                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
api/core/docstore/dataset_docstore.py

@@ -68,7 +68,7 @@ class DatesetDocumentStore:
         self, docs: Sequence[Document], allow_update: bool = True
     ) -> None:
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
-            DocumentSegment.document == self._document_id
+            DocumentSegment.document_id == self._document_id
         ).scalar()

         if max_position is None:
@@ -105,9 +105,14 @@ class DatesetDocumentStore:
                 tokens=tokens,
                 created_by=self._user_id,
             )
+            if 'answer' in doc.metadata and doc.metadata['answer']:
+                segment_document.answer = doc.metadata.pop('answer', '')

             db.session.add(segment_document)
         else:
             segment_document.content = doc.page_content
+            if 'answer' in doc.metadata and doc.metadata['answer']:
+                segment_document.answer = doc.metadata.pop('answer', '')
             segment_document.index_node_hash = doc.metadata['doc_hash']
             segment_document.word_count = len(doc.page_content)
             segment_document.tokens = tokens
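The pop() here both copies the answer onto the segment and removes it from the metadata that travels on with the index node. A tiny standalone illustration of that pattern, with plain dicts standing in for the real Document metadata and DocumentSegment:

# Illustration only; not the project's classes.
metadata = {'doc_id': 'abc', 'doc_hash': 'def', 'answer': 'Paris'}

segment_answer = None
if 'answer' in metadata and metadata['answer']:
    # pop() returns the answer and deletes the key, so only doc_id/doc_hash remain.
    segment_answer = metadata.pop('answer', '')

print(segment_answer)  # 'Paris'
print(metadata)        # {'doc_id': 'abc', 'doc_hash': 'def'}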
api/core/generator/llm_generator.py

@@ -193,7 +193,7 @@ class LLMGenerator:
         llm: StreamableOpenAI = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
-            max_tokens=100
+            max_tokens=1000
         )

         if isinstance(llm, BaseChatModel):
api/core/indexing_runner.py

@@ -100,21 +100,21 @@ class IndexingRunner:
         db.session.commit()

     def format_split_text(self, text):
-        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex that matches the Q and A pairs
-        matches = re.findall(regex, text, re.MULTILINE)  # collect all matches
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
+        matches = re.findall(regex, text, re.MULTILINE)

-        result = []  # store the final results
+        result = []
         for match in matches:
             q = match[0]
             a = match[1]
             if q and a:
-                # if both Q and A exist, add them to the result
                 result.append({
                     "question": q,
                     "answer": re.sub(r"\n\s*", "\n", a.strip())
                 })

         return result

     def run_in_splitting_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is splitting."""
         try:
@@ -249,11 +249,10 @@ class IndexingRunner:
             splitter = self._get_splitter(processing_rule)

             # split to documents
-            documents = self._split_to_documents(
+            documents = self._split_to_documents_for_estimate(
                 text_docs=text_docs,
                 splitter=splitter,
-                processing_rule=processing_rule,
-                tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
+                processing_rule=processing_rule
             )

             total_segments += len(documents)
             for document in documents:
@@ -310,11 +309,10 @@ class IndexingRunner:
             splitter = self._get_splitter(processing_rule)

             # split to documents
-            documents = self._split_to_documents(
+            documents = self._split_to_documents_for_estimate(
                 text_docs=documents,
                 splitter=splitter,
-                processing_rule=processing_rule,
-                tenant_id='84b2202c-c359-46b7-a810-bce50feaa4d1'
+                processing_rule=processing_rule
             )

             total_segments += len(documents)
             for document in documents:
@@ -418,7 +416,8 @@ class IndexingRunner:
             text_docs=text_docs,
             splitter=splitter,
             processing_rule=processing_rule,
-            tenant_id=dataset.tenant_id
+            tenant_id=dataset.tenant_id,
+            document_form=dataset_document.doc_form
         )

         # save node to document segment
@@ -455,7 +454,7 @@ class IndexingRunner:
         return documents

     def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
-                            processing_rule: DatasetProcessRule, tenant_id) -> List[Document]:
+                            processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]:
         """
         Split the text documents into nodes.
         """
@@ -472,51 +471,59 @@ class IndexingRunner:
             for document in documents:
                 if document.page_content is None or not document.page_content.strip():
                     continue
-                response = LLMGenerator.generate_qa_document(tenant_id, document.page_content)
-                document_qa_list = self.format_split_text(response)
-                # CONVERSATION_PROMPT = (
-                #     "You are the question setter.\n"
-                #     "The user will send a long text.\nPlease think step by step."
-                #     'Step 1: understand and summarize the main content of this text.\n'
-                #     'Step 2: what key information or concepts does this text mention?\n'
-                #     'Step 3: decompose or combine multiple pieces of information and concepts.\n'
-                #     'Step 4: generate 10 questions and answers from these key points; describe the questions clearly and completely, and give detailed, complete answers.\n'
-                #     "Answer in the format: Q1:\nA1:\nQ2:\nA2:...\n"
-                # )
-                # openai.api_key = "sk-..."
-                # response = openai.ChatCompletion.create(
-                #     model='gpt-3.5-turbo',
-                #     messages=[
-                #         {
-                #             'role': 'system',
-                #             'content': CONVERSATION_PROMPT
-                #         },
-                #         {
-                #             'role': 'user',
-                #             'content': document.page_content
-                #         }
-                #     ],
-                #     temperature=0,
-                #     stream=False,  # this time, we set stream=True
-                #     n=1,
-                #     top_p=1,
-                #     frequency_penalty=0,
-                #     presence_penalty=0
-                # )
-                # # response = LLMGenerator.generate_qa_document('84b2202c-c359-46b7-a810-bce50feaa4d1', doc.page_content)
-                # document_qa_list = self.format_split_text(response['choices'][0]['message']['content'])
-                qa_documents = []
-                for result in document_qa_list:
-                    document = Document(page_content=result['question'], metadata={'source': result['answer']})
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(document.page_content)
-                    document.metadata['doc_id'] = doc_id
-                    document.metadata['doc_hash'] = hash
-                    qa_documents.append(document)
-                split_documents.extend(qa_documents)
+                if document_form == 'text_model':
+                    # text model document
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document.page_content)
+
+                    document.metadata['doc_id'] = doc_id
+                    document.metadata['doc_hash'] = hash
+
+                    split_documents.append(document)
+                elif document_form == 'qa_model':
+                    # qa model document
+                    response = LLMGenerator.generate_qa_document(tenant_id, document.page_content)
+                    document_qa_list = self.format_split_text(response)
+                    qa_documents = []
+                    for result in document_qa_list:
+                        qa_document = Document(page_content=result['question'], metadata=document.metadata)
+                        doc_id = str(uuid.uuid4())
+                        hash = helper.generate_text_hash(document.page_content)
+                        qa_document.metadata['answer'] = result['answer']
+                        qa_document.metadata['doc_id'] = doc_id
+                        qa_document.metadata['doc_hash'] = hash
+                        qa_documents.append(qa_document)
+                    split_documents.extend(qa_documents)

             all_documents.extend(split_documents)

         return all_documents

+    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
+                                         processing_rule: DatasetProcessRule) -> List[Document]:
+        """
+        Split the text documents into nodes.
+        """
+        all_documents = []
+        for text_doc in text_docs:
+            # document clean
+            document_text = self._document_clean(text_doc.page_content, processing_rule)
+            text_doc.page_content = document_text
+
+            # parse document to nodes
+            documents = splitter.split_documents([text_doc])
+
+            split_documents = []
+            for document in documents:
+                if document.page_content is None or not document.page_content.strip():
+                    continue
+
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document.page_content)
+
+                document.metadata['doc_id'] = doc_id
+                document.metadata['doc_hash'] = hash
+
+                split_documents.append(document)
+
+            all_documents.extend(split_documents)
@@ -550,6 +557,7 @@ class IndexingRunner:
         text = re.sub(pattern, '', text)

         return text

     def format_split_text(self, text):
         regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex that matches the Q and A pairs
         matches = re.findall(regex, text, re.MULTILINE)  # collect all matches
@@ -566,6 +574,7 @@ class IndexingRunner:
                 })

         return result

     def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
         """
         Build the index for the document.
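To make the Q/A parsing concrete, here is a standalone sketch applying the same pattern used by format_split_text above. Only the regex and the post-processing are taken from the diff; the sample text and its output are made up for illustration.

import re

# Pattern copied from format_split_text; sample text is illustrative only.
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
sample = (
    "Q1: What is dify?\n"
    "A1: An LLM application development platform.\n"
    "Q2: What is a document segment?\n"
    "A2: One indexed chunk of a document."
)

result = []
for q, a in re.findall(regex, sample, re.MULTILINE):
    if q and a:
        result.append({"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())})

print(result)
# [{'question': 'What is dify?', 'answer': 'An LLM application development platform.'},
#  {'question': 'What is a document segment?', 'answer': 'One indexed chunk of a document.'}]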
api/models/dataset.py

@@ -330,6 +330,9 @@ class DocumentSegment(db.Model):
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False,
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_by = db.Column(UUID, nullable=True)
+    updated_at = db.Column(db.DateTime, nullable=False,
+                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
     indexing_at = db.Column(db.DateTime, nullable=True)
     completed_at = db.Column(db.DateTime, nullable=True)
     error = db.Column(db.Text, nullable=True)
api/services/dataset_service.py

@@ -3,16 +3,21 @@ import logging
 import datetime
 import time
 import random
+import uuid
 from typing import Optional, List

 from flask import current_app
+from sqlalchemy import func

+from controllers.console.datasets.error import InvalidActionError
+from core.llm.token_calculator import TokenCalculator
 from extensions.ext_redis import redis_client
 from flask_login import current_user

 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
+from libs import helper
 from models.account import Account
 from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
 from models.model import UploadFile
@@ -25,6 +30,9 @@ from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
+from tasks.create_segment_to_index_task import create_segment_to_index_task
+from tasks.update_segment_index_task import update_segment_index_task


 class DatasetService:
@@ -308,6 +316,7 @@ class DocumentService:
         ).all()

         return documents

     @staticmethod
     def get_document_file_detail(file_id: str):
         file_detail = db.session.query(UploadFile). \
@@ -440,6 +449,7 @@ class DocumentService:
             }

             document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                      document_data["data_source"]["type"],
+                                                     document_data["doc_form"],
                                                      data_source_info, created_from, position,
                                                      account, file_name, batch)
             db.session.add(document)
@@ -484,6 +494,7 @@ class DocumentService:
             }

             document = DocumentService.save_document(dataset, dataset_process_rule.id,
                                                      document_data["data_source"]["type"],
+                                                     document_data["doc_form"],
                                                      data_source_info, created_from, position,
                                                      account, page['page_name'], batch)
             # if page['type'] == 'database':
@@ -514,8 +525,9 @@ class DocumentService:
         return documents, batch

     @staticmethod
-    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict,
-                      created_from: str, position: int, account: Account, name: str, batch: str):
+    def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
+                      data_source_info: dict, created_from: str, position: int, account: Account, name: str,
+                      batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -527,6 +539,7 @@ class DocumentService:
             name=name,
             created_from=created_from,
             created_by=account.id,
+            doc_form=document_form
         )

         return document
@@ -618,6 +631,7 @@ class DocumentService:
             document.splitting_completed_at = None
             document.updated_at = datetime.datetime.utcnow()
             document.created_from = created_from
+            document.doc_form = document_data['doc_form']
             db.session.add(document)
             db.session.commit()
             # update document segment
@@ -667,7 +681,7 @@ class DocumentService:
             DocumentService.data_source_args_validate(args)
             DocumentService.process_rule_args_validate(args)
         else:
             if ('data_source' not in args and not args['data_source']) \
                     and ('process_rule' not in args and not args['process_rule']):
                 raise ValueError("Data source or Process rule is required")
             else:
@@ -694,10 +708,12 @@ class DocumentService:
                 raise ValueError("Data source info is required")

             if args['data_source']['type'] == 'upload_file':
                 if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']:
                     raise ValueError("File source info is required")
             if args['data_source']['type'] == 'notion_import':
                 if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']:
                     raise ValueError("Notion source info is required")

     @classmethod
@@ -843,3 +859,75 @@ class DocumentService:
         if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
             raise ValueError("Process rule segmentation max_tokens is invalid")
+
+
+class SegmentService:
+    @classmethod
+    def segment_create_args_validate(cls, args: dict, document: Document):
+        if document.doc_form == 'qa_model':
+            if 'answer' not in args or not args['answer']:
+                raise ValueError("Answer is required")
+
+    @classmethod
+    def create_segment(cls, args: dict, document: Document):
+        content = args['content']
+        doc_id = str(uuid.uuid4())
+        segment_hash = helper.generate_text_hash(content)
+
+        # calc embedding use tokens
+        tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
+        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
+            DocumentSegment.document_id == document.id
+        ).scalar()
+        segment_document = DocumentSegment(
+            tenant_id=current_user.current_tenant_id,
+            dataset_id=document.dataset_id,
+            document_id=document.id,
+            index_node_id=doc_id,
+            index_node_hash=segment_hash,
+            position=max_position + 1 if max_position else 1,
+            content=content,
+            word_count=len(content),
+            tokens=tokens,
+            created_by=current_user.id
+        )
+        if document.doc_form == 'qa_model':
+            segment_document.answer = args['answer']
+
+        db.session.add(segment_document)
+        db.session.commit()
+        indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id)
+        redis_client.setex(indexing_cache_key, 600, 1)
+        create_segment_to_index_task.delay(segment_document.id)
+        return segment_document
+
+    @classmethod
+    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document):
+        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
+        cache_result = redis_client.get(indexing_cache_key)
+        if cache_result is not None:
+            raise InvalidActionError("Segment is indexing, please try again later")
+        content = args['content']
+        if segment.content == content:
+            if document.doc_form == 'qa_model':
+                segment.answer = args['answer']
+            db.session.add(segment)
+            db.session.commit()
+        else:
+            segment_hash = helper.generate_text_hash(content)
+
+            # calc embedding use tokens
+            tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content)
+            segment.content = content
+            segment.index_node_hash = segment_hash
+            segment.word_count = len(content)
+            segment.tokens = tokens
+            segment.status = 'updating'
+            segment.updated_by = current_user.id
+            segment.updated_at = datetime.datetime.utcnow()
+            if document.doc_form == 'qa_model':
+                segment.answer = args['answer']
+            db.session.add(segment)
+            db.session.commit()
+            # update segment index task
+            redis_client.setex(indexing_cache_key, 600, 1)
+            update_segment_index_task.delay(segment.id)
+        return segment
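Both create_segment and update_segment use a short-lived Redis key as an indexing guard: the key is set right before the async task is dispatched, and update_segment refuses to run while it still exists. A minimal sketch of that pattern in isolation, assuming a local Redis and using redis-py directly as a stand-in for the project's redis_client:

import redis

# Sketch of the guard used above; redis.Redis() stands in for extensions.ext_redis.redis_client.
r = redis.Redis()
segment_id = 'example-segment-id'
indexing_cache_key = 'segment_{}_indexing'.format(segment_id)

if r.get(indexing_cache_key) is not None:
    # update_segment raises InvalidActionError here: the segment is still being indexed.
    raise RuntimeError("Segment is indexing, please try again later")

# Mark the segment as indexing for up to 10 minutes, then hand off to the async task.
r.setex(indexing_cache_key, 600, 1)
# update_segment_index_task.delay(segment_id)  # Celery dispatch, as in the diff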
api/tasks/deal_dataset_vector_index_task.py

@@ -49,7 +49,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             segments = db.session.query(DocumentSegment).filter(
                 DocumentSegment.document_id == dataset_document.id,
                 DocumentSegment.enabled == True
             ).order_by(DocumentSegment.position.asc()).all()

             documents = []
             for segment in segments:
api/tasks/add_segment_to_index_task.py → api/tasks/enable_segment_to_index_task.py

@@ -14,14 +14,14 @@ from models.dataset import DocumentSegment
 @shared_task
-def add_segment_to_index_task(segment_id: str):
+def enable_segment_to_index_task(segment_id: str):
     """
-    Async Add segment to index
+    Async enable segment to index
     :param segment_id:

-    Usage: add_segment_to_index.delay(segment_id)
+    Usage: enable_segment_to_index_task.delay(segment_id)
     """
-    logging.info(click.style('Start add segment to index: {}'.format(segment_id), fg='green'))
+    logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green'))
     start_at = time.perf_counter()

     segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
@@ -71,9 +71,9 @@ def add_segment_to_index_task(segment_id: str):
         index.add_texts([document])

         end_at = time.perf_counter()
-        logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
+        logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
     except Exception as e:
-        logging.exception("add segment to index failed")
+        logging.exception("enable segment to index failed")
         segment.enabled = False
         segment.disabled_at = datetime.datetime.utcnow()
         segment.status = 'error'