Commit 9e9d15ec, authored Jun 18, 2023 by John Wang
feat: replace llama-index to langchain in index build
Parent: 23ef2262
Showing 43 changed files with 1692 additions and 1995 deletions (+1692, -1995)
api/config.py (+2, -0)
api/controllers/console/datasets/data_source.py (+10, -10)
api/controllers/console/datasets/file.py (+2, -28)
api/core/__init__.py (+0, -18)
api/core/data_loader/file_extractor.py (+42, -0)
api/core/data_loader/loader/csv.py (+67, -0)
api/core/data_loader/loader/excel.py (+46, -0)
api/core/data_loader/loader/html.py (+36, -0)
api/core/data_loader/loader/markdown.py (+134, -0)
api/core/data_loader/loader/notion.py (+240, -235)
api/core/data_loader/loader/pdf.py (+69, -0)
api/core/docstore/dataset_docstore.py (+35, -45)
api/core/embedding/cached_embedding.py (+72, -0)
api/core/embedding/openai_embedding.py (+0, -214)
api/core/index/base.py (+50, -0)
api/core/index/index_builder.py (+0, -60)
api/core/index/keyword_table/jieba_keyword_table.py (+0, -159)
api/core/index/keyword_table_index.py (+0, -135)
.../index/keyword_table_index/jieba_keyword_table_handler.py (+33, -0)
api/core/index/keyword_table_index/keyword_table_index.py (+186, -0)
api/core/index/keyword_table_index/stopwords.py (+0, -0)
api/core/index/query/synthesizer.py (+0, -79)
api/core/index/readers/html_parser.py (+0, -22)
api/core/index/readers/xlsx_parser.py (+0, -33)
api/core/index/vector_index.py (+0, -136)
api/core/index/vector_index/base.py (+24, -0)
api/core/index/vector_index/qdrant_vector_index.py (+141, -0)
api/core/index/vector_index/vector_index.py (+70, -0)
api/core/index/vector_index/weaviate_vector_index.py (+145, -0)
api/core/indexing_runner.py (+186, -232)
api/core/llm/provider/azure_provider.py (+1, -1)
api/core/spiltter/fixed_text_splitter.py (+0, -0)
api/core/vector_store/base.py (+0, -34)
api/core/vector_store/qdrant_vector_store.py (+45, -0)
api/core/vector_store/qdrant_vector_store_client.py (+0, -147)
api/core/vector_store/vector_store.py (+0, -62)
api/core/vector_store/vector_store_index_query.py (+0, -66)
api/core/vector_store/weaviate_vector_store.py (+35, -0)
api/core/vector_store/weaviate_vector_store_client.py (+0, -270)
api/libs/helper.py (+6, -0)
api/requirements.txt (+3, -3)
api/services/dataset_service.py (+0, -3)
api/tasks/document_indexing_sync_task.py (+12, -3)
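The thrust of the commit is visible in the file list: llama-index `BaseParser`/`BaseReader` classes under `api/core/index/readers` and `api/core/data_source` are replaced by langchain `BaseLoader` subclasses under `api/core/data_loader/loader`, and indexes and vector stores are rebuilt on langchain abstractions. A minimal sketch of the loader shape the new modules follow, using langchain's public `BaseLoader`/`Document` API; `PlainTextLoader` is a hypothetical name, not one of the loaders added in this commit:

    from typing import List

    from langchain.document_loaders.base import BaseLoader
    from langchain.schema import Document


    class PlainTextLoader(BaseLoader):
        """Hypothetical loader in the style this commit introduces."""

        def __init__(self, file_path: str):
            self._file_path = file_path

        def load(self) -> List[Document]:
            # read the file and wrap it in a langchain Document that records its source
            with open(self._file_path, encoding="utf-8") as f:
                text = f.read()
            return [Document(page_content=text, metadata={"source": self._file_path})]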
api/config.py

...
@@ -187,11 +187,13 @@ class Config:
         # For temp use only
         # set default LLM provider, default is 'openai', support `azure_openai`
         self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')

         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
         self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
         self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')

 class CloudEditionConfig(Config):
...
api/controllers/console/datasets/data_source.py

...
@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_source.notion import NotionPageReader
+from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
 from extensions.ext_database import db
 from libs.helper import TimestampField
 from libs.oauth_data_source import NotionOAuth
 from models.dataset import Document
 from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService
...
@@ -232,15 +231,16 @@ class DataSourceNotionApi(Resource):
         ).first()

         if not data_source_binding:
             raise NotFound('Data source binding not found.')

-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        if page_type == 'page':
-            page_content = reader.read_page(page_id)
-        elif page_type == 'database':
-            page_content = reader.query_database_data(page_id)
-        else:
-            page_content = ""
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
         return {
-            'content': page_content
+            'content': loader.load_as_text()
         }, 200

     @setup_required
...
api/controllers/console/datasets/file.py

...
@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
     UnsupportedFileTypeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db
...
@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
         if extension not in ALLOWED_EXTENSIONS:
             raise UnsupportedFileTypeError()

-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, filepath)
-
-            if extension == 'pdf':
-                parser = PDFParser({'upload_file': upload_file})
-                text = parser.parse_file(Path(filepath))
-            elif extension in ['html', 'htm']:
-                # Use BeautifulSoup to extract text
-                parser = HTMLParser()
-                text = parser.parse_file(Path(filepath))
-            elif extension == 'xlsx':
-                parser = XLSXParser()
-                text = parser.parse_file(filepath)
-            else:
-                # ['txt', 'markdown', 'md']
-                with open(filepath, "rb") as fp:
-                    data = fp.read()
-                    encoding = chardet.detect(data)['encoding']
-                    if encoding:
-                        text = data.decode(encoding=encoding).strip() if data else ''
-                    else:
-                        text = data.decode(encoding='utf-8').strip() if data else ''
+        text = FileExtractor.load(upload_file, return_text=True)

         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
         return {'content': text}
...
api/core/__init__.py

...
@@ -3,19 +3,12 @@ from typing import Optional

 import langchain
 from flask import Flask
-from jieba.analyse import default_tfidf
 from langchain import set_handler
 from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
-from llama_index import IndexStructType, QueryMode
-from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
 from pydantic import BaseModel

 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.index.keyword_table.stopwords import STOPWORDS
 from core.prompt.prompt_template import OneLineFormatter
-from core.vector_store.vector_store import VectorStore
-from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery


 class HostedOpenAICredential(BaseModel):
...
@@ -32,17 +25,6 @@ hosted_llm_credentials = HostedLLMCredentials()

 def init_app(app: Flask):
     formatter = OneLineFormatter()
     DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    default_tfidf.stop_words = STOPWORDS

     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
...
api/core/data_loader/file_extractor.py (new file)

import tempfile
from pathlib import Path
from typing import List, Union

from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document

from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            input_file = Path(file_path)
            if input_file.suffix == '.xlxs':
                loader = ExcelLoader(file_path)
            elif input_file.suffix == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif input_file.suffix in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif input_file.suffix in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif input_file.suffix == '.docx':
                loader = Docx2txtLoader(file_path)
            elif input_file.suffix == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

            return loader.load_as_text() if return_text else loader.load()
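FileExtractor picks a loader from the file suffix and returns either langchain Documents or plain text. A hypothetical call site, assuming `upload_file` is an UploadFile row obtained elsewhere in the app (as in the file preview controller above):

    from core.data_loader.file_extractor import FileExtractor

    documents = FileExtractor.load(upload_file)                        # List[Document] for indexing
    preview_text = FileExtractor.load(upload_file, return_text=True)   # plain text, as used by FilePreviewApi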
api/core/data_loader/loader/csv.py (new file)

import logging
from typing import Optional, Dict, List

from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings

from models.dataset import Document

logger = logging.getLogger(__name__)


class CSVLoader(LCCSVLoader):
    def __init__(
            self,
            file_path: str,
            source_column: Optional[str] = None,
            csv_args: Optional[Dict] = None,
            encoding: Optional[str] = None,
            autodetect_encoding: bool = True,
    ):
        self.file_path = file_path
        self.source_column = source_column
        self.encoding = encoding
        self.csv_args = csv_args or {}
        self.autodetect_encoding = autodetect_encoding

    def load(self) -> List[Document]:
        """Load data into document objects."""
        try:
            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
                docs = self._read_from_file(csvfile)
        except UnicodeDecodeError as e:
            if self.autodetect_encoding:
                detected_encodings = detect_file_encodings(self.file_path)
                for encoding in detected_encodings:
                    logger.debug("Trying encoding: ", encoding.encoding)
                    try:
                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
                            docs = self._read_from_file(csvfile)
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self.file_path}") from e

        return docs

    def _read_from_file(self, csvfile):
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        for i, row in enumerate(csv_reader):
            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
            try:
                source = (
                    row[self.source_column]
                    if self.source_column is not None
                    else self.file_path
                )
            except KeyError:
                raise ValueError(
                    f"Source column '{self.source_column}' not found in CSV file."
                )
            metadata = {"source": source, "row": i}
            doc = Document(page_content=content, metadata=metadata)
            docs.append(doc)

        return docs
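The subclass exists mainly to tolerate unknown encodings: it retries with the encodings reported by langchain's detect_file_encodings before giving up. A hypothetical usage, with the path as a placeholder:

    loader = CSVLoader('/tmp/example.csv', autodetect_encoding=True)
    docs = loader.load()   # one Document per row, content formatted as "header: value" lines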
api/core/data_loader/loader/excel.py (new file)

import json
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook

logger = logging.getLogger(__name__)


class ExcelLoader(BaseLoader):
    """Load xlxs files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        data = []
        keys = []
        wb = load_workbook(filename=self._file_path, read_only=True)
        # loop over all sheets
        for sheet in wb:
            for row in sheet.iter_rows(values_only=True):
                if all(v is None for v in row):
                    continue
                if keys == []:
                    keys = list(map(str, row))
                else:
                    data.append(json.dumps(dict(zip(keys, list(map(str, row)))), ensure_ascii=False))

        metadata = {"source": self._file_path}

        return [Document(page_content='\n\n'.join(data), metadata=metadata)]

    def load_as_text(self) -> str:
        documents = self.load()
        return ''.join([document.page_content for document in documents])
api/core/data_loader/loader/html.py (new file)

import logging
from typing import List

from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class HTMLLoader(BaseLoader):
    """Load html files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        metadata = {"source": self._file_path}
        return [Document(page_content=self.load_as_text(), metadata=metadata)]

    def load_as_text(self) -> str:
        with open(self._file_path, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
api/core/index/readers/markdown_parser.py → api/core/data_loader/loader/markdown.py (renamed)

-"""Markdown parser.
-
-Contains parser for md files.
-
-"""
-import re
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
-
-from llama_index.readers.file.base_parser import BaseParser
-
-
-class MarkdownParser(BaseParser):
-    """Markdown parser.
-
-    Extract text from markdown files.
-    Returns dictionary with keys as headers and values as the text between headers.
-
-    """
-
-    def __init__(
-        self,
-        *args: Any,
-        remove_hyperlinks: bool = True,
-        remove_images: bool = True,
-        **kwargs: Any,
-    ) -> None:
-        """Init params."""
-        super().__init__(*args, **kwargs)
-        self._remove_hyperlinks = remove_hyperlinks
-        self._remove_images = remove_images
+import logging
+import re
+from typing import Optional, List, Tuple, cast
+
+from langchain.document_loaders.base import BaseLoader
+from langchain.document_loaders.helpers import detect_file_encodings
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class MarkdownLoader(BaseLoader):
+    """Load md files.
+
+    Args:
+        file_path: Path to the file to load.
+
+        remove_hyperlinks: Whether to remove hyperlinks from the text.
+
+        remove_images: Whether to remove images from the text.
+
+        encoding: File encoding to use. If `None`, the file will be loaded
+            with the default system encoding.
+
+        autodetect_encoding: Whether to try to autodetect the file encoding
+            if the specified encoding fails.
+    """
+
+    def __init__(
+        self,
+        file_path: str,
+        remove_hyperlinks: bool = True,
+        remove_images: bool = True,
+        encoding: Optional[str] = None,
+        autodetect_encoding: bool = True,
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._remove_hyperlinks = remove_hyperlinks
+        self._remove_images = remove_images
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
+
+    def load(self) -> List[Document]:
+        tups = self.parse_tups(self._file_path)
+        documents = []
+        metadata = {"source": self._file_path}
+        for header, value in tups:
+            if header is None:
+                documents.append(Document(page_content=value, metadata=metadata))
+            else:
+                documents.append(Document(page_content=f"\n\n{header}\n{value}", metadata=metadata))
+
+        return documents

     def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
         """Convert a markdown file to a dictionary. ...
...
@@ -79,33 +103,32 @@ class MarkdownParser(BaseParser):
         content = re.sub(pattern, r"\1", content)
         return content

-    def _init_parser(self) -> Dict:
-        """Initialize the parser with the config."""
-        return {}
-
-    def parse_tups(self, filepath: Path, errors: str = "ignore") -> List[Tuple[Optional[str], str]]:
+    def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
         """Parse file into tuples."""
-        with open(filepath, "r", encoding="utf-8") as f:
-            content = f.read()
+        content = ""
+        try:
+            with open(filepath, "r", encoding=self._encoding) as f:
+                content = f.read()
+        except UnicodeDecodeError as e:
+            if self._autodetect_encoding:
+                detected_encodings = detect_file_encodings(filepath)
+                for encoding in detected_encodings:
+                    logger.debug("Trying encoding: ", encoding.encoding)
+                    try:
+                        with open(filepath, encoding=encoding.encoding) as f:
+                            content = f.read()
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {filepath}") from e
+        except Exception as e:
+            raise RuntimeError(f"Error loading {filepath}") from e
+
         if self._remove_hyperlinks:
             content = self.remove_hyperlinks(content)
         if self._remove_images:
             content = self.remove_images(content)
-        markdown_tups = self.markdown_to_tups(content)
-        return markdown_tups
-
-    def parse_file(self, filepath: Path, errors: str = "ignore") -> Union[str, List[str]]:
-        """Parse file into string."""
-        tups = self.parse_tups(filepath, errors=errors)
-        results = []
-        # TODO: don't include headers right now
-        for header, value in tups:
-            if header is None:
-                results.append(value)
-            else:
-                results.append(f"\n\n{header}\n{value}")
-        return results
+
+        return self.markdown_to_tups(content)
api/core/data_source/notion.py → api/core/data_loader/loader/notion.py (renamed, +240 -235)

-"""Notion reader."""
+import json
 import logging
-import os
-from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import List, Dict, Any, Optional

-import requests  # type: ignore
+import requests
+from flask import current_app
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document

-from llama_index.readers.base import BaseReader
-from llama_index.readers.schema.base import Document
+from extensions.ext_database import db
+from models.dataset import Document as DocumentModel
+from models.source import DataSourceBinding

 logger = logging.getLogger(__name__)

-INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
 BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
 DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
 SEARCH_URL = "https://api.notion.com/v1/search"
 RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
 RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']


-# TODO: Notion DB reader coming soon!
-class NotionPageReader(BaseReader):
-    """Notion Page reader.
-
-    Reads a set of Notion pages.
-
-    Args:
-        integration_token (str): Notion integration token.
-    """
-
-    def __init__(self, integration_token: Optional[str] = None) -> None:
-        """Initialize with parameters."""
-        if integration_token is None:
-            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
-        if integration_token is None:
-            raise ValueError(
-                "Must specify `integration_token` or set environment "
-                "variable `NOTION_INTEGRATION_TOKEN`."
-            )
-        self.token = integration_token
-        self.headers = {
-            "Authorization": "Bearer " + self.token,
-            "Content-Type": "application/json",
-            "Notion-Version": "2022-06-28",
-        }
+class NotionLoader(BaseLoader):
+    def __init__(
+            self,
+            notion_access_token: str,
+            notion_workspace_id: str,
+            notion_obj_id: str,
+            notion_page_type: str,
+            document_model: Optional[DocumentModel] = None
+    ):
+        self._document_model = document_model
+        self._notion_workspace_id = notion_workspace_id
+        self._notion_obj_id = notion_obj_id
+        self._notion_page_type = notion_page_type
+        self._notion_access_token = notion_access_token
+
+        if not self._notion_access_token:
+            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
+            if integration_token is None:
+                raise ValueError(
+                    "Must specify `integration_token` or set environment "
+                    "variable `NOTION_INTEGRATION_TOKEN`."
+                )
+
+            self._notion_access_token = integration_token
+
+    @classmethod
+    def from_document(cls, document_model: DocumentModel):
+        data_source_info = document_model.data_source_info_dict
+        if not data_source_info or 'notion_page_id' not in data_source_info \
+                or 'notion_workspace_id' not in data_source_info:
+            raise ValueError("no notion page found")
+
+        notion_workspace_id = data_source_info['notion_workspace_id']
+        notion_obj_id = data_source_info['notion_page_id']
+        notion_page_type = data_source_info['type']
+        notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
+
+        return cls(
+            notion_access_token=notion_access_token,
+            notion_workspace_id=notion_workspace_id,
+            notion_obj_id=notion_obj_id,
+            notion_page_type=notion_page_type,
+            document_model=document_model
+        )
+
+    def load(self) -> List[Document]:
+        self.update_last_edited_time(self._document_model)
+        text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
+        return text_docs
+
+    def load_as_text(self) -> str:
+        text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
+        text = "\n".join([doc.page_content for doc in text_docs])
+        return text
+
+    def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> List[Document]:
+        docs = []
+        if notion_page_type == 'database':
+            # get all the pages in the database
+            page_text = self._get_notion_database_data(notion_obj_id)
+            docs.append(Document(page_content=page_text))
+        elif notion_page_type == 'page':
+            page_text_list = self._get_notion_block_data(notion_obj_id)
+            for page_text in page_text_list:
+                docs.append(Document(page_content=page_text))
+        else:
+            raise ValueError("notion page type not supported")
+
+        return docs
+
+    def _get_notion_database_data(self, database_id: str, query_dict: Dict[str, Any] = {}) -> str:
...
+    def _get_notion_block_data(self, page_id: str) -> List[str]:
...
+    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
...
+    def _read_table_rows(self, block_id: str) -> str:
...
+    def update_last_edited_time(self, document_model: DocumentModel):
+        if not document_model:
+            return
+
+        last_edited_time = self.get_notion_last_edited_time()
+        data_source_info = document_model.data_source_info_dict
+        data_source_info['last_edited_time'] = last_edited_time
+        update_params = {
+            DocumentModel.data_source_info: json.dumps(data_source_info)
+        }
+
+        DocumentModel.query.filter_by(id=document_model.id).update(update_params)
+        db.session.commit()
+
+    def get_notion_last_edited_time(self) -> str:
+        obj_id = self._notion_obj_id
+        page_type = self._notion_page_type
+        if page_type == 'database':
+            retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
+        else:
+            retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
+
+        query_dict: Dict[str, Any] = {}
+        res = requests.request(
+            "GET",
+            retrieve_page_url,
+            headers={
+                "Authorization": "Bearer " + self._notion_access_token,
+                "Content-Type": "application/json",
+                "Notion-Version": "2022-06-28",
+            },
+            json=query_dict
+        )
+
+        data = res.json()
+        return data["last_edited_time"]
+
+    @classmethod
+    def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
+        data_source_binding = DataSourceBinding.query.filter(
+            db.and_(
+                DataSourceBinding.tenant_id == tenant_id,
+                DataSourceBinding.provider == 'notion',
+                DataSourceBinding.disabled == False,
+                DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
+            )
+        ).first()
+
+        if not data_source_binding:
+            raise Exception(f'No notion data source binding found for tenant {tenant_id} '
+                            f'and notion workspace {notion_workspace_id}')
+
+        return data_source_binding.access_token
...
-    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
-    def _read_table_rows(self, block_id: str) -> str:
-    def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
-    def read_page(self, page_id: str) -> str:
-    def read_page_as_documents(self, page_id: str) -> List[str]:
-    def query_database_data(self, database_id: str, query_dict: Dict[str, Any] = {}) -> str:
-    def query_database(self, database_id: str, query_dict: Dict[str, Any] = {}) -> List[str]:
-    def search(self, query: str) -> List[str]:
-    def load_data(self, page_ids: List[str] = [], database_id: Optional[str] = None) -> List[Document]:
-    def load_data_as_documents(self, page_ids: List[str] = [], database_id: Optional[str] = None) -> List[Document]:
-    def get_page_last_edited_time(self, page_id: str) -> str:
-    def get_database_last_edited_time(self, database_id: str) -> str:
...
-if __name__ == "__main__":
-    reader = NotionPageReader()
-    logger.info(reader.search("What I"))
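NotionLoader can be built directly from credentials or from a dataset Document record via from_document. A hypothetical direct construction, with the token and ids as placeholders:

    loader = NotionLoader(
        notion_access_token='secret_xxx',
        notion_workspace_id='workspace-id',
        notion_obj_id='page-id',
        notion_page_type='page'
    )
    docs = loader.load()           # List[Document] extracted from the page
    text = loader.load_as_text()   # the same content joined as plain text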
api/core/index/readers/pdf_parser.py → api/core/data_loader/loader/pdf.py (renamed)

-from pathlib import Path
-from typing import Dict
+import logging
+from typing import List, Optional

 from flask import current_app
-from llama_index.readers.file.base_parser import BaseParser
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
 from pypdf import PdfReader

 from extensions.ext_storage import storage
 from models.model import UploadFile

+logger = logging.getLogger(__name__)
+

-class PDFParser(BaseParser):
-    """PDF parser."""
-
-    def _init_parser(self) -> Dict:
-        """Init parser."""
-        return {}
-
-    def parse_file(self, file: Path, errors: str = "ignore") -> str:
-        """Parse file."""
-        if not current_app.config.get('PDF_PREVIEW', True):
-            return ''
-
+class PdfLoader(BaseLoader):
+    """Load pdf files.
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+        self,
+        file_path: str,
+        upload_file: Optional[UploadFile] = None
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._upload_file = upload_file
+
+    def load(self) -> List[Document]:
         plaintext_file_key = ''
         plaintext_file_exists = False
-        if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
-            upload_file: UploadFile = self._parser_config['upload_file']
-            if upload_file.hash:
-                plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
+        if self._upload_file:
+            if self._upload_file.hash:
+                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+                                     + self._upload_file.hash + '.plaintext'

                 try:
                     text = storage.load(plaintext_file_key).decode('utf-8')
                     plaintext_file_exists = True
...
@@ -35,7 +43,7 @@
             pass

         text_list = []
-        with open(file, "rb") as fp:
+        with open(self._file_path, "rb") as fp:
             # Create a PDF object
             pdf = PdfReader(fp)
...
@@ -53,4 +61,9 @@
         if not plaintext_file_exists and plaintext_file_key:
             storage.save(plaintext_file_key, text.encode('utf-8'))

-        return text
+        metadata = {"source": self._file_path}
+        return [Document(page_content=text, metadata=metadata)]
+
+    def load_as_text(self) -> str:
+        documents = self.load()
+        return '\n'.join([document.page_content for document in documents])
api/core/docstore/dataset_docstore.py

 from typing import Any, Dict, Optional, Sequence
-import tiktoken
-from llama_index.data_structs import Node
-from llama_index.docstore.types import BaseDocumentStore
-from llama_index.docstore.utils import json_to_doc
-from llama_index.schema import BaseDocument
+from langchain.schema import Document
 from sqlalchemy import func

 from core.llm.token_calculator import TokenCalculator
...
@@ -12,7 +8,7 @@ from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment


-class DatesetDocumentStore(BaseDocumentStore):
+class DatesetDocumentStore:
     def __init__(
         self,
         dataset: Dataset,
...
@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
         return self._embedding_model_name

     @property
-    def docs(self) -> Dict[str, BaseDocument]:
+    def docs(self) -> Dict[str, Document]:
         document_segments = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == self._dataset.id
         ).all()
...
@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
         output = {}
         for document_segment in document_segments:
             doc_id = document_segment.index_node_id
-            result = self.segment_to_dict(document_segment)
-            output[doc_id] = json_to_doc(result)
+            output[doc_id] = Document(
+                page_content=document_segment.content,
+                metadata={
+                    "doc_id": document_segment.index_node_id,
+                    "doc_hash": document_segment.index_node_hash,
+                    "document_id": document_segment.document_id,
+                    "dataset_id": document_segment.dataset_id,
+                }
+            )

         return output

     def add_documents(
-        self, docs: Sequence[BaseDocument], allow_update: bool = True
+        self, docs: Sequence[Document], allow_update: bool = True
     ) -> None:
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document == self._document_id
...
@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
             max_position = 0

         for doc in docs:
-            if doc.is_doc_id_none:
-                raise ValueError("doc_id not set")
-
-            if not isinstance(doc, Node):
-                raise ValueError("doc must be a Node")
+            if not isinstance(doc, Document):
+                raise ValueError("doc must be a Document")

-            segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
+            segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)

             # NOTE: doc could already exist in the store, but we overwrite it
             if not allow_update and segment_document:
                 raise ValueError(
-                    f"doc_id {doc.get_doc_id()} already exists. "
+                    f"doc_id {doc.metadata['doc_id']} already exists. "
                     "Set allow_update to True to overwrite."
                 )

             # calc embedding use tokens
-            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())
+            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)

             if not segment_document:
                 max_position += 1
...
@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
                     tenant_id=self._dataset.tenant_id,
                     dataset_id=self._dataset.id,
                     document_id=self._document_id,
-                    index_node_id=doc.get_doc_id(),
-                    index_node_hash=doc.get_doc_hash(),
+                    index_node_id=doc.metadata['doc_id'],
+                    index_node_hash=doc.metadata['doc_hash'],
                     position=max_position,
-                    content=doc.get_text(),
-                    word_count=len(doc.get_text()),
+                    content=doc.page_content,
+                    word_count=len(doc.page_content),
                     tokens=tokens,
                     created_by=self._user_id,
                 )
                 db.session.add(segment_document)
             else:
-                segment_document.content = doc.get_text()
-                segment_document.index_node_hash = doc.get_doc_hash()
-                segment_document.word_count = len(doc.get_text())
+                segment_document.content = doc.page_content
+                segment_document.index_node_hash = doc.metadata['doc_hash']
+                segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens

             db.session.commit()
...
@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
     def get_document(
         self, doc_id: str, raise_error: bool = True
-    ) -> Optional[BaseDocument]:
+    ) -> Optional[Document]:
         document_segment = self.get_document_segment(doc_id)

         if document_segment is None:
...
@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
             else:
                 return None

-        result = self.segment_to_dict(document_segment)
-        return json_to_doc(result)
+        return Document(
+            page_content=document_segment.content,
+            metadata={
+                "doc_id": document_segment.index_node_id,
+                "doc_hash": document_segment.index_node_hash,
+                "document_id": document_segment.document_id,
+                "dataset_id": document_segment.dataset_id,
+            }
+        )

     def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
         document_segment = self.get_document_segment(doc_id)
...
@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
         return document_segment.index_node_hash

-    def update_docstore(self, other: "BaseDocumentStore") -> None:
-        """Update docstore.
-
-        Args:
-            other (BaseDocumentStore): docstore to update from
-
-        """
-        self.add_documents(list(other.docs.values()))
-
     def get_document_segment(self, doc_id: str) -> DocumentSegment:
         document_segment = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == self._dataset.id,
...
@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
         ).first()

         return document_segment
-
-    def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
-        return {
-            "doc_id": segment.index_node_id,
-            "doc_hash": segment.index_node_hash,
-            "text": segment.content,
-            "__type__": Node.get_type()
-        }
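After this change the docstore exchanges plain langchain Documents whose metadata carries the segment identifiers. A sketch of the shape it stores and returns, with placeholder values:

    from langchain.schema import Document

    doc = Document(
        page_content="segment text",
        metadata={
            "doc_id": "index-node-id",        # placeholders; real values come from DocumentSegment rows
            "doc_hash": "index-node-hash",
            "document_id": "document-id",
            "dataset_id": "dataset-id",
        }
    )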
api/core/embedding/cached_embedding.py (new file)

import logging
from typing import List

from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding


class CacheEmbedding(Embeddings):
    def __init__(self, embeddings: Embeddings):
        self._embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        # use doc embedding cache or store if not exists
        text_embeddings = []
        embedding_queue_texts = []
        for text in texts:
            hash = helper.generate_text_hash(text)
            embedding = db.session.query(Embedding).filter_by(hash=hash).first()
            if embedding:
                text_embeddings.append(embedding.embedding)
            else:
                embedding_queue_texts.append(text)

        embedding_results = self._embeddings.embed_documents(embedding_queue_texts)

        i = 0
        for text in embedding_queue_texts:
            hash = helper.generate_text_hash(text)

            try:
                embedding = Embedding(hash=hash)
                embedding.set_embedding(embedding_results[i])
                db.session.add(embedding)
                db.session.commit()
            except IntegrityError:
                db.session.rollback()
                continue
            except:
                logging.exception('Failed to add embedding to db')
                continue

            i += 1

        embedding_queue_texts.extend(embedding_results)

        return embedding_queue_texts

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        # use doc embedding cache or store if not exists
        hash = helper.generate_text_hash(text)
        embedding = db.session.query(Embedding).filter_by(hash=hash).first()
        if embedding:
            return embedding.embedding

        embedding_results = self._embeddings.embed_query(text)

        try:
            embedding = Embedding(hash=hash)
            embedding.set_embedding(embedding_results)
            db.session.add(embedding)
            db.session.commit()
        except IntegrityError:
            db.session.rollback()
        except:
            logging.exception('Failed to add embedding to db')

        return embedding_results
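CacheEmbedding wraps any langchain Embeddings implementation and checks the Embedding table before calling the underlying model. A hypothetical wiring with langchain's OpenAIEmbeddings (the API key is a placeholder):

    from langchain.embeddings import OpenAIEmbeddings

    base = OpenAIEmbeddings(openai_api_key='sk-...')
    embeddings = CacheEmbedding(base)

    vector = embeddings.embed_query('hello world')   # cache miss: calls the model, stores the vector
    vector = embeddings.embed_query('hello world')   # cache hit: served from the Embedding table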
api/core/embedding/openai_embedding.py (deleted)

from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    text = text.replace("\n", " ")
    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[float]:
    """Asynchronously get embedding. ..."""
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")
    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[List[float]]:
    """Get embeddings. ..."""
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]
    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[List[float]]:
    """Asynchronously get embeddings. ..."""
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]
    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


class OpenAIEmbedding(BaseEmbedding):
    def __init__(
            self,
            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
            deployment_name: Optional[str] = None,
            openai_api_key: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Init params."""
        new_kwargs = {}
        if 'embed_batch_size' in kwargs:
            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
        if 'tokenizer' in kwargs:
            new_kwargs['tokenizer'] = kwargs['tokenizer']
        super().__init__(**new_kwargs)
        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key
        self.openai_api_type = kwargs.get('openai_api_type')
        self.openai_api_version = kwargs.get('openai_api_version')
        self.openai_api_base = kwargs.get('openai_api_base')

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _QUERY_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _QUERY_MODE_MODEL_DICT[key]
        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        ...

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        ...

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overriden for batch queries.
        """
        if self.openai_api_type and self.openai_api_type == 'azure':
            embeddings = []
            for text in texts:
                embeddings.append(self._get_text_embedding(text))
            return embeddings
        ...

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        ...
api/core/index/base.py
0 → 100644
View file @
9e9d15ec
from __future__ import annotations

from abc import abstractmethod, ABC
from typing import List, Any

from langchain.schema import Document, BaseRetriever


class BaseIndex(ABC):

    @abstractmethod
    def create(self, texts: list[Document]) -> BaseIndex:
        raise NotImplementedError

    @abstractmethod
    def add_texts(self, texts: list[Document]):
        raise NotImplementedError

    @abstractmethod
    def text_exists(self, id: str) -> bool:
        raise NotImplementedError

    @abstractmethod
    def delete_by_ids(self, ids: list[str]) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        raise NotImplementedError

    @abstractmethod
    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        raise NotImplementedError

    @abstractmethod
    def search(self, query: str, **kwargs: Any) -> List[Document]:
        raise NotImplementedError

    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        for text in texts:
            doc_id = text.metadata['doc_id']
            exists_duplicate_node = self.text_exists(doc_id)
            if exists_duplicate_node:
                texts.remove(text)

        return texts

    def _get_uuids(self, texts: list[Document]) -> list[str]:
        return [text.metadata['doc_id'] for text in texts]
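Every helper in this base class keys off the `doc_id` entry in `Document.metadata`, so callers are expected to stamp ids before handing documents to an index. A minimal sketch of that contract (the literal id values are made up):

    from langchain.schema import Document

    # Hypothetical input: each Document must carry metadata['doc_id'] (and usually
    # 'doc_hash'), because _filter_duplicate_texts() and _get_uuids() read it.
    docs = [
        Document(
            page_content="Dify keyword and vector indexes share this interface.",
            metadata={"doc_id": "d9a3f2c1", "doc_hash": "5e884898"}
        )
    ]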
api/core/index/index_builder.py (deleted)
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder


class IndexBuilder:
    @classmethod
    def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
        # set number of output tokens
        num_output = 512

        # only for verbose
        callback_manager = CallbackManager([DifyStdOutCallbackHandler()])

        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='text-davinci-003',
            temperature=0,
            max_tokens=num_output,
            callback_manager=callback_manager,
        )

        llm_predictor = LLMPredictor(llm=llm)

        # These parameters here will affect the logic of segmenting the final synthesized response.
        # The number of refinement iterations in the synthesis process depends
        # on whether the length of the segmented output exceeds the max_input_size.
        prompt_helper = PromptHelper(
            max_input_size=3500,
            num_output=num_output,
            max_chunk_overlap=20
        )

        provider = LLMBuilder.get_default_provider(tenant_id)

        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=tenant_id,
            model_provider=provider,
            model_name='text-embedding-ada-002'
        )

        return ServiceContext.from_defaults(
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper,
            embed_model=OpenAIEmbedding(**model_credentials),
        )

    @classmethod
    def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='fake'
        )

        return ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )
api/core/index/keyword_table/jieba_keyword_table.py (deleted)
import re
from typing import (
    Any,
    Dict,
    List,
    Set,
    Optional
)

import jieba.analyse

from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


def jieba_extract_keywords(
        text_chunk: str,
        max_keywords: Optional[int] = None,
        expand_with_subtokens: bool = True,
) -> Set[str]:
    """Extract keywords with JIEBA tfidf."""
    keywords = jieba.analyse.extract_tags(
        sentence=text_chunk,
        topK=max_keywords,
    )

    if expand_with_subtokens:
        return set(expand_tokens_with_subtokens(keywords))
    else:
        return set(keywords)


def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
    """Get subtokens from a list of tokens., filtering for stopwords."""
    results = set()
    for token in tokens:
        results.add(token)
        sub_tokens = re.findall(r"\w+", token)
        if len(sub_tokens) > 1:
            results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

    return results


class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
    """GPT JIEBA Keyword Table Index.

    This index uses a JIEBA keyword extractor to extract keywords from the text.
    """

    def _extract_keywords(self, text: str) -> Set[str]:
        """Extract keywords from text."""
        return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)

    @classmethod
    def get_query_map(self) -> QueryMap:
        """Get query map."""
        super_map = super().get_query_map()
        super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery

        return super_map

    def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document."""
        # get set of ids that correspond to node
        node_idxs_to_delete = {doc_id}

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in self._index_struct.table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                self._index_struct.table[keyword] = node_idxs.difference(
                    node_idxs_to_delete
                )
                if not self._index_struct.table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del self._index_struct.table[keyword]


class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
    """GPT Keyword Table Index JIEBA Query.

    Extracts keywords using JIEBA keyword extractor.
    Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.

    .. code-block:: python

        response = index.query("<query_str>", mode="jieba")

    See BaseGPTKeywordTableQuery for arguments.
    """

    @classmethod
    def from_args(
            cls,
            index_struct: IS,
            service_context: ServiceContext,
            docstore: Optional[BaseDocumentStore] = None,
            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
            verbose: bool = False,
            # response synthesizer args
            response_mode: ResponseMode = ResponseMode.DEFAULT,
            text_qa_template: Optional[QuestionAnswerPrompt] = None,
            refine_template: Optional[RefinePrompt] = None,
            simple_template: Optional[SimpleInputPrompt] = None,
            response_kwargs: Optional[Dict] = None,
            use_async: bool = False,
            streaming: bool = False,
            optimizer: Optional[BaseTokenUsageOptimizer] = None,
            # class-specific args
            **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )
api/core/index/keyword_table_index.py (deleted)
import json
from typing import List, Optional

from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict

from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment


class KeywordTableIndex:
    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            index_struct = KeywordTable()
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node in nodes:
            keywords = index._extract_keywords(node.get_text())
            self.update_segment_keywords(node.doc_id, list(keywords))
            index._index_struct.add_node(list(keywords), node)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    def del_nodes(self, node_ids: List[str]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node_id in node_ids:
            index.delete(node_id)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    @property
    def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
        docstore = DatesetDocumentStore(
            dataset=self._dataset,
            user_id=self._dataset.created_by,
            embedding_model_name="text-embedding-ada-002"
        )

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return None

        index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)

        return GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=docstore,
            service_context=service_context
        )

    def get_keyword_table(self):
        dataset_keyword_table = self._dataset.dataset_keyword_table
        if dataset_keyword_table:
            return dataset_keyword_table
        return None

    def update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()
api/core/index/keyword_table_index/jieba_keyword_table_handler.py (new file)
import re
from typing import Set

import jieba
from jieba.analyse import default_tfidf

from core.index.keyword_table_index.stopwords import STOPWORDS


class JiebaKeywordTableHandler:
    def __init__(self):
        default_tfidf.stop_words = STOPWORDS

    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
        """Extract keywords with JIEBA tfidf."""
        keywords = jieba.analyse.extract_tags(
            sentence=text,
            topK=max_keywords_per_chunk,
        )

        return set(self._expand_tokens_with_subtokens(keywords))

    def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
        """Get subtokens from a list of tokens., filtering for stopwords."""
        results = set()
        for token in tokens:
            results.add(token)
            sub_tokens = re.findall(r"\w+", token)
            if len(sub_tokens) > 1:
                results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

        return results
\ No newline at end of file
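A quick usage sketch for the handler above; the sample sentence is invented and assumes `jieba` is installed with its default TF-IDF model:

    handler = JiebaKeywordTableHandler()
    keywords = handler.extract_keywords("Dify 在索引构建中使用 jieba 提取关键词", max_keywords_per_chunk=5)
    print(keywords)  # a set of TF-IDF ranked tags, expanded with sub-tokens and stopword-filtered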
api/core/index/keyword_table_index/keyword_table_index.py (new file)
import json
from collections import defaultdict
from typing import Any, List, Optional, Dict

from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field

from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable


class KeywordTableConfig(BaseModel):
    max_keywords_per_chunk: int = 10


class KeywordTableIndex(BaseIndex):
    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
        self._dataset = dataset
        self._config = config

    def create(self, texts: list[Document]) -> BaseIndex:
        keyword_table_handler = JiebaKeywordTableHandler()

        keyword_table = {}
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        dataset_keyword_table = DatasetKeywordTable(
            dataset_id=self._dataset.id,
            keyword_table=json.dumps(keyword_table)
        )
        db.session.add(dataset_keyword_table)
        db.session.commit()

        return self

    def add_texts(self, texts: list[Document]):
        keyword_table_handler = JiebaKeywordTableHandler()

        keyword_table = self._get_dataset_keyword_table()
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table)
        db.session.commit()

    def text_exists(self, id: str) -> bool:
        keyword_table = self._get_dataset_keyword_table()
        return id in set.union(*keyword_table.values())

    def delete_by_ids(self, ids: list[str]) -> None:
        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table)
        db.session.commit()

    def delete_by_document_id(self, document_id: str):
        # get segment ids by document_id
        segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id,
            DocumentSegment.document_id == document_id
        ).all()

        ids = [segment.id for segment in segments]

        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table)
        db.session.commit()

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        return KeywordTableRetriever(index=self, **kwargs)

    def search(self, query: str, **kwargs: Any) -> List[Document]:
        keyword_table = self._get_dataset_keyword_table()

        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
        k = search_kwargs.get('k') if search_kwargs.get('k') else 4

        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)

        documents = []
        for chunk_index in sorted_chunk_indices:
            segment = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == self._dataset.id,
                DocumentSegment.index_node_id == chunk_index
            ).first()

            if segment:
                documents.append(Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": chunk_index,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                ))

        return documents

    def _get_dataset_keyword_table(self) -> Optional[dict]:
        keyword_table_dict = self._dataset.dataset_keyword_table.keyword_table_dict
        if keyword_table_dict:
            return keyword_table_dict
        return {}

    def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
        for keyword in keywords:
            if keyword not in keyword_table:
                keyword_table[keyword] = set()
            keyword_table[keyword].add(id)
        return keyword_table

    def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
        # get set of ids that correspond to node
        node_idxs_to_delete = set(ids)

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in keyword_table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete)
                if not keyword_table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del keyword_table[keyword]

        return keyword_table

    def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
        keyword_table_handler = JiebaKeywordTableHandler()
        keywords = keyword_table_handler.extract_keywords(query)

        # go through text chunks in order of most matching keywords
        chunk_indices_count: Dict[str, int] = defaultdict(int)
        keywords = [k for k in keywords if k in set(keyword_table.keys())]
        for k in keywords:
            for node_id in keyword_table[k]:
                chunk_indices_count[node_id] += 1

        sorted_chunk_indices = sorted(
            list(chunk_indices_count.keys()),
            key=lambda x: chunk_indices_count[x],
            reverse=True,
        )

        return sorted_chunk_indices[:k]

    def _update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()


class KeywordTableRetriever(BaseRetriever, BaseModel):
    index: KeywordTableIndex
    search_kwargs: dict = Field(default_factory=dict)

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        return self.index.search(query, **self.search_kwargs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError("KeywordTableRetriever does not support async")
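A rough usage sketch, assuming `dataset` is a `models.dataset.Dataset` loaded inside a Flask app context and that every Document already carries `metadata['doc_id']` (see `core/index/base.py` above):

    index = KeywordTableIndex(dataset, KeywordTableConfig(max_keywords_per_chunk=10))
    index.create(documents)     # extracts keywords per chunk and persists the table as JSON

    hits = index.search("向量索引", search_kwargs={'k': 4})       # keyword-overlap ranking
    retriever = index.get_retriever(search_kwargs={'k': 4})       # langchain BaseRetriever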
api/core/index/keyword_table/stopwords.py → api/core/index/keyword_table_index/stopwords.py (file moved)
api/core/index/query/synthesizer.py (deleted)
from typing import (
    Any,
    Dict,
    Optional,
    Sequence,
)

from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE


class EnhanceResponseSynthesizer(ResponseSynthesizer):
    @classmethod
    def from_args(
            cls,
            service_context: ServiceContext,
            streaming: bool = False,
            use_async: bool = False,
            text_qa_template: Optional[QuestionAnswerPrompt] = None,
            refine_template: Optional[RefinePrompt] = None,
            simple_template: Optional[SimpleInputPrompt] = None,
            response_mode: ResponseMode = ResponseMode.DEFAULT,
            response_kwargs: Optional[Dict] = None,
            optimizer: Optional[BaseTokenUsageOptimizer] = None,
    ) -> "ResponseSynthesizer":
        response_builder: Optional[BaseResponseBuilder] = None
        if response_mode != ResponseMode.NO_TEXT:
            if response_mode == 'no_synthesizer':
                response_builder = NoSynthesizer(
                    service_context=service_context,
                    simple_template=simple_template,
                    streaming=streaming,
                )
            else:
                response_builder = get_response_builder(
                    service_context,
                    text_qa_template,
                    refine_template,
                    simple_template,
                    response_mode,
                    use_async=use_async,
                    streaming=streaming,
                )
        return cls(response_builder, response_mode, response_kwargs, optimizer)


class NoSynthesizer(BaseResponseBuilder):
    def __init__(
            self,
            service_context: ServiceContext,
            simple_template: Optional[SimpleInputPrompt] = None,
            streaming: bool = False,
    ) -> None:
        super().__init__(service_context, streaming)

    async def aget_response(
            self,
            query_str: str,
            text_chunks: Sequence[str],
            prev_response: Optional[str] = None,
            **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)

    def get_response(
            self,
            query_str: str,
            text_chunks: Sequence[str],
            prev_response: Optional[str] = None,
            **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)
\ No newline at end of file
api/core/index/readers/html_parser.py (deleted)
from pathlib import Path
from typing import Dict

from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser


class HTMLParser(BaseParser):
    """HTML parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        with open(file, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
api/core/index/readers/xlsx_parser.py (deleted)
from pathlib import Path
import json
from typing import Dict

from openpyxl import load_workbook
from llama_index.readers.file.base_parser import BaseParser
from flask import current_app


class XLSXParser(BaseParser):
    """XLSX parser."""

    def _init_parser(self) -> Dict:
        """Init parser"""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        data = []
        keys = []
        with open(file, "r") as fp:
            wb = load_workbook(filename=file, read_only=True)
            # loop over all sheets
            for sheet in wb:
                for row in sheet.iter_rows(values_only=True):
                    if all(v is None for v in row):
                        continue
                    if keys == []:
                        keys = list(map(str, row))
                    else:
                        row_dict = dict(zip(keys, row))
                        row_dict = {k: v for k, v in row_dict.items() if v}
                        data.append(json.dumps(row_dict, ensure_ascii=False))

        return '\n\n'.join(data)
api/core/index/vector_index.py (deleted)
import json
import logging
from typing import List, Optional

from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type

from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding


class VectorIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedding_queue_nodes = []
        embedded_nodes = []
        for node in nodes:
            node_hash = node.doc_hash

            # if node hash in cached embedding tables, use cached embedding
            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
            if embedding:
                node.embedding = embedding.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            # pre embed nodes for cached embedding
            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding

                try:
                    embedding = Embedding(hash=node.doc_hash)
                    embedding.set_embedding(node.embedding)
                    db.session.add(embedding)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    continue
                except:
                    logging.exception('Failed to add embedding to db')
                    continue

                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        for node in nodes:
            node_id = node.doc_id
            exists_duplicate_node = index.exists_by_node_id(node_id)
            if exists_duplicate_node:
                nodes.remove(node)

        return nodes
api/core/index/vector_index/base.py (new file)
from abc import abstractmethod
from typing import List, Any, Tuple

from langchain.schema import Document
from langchain.vectorstores import VectorStore

from index.base import BaseIndex


class BaseVectorIndex(BaseIndex):

    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset_id: str) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError
api/core/index/vector_index/qdrant_vector_index.py (new file)
import os
from typing import Optional, Any, List, cast

import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    root_path: Optional[str]

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(self.root_path, path)

            return {
                'path': path
            }
        else:
            return {
                'url': self.endpoint,
                'api_key': self.api_key,
            }


class QdrantVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
        self._dataset = dataset
        self._client_config = config
        self._embeddings = embeddings
        self._vector_store = None

    def get_type(self) -> str:
        return 'qdrant'

    def get_index_name(self, dataset_id: str) -> str:
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"collection_name": self.get_index_name(self._dataset.get_id())}
        }

    def create(self, texts: list[Document]) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self._dataset.get_id()),
            ids=uuids,
            **self._client_config.to_qdrant_params()
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        client = qdrant_client.QdrantClient(
            **self._client_config.to_qdrant_params()
        )

        return QdrantVectorStore(
            client=client,
            collection_name=self.get_index_name(self._dataset.get_id()),
            embeddings=self._embeddings
        )

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(QdrantVectorStore, vector_store)

        return vector_store.as_retriever(**kwargs)

    def search(self, query: str, **kwargs: Any) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(QdrantVectorStore, vector_store)

        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}

        # similarity k
        # mmr k, fetch_k, lambda_mult
        # similarity_score_threshold k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)

    def add_texts(self, texts: list[Document]):
        vector_store = self._get_vector_store()
        vector_store = cast(QdrantVectorStore, vector_store)

        texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(QdrantVectorStore, vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(QdrantVectorStore, vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete_by_document_id(self, document_id: str):
        vector_store = self._get_vector_store()
        vector_store = cast(QdrantVectorStore, vector_store)

        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.document_id",
                    match=models.MatchValue(value=document_id),
                ),
            ],
        ))
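A configuration sketch for the two endpoint forms that `to_qdrant_params()` above distinguishes; the URL, key and paths are placeholders:

    # remote Qdrant server
    remote = QdrantConfig(endpoint='https://qdrant.example.com:6333', api_key='example-key', root_path=None)
    remote.to_qdrant_params()   # {'url': 'https://qdrant.example.com:6333', 'api_key': 'example-key'}

    # embedded storage relative to the API root ('path:' prefix)
    local = QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app/api')
    local.to_qdrant_params()    # {'path': '/app/api/storage/qdrant'}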
api/core/index/vector_index/vector_index.py (new file)
import json

from flask import current_app
from langchain.embeddings.base import Embeddings

from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document


class VectorIndex:
    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
        self._dataset = dataset
        self._embeddings = embeddings
        self._vector_index = self._init_vector_index(dataset, config, embeddings)

    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
        vector_type = config.get('VECTOR_STORE')

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict['type']

        if not vector_type:
            raise ValueError(f"Vector store must be specified.")

        if vector_type == "weaviate":
            from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig

            return WeaviateVectorIndex(
                dataset=dataset,
                config=WeaviateConfig(
                    endpoint=config.get('WEAVIATE_URL'),
                    api_key=config.get('WEAVIATE_API_KEY'),
                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
                ),
                embeddings=embeddings,
                attributes=['doc_id', 'dataset_id', 'document_id', 'source'],
            )
        elif vector_type == "qdrant":
            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

            return QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=config.get('QDRANT_URL'),
                    api_key=config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path
                ),
                embeddings=embeddings
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

    def add_texts(self, texts: list[Document]):
        if not self._dataset.index_struct_dict:
            self._vector_index.create(texts)
            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
            db.session.commit()
            return

        self._vector_index.add_texts(texts)

    def __getattr__(self, name):
        if self._vector_index is not None:
            method = getattr(self._vector_index, name)
            if callable(method):
                return method

        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
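A usage sketch of the facade, assuming a Flask app context and any langchain `Embeddings` implementation (`OpenAIEmbeddings` below is just an example; the key is a placeholder):

    from flask import current_app
    from langchain.embeddings import OpenAIEmbeddings

    embeddings = OpenAIEmbeddings(openai_api_key='example-key')
    index = VectorIndex(dataset=dataset, config=current_app.config, embeddings=embeddings)

    index.add_texts(documents)                                   # first call creates the store and saves index_struct
    docs = index.search("query text", search_kwargs={'k': 4})    # forwarded to Qdrant/Weaviate via __getattr__
    index.delete_by_document_id(document_id)                     # likewise forwarded

Everything except `add_texts` is resolved through `__getattr__`, so callers never branch on `VECTOR_STORE` themselves.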
api/core/index/vector_index/weaviate_vector_index.py (new file)
from typing import Optional, Any, List, cast

import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset


class WeaviateConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    batch_size: int = 100


class WeaviateVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list[str]):
        self._dataset = dataset
        self._client = self._init_client(config)
        self._embeddings = embeddings
        self._attributes = attributes
        self._vector_store = None

    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)

        weaviate.connect.connection.has_grpc = False

        client = weaviate.Client(
            url=config.endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 60),
            startup_period=None
        )

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=config.batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_type(self) -> str:
        return 'weaviate'

    def get_index_name(self, dataset_id: str) -> str:
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self._dataset.get_id())}
        }

    def create(self, texts: list[Document]) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self._dataset.get_id()),
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        return WeaviateVectorStore(
            client=self._client,
            index_name=self.get_index_name(self._dataset.get_id()),
            text_key='text',
            embedding=self._embeddings,
            attributes=self._attributes,
            by_text=False
        )

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(WeaviateVectorStore, vector_store)

        return vector_store.as_retriever(**kwargs)

    def search(self, query: str, **kwargs: Any) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(WeaviateVectorStore, vector_store)

        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}

        # similarity k
        # mmr k, fetch_k, lambda_mult
        # similarity_score_threshold k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)

    def add_texts(self, texts: list[Document]):
        vector_store = self._get_vector_store()
        vector_store = cast(WeaviateVectorStore, vector_store)

        texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(WeaviateVectorStore, vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(WeaviateVectorStore, vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete_by_document_id(self, document_id: str):
        vector_store = self._get_vector_store()
        vector_store = cast(WeaviateVectorStore, vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        })
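Both vector index implementations expose the same retriever surface; a hedged retrieval sketch (endpoint, key and query are placeholders, and `embeddings` is any langchain `Embeddings`):

    config = WeaviateConfig(endpoint='http://weaviate:8080', api_key='example-key', batch_size=100)
    index = WeaviateVectorIndex(dataset=dataset, config=config, embeddings=embeddings,
                                attributes=['doc_id', 'dataset_id', 'document_id', 'source'])

    retriever = index.get_retriever(search_type='similarity', search_kwargs={'k': 4})
    docs = retriever.get_relevant_documents("how are document segments indexed?")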
api/core/indexing_runner.py (changed)
 import datetime
 import json
 import re
-import tempfile
 import time
-from pathlib import Path
-from typing import Optional, List
+import uuid
+from typing import Optional, List, cast
 
 from flask import current_app
 from flask_login import current_user
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
-from llama_index import SimpleDirectoryReader
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
-from llama_index.node_parser import SimpleNodeParser, NodeParser
-from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
-from llama_index.readers.file.markdown_parser import MarkdownParser
-from core.data_source.notion import NotionPageReader
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
+from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.markdown_parser import MarkdownParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
 from core.llm.llm_builder import LLMBuilder
+from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
-from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
+from libs import helper
+from models.dataset import Document as DatasetDocument
+from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
 from models.model import UploadFile
 from models.source import DataSourceBinding
@@ -40,49 +36,49 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name
 
-    def run(self, documents: List[Document]):
+    def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
-        for document in documents:
+        for dataset_document in dataset_documents:
             # get dataset
             dataset = Dataset.query.filter_by(
-                id=document.dataset_id
+                id=dataset_document.dataset_id
             ).first()
 
             if not dataset:
                 raise ValueError("no dataset found")
 
             # load file
-            text_docs = self._load_data(document)
+            text_docs = self._load_data(dataset_document)
 
             # get the process rule
             processing_rule = db.session.query(DatasetProcessRule). \
-                filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                 first()
 
-            # get node parser for splitting
-            node_parser = self._get_node_parser(processing_rule)
+            # get splitter
+            splitter = self._get_splitter(processing_rule)
 
-            # split to nodes
-            nodes = self._step_split(
+            # split to documents
+            documents = self._step_split(
                 text_docs=text_docs,
-                node_parser=node_parser,
+                splitter=splitter,
                 dataset=dataset,
-                document=document,
+                dataset_document=dataset_document,
                 processing_rule=processing_rule
             )
 
             # build index
             self._build_index(
                 dataset=dataset,
-                document=document,
-                nodes=nodes
+                dataset_document=dataset_document,
+                documents=documents
             )
 
-    def run_in_splitting_status(self, document: Document):
+    def run_in_splitting_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is splitting."""
         # get dataset
         dataset = Dataset.query.filter_by(
-            id=document.dataset_id
+            id=dataset_document.dataset_id
         ).first()
 
         if not dataset:
@@ -91,42 +87,44 @@ class IndexingRunner:
         # get exist document_segment list and delete
         document_segments = DocumentSegment.query.filter_by(
             dataset_id=dataset.id,
-            document_id=document.id
+            document_id=dataset_document.id
         ).all()
         db.session.delete(document_segments)
         db.session.commit()
 
         # load file
-        text_docs = self._load_data(document)
+        text_docs = self._load_data(dataset_document)
 
         # get the process rule
         processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+            filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
            first()
 
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
+        # get splitter
+        splitter = self._get_splitter(processing_rule)
 
-        # split to nodes
-        nodes = self._step_split(
+        # split to documents
+        documents = self._step_split(
             text_docs=text_docs,
-            node_parser=node_parser,
+            splitter=splitter,
             dataset=dataset,
-            document=document,
+            dataset_document=dataset_document,
             processing_rule=processing_rule
         )
 
         # build index
         self._build_index(
             dataset=dataset,
-            document=document,
-            nodes=nodes
+            dataset_document=dataset_document,
+            documents=documents
         )
 
-    def run_in_indexing_status(self, document: Document):
+    def run_in_indexing_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is indexing."""
         # get dataset
         dataset = Dataset.query.filter_by(
-            id=document.dataset_id
+            id=dataset_document.dataset_id
         ).first()
 
         if not dataset:
@@ -135,39 +133,31 @@ class IndexingRunner:
         # get exist document_segment list and delete
         document_segments = DocumentSegment.query.filter_by(
             dataset_id=dataset.id,
-            document_id=document.id
+            document_id=dataset_document.id
         ).all()
 
-        nodes = []
+        documents = []
         if document_segments:
             for document_segment in document_segments:
                 # transform segment to node
                 if document_segment.status != "completed":
-                    relationships = {
-                        DocumentRelationship.SOURCE: document_segment.document_id,
-                    }
-                    previous_segment = document_segment.previous_segment
-                    if previous_segment:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-                    next_segment = document_segment.next_segment
-                    if next_segment:
-                        relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-                    node = Node(
-                        doc_id=document_segment.index_node_id,
-                        doc_hash=document_segment.index_node_hash,
-                        text=document_segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
+                    document = Document(
+                        page_content=document_segment.content,
+                        metadata={
+                            "doc_id": document_segment.index_node_id,
+                            "doc_hash": document_segment.index_node_hash,
+                            "document_id": document_segment.document_id,
+                            "dataset_id": document_segment.dataset_id,
+                        }
                     )
 
-                    nodes.append(node)
+                    documents.append(document)
 
         # build index
         self._build_index(
             dataset=dataset,
-            document=document,
-            nodes=nodes
+            dataset_document=dataset_document,
+            documents=documents
         )
 
     def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
@@ -179,28 +169,28 @@ class IndexingRunner:
         total_segments = 0
         for file_detail in file_details:
             # load data from file
-            text_docs = self._load_data_from_file(file_detail)
+            text_docs = FileExtractor.load(file_detail)
 
             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"],
                 rules=json.dumps(tmp_processing_rule["rules"])
             )
 
-            # get node parser for splitting
-            node_parser = self._get_node_parser(processing_rule)
+            # get splitter
+            splitter = self._get_splitter(processing_rule)
 
-            # split to nodes
-            nodes = self._split_to_nodes(
+            # split to documents
+            documents = self._split_to_documents(
                 text_docs=text_docs,
-                node_parser=node_parser,
+                splitter=splitter,
                 processing_rule=processing_rule
            )
 
-            total_segments += len(nodes)
-            for node in nodes:
+            total_segments += len(documents)
+            for document in documents:
                 if len(preview_texts) < 5:
-                    preview_texts.append(node.get_text())
+                    preview_texts.append(document.page_content)
 
-                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
 
         return {
             "total_segments": total_segments,
@@ -230,35 +220,36 @@ class IndexingRunner:
             ).first()
 
             if not data_source_binding:
                 raise ValueError('Data source binding not found.')
 
-            reader = NotionPageReader(integration_token=data_source_binding.access_token)
             for page in notion_info['pages']:
-                if page['type'] == 'page':
-                    page_ids = [page['page_id']]
-                    documents = reader.load_data_as_documents(page_ids=page_ids)
-                elif page['type'] == 'database':
-                    documents = reader.load_data_as_documents(database_id=page['page_id'])
-                else:
-                    documents = []
+                loader = NotionLoader(
+                    notion_access_token=data_source_binding.access_token,
+                    notion_workspace_id=workspace_id,
+                    notion_obj_id=page['page_id'],
+                    notion_page_type=page['type']
+                )
+                documents = loader.load()
 
                 processing_rule = DatasetProcessRule(
                     mode=tmp_processing_rule["mode"],
                     rules=json.dumps(tmp_processing_rule["rules"])
                 )
 
-                # get node parser for splitting
-                node_parser = self._get_node_parser(processing_rule)
+                # get splitter
+                splitter = self._get_splitter(processing_rule)
 
-                # split to nodes
-                nodes = self._split_to_nodes(
+                # split to documents
+                documents = self._split_to_documents(
                     text_docs=documents,
-                    node_parser=node_parser,
+                    splitter=splitter,
                     processing_rule=processing_rule
                 )
 
-                total_segments += len(nodes)
-                for node in nodes:
+                total_segments += len(documents)
+                for document in documents:
                     if len(preview_texts) < 5:
-                        preview_texts.append(node.get_text())
+                        preview_texts.append(document.page_content)
 
-                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
 
         return {
             "total_segments": total_segments,
@@ -268,14 +259,14 @@ class IndexingRunner:
             "preview": preview_texts
         }
 
-    def _load_data(self, document: Document) -> List[Document]:
+    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
         # load file
-        if document.data_source_type not in ["upload_file", "notion_import"]:
+        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []
 
-        data_source_info = document.data_source_info_dict
+        data_source_info = dataset_document.data_source_info_dict
         text_docs = []
-        if document.data_source_type == 'upload_file':
+        if dataset_document.data_source_type == 'upload_file':
             if not data_source_info or 'upload_file_id' not in data_source_info:
                 raise ValueError("no upload file found")
@@ -283,47 +274,28 @@ class IndexingRunner:
                 filter(UploadFile.id == data_source_info['upload_file_id']). \
                 one_or_none()
 
-            text_docs = self._load_data_from_file(file_detail)
-        elif document.data_source_type == 'notion_import':
-            if not data_source_info or 'notion_page_id' not in data_source_info \
-                    or 'notion_workspace_id' not in data_source_info:
-                raise ValueError("no notion page found")
-
-            workspace_id = data_source_info['notion_workspace_id']
-            page_id = data_source_info['notion_page_id']
-            page_type = data_source_info['type']
-            data_source_binding = DataSourceBinding.query.filter(
-                db.and_(
-                    DataSourceBinding.tenant_id == document.tenant_id,
-                    DataSourceBinding.provider == 'notion',
-                    DataSourceBinding.disabled == False,
-                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
-                )
-            ).first()
-
-            if not data_source_binding:
-                raise ValueError('Data source binding not found.')
-
-            if page_type == 'page':
-                # add page last_edited_time to data_source_info
-                self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
-            elif page_type == 'database':
-                # add page last_edited_time to data_source_info
-                self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
+            text_docs = FileExtractor.load(file_detail)
+        elif dataset_document.data_source_type == 'notion_import':
+            loader = NotionLoader.from_document(dataset_document)
+            text_docs = loader.load()
 
         # update document status to splitting
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="splitting",
             extra_update_params={
-                Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
-                Document.parsing_completed_at: datetime.datetime.utcnow()
+                DatasetDocument.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
+                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
             }
         )
 
         # replace doc id to document model id
+        text_docs = cast(List[Document], text_docs)
         for text_doc in text_docs:
             # remove invalid symbol
-            text_doc.text = self.filter_string(text_doc.get_text())
-            text_doc.doc_id = document.id
+            text_doc.page_content = self.filter_string(text_doc.page_content)
+            text_doc.metadata['document_id'] = dataset_document.id
+            text_doc.metadata['dataset_id'] = dataset_document.dataset_id
 
         return text_docs
@@ -331,61 +303,7 @@ class IndexingRunner:
         pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
         return pattern.sub('', text)
 
-    def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            self.storage.download(upload_file.key, filepath)
-
-            file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
-            file_extractor[".markdown"] = MarkdownParser()
-            file_extractor[".md"] = MarkdownParser()
-            file_extractor[".html"] = HTMLParser()
-            file_extractor[".htm"] = HTMLParser()
-            file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
-            file_extractor[".xlsx"] = XLSXParser()
-
-            loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
-            text_docs = loader.load_data()
-
-            return text_docs
-
-    def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
-        page_ids = [page_id]
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(page_ids=page_ids)
-
-        return text_docs
-
-    def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(database_id=database_id)
-
-        return text_docs
-
-    def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_page_last_edited_time(page_id)
-
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_database_last_edited_time(page_id)
-
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
+    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
@@ -414,68 +332,83 @@ class IndexingRunner:
separators
=
[
"
\n\n
"
,
"。"
,
"."
,
" "
,
""
]
)
return
SimpleNodeParser
(
text_splitter
=
character_splitter
,
include_extra_info
=
True
)
return
character_splitter
-    def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
-                    dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
+    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
+                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
+            -> List[Document]:
        """
-        Split the text documents into nodes and save them to the document segment.
+        Split the text documents into documents and save them to the document segment.
        """
-        nodes = self._split_to_nodes(
+        documents = self._split_to_documents(
            text_docs=text_docs,
-            node_parser=node_parser,
+            splitter=splitter,
            processing_rule=processing_rule
        )

        # save node to document segment
        doc_store = DatesetDocumentStore(
            dataset=dataset,
-            user_id=document.created_by,
+            user_id=dataset_document.created_by,
            embedding_model_name=self.embedding_model_name,
-            document_id=document.id
+            document_id=dataset_document.id
        )

        # add document segments
-        doc_store.add_documents(nodes)
+        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
-                Document.cleaning_completed_at: cur_time,
-                Document.splitting_completed_at: cur_time,
+                DatasetDocument.cleaning_completed_at: cur_time,
+                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
-            document_id=document.id,
+            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

-        return nodes
+        return documents
-    def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
-                        processing_rule: DatasetProcessRule) -> List[Node]:
+    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
+                            processing_rule: DatasetProcessRule) -> List[Document]:
        """
        Split the text documents into nodes.
        """
-        all_nodes = []
+        all_documents = []
        for text_doc in text_docs:
            # document clean
-            document_text = self._document_clean(text_doc.get_text(), processing_rule)
-            text_doc.text = document_text
+            document_text = self._document_clean(text_doc.page_content, processing_rule)
+            text_doc.page_content = document_text

            # parse document to nodes
-            nodes = node_parser.get_nodes_from_documents([text_doc])
-            nodes = [node for node in nodes if node.text is not None and node.text.strip()]
-            all_nodes.extend(nodes)
+            documents = splitter.split_documents([text_doc])
+            split_documents = []
+            for document in documents:
+                if document.page_content is None or not document.page_content.strip():
+                    continue
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document.page_content)
+
+                document.metadata['doc_id'] = doc_id
+                document.metadata['doc_hash'] = hash
+
-        return all_nodes
+                split_documents.append(document)
+
+            all_documents.extend(split_documents)
+
+        return all_documents
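Below is a minimal, self-contained sketch of the same chunk-and-tag pattern using langchain's stock RecursiveCharacterTextSplitter (the repository's own fixed splitter is not reproduced here; the chunk sizes and the sample document are placeholders):

import uuid
from hashlib import sha256

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    separators=["\n\n", "。", ".", " ", ""],
)

source = Document(page_content="First paragraph.\n\nSecond paragraph.",
                  metadata={"document_id": "doc-1", "dataset_id": "ds-1"})

chunks = []
for chunk in splitter.split_documents([source]):
    if not chunk.page_content or not chunk.page_content.strip():
        continue
    # Mirror the bookkeeping above: a fresh doc_id plus a content hash.
    chunk.metadata["doc_id"] = str(uuid.uuid4())
    chunk.metadata["doc_hash"] = sha256((chunk.page_content + "None").encode()).hexdigest()
    chunks.append(chunk)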
    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
...
...
@@ -506,37 +439,58 @@ class IndexingRunner:
        return text

-    def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
+    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
        """
        Build the index for the document.
        """
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        model_credentials = LLMBuilder.get_model_credentials(
+            tenant_id=dataset.tenant_id,
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_name='text-embedding-ada-002'
+        )
+
+        embeddings = CacheEmbedding(OpenAIEmbeddings(**model_credentials))
+
+        vector_index = VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
+        keyword_table_index = KeywordTableIndex(
+            dataset=dataset,
+            config=KeywordTableConfig(
+                max_keywords_per_chunk=10
+            )
+        )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100
-        for i in range(0, len(nodes), chunk_size):
+        for i in range(0, len(documents), chunk_size):
            # check document is paused
-            self._check_document_paused_status(document.id)
-            chunk_nodes = nodes[i:i + chunk_size]
+            self._check_document_paused_status(dataset_document.id)
+            chunk_documents = documents[i:i + chunk_size]

            tokens += sum(
-                TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
+                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                for document in chunk_documents
            )

            # save vector index
            if dataset.indexing_technique == "high_quality":
-                vector_index.add_nodes(chunk_nodes)
+                vector_index.add_texts(chunk_documents)

            # save keyword index
-            keyword_table_index.add_nodes(chunk_nodes)
+            keyword_table_index.add_texts(chunk_documents)

-            node_ids = [node.doc_id for node in chunk_nodes]
+            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == document.id,
-                DocumentSegment.index_node_id.in_(node_ids),
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
...
...
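To make the new control flow easier to follow, here is a stripped-down sketch of the batching idea above, with the index objects passed in as plain parameters; anything exposing an add_texts(documents) method would do, and the 100-document chunk size is simply the value used in the diff.

from typing import List

from langchain.schema import Document


def index_in_chunks(documents: List[Document], vector_index, keyword_table_index,
                    chunk_size: int = 100) -> None:
    # Index in fixed-size batches so one long document does not turn into a
    # single huge embedding request.
    for i in range(0, len(documents), chunk_size):
        chunk_documents = documents[i:i + chunk_size]
        vector_index.add_texts(chunk_documents)
        keyword_table_index.add_texts(chunk_documents)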
@@ -549,12 +503,12 @@ class IndexingRunner:
        # update document status to completed
        self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
-                Document.tokens: tokens,
-                Document.completed_at: datetime.datetime.utcnow(),
-                Document.indexing_latency: indexing_end_at - indexing_start_at,
+                DatasetDocument.tokens: tokens,
+                DatasetDocument.completed_at: datetime.datetime.utcnow(),
+                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )
...
...
@@ -569,25 +523,25 @@ class IndexingRunner:
"""
Update the document indexing status.
"""
count
=
Document
.
query
.
filter_by
(
id
=
document_id
,
is_paused
=
True
)
.
count
()
count
=
D
atasetD
ocument
.
query
.
filter_by
(
id
=
document_id
,
is_paused
=
True
)
.
count
()
if
count
>
0
:
raise
DocumentIsPausedException
()
update_params
=
{
Document
.
indexing_status
:
after_indexing_status
D
atasetD
ocument
.
indexing_status
:
after_indexing_status
}
if
extra_update_params
:
update_params
.
update
(
extra_update_params
)
Document
.
query
.
filter_by
(
id
=
document_id
)
.
update
(
update_params
)
D
atasetD
ocument
.
query
.
filter_by
(
id
=
document_id
)
.
update
(
update_params
)
db
.
session
.
commit
()
def
_update_segments_by_document
(
self
,
document_id
:
str
,
update_params
:
dict
)
->
None
:
def
_update_segments_by_document
(
self
,
d
ataset_d
ocument_id
:
str
,
update_params
:
dict
)
->
None
:
"""
Update the document segment by document id.
"""
DocumentSegment
.
query
.
filter_by
(
document_id
=
document_id
)
.
update
(
update_params
)
DocumentSegment
.
query
.
filter_by
(
document_id
=
d
ataset_d
ocument_id
)
.
update
(
update_params
)
db
.
session
.
commit
()
...
...
api/core/llm/provider/azure_provider.py
View file @ 9e9d15ec
...
...
@@ -42,7 +42,7 @@ class AzureProvider(BaseProvider):
"""
config
=
self
.
get_provider_api_key
(
model_id
=
model_id
)
config
[
'openai_api_type'
]
=
'azure'
config
[
'deployment_name'
]
=
model_id
.
replace
(
'.'
,
''
)
if
model_id
else
None
config
[
'deployment
'
]
=
config
[
'deployment
_name'
]
=
model_id
.
replace
(
'.'
,
''
)
if
model_id
else
None
return
config
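Setting both deployment and deployment_name presumably lets the same credentials dict be splatted into langchain wrappers that name the Azure deployment differently. As a hedged sketch only (all values are placeholders, and the exact keyword set each wrapper accepts should be checked against the pinned langchain version):

from langchain.embeddings import OpenAIEmbeddings

# Placeholder credentials shaped like the dict built above (values are fake).
model_credentials = {
    'openai_api_type': 'azure',
    'openai_api_base': 'https://my-resource.openai.azure.com',
    'openai_api_key': 'azure-api-key',
    'openai_api_version': '2023-05-15',
    'deployment': 'text-embedding-ada-002',
}

# OpenAIEmbeddings picks the Azure deployment up from `deployment`.
embeddings = OpenAIEmbeddings(**model_credentials)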
    def get_provider_name(self):
...
...
api/core/index/spiltter/fixed_text_splitter.py → api/core/spiltter/fixed_text_splitter.py
View file @ 9e9d15ec
File moved
api/core/vector_store/base.py
deleted 100644 → 0
View file @ 23ef2262

from abc import ABC, abstractmethod
from typing import Optional

from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore


class BaseVectorStoreClient(ABC):
    @abstractmethod
    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        raise NotImplementedError

    @abstractmethod
    def to_index_config(self, index_id: str) -> dict:
        raise NotImplementedError


class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
    def delete_node(self, node_id: str):
        self._vector_store.delete_node(node_id)

    def exists_by_node_id(self, node_id: str) -> bool:
        return self._vector_store.exists_by_node_id(node_id)


class EnhanceVectorStore(ABC):
    @abstractmethod
    def delete_node(self, node_id: str):
        pass

    @abstractmethod
    def exists_by_node_id(self, node_id: str) -> bool:
        pass
api/core/vector_store/qdrant_vector_store.py
0 → 100644
View file @ 9e9d15ec

from typing import cast

from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal


class QdrantVectorStore(Qdrant):
    def del_texts(self, filter: Filter):
        if not filter:
            raise ValueError('filter must not be empty')

        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=FilterSelector(
                filter=filter
            ),
        )

    def del_text(self, uuid: str) -> None:
        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(
                points=[uuid],
            ),
        )

    def text_exists(self, uuid: str) -> bool:
        self._reload_if_needed()

        response = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[uuid]
        )

        return len(response) > 0

    def _reload_if_needed(self):
        if isinstance(self.client, QdrantLocal):
            self.client = cast(QdrantLocal, self.client)
            self.client._load()
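A short usage sketch of the subclass above. The collection name, the payload key used in the filter and the OPENAI_API_KEY needed by OpenAIEmbeddings are assumptions of this example (the collection is assumed to already exist with points written by the index classes):

import uuid

from langchain.embeddings import OpenAIEmbeddings
from qdrant_client import QdrantClient
from qdrant_client.http import models

from core.vector_store.qdrant_vector_store import QdrantVectorStore

client = QdrantClient(path="/tmp/qdrant-local")   # local mode; path is illustrative
store = QdrantVectorStore(
    client=client,
    collection_name="dataset-123",                 # hypothetical collection name
    embeddings=OpenAIEmbeddings(),
)

point_id = str(uuid.uuid4())
print(store.text_exists(point_id))   # False unless such a point was written earlier

store.del_text(point_id)             # remove a single point by id
store.del_texts(models.Filter(       # or remove every point matching a payload filter
    must=[models.FieldCondition(key="metadata.document_id",
                                match=models.MatchValue(value="doc-1"))]
))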
api/core/vector_store/qdrant_vector_store_client.py
deleted 100644 → 0
View file @ 23ef2262

import os
from typing import cast, List

from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter

import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore


class QdrantVectorStoreClient(BaseVectorStoreClient):
    def __init__(self, url: str, api_key: str, root_path: str):
        self._client = self.init_from_config(url, api_key, root_path)

    @classmethod
    def init_from_config(cls, url: str, api_key: str, root_path: str):
        if url and url.startswith('path:'):
            path = url.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(root_path, path)

            return qdrant_client.QdrantClient(path=path)
        else:
            return qdrant_client.QdrantClient(
                url=url,
                api_key=api_key,
            )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = QdrantIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"collection_name": "Gpt_index_xxx"}
        collection_name = config.get('collection_name')
        if not collection_name:
            raise Exception("collection_name cannot be None.")

        return GPTQdrantEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=QdrantEnhanceVectorStore(
                client=self._client,
                collection_name=collection_name
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"collection_name": index_id}


class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
    pass


class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        from qdrant_client.http import models as rest

        self._reload_if_needed()

        self._client.delete(
            collection_name=self._collection_name,
            points_selector=rest.Filter(
                must=[
                    rest.FieldCondition(
                        key="id", match=rest.MatchValue(value=node_id)
                    )
                ]
            ),
        )

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        self._reload_if_needed()

        response = self._client.retrieve(
            collection_name=self._collection_name,
            ids=[node_id]
        )

        return len(response) > 0

    def query(
            self,
            query: VectorStoreQuery,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query
        """
        query_embedding = cast(List[float], query.query_embedding)

        self._reload_if_needed()

        response = self._client.search(
            collection_name=self._collection_name,
            query_vector=query_embedding,
            limit=cast(int, query.similarity_top_k),
            query_filter=cast(Filter, self._build_query_filter(query)),
            with_vectors=True
        )

        nodes = []
        similarities = []
        ids = []

        for point in response:
            payload = cast(Payload, point.payload)
            node = Node(
                doc_id=str(point.id),
                text=payload.get("text"),
                embedding=point.vector,
                extra_info=payload.get("extra_info"),
                relationships={
                    DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
                },
            )
            nodes.append(node)
            similarities.append(point.score)
            ids.append(str(point.id))

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def _reload_if_needed(self):
        if isinstance(self._client._client, QdrantLocal):
            self._client._client._load()
api/core/vector_store/vector_store.py
deleted 100644 → 0
View file @ 23ef2262

from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient

SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']


class VectorStore:
    def __init__(self):
        self._vector_store = None
        self._client = None

    def init_app(self, app: Flask):
        if not app.config['VECTOR_STORE']:
            return

        self._vector_store = app.config['VECTOR_STORE']
        if self._vector_store not in SUPPORTED_VECTOR_STORES:
            raise ValueError(f"Vector store {self._vector_store} is not supported.")

        if self._vector_store == 'weaviate':
            self._client = WeaviateVectorStoreClient(
                endpoint=app.config['WEAVIATE_ENDPOINT'],
                api_key=app.config['WEAVIATE_API_KEY'],
                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
                batch_size=app.config['WEAVIATE_BATCH_SIZE']
            )
        elif self._vector_store == 'qdrant':
            self._client = QdrantVectorStoreClient(
                url=app.config['QDRANT_URL'],
                api_key=app.config['QDRANT_API_KEY'],
                root_path=app.root_path
            )

        app.extensions['vector_store'] = self

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
        vector_store_config: dict = index_struct.get('vector_store')
        index = self.get_client().get_index(
            service_context=service_context,
            config=vector_store_config
        )

        return index

    def to_index_struct(self, index_id: str) -> dict:
        return {
            "type": self._vector_store,
            "vector_store": self.get_client().to_index_config(index_id)
        }

    def get_client(self):
        if not self._client:
            raise Exception("Vector store client is not initialized.")

        return self._client
api/core/vector_store/vector_store_index_query.py
deleted 100644 → 0
View file @ 23ef2262

from llama_index.indices.query.base import IS
from typing import (
    Any,
    Dict,
    List,
    Optional,
)
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
    @classmethod
    def from_args(
            cls,
            index_struct: IS,
            service_context: ServiceContext,
            docstore: Optional[BaseDocumentStore] = None,
            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
            verbose: bool = False,
            # response synthesizer args
            response_mode: ResponseMode = ResponseMode.DEFAULT,
            text_qa_template: Optional[QuestionAnswerPrompt] = None,
            refine_template: Optional[RefinePrompt] = None,
            simple_template: Optional[SimpleInputPrompt] = None,
            response_kwargs: Optional[Dict] = None,
            use_async: bool = False,
            streaming: bool = False,
            optimizer: Optional[BaseTokenUsageOptimizer] = None,
            # class-specific args
            **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )
api/core/vector_store/weaviate_vector_store.py
0 → 100644
View file @ 9e9d15ec

from langchain.vectorstores import Weaviate


class WeaviateVectorStore(Weaviate):
    def del_texts(self, where_filter: dict):
        if not where_filter:
            raise ValueError('where_filter must not be empty')

        self._client.batch.delete_objects(
            class_name=self._index_name,
            where=where_filter,
            output='minimal'
        )

    def del_text(self, uuid: str) -> None:
        self._client.data_object.delete(
            uuid,
            class_name=self._index_name
        )

    def text_exists(self, uuid: str) -> bool:
        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
            "path": ["doc_id"],
            "operator": "Equal",
            "valueText": uuid,
        }).with_limit(1).do()

        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        entries = result["data"]["Get"][self._index_name]
        if len(entries) == 0:
            return False

        return True
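A usage sketch for the Weaviate subclass, with the same caveats: the running Weaviate endpoint, the class name and the doc_id/document_id properties are assumptions of this example, not something this hunk guarantees.

import weaviate
from langchain.embeddings import OpenAIEmbeddings

from core.vector_store.weaviate_vector_store import WeaviateVectorStore

# Assumed: a reachable Weaviate instance with an existing class holding doc_id properties.
client = weaviate.Client(url="http://localhost:8080")
store = WeaviateVectorStore(
    client=client,
    index_name="Dataset_123",        # hypothetical class name
    text_key="text",
    embedding=OpenAIEmbeddings(),
)

some_id = "3f2c9a7e-0000-0000-0000-000000000000"   # placeholder uuid
print(store.text_exists(some_id))                   # looks the object up by doc_id

store.del_text(some_id)              # delete one object by Weaviate uuid
store.del_texts({                    # or delete everything a where-filter matches
    "path": ["document_id"],
    "operator": "Equal",
    "valueText": "doc-1",
})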
api/core/vector_store/weaviate_vector_store_client.py
deleted 100644 → 0
View file @ 23ef2262

import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
    parse_get_response,
    validate_client,
)


class WeaviateVectorStoreClient(BaseVectorStoreClient):
    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
        self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)

    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
        auth_config = weaviate.auth.AuthApiKey(api_key=api_key)

        weaviate.connect.connection.has_grpc = grpc_enabled

        client = weaviate.Client(
            url=endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 60),
            startup_period=None
        )

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = WeaviateIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"class_prefix": "Gpt_index_xxx"}
        class_prefix = config.get('class_prefix')
        if not class_prefix:
            raise Exception("class_prefix cannot be None.")

        return GPTWeaviateEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=WeaviateWithSimilaritiesVectorStore(
                weaviate_client=self._client,
                class_prefix=class_prefix
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"class_prefix": index_id}


class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes."""
        nodes = self.weaviate_query(
            self._client,
            self._class_prefix,
            query,
        )
        nodes = nodes[: query.similarity_top_k]
        node_idxs = [str(i) for i in range(len(nodes))]

        similarities = []
        for node in nodes:
            similarities.append(node.extra_info['similarity'])
            del node.extra_info['similarity']

        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)

    def weaviate_query(
            self,
            client: Any,
            class_prefix: str,
            query_spec: VectorStoreQuery,
    ) -> List[Node]:
        """Convert to LlamaIndex list."""
        validate_client(client)

        class_name = _class_name(class_prefix)
        prop_names = [p["name"] for p in NODE_SCHEMA]
        vector = query_spec.query_embedding

        # build query
        query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
            _logger.debug("Using vector search")
            if vector is not None:
                query = query.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
            _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
            query = query.with_hybrid(
                query=query_spec.query_str,
                alpha=query_spec.alpha,
                vector=vector,
            )
        query = query.with_limit(query_spec.similarity_top_k)
        _logger.debug(f"Using limit of {query_spec.similarity_top_k}")

        # execute query
        query_result = query.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        results = [self._to_node(entry) for entry in entries]
        return results

    def _to_node(self, entry: Dict) -> Node:
        """Convert to Node."""
        extra_info_str = entry["extra_info"]
        if extra_info_str == "":
            extra_info = None
        else:
            extra_info = json.loads(extra_info_str)

        if 'certainty' in entry['_additional']:
            if extra_info:
                extra_info['similarity'] = entry['_additional']['certainty']
            else:
                extra_info = {'similarity': entry['_additional']['certainty']}

        node_info_str = entry["node_info"]
        if node_info_str == "":
            node_info = None
        else:
            node_info = json.loads(node_info_str)

        relationships_str = entry["relationships"]
        relationships: Dict[DocumentRelationship, str]
        if relationships_str == "":
            relationships = field(default_factory=dict)
        else:
            relationships = {
                DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
            }

        return Node(
            text=entry["text"],
            doc_id=entry["doc_id"],
            embedding=entry["_additional"]["vector"],
            extra_info=extra_info,
            node_info=node_info,
            relationships=relationships,
        )

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document.

        Args:
            doc_id (str): document id
        """
        delete_document(self._client, doc_id, self._class_prefix)

    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        delete_node(self._client, node_id, self._class_prefix)

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        entry = get_by_node_id(self._client, node_id, self._class_prefix)
        return True if entry else False


class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
    pass


def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["ref_doc_id"],
        "operator": "Equal",
        "valueString": ref_doc_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)

    while len(entries) > 0:
        query_result = query.do()
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        for entry in entries:
            client.data_object.delete(entry["_additional"]["id"], class_name)


def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)


def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    if len(entries) == 0:
        return None

    return entries[0]
api/libs/helper.py
View file @ 9e9d15ec
...
...
@@ -3,6 +3,7 @@ import re
import subprocess
import uuid
from datetime import datetime
from hashlib import sha256
from zoneinfo import available_timezones
import random
import string
...
...
@@ -147,3 +148,8 @@ def get_remote_ip(request):
        return request.headers.getlist("X-Forwarded-For")[0]
    else:
        return request.remote_addr
+
+
+def generate_text_hash(text: str) -> str:
+    hash_text = str(text) + 'None'
+    return sha256(hash_text.encode()).hexdigest()
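A quick usage sketch of the new helper; the same input always yields the same digest, which is what the doc_hash metadata written during splitting relies on (the sample strings are arbitrary):

from hashlib import sha256


def generate_text_hash(text: str) -> str:
    hash_text = str(text) + 'None'
    return sha256(hash_text.encode()).hexdigest()


assert generate_text_hash("hello world") == generate_text_hash("hello world")
assert generate_text_hash("hello world") != generate_text_hash("hello world!")
print(generate_text_hash("hello world")[:16])  # stable 64-char hex digest, prefix shown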
api/requirements.txt
View file @ 9e9d15ec
...
...
@@ -9,8 +9,7 @@ flask-session2==1.3.1
flask-cors==3.0.10
gunicorn~=20.1.0
gevent~=22.10.2
-langchain==0.0.142
-llama-index==0.5.27
+langchain==0.0.201
openai~=0.27.5
psycopg2-binary~=2.9.6
pycryptodome==3.17
...
...
@@ -31,4 +30,5 @@ celery==5.2.7
redis~=4.5.4
pypdf==3.8.1
openpyxl==3.1.2
-chardet~=5.1.0
\ No newline at end of file
+chardet~=5.1.0
+docx2txt==0.8
\ No newline at end of file
api/services/dataset_service.py
View file @ 9e9d15ec
...
...
@@ -7,7 +7,6 @@ from typing import Optional, List
from extensions.ext_redis import redis_client
from flask_login import current_user
-from core.index.index_builder import IndexBuilder
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db
...
...
@@ -386,8 +385,6 @@ class DocumentService:
        dataset.indexing_technique = document_data["indexing_technique"]
-        if dataset.indexing_technique == 'high_quality':
-            IndexBuilder.get_default_service_context(dataset.tenant_id)
-
        documents = []
        batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))

        if 'original_document_id' in document_data and document_data["original_document_id"]:
...
...
api/tasks/document_indexing_sync_task.py
View file @ 9e9d15ec
...
...
@@ -6,7 +6,7 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound

-from core.data_source.notion import NotionPageReader
+from core.data_loader.loader.notion import NotionLoader
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.indexing_runner import IndexingRunner, DocumentIsPausedException
...
...
@@ -43,6 +43,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
        raise ValueError("no notion page found")

+    workspace_id = data_source_info['notion_workspace_id']
    page_id = data_source_info['notion_page_id']
    page_type = data_source_info['type']
    page_edited_time = data_source_info['last_edited_time']
    data_source_binding = DataSourceBinding.query.filter(
        db.and_(
...
...
@@ -54,8 +55,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
        )
    ).first()

    if not data_source_binding:
        raise ValueError('Data source binding not found.')

-    reader = NotionPageReader(integration_token=data_source_binding.access_token)
-    last_edited_time = reader.get_page_last_edited_time(page_id)
+    loader = NotionLoader(
+        notion_access_token=data_source_binding.access_token,
+        notion_workspace_id=workspace_id,
+        notion_obj_id=page_id,
+        notion_page_type=page_type
+    )
+
+    last_edited_time = loader.get_notion_last_edited_time()

    # check the page is updated
    if last_edited_time != page_edited_time:
        document.indexing_status = 'parsing'
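For completeness, a hedged sketch of driving the same loader end to end: the constructor arguments and get_notion_last_edited_time() come straight from the hunk above, while the final load() call is an assumption that NotionLoader follows langchain's BaseLoader interface (all values are placeholders).

from core.data_loader.loader.notion import NotionLoader

loader = NotionLoader(
    notion_access_token="secret_placeholder_token",
    notion_workspace_id="workspace-id",
    notion_obj_id="page-id",
    notion_page_type="page",
)

previously_seen = "2023-06-01T00:00:00.000Z"  # placeholder for the stored last_edited_time
if loader.get_notion_last_edited_time() != previously_seen:
    documents = loader.load()  # assumed BaseLoader-style API returning langchain Documents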
...
...