ai-tech / dify · Commits

Commit eea011bd: Merge branch 'main' into feat/add-icons
Authored Jun 25, 2023 by StyleZhang
Parents: 3eb8e66b, 951afcaa

Showing 92 changed files with 2696 additions and 3031 deletions (+2696, -3031).
Changed files:

api/app.py  (+1, -2)
api/commands.py  (+35, -0)
api/config.py  (+2, -0)
api/controllers/console/datasets/data_source.py  (+11, -10)
api/controllers/console/datasets/file.py  (+2, -28)
api/controllers/console/version.py  (+7, -2)
api/core/__init__.py  (+0, -20)
api/core/agent/agent_builder.py  (+8, -11)
api/core/callback_handler/agent_loop_gather_callback_handler.py  (+1, -29)
api/core/callback_handler/dataset_tool_callback_handler.py  (+1, -50)
api/core/callback_handler/index_tool_callback_handler.py  (+8, -21)
api/core/callback_handler/llm_callback_handler.py  (+31, -85)
api/core/callback_handler/main_chain_gather_callback_handler.py  (+12, -70)
api/core/callback_handler/std_out_callback_handler.py  (+36, -9)
api/core/chain/chain_builder.py  (+2, -4)
api/core/chain/llm_router_chain.py  (+6, -4)
api/core/chain/main_chain_builder.py  (+10, -8)
api/core/chain/multi_dataset_router_chain.py  (+62, -16)
api/core/chain/sensitive_word_avoidance_chain.py  (+7, -2)
api/core/chain/tool_chain.py  (+12, -3)
api/core/completion.py  (+42, -17)
api/core/conversation_message_task.py  (+4, -4)
api/core/data_loader/file_extractor.py  (+43, -0)
api/core/data_loader/loader/csv.py  (+67, -0)
api/core/data_loader/loader/excel.py  (+43, -0)
api/core/data_loader/loader/html.py  (+35, -0)
api/core/data_loader/loader/markdown.py  (+134, -0)
api/core/data_loader/loader/notion.py  (+236, -236)
api/core/data_loader/loader/pdf.py  (+55, -0)
api/core/docstore/dataset_docstore.py  (+35, -45)
api/core/docstore/empty_docstore.py  (+0, -51)
api/core/embedding/cached_embedding.py  (+72, -0)
api/core/embedding/openai_embedding.py  (+0, -214)
api/core/index/base.py  (+59, -0)
api/core/index/index.py  (+41, -0)
api/core/index/index_builder.py  (+0, -60)
api/core/index/keyword_table/jieba_keyword_table.py  (+0, -159)
api/core/index/keyword_table_index.py  (+0, -135)
.../index/keyword_table_index/jieba_keyword_table_handler.py  (+33, -0)
api/core/index/keyword_table_index/keyword_table_index.py  (+238, -0)
api/core/index/keyword_table_index/stopwords.py  (+0, -0)
api/core/index/query/synthesizer.py  (+0, -79)
api/core/index/readers/html_parser.py  (+0, -22)
api/core/index/readers/pdf_parser.py  (+0, -56)
api/core/index/readers/xlsx_parser.py  (+0, -33)
api/core/index/vector_index.py  (+0, -136)
api/core/index/vector_index/base.py  (+175, -0)
api/core/index/vector_index/qdrant_vector_index.py  (+116, -0)
api/core/index/vector_index/vector_index.py  (+69, -0)
api/core/index/vector_index/weaviate_vector_index.py  (+136, -0)
api/core/indexing_runner.py  (+262, -283)
api/core/llm/llm_builder.py  (+17, -14)
api/core/llm/provider/azure_provider.py  (+4, -1)
api/core/llm/streamable_azure_chat_open_ai.py  (+13, -50)
api/core/llm/streamable_azure_open_ai.py  (+13, -6)
api/core/llm/streamable_chat_open_ai.py  (+14, -48)
api/core/llm/streamable_open_ai.py  (+13, -5)
...only_conversation_token_db_string_buffer_shared_memory.py  (+1, -1)
api/core/prompt/prompts.py  (+0, -19)
api/core/spiltter/fixed_text_splitter.py  (+0, -0)
api/core/tool/dataset_index_tool.py  (+87, -0)
api/core/tool/dataset_tool_builder.py  (+0, -73)
api/core/tool/llama_index_tool.py  (+0, -43)
api/core/vector_store/base.py  (+0, -34)
api/core/vector_store/qdrant_vector_store.py  (+69, -0)
api/core/vector_store/qdrant_vector_store_client.py  (+0, -147)
api/core/vector_store/vector_store.py  (+0, -62)
api/core/vector_store/vector_store_index_query.py  (+0, -66)
api/core/vector_store/weaviate_vector_store.py  (+38, -0)
api/core/vector_store/weaviate_vector_store_client.py  (+0, -270)
api/extensions/ext_vector_store.py  (+0, -7)
api/libs/helper.py  (+6, -0)
api/models/account.py  (+0, -2)
api/models/dataset.py  (+30, -2)
api/requirements.txt  (+5, -4)
api/services/app_model_config_service.py  (+0, -1)
api/services/dataset_service.py  (+0, -3)
api/services/hit_testing_service.py  (+44, -36)
api/tasks/add_document_to_index_task.py  (+33, -48)
api/tasks/add_segment_to_index_task.py  (+27, -32)
api/tasks/clean_dataset_task.py  (+13, -20)
api/tasks/clean_document_task.py  (+7, -6)
api/tasks/clean_notion_document_task.py  (+7, -6)
api/tasks/deal_dataset_vector_index_task.py  (+36, -36)
api/tasks/document_indexing_sync_task.py  (+23, -23)
api/tasks/document_indexing_task.py  (+6, -16)
api/tasks/document_indexing_update_task.py  (+11, -20)
api/tasks/recover_document_indexing_task.py  (+4, -9)
api/tasks/remove_document_from_index_task.py  (+5, -6)
api/tasks/remove_segment_from_index_task.py  (+18, -8)
sdks/python-client/dify_client/client.py  (+2, -2)
sdks/python-client/setup.py  (+1, -1)
api/app.py

@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
+from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
     ext_database, ext_storage
 from extensions.ext_database import db
 from extensions.ext_login import login_manager

@@ -79,7 +79,6 @@ def initialize_extensions(app):
     ext_database.init_app(app)
     ext_migrate.init(app, db)
     ext_redis.init_app(app)
-    ext_vector_store.init_app(app)
     ext_storage.init_app(app)
     ext_celery.init_app(app)
     ext_session.init_app(app)
api/commands.py

 import datetime
+import logging
 import random
 import string

 import click
 from flask import current_app
+from werkzeug.exceptions import NotFound

+from core.index.index import IndexBuilder
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
+from models.dataset import Dataset
 from models.model import Account
 import secrets
 import base64

@@ -159,8 +163,39 @@ def generate_upper_string():
     return result


+@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
+def recreate_all_dataset_indexes():
+    click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
+    recreate_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Recreating dataset index: {}'.format(dataset.id))
+                index = IndexBuilder.get_index(dataset, 'high_quality')
+                if index and index._is_origin():
+                    index.recreate_dataset(dataset)
+                    recreate_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(click.style('Recreate dataset index error: {} {}'.format(
+                    e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
+    app.cli.add_command(recreate_all_dataset_indexes)
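Note: because the new command is registered through app.cli.add_command, it should be reachable from the standard Flask CLI of the API service (presumably invoked as flask recreate-all-dataset-indexes, depending on how FLASK_APP is configured in the deployment; the invocation itself is not part of the diff). It pages through high_quality datasets 50 at a time and only rebuilds an index when IndexBuilder.get_index returns one whose _is_origin() check passes, logging and skipping any dataset that raises.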
api/config.py

@@ -187,11 +187,13 @@ class Config:
         # For temp use only
         # set default LLM provider, default is 'openai', support `azure_openai`
         self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')

         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
         self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
         self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')


 class CloudEditionConfig(Config):
api/controllers/console/datasets/data_source.py

@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_source.notion import NotionPageReader
+from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
 from extensions.ext_database import db
 from libs.helper import TimestampField
-from libs.oauth_data_source import NotionOAuth
 from models.dataset import Document
 from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService

@@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource):
         ).first()

         if not data_source_binding:
             raise NotFound('Data source binding not found.')

-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        if page_type == 'page':
-            page_content = reader.read_page(page_id)
-        elif page_type == 'database':
-            page_content = reader.query_database_data(page_id)
-        else:
-            page_content = ""
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        text_docs = loader.load()

         return {
-            'content': page_content
+            'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200

     @setup_required
api/controllers/console/datasets/file.py

@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
     UnsupportedFileTypeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db

@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
         if extension not in ALLOWED_EXTENSIONS:
             raise UnsupportedFileTypeError()

-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, filepath)
-
-            if extension == 'pdf':
-                parser = PDFParser({'upload_file': upload_file})
-                text = parser.parse_file(Path(filepath))
-            elif extension in ['html', 'htm']:
-                # Use BeautifulSoup to extract text
-                parser = HTMLParser()
-                text = parser.parse_file(Path(filepath))
-            elif extension == 'xlsx':
-                parser = XLSXParser()
-                text = parser.parse_file(filepath)
-            else:
-                # ['txt', 'markdown', 'md']
-                with open(filepath, "rb") as fp:
-                    data = fp.read()
-                    encoding = chardet.detect(data)['encoding']
-                    if encoding:
-                        text = data.decode(encoding=encoding).strip() if data else ''
-                    else:
-                        text = data.decode(encoding='utf-8').strip() if data else ''
+        text = FileExtractor.load(upload_file, return_text=True)

         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
         return {'content': text}
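With FileExtractor in place the preview endpoint no longer downloads the file and dispatches on extension itself; the handler body reduces to the call shown in the hunk above, roughly:

    text = FileExtractor.load(upload_file, return_text=True)
    text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
    return {'content': text}

(The internals of FileExtractor.load live in the new api/core/data_loader/file_extractor.py, which is not expanded in this view.)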
api/controllers/console/version.py

@@ -32,8 +32,13 @@ class VersionApi(Resource):
                 'current_version': args.get('current_version')
             })
         except Exception as error:
-            logging.exception("Check update error.")
-            raise InternalServerError()
+            logging.warning("Check update version error: {}.".format(str(error)))
+            return {
+                'version': args.get('current_version'),
+                'release_date': '',
+                'release_notes': '',
+                'can_auto_update': False
+            }

         content = json.loads(response.content)
         return {
api/core/__init__.py

@@ -3,19 +3,11 @@ from typing import Optional
 import langchain
 from flask import Flask
-from jieba.analyse import default_tfidf
-from langchain import set_handler
 from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
-from llama_index import IndexStructType, QueryMode
-from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
 from pydantic import BaseModel

 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.index.keyword_table.stopwords import STOPWORDS
 from core.prompt.prompt_template import OneLineFormatter
-from core.vector_store.vector_store import VectorStore
-from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery


 class HostedOpenAICredential(BaseModel):

@@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials()
 def init_app(app: Flask):
     formatter = OneLineFormatter()
     DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    default_tfidf.stop_words = STOPWORDS

     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
-        set_handler(DifyStdOutCallbackHandler())

     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
api/core/agent/agent_builder.py

@@ -2,7 +2,7 @@ from typing import Optional
 from langchain import LLMChain
 from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
-from langchain.callbacks import CallbackManager
+from langchain.callbacks.manager import CallbackManager
 from langchain.memory.chat_memory import BaseChatMemory

 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler

@@ -16,23 +16,20 @@ class AgentBuilder:
     def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                        dataset_tool_callback_handler: DatasetToolCallbackHandler,
                        agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
-        llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name=agent_loop_gather_callback_handler.model_name,
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
         )

-        tool_callback_manager = CallbackManager([
-            agent_loop_gather_callback_handler,
-            dataset_tool_callback_handler,
-            DifyStdOutCallbackHandler()
-        ])
-
         for tool in tools:
-            tool.callback_manager = tool_callback_manager
+            tool.callbacks = [
+                agent_loop_gather_callback_handler,
+                dataset_tool_callback_handler,
+                DifyStdOutCallbackHandler()
+            ]

         prompt = cls.build_agent_prompt_template(
             tools=tools,

@@ -54,7 +51,7 @@ class AgentBuilder:
             tools=tools,
             agent=agent,
             memory=memory,
-            callback_manager=agent_callback_manager,
+            callbacks=agent_callback_manager,
             max_iterations=6,
             early_stopping_method="generate",
             # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
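This file shows the callback migration pattern that recurs throughout the commit: handlers are no longer wrapped in langchain's CallbackManager and assigned to a callback_manager attribute; they are passed as a plain list via the callbacks argument (on the LLM, on each tool, and on the AgentExecutor), which is the interface newer langchain releases expect. The same CallbackManager-to-callbacks swap appears again in chain_builder.py, main_chain_builder.py and multi_dataset_router_chain.py below.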
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""

@@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._current_loop.completion = response.generations[0][0].text
         self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']

-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
     def on_llm_error(
             self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:

@@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None

-    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        logging.error(error)
-
     def on_tool_start(
             self,
             serialized: Dict[str, Any],

@@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None

-    def on_text(
-            self,
-            text: str,
-            color: Optional[str] = None,
-            end: str = "",
-            **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
         # Final Answer
api/core/callback_handler/dataset_tool_callback_handler.py

@@ -3,7 +3,6 @@ import logging
 from typing import Any, Dict, List, Union, Optional

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult

 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.conversation_message_task import ConversationMessageTask

@@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask
 class DatasetToolCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""

@@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Do nothing."""
         logging.error(error)
-
-    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        pass
-
-    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        logging.error(error)
-
-    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
-        pass
-
-    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
api/core/callback_handler/index_tool_callback_handler.py

-from llama_index import Response
+from typing import List
+
+from langchain.schema import Document

 from extensions.ext_database import db
 from models.dataset import DocumentSegment


-class IndexToolCallbackHandler:
-
-    def __init__(self) -> None:
-        self._response = None
-
-    @property
-    def response(self) -> Response:
-        return self._response
-
-    def on_tool_end(self, response: Response) -> None:
-        """Handle tool end."""
-        self._response = response
-
-
-class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
+class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""

     def __init__(self, dataset_id: str) -> None:
-        super().__init__()
         self.dataset_id = dataset_id

-    def on_tool_end(self, response: Response) -> None:
+    def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
-        for node in response.source_nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            doc_id = document.metadata['doc_id']

             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == self.dataset_id,
-                DocumentSegment.index_node_id == index_node_id
+                DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                 synchronize_session=False
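The reworked handler is now driven by plain langchain Document objects rather than a llama_index Response. A minimal sketch of how a caller would feed it (the Document construction and sample values are illustrative; only the 'doc_id' metadata key and the handler API come from the diff above):

    from langchain.schema import Document

    handler = DatasetIndexToolCallbackHandler(dataset_id='<dataset id>')
    docs = [Document(page_content='...', metadata={'doc_id': '<index node id>'})]
    handler.on_tool_end(docs)  # bumps hit_count on the matching DocumentSegment rows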
api/core/callback_handler/llm_callback_handler.py

@@ -3,7 +3,7 @@ import time
 from typing import Any, Dict, List, Union, Optional

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage

 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException

@@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI
 class LLMCallbackHandler(BaseCallbackHandler):
+    raise_error: bool = True

     def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                  conversation_message_task: ConversationMessageTask):

@@ -25,41 +26,41 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Whether to call verbose callbacks even if verbose is False."""
         return True

+    def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
+        self.start_at = time.perf_counter()
+        real_prompts = []
+        for message in messages[0]:
+            if message.type == 'human':
+                role = 'user'
+            elif message.type == 'ai':
+                role = 'assistant'
+            else:
+                role = 'system'
+
+            real_prompts.append({
+                "role": role,
+                "text": message.content
+            })
+
+        self.llm_message.prompt = real_prompts
+        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+
     def on_llm_start(
             self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         self.start_at = time.perf_counter()

-        if 'Chat' in serialized['name']:
-            real_prompts = []
-            messages = []
-            for prompt in prompts:
-                role, content = prompt.split(': ', maxsplit=1)
-                if role == 'human':
-                    role = 'user'
-                    message = HumanMessage(content=content)
-                elif role == 'ai':
-                    role = 'assistant'
-                    message = AIMessage(content=content)
-                else:
-                    message = SystemMessage(content=content)
-
-                real_prompt = {
-                    "role": role,
-                    "text": content
-                }
-                real_prompts.append(real_prompt)
-                messages.append(message)
-
-            self.llm_message.prompt = real_prompts
-            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
-        else:
-            self.llm_message.prompt = [{
-                "role": 'user',
-                "text": prompts[0]
-            }]
-            self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt = [{
+            "role": 'user',
+            "text": prompts[0]
+        }]
+
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()

@@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler):
             self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)
-
-    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        pass
-
-    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
-        pass
-
-    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
-        pass
-
-    def on_tool_end(
-            self,
-            output: str,
-            color: Optional[str] = None,
-            observation_prefix: Optional[str] = None,
-            llm_prefix: Optional[str] = None,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        pass
-
-    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
-        pass
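The chat path now keys off BaseMessage.type in on_chat_model_start instead of parsing "role: content" strings out of the flattened prompt; langchain's message types are 'human', 'ai' and 'system'. The mapping used above is equivalent to this small helper (the helper itself is illustrative, not part of the commit):

    def to_role(message_type: str) -> str:
        # 'human' -> 'user', 'ai' -> 'assistant', anything else -> 'system'
        if message_type == 'human':
            return 'user'
        if message_type == 'ai':
            return 'assistant'
        return 'system'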
api/core/callback_handler/main_chain_gather_callback_handler.py

 import logging
 import time

-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, Union

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult

 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult

@@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask
 class MainChainGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""

@@ -50,13 +50,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Print out that we are entering a chain."""
         if not self._current_chain_result:
-            self._current_chain_result = ChainResult(
-                type=serialized['name'],
-                prompt=inputs,
-                started_at=time.perf_counter()
-            )
-            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
+            chain_type = serialized['id'][-1]
+            if chain_type:
+                self._current_chain_result = ChainResult(
+                    type=chain_type,
+                    prompt=inputs,
+                    started_at=time.perf_counter()
+                )
+                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
+                self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message

     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""

@@ -74,64 +76,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
             self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         logging.error(error)
         self.clear_chain_results()
\ No newline at end of file
-
-    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        logging.error(error)
-
-    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
-        pass
-
-    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
-        pass
-
-    def on_tool_end(
-            self,
-            output: str,
-            color: Optional[str] = None,
-            observation_prefix: Optional[str] = None,
-            llm_prefix: Optional[str] = None,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        """Do nothing."""
-        logging.error(error)
-
-    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
api/core/callback_handler/std_out_callback_handler.py

+import os
 import sys

 from typing import Any, Dict, List, Optional, Union

 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage


 class DifyStdOutCallbackHandler(BaseCallbackHandler):

@@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Initialize callback handler."""
         self.color = color

+    def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
+        print_text("\n[on_chat_model_start]\n", color='blue')
+        for sub_messages in messages:
+            for sub_message in sub_messages:
+                print_text(str(sub_message) + "\n", color='blue')
+
     def on_llm_start(
             self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
-        print_text(prompts[0] + "\n", color='blue')
+
+        if 'Chat' in serialized['name']:
+            for prompt in prompts:
+                print_text(prompt + "\n", color='blue')
+        else:
+            print_text(prompts[0] + "\n", color='blue')

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""

@@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
             self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
-        class_name = serialized["name"]
-        print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
+        chain_type = serialized['id'][-1]
+        print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')

     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""

@@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Run on agent end."""
         print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")

+    @property
+    def ignore_llm(self) -> bool:
+        """Whether to ignore LLM callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chain(self) -> bool:
+        """Whether to ignore chain callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_agent(self) -> bool:
+        """Whether to ignore agent callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+

 class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
     """Callback handler for streaming. Only works with LLMs that support streaming."""
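The four new ignore_* properties share one predicate, so the stdout handler can stay registered everywhere but stays silent unless DEBUG is enabled. The repeated expression is equivalent to negating a small helper like the one below (the helper is illustrative, not in the commit):

    import os

    def debug_enabled() -> bool:
        # True only when the DEBUG env var is set and equals 'true' (case-insensitive)
        value = os.environ.get("DEBUG")
        return bool(value) and value.lower() == 'true'

    # each ignore_* property above returns: not debug_enabled()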
api/core/chain/chain_builder.py

 from typing import Optional

-from langchain.callbacks import CallbackManager
-
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.chain.tool_chain import ToolChain

@@ -14,7 +12,7 @@ class ChainBuilder:
             tool=tool,
             input_key=kwargs.get('input_key', 'input'),
             output_key=kwargs.get('output_key', 'tool_output'),
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+            callbacks=[DifyStdOutCallbackHandler()]
         )

     @classmethod

@@ -27,7 +25,7 @@ class ChainBuilder:
             sensitive_words=sensitive_words.split(","),
             canned_response=tool_config.get("canned_response", ''),
             output_key="sensitive_word_avoidance_output",
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
+            callbacks=[DifyStdOutCallbackHandler()],
             **kwargs
         )
api/core/chain/llm_router_chain.py

 """Base classes for LLM-powered router chains."""
 from __future__ import annotations

-import json
 from typing import Any, Dict, List, Optional, Type, cast, NamedTuple

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from pydantic import root_validator

 from langchain.chains import LLMChain
 from langchain.prompts import BasePromptTemplate
-from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
+from langchain.schema import BaseOutputParser, OutputParserException

 from libs.json_in_md_parser import parse_and_check_json_markdown

@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
             raise ValueError

     def _call(
             self,
-            inputs: Dict[str, Any]
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         output = cast(
             Dict[str, Any],
api/core/chain/main_chain_builder.py

-from typing import Optional, List
+from typing import Optional, List, cast

-from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory

-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder

@@ -18,6 +16,7 @@ from models.dataset import Dataset
 class MainChainBuilder:
     @classmethod
     def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+                                rest_tokens: int,
                                 conversation_message_task: ConversationMessageTask):
         first_input_key = "input"
         final_output_key = "output"

@@ -30,6 +29,7 @@ class MainChainBuilder:
         tool_chains, chains_output_key = cls.get_agent_chains(
             tenant_id=tenant_id,
             agent_mode=agent_mode,
+            rest_tokens=rest_tokens,
             memory=memory,
             conversation_message_task=conversation_message_task
         )

@@ -42,9 +42,8 @@ class MainChainBuilder:
             return None

         for chain in chains:
-            # do not add handler into singleton callback manager
-            if not isinstance(chain.callback_manager, SharedCallbackManager):
-                chain.callback_manager.add_handler(chain_callback_handler)
+            chain = cast(Chain, chain)
+            chain.callbacks.append(chain_callback_handler)

         # build main chain
         overall_chain = SequentialChain(

@@ -57,7 +56,9 @@ class MainChainBuilder:
         return overall_chain

     @classmethod
-    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
+                         rest_tokens: int,
+                         memory: Optional[BaseChatMemory],
                          conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []

@@ -93,7 +94,8 @@ class MainChainBuilder:
                 tenant_id=tenant_id,
                 datasets=datasets,
                 conversation_message_task=conversation_message_task,
-                callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+                rest_tokens=rest_tokens,
+                callbacks=[DifyStdOutCallbackHandler()]
             )
             chains.append(multi_dataset_router_chain)
api/core/chain/multi_dataset_router_chain.py
View file @
eea011bd
import
math
from
typing
import
Mapping
,
List
,
Dict
,
Any
,
Optional
from
typing
import
Mapping
,
List
,
Dict
,
Any
,
Optional
from
langchain
import
LLMChain
,
PromptTemplate
,
ConversationChain
from
langchain
import
PromptTemplate
from
langchain.callbacks
import
CallbackManager
from
langchain.callbacks
.manager
import
CallbackManagerForChainRun
from
langchain.chains.base
import
Chain
from
langchain.chains.base
import
Chain
from
langchain.schema
import
BaseLanguageModel
from
pydantic
import
Extra
from
pydantic
import
Extra
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
...
@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
...
@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
-from core.tool.dataset_tool_builder import DatasetToolBuilder
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from models.dataset import Dataset
+from core.tool.dataset_index_tool import DatasetTool
+from models.dataset import Dataset, DatasetProcessRule

+DEFAULT_K = 2
+CONTEXT_TOKENS_PERCENT = 0.3

 MULTI_PROMPT_ROUTER_TEMPLATE = """
 Given a raw text input to a language model select the model prompt best suited for \
 the input. You will be given the names of the available prompts and a description of \
...
@@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain):

     router_chain: LLMRouterChain
     """Chain for deciding a destination chain and the input to it."""
-    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    dataset_tools: Mapping[str, DatasetTool]
     """Map of name to candidate chains that inputs can be routed to."""

     class Config:
...
@@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain):
             tenant_id: str,
             datasets: List[Dataset],
             conversation_message_task: ConversationMessageTask,
+            rest_tokens: int,
             **kwargs: Any,
     ):
         """Convenience constructor for instantiating from destination prompts."""
-        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[DifyStdOutCallbackHandler()]
         )

-        destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
+        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                         else ('useful for when you want to answer queries about the ' + d.name))
                         for d in datasets]
         destinations_str = "\n".join(destinations)
         router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
             destinations=destinations_str
         )
         router_prompt = PromptTemplate(
             template=router_template,
             input_variables=["input"],
             output_parser=RouterOutputParser(),
         )
         router_chain = LLMRouterChain.from_llm(llm, router_prompt)
         dataset_tools = {}
         for dataset in datasets:
-            dataset_tool = DatasetToolBuilder.build_dataset_tool(
-                dataset=dataset,
-                response_mode='no_synthesizer',  # "compact"
-                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
-            )
-            if dataset_tool:
-                dataset_tools[dataset.id] = dataset_tool
+            # fulfill description when it is empty
+            if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+                continue
+
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
+            if k == 0:
+                continue
+
+            dataset_tool = DatasetTool(
+                name=f"dataset-{dataset.id}",
+                description=description,
+                k=k,
+                dataset=dataset,
+                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
+            )
+
+            dataset_tools[str(dataset.id)] = dataset_tool

         return cls(
             router_chain=router_chain,
...
@@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain):
             **kwargs,
         )

+    @classmethod
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
+        processing_rule = dataset.latest_process_rule
+        if not processing_rule:
+            return DEFAULT_K
+
+        if processing_rule.mode == "custom":
+            rules = processing_rule.rules_dict
+            if not rules:
+                return DEFAULT_K
+
+            segmentation = rules["segmentation"]
+            segment_max_tokens = segmentation["max_tokens"]
+        else:
+            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
+
+        # when rest_tokens is less than default context tokens
+        if rest_tokens < segment_max_tokens * DEFAULT_K:
+            return rest_tokens // segment_max_tokens
+
+        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
+
+        # when context_limit_tokens is less than default context tokens, use default_k
+        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
+            return DEFAULT_K
+
+        # Expand the k value when there's still some room left in the 30% rest tokens space
+        return context_limit_tokens // segment_max_tokens
+
     def _call(
             self,
-            inputs: Dict[str, Any]
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         if len(self.dataset_tools) == 0:
             return {"text": ''}
...
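A minimal sketch of the retrieval-size arithmetic added above. DEFAULT_K = 2 and CONTEXT_TOKENS_PERCENT = 0.3 come from the diff; the sample token counts are assumptions for illustration only.

import math

DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3

def calc_k(rest_tokens: int, segment_max_tokens: int) -> int:
    # not enough room even for the default two segments: shrink k
    if rest_tokens < segment_max_tokens * DEFAULT_K:
        return rest_tokens // segment_max_tokens
    # otherwise cap retrieved context at 30% of the remaining budget
    context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
    if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
        return DEFAULT_K
    return context_limit_tokens // segment_max_tokens

# Illustrative: 500-token segments, 3000 rest tokens -> 2; 10000 rest tokens -> 6
print(calc_k(3000, 500), calc_k(10000, 500))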
api/core/chain/sensitive_word_avoidance_chain.py  View file @ eea011bd

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
...
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
                 return self.canned_response
         return text

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         return {self.output_key: output}
api/core/chain/tool_chain.py  View file @ eea011bd

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool
...
@@ -30,12 +31,20 @@ class ToolChain(Chain):
         """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)
...
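The signature changes in the two chains above follow the LangChain `Chain` base class of the versions this refactor targets, where `_call` receives an optional `run_manager`. A minimal sketch of a custom chain written against that signature; the class name and keys are placeholders, not part of the commit.

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class EchoChain(Chain):
    """Toy chain using the post-refactor _call signature."""
    input_key: str = "input"
    output_key: str = "output"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # run_manager is supplied by the callback system and may be ignored
        return {self.output_key: inputs[self.input_key]}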
api/core/completion.py  View file @ eea011bd

 import logging
 from typing import Optional, List, Union, Tuple

-from langchain.callbacks import CallbackManager
-from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms import BaseLLM
+from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError

 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
     DifyStdOutCallbackHandler
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
 from core.llm.llm_builder import LLMBuilder
 from core.chain.main_chain_builder import MainChainBuilder
...
@@ -34,8 +35,6 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
-        cls.validate_query_tokens(app.tenant_id, app_model_config, query)
-
         memory = None
         if conversation:
             # get memory of conversation (read-only)
...
@@ -48,6 +47,14 @@ class Completion:

             inputs = conversation.inputs

+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            tenant_id=app.tenant_id,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
+
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
             app=app,
...
@@ -64,6 +71,7 @@ class Completion:
         main_chain = MainChainBuilder.to_langchain_components(
             tenant_id=app.tenant_id,
             agent_mode=app_model_config.agent_mode_dict,
+            rest_tokens=rest_tokens_for_context_and_memory,
             memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
             conversation_message_task=conversation_message_task
         )
...
@@ -115,7 +123,7 @@ class Completion:
             memory=memory
         )

-        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
+        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)

         cls.recale_llm_max_tokens(
             final_llm=final_llm,
...
@@ -247,16 +255,14 @@ And answer according to the language of the user's question.
         return messages, ['\nHuman:']

     @classmethod
-    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool,
-                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
+    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+                          streaming: bool,
+                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
-            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
+            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
         else:
-            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
-
-        return CallbackManager(callback_handlers)
+            return [llm_callback_handler, DifyStdOutCallbackHandler()]

     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
...
@@ -293,7 +299,8 @@ And answer according to the language of the user's question.
         return memory

     @classmethod
-    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
+    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+                                 query: str, inputs: dict) -> int:
         llm = LLMBuilder.to_llm_from_model(
             tenant_id=tenant_id,
             model=app_model_config.model_dict
...
@@ -302,8 +309,26 @@ And answer according to the language of the user's question.
         model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
         max_tokens = llm.max_tokens

-        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
-            raise LLMBadRequestError("Query is too long")
+        # get prompt without memory and context
+        prompt, _ = cls.get_main_llm_prompt(
+            mode=mode,
+            llm=llm,
+            pre_prompt=app_model_config.pre_prompt,
+            query=query,
+            inputs=inputs,
+            chain_output=None,
+            memory=None
+        )
+
+        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
+            else llm.get_num_tokens_from_messages(prompt)
+
+        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                                     "or shrink the max token, or switch to a llm with a larger token limit size.")
+
+        return rest_tokens

     @classmethod
     def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
...
@@ -360,7 +385,7 @@ And answer according to the language of the user's question.
             streaming=streaming
         )

-        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
+        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)

         cls.recale_llm_max_tokens(
             final_llm=llm,
...
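The budgeting in get_validate_rest_tokens reduces to subtracting the prompt size and the reserved completion tokens from the model's context window. A sketch with assumed numbers (the 4096/512/700 values are illustrative, not from the diff):

model_limited_tokens = 4096   # assumed context window for the configured model
max_tokens = 512              # assumed completion reservation
prompt_tokens = 700           # assumed size of the prompt without memory/context

rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
    raise ValueError("Query or prefix prompt is too long")
print(rest_tokens)  # 2884 tokens left for retrieved context and memory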
api/core/conversation_message_task.py  View file @ eea011bd
...
@@ -293,12 +293,12 @@ class PubHandler:
         if not user:
             raise ValueError("user is required")

-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result:{}-{}".format(user_str, task_id)

     @classmethod
     def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result_stopped:{}-{}".format(user_str, task_id)

     def pub_text(self, text: str):
...
@@ -306,10 +306,10 @@ class PubHandler:
                 'event': 'message',
                 'data': {
                     'task_id': self._task_id,
-                    'message_id': self._message.id,
+                    'message_id': str(self._message.id),
                     'text': text,
                     'mode': self._conversation.mode,
-                    'conversation_id': self._conversation.id
+                    'conversation_id': str(self._conversation.id)
                 }
             }
...
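A short sketch of why the str() casts above matter, assuming (as the change suggests) that the ORM ids are UUID-like objects rather than plain strings; the uuid value here is made up.

import json
import uuid

message_id = uuid.uuid4()
# json.dumps({'message_id': message_id}) would raise:
# TypeError: Object of type UUID is not JSON serializable
payload = {'message_id': str(message_id)}  # cast first, as the diff now does
print(json.dumps(payload))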
api/core/data_loader/file_extractor.py  0 → 100644  View file @ eea011bd

import tempfile
from pathlib import Path
from typing import List, Union

from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document

from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            input_file = Path(file_path)
            delimiter = '\n'
            if input_file.suffix == '.xlsx':
                loader = ExcelLoader(file_path)
            elif input_file.suffix == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif input_file.suffix in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif input_file.suffix in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif input_file.suffix == '.docx':
                loader = Docx2txtLoader(file_path)
            elif input_file.suffix == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

            return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
0 → 100644
View file @
eea011bd
import
logging
from
typing
import
Optional
,
Dict
,
List
from
langchain.document_loaders
import
CSVLoader
as
LCCSVLoader
from
langchain.document_loaders.helpers
import
detect_file_encodings
from
models.dataset
import
Document
logger
=
logging
.
getLogger
(
__name__
)
class
CSVLoader
(
LCCSVLoader
):
def
__init__
(
self
,
file_path
:
str
,
source_column
:
Optional
[
str
]
=
None
,
csv_args
:
Optional
[
Dict
]
=
None
,
encoding
:
Optional
[
str
]
=
None
,
autodetect_encoding
:
bool
=
True
,
):
self
.
file_path
=
file_path
self
.
source_column
=
source_column
self
.
encoding
=
encoding
self
.
csv_args
=
csv_args
or
{}
self
.
autodetect_encoding
=
autodetect_encoding
def
load
(
self
)
->
List
[
Document
]:
"""Load data into document objects."""
try
:
with
open
(
self
.
file_path
,
newline
=
""
,
encoding
=
self
.
encoding
)
as
csvfile
:
docs
=
self
.
_read_from_file
(
csvfile
)
except
UnicodeDecodeError
as
e
:
if
self
.
autodetect_encoding
:
detected_encodings
=
detect_file_encodings
(
self
.
file_path
)
for
encoding
in
detected_encodings
:
logger
.
debug
(
"Trying encoding: "
,
encoding
.
encoding
)
try
:
with
open
(
self
.
file_path
,
newline
=
""
,
encoding
=
encoding
.
encoding
)
as
csvfile
:
docs
=
self
.
_read_from_file
(
csvfile
)
break
except
UnicodeDecodeError
:
continue
else
:
raise
RuntimeError
(
f
"Error loading {self.file_path}"
)
from
e
return
docs
def
_read_from_file
(
self
,
csvfile
):
docs
=
[]
csv_reader
=
csv
.
DictReader
(
csvfile
,
**
self
.
csv_args
)
# type: ignore
for
i
,
row
in
enumerate
(
csv_reader
):
content
=
"
\n
"
.
join
(
f
"{k.strip()}: {v.strip()}"
for
k
,
v
in
row
.
items
())
try
:
source
=
(
row
[
self
.
source_column
]
if
self
.
source_column
is
not
None
else
''
)
except
KeyError
:
raise
ValueError
(
f
"Source column '{self.source_column}' not found in CSV file."
)
metadata
=
{
"source"
:
source
,
"row"
:
i
}
doc
=
Document
(
page_content
=
content
,
metadata
=
metadata
)
docs
.
append
(
doc
)
return
docs
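A minimal usage sketch, assuming a local ./example.csv with a header row exists (the path is a placeholder). Each row becomes one document whose content lists "column: value" pairs.

loader = CSVLoader("./example.csv", autodetect_encoding=True)
for doc in loader.load():
    print(doc.metadata["row"], doc.page_content[:80])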
api/core/data_loader/loader/excel.py  0 → 100644  View file @ eea011bd

import json
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook

logger = logging.getLogger(__name__)


class ExcelLoader(BaseLoader):
    """Load xlsx files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        data = []
        keys = []
        wb = load_workbook(filename=self._file_path, read_only=True)
        # loop over all sheets
        for sheet in wb:
            for row in sheet.iter_rows(values_only=True):
                if all(v is None for v in row):
                    continue
                if keys == []:
                    keys = list(map(str, row))
                else:
                    row_dict = dict(zip(keys, row))
                    row_dict = {k: v for k, v in row_dict.items() if v}
                    data.append(json.dumps(row_dict, ensure_ascii=False))

        return [Document(page_content='\n\n'.join(data))]
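A minimal usage sketch, assuming ./prices.xlsx exists with a header row such as "name | price" (path and columns are made up). Each later row is serialized as one JSON object, and everything is joined into a single Document.

loader = ExcelLoader("./prices.xlsx")
docs = loader.load()          # exactly one Document
print(docs[0].page_content)   # e.g. {"name": "apple", "price": 3}\n\n{"name": "pear", "price": 2}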
api/core/data_loader/loader/html.py  0 → 100644  View file @ eea011bd

import logging
from typing import List

from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class HTMLLoader(BaseLoader):
    """Load html files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        return [Document(page_content=self._load_as_text())]

    def _load_as_text(self) -> str:
        with open(self._file_path, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
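A minimal usage sketch, assuming ./page.html exists (the path is a placeholder); BeautifulSoup strips the markup and only visible text is kept as a single Document.

loader = HTMLLoader("./page.html")
print(loader.load()[0].page_content)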
api/core/index/readers/markdown_parser.py → api/core/data_loader/loader/markdown.py  View file @ eea011bd

-"""Markdown parser.
-
-Contains parser for md files.
-
-"""
-import re
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
-
-from llama_index.readers.file.base_parser import BaseParser
-
-
-class MarkdownParser(BaseParser):
-    """Markdown parser.
-
-    Extract text from markdown files.
-    Returns dictionary with keys as headers and values as the text between headers.
-
-    """
+import logging
+import re
+from typing import Optional, List, Tuple, cast
+
+from langchain.document_loaders.base import BaseLoader
+from langchain.document_loaders.helpers import detect_file_encodings
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class MarkdownLoader(BaseLoader):
+    """Load md files.
+
+    Args:
+        file_path: Path to the file to load.
+
+        remove_hyperlinks: Whether to remove hyperlinks from the text.
+
+        remove_images: Whether to remove images from the text.
+
+        encoding: File encoding to use. If `None`, the file will be loaded
+        with the default system encoding.
+
+        autodetect_encoding: Whether to try to autodetect the file encoding
+            if the specified encoding fails.
+    """

     def __init__(
         self,
-        *args: Any,
+        file_path: str,
         remove_hyperlinks: bool = True,
         remove_images: bool = True,
-        **kwargs: Any,
-    ) -> None:
-        """Init params."""
-        super().__init__(*args, **kwargs)
+        encoding: Optional[str] = None,
+        autodetect_encoding: bool = True,
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
         self._remove_hyperlinks = remove_hyperlinks
         self._remove_images = remove_images
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
+
+    def load(self) -> List[Document]:
+        tups = self.parse_tups(self._file_path)
+        documents = []
+        for header, value in tups:
+            value = value.strip()
+            if header is None:
+                documents.append(Document(page_content=value))
+            else:
+                documents.append(Document(page_content=f"\n\n{header}\n{value}"))
+
+        return documents

     def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
         """Convert a markdown file to a dictionary.
...
@@ -79,33 +103,32 @@ class MarkdownParser(BaseParser):
         content = re.sub(pattern, r"\1", content)
         return content

-    def _init_parser(self) -> Dict:
-        """Initialize the parser with the config."""
-        return {}
-
-    def parse_tups(self, filepath: Path, errors: str = "ignore") -> List[Tuple[Optional[str], str]]:
+    def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
         """Parse file into tuples."""
-        with open(filepath, "r", encoding="utf-8") as f:
-            content = f.read()
+        content = ""
+        try:
+            with open(filepath, "r", encoding=self._encoding) as f:
+                content = f.read()
+        except UnicodeDecodeError as e:
+            if self._autodetect_encoding:
+                detected_encodings = detect_file_encodings(filepath)
+                for encoding in detected_encodings:
+                    logger.debug("Trying encoding: ", encoding.encoding)
+                    try:
+                        with open(filepath, encoding=encoding.encoding) as f:
+                            content = f.read()
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {filepath}") from e
+        except Exception as e:
+            raise RuntimeError(f"Error loading {filepath}") from e
+
         if self._remove_hyperlinks:
             content = self.remove_hyperlinks(content)
         if self._remove_images:
             content = self.remove_images(content)
-        markdown_tups = self.markdown_to_tups(content)
-        return markdown_tups
-
-    def parse_file(
-        self, filepath: Path, errors: str = "ignore"
-    ) -> Union[str, List[str]]:
-        """Parse file into string."""
-        tups = self.parse_tups(filepath, errors=errors)
-        results = []
-        # TODO: don't include headers right now
-        for header, value in tups:
-            if header is None:
-                results.append(value)
-            else:
-                results.append(f"\n\n{header}\n{value}")
-        return results
+
+        return self.markdown_to_tups(content)
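A minimal usage sketch, assuming ./notes.md exists (placeholder path). Each header section becomes its own Document, with hyperlinks and images stripped by default.

loader = MarkdownLoader("./notes.md", autodetect_encoding=True)
for doc in loader.load():
    print(repr(doc.page_content[:60]))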
api/core/data_source/notion.py → api/core/data_loader/loader/notion.py  View file @ eea011bd

-"""Notion reader."""
 import json
 import logging
-import os
-from typing import List, Dict, Any, Optional
+from datetime import datetime
+from typing import Any, Dict, List, Optional

-import requests  # type: ignore
+import requests
+from flask import current_app
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document

-from llama_index.readers.base import BaseReader
-from llama_index.readers.schema.base import Document
+from extensions.ext_database import db
+from models.dataset import Document as DocumentModel
+from models.source import DataSourceBinding

+logger = logging.getLogger(__name__)

-INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
 BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
 DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
 SEARCH_URL = "https://api.notion.com/v1/search"
 RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
 RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
-logger = logging.getLogger(__name__)


-# TODO: Notion DB reader coming soon!
-class NotionPageReader(BaseReader):
-    """Notion Page reader.
-
-    Reads a set of Notion pages.
-
-    Args:
-        integration_token (str): Notion integration token.
-
-    """
-
-    def __init__(self, integration_token: Optional[str] = None) -> None:
-        """Initialize with parameters."""
-        if integration_token is None:
-            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
-            if integration_token is None:
-                raise ValueError(
-                    "Must specify `integration_token` or set environment "
-                    "variable `NOTION_INTEGRATION_TOKEN`."
-                )
-        self.token = integration_token
-        self.headers = {
-            "Authorization": "Bearer " + self.token,
-            "Content-Type": "application/json",
-            "Notion-Version": "2022-06-28",
-        }
-
-    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
-        """Read a block."""
-        done = False
+class NotionLoader(BaseLoader):
+    def __init__(
+            self,
+            notion_access_token: str,
+            notion_workspace_id: str,
+            notion_obj_id: str,
+            notion_page_type: str,
+            document_model: Optional[DocumentModel] = None
+    ):
+        self._document_model = document_model
+        self._notion_workspace_id = notion_workspace_id
+        self._notion_obj_id = notion_obj_id
+        self._notion_page_type = notion_page_type
+        self._notion_access_token = notion_access_token
+
+        if not self._notion_access_token:
+            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
+            if integration_token is None:
+                raise ValueError(
+                    "Must specify `integration_token` or set environment "
+                    "variable `NOTION_INTEGRATION_TOKEN`."
+                )
+
+            self._notion_access_token = integration_token
+
+    @classmethod
+    def from_document(cls, document_model: DocumentModel):
+        data_source_info = document_model.data_source_info_dict
+        if not data_source_info or 'notion_page_id' not in data_source_info \
+                or 'notion_workspace_id' not in data_source_info:
+            raise ValueError("no notion page found")
+
+        notion_workspace_id = data_source_info['notion_workspace_id']
+        notion_obj_id = data_source_info['notion_page_id']
+        notion_page_type = data_source_info['type']
+        notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
+
+        return cls(
+            notion_access_token=notion_access_token,
+            notion_workspace_id=notion_workspace_id,
+            notion_obj_id=notion_obj_id,
+            notion_page_type=notion_page_type,
+            document_model=document_model
+        )
+
+    def load(self) -> List[Document]:
+        self.update_last_edited_time(
+            self._document_model
+        )
+
+        text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
+
+        return text_docs
+
+    def _load_data_as_documents(
+            self, notion_obj_id: str, notion_page_type: str
+    ) -> List[Document]:
+        docs = []
+        if notion_page_type == 'database':
+            # get all the pages in the database
+            page_text = self._get_notion_database_data(notion_obj_id)
+            docs.append(Document(page_content=page_text))
+        elif notion_page_type == 'page':
+            page_text_list = self._get_notion_block_data(notion_obj_id)
+            for page_text in page_text_list:
+                docs.append(Document(page_content=page_text))
+        else:
+            raise ValueError("notion page type not supported")
+
+        return docs
+
+    def _get_notion_database_data(
+            self, database_id: str, query_dict: Dict[str, Any] = {}
+    ) -> str:
+        """Get all the pages from a Notion database."""
+        res = requests.post(
+            DATABASE_URL_TMPL.format(database_id=database_id),
+            headers={
+                "Authorization": "Bearer " + self._notion_access_token,
+                "Content-Type": "application/json",
+                "Notion-Version": "2022-06-28",
+            },
+            json=query_dict,
+        )
+
+        data = res.json()
+
+        database_content_list = []
+        if 'results' not in data or data["results"] is None:
+            return ""
+        for result in data["results"]:
+            properties = result['properties']
+            data = {}
+            for property_name, property_value in properties.items():
+                type = property_value['type']
+                if type == 'multi_select':
+                    value = []
+                    multi_select_list = property_value[type]
+                    for multi_select in multi_select_list:
+                        value.append(multi_select['name'])
+                elif type == 'rich_text' or type == 'title':
+                    if len(property_value[type]) > 0:
+                        value = property_value[type][0]['plain_text']
+                    else:
+                        value = ''
+                elif type == 'select' or type == 'status':
+                    if property_value[type]:
+                        value = property_value[type]['name']
+                    else:
+                        value = ''
+                else:
+                    value = property_value[type]
+                data[property_name] = value
+            database_content_list.append(json.dumps(data, ensure_ascii=False))
+
+        return "\n\n".join(database_content_list)
+
+    def _get_notion_block_data(self, page_id: str) -> List[str]:
         result_lines_arr = []
-        cur_block_id = block_id
-        while not done:
+        cur_block_id = page_id
+        while True:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
             query_dict: Dict[str, Any] = {}

             res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
             )
             data = res.json()
+            if 'results' not in data or data["results"] is None:
+                break
             # current block's heading
             heading = ''
             for result in data["results"]:
                 result_type = result["type"]
...
@@ -71,6 +165,7 @@ class NotionPageReader(BaseReader):
                 if result_type == 'table':
                     result_block_id = result["id"]
                     text = self._read_table_rows(result_block_id)
+                    text += "\n\n"
                     result_lines_arr.append(text)
                 else:
                     if "rich_text" in result_obj:
...
@@ -78,91 +173,53 @@ class NotionPageReader(BaseReader):
                         # skip if doesn't have text object
                         if "text" in rich_text:
                             text = rich_text["text"]["content"]
-                            prefix = "\t" * num_tabs
-                            cur_result_text_arr.append(prefix + text)
+                            cur_result_text_arr.append(text)
                             if result_type in HEADING_TYPE:
                                 heading = text

                 result_block_id = result["id"]
                 has_children = result["has_children"]
                 block_type = result["type"]
                 if has_children and block_type != 'child_page':
                     children_text = self._read_block(
-                        result_block_id, num_tabs=num_tabs + 1
+                        result_block_id, num_tabs=1
                     )
                     cur_result_text_arr.append(children_text)

                 cur_result_text = "\n".join(cur_result_text_arr)
+                cur_result_text += "\n\n"
                 if result_type in HEADING_TYPE:
                     result_lines_arr.append(cur_result_text)
                 else:
                     result_lines_arr.append(f'{heading}\n{cur_result_text}')

             if data["next_cursor"] is None:
-                done = True
                 break
             else:
                 cur_block_id = data["next_cursor"]
+        return result_lines_arr

-        result_lines = "\n".join(result_lines_arr)
-        return result_lines
-
-    def _read_table_rows(self, block_id: str) -> str:
-        """Read table rows."""
-        done = False
-        result_lines_arr = []
-        cur_block_id = block_id
-        while not done:
-            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
-            query_dict: Dict[str, Any] = {}
-            res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
-            )
-            data = res.json()
-            # get table headers text
-            table_header_cell_texts = []
-            tabel_header_cells = data["results"][0]['table_row']['cells']
-            for tabel_header_cell in tabel_header_cells:
-                if tabel_header_cell:
-                    for table_header_cell_text in tabel_header_cell:
-                        text = table_header_cell_text["text"]["content"]
-                        table_header_cell_texts.append(text)
-            # get table columns text and format
-            results = data["results"]
-            for i in range(len(results) - 1):
-                column_texts = []
-                tabel_column_cells = data["results"][i + 1]['table_row']['cells']
-                for j in range(len(tabel_column_cells)):
-                    if tabel_column_cells[j]:
-                        for table_column_cell_text in tabel_column_cells[j]:
-                            column_text = table_column_cell_text["text"]["content"]
-                            column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
-                cur_result_text = "\n".join(column_texts)
-                result_lines_arr.append(cur_result_text)
-            if data["next_cursor"] is None:
-                done = True
-                break
-            else:
-                cur_block_id = data["next_cursor"]
-        return result_lines_arr
-
-    def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
+    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
         """Read a block."""
-        done = False
         result_lines_arr = []
         cur_block_id = block_id
-        while not done:
+        while True:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
             query_dict: Dict[str, Any] = {}

             res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
             )
             data = res.json()
+            if 'results' not in data or data["results"] is None:
+                break
             # current block's heading
             heading = ''
             for result in data["results"]:
                 result_type = result["type"]
...
@@ -171,7 +228,6 @@ class NotionPageReader(BaseReader):
                 if result_type == 'table':
                     result_block_id = result["id"]
                     text = self._read_table_rows(result_block_id)
-                    text += "\n\n"
                     result_lines_arr.append(text)
                 else:
                     if "rich_text" in result_obj:
...
@@ -179,10 +235,10 @@ class NotionPageReader(BaseReader):
                         # skip if doesn't have text object
                         if "text" in rich_text:
                             text = rich_text["text"]["content"]
-                            cur_result_text_arr.append(text)
+                            prefix = "\t" * num_tabs
+                            cur_result_text_arr.append(prefix + text)
                             if result_type in HEADING_TYPE:
                                 heading = text

                 result_block_id = result["id"]
                 has_children = result["has_children"]
                 block_type = result["type"]
...
@@ -193,177 +249,121 @@ class NotionPageReader(BaseReader):
                     cur_result_text_arr.append(children_text)

                 cur_result_text = "\n".join(cur_result_text_arr)
-                cur_result_text += "\n\n"
                 if result_type in HEADING_TYPE:
                     result_lines_arr.append(cur_result_text)
                 else:
                     result_lines_arr.append(f'{heading}\n{cur_result_text}')

             if data["next_cursor"] is None:
-                done = True
                 break
             else:
                 cur_block_id = data["next_cursor"]

-        return result_lines_arr
-
-    def read_page(self, page_id: str) -> str:
-        """Read a page."""
-        return self._read_block(page_id)
-
-    def read_page_as_documents(self, page_id: str) -> List[str]:
-        """Read a page as documents."""
-        return self._read_parent_blocks(page_id)
-
-    def query_database_data(
-            self, database_id: str, query_dict: Dict[str, Any] = {}
-    ) -> str:
-        """Get all the pages from a Notion database."""
-        res = requests.post \
-            (DATABASE_URL_TMPL.format(database_id=database_id),
-             headers=self.headers,
-             json=query_dict,
-             )
-        data = res.json()
-        database_content_list = []
-        if 'results' not in data or data["results"] is None:
-            return ""
-        for result in data["results"]:
-            properties = result['properties']
-            data = {}
-            for property_name, property_value in properties.items():
-                type = property_value['type']
-                if type == 'multi_select':
-                    value = []
-                    multi_select_list = property_value[type]
-                    for multi_select in multi_select_list:
-                        value.append(multi_select['name'])
-                elif type == 'rich_text' or type == 'title':
-                    if len(property_value[type]) > 0:
-                        value = property_value[type][0]['plain_text']
-                    else:
-                        value = ''
-                elif type == 'select' or type == 'status':
-                    if property_value[type]:
-                        value = property_value[type]['name']
-                    else:
-                        value = ''
-                else:
-                    value = property_value[type]
-                data[property_name] = value
-            database_content_list.append(json.dumps(data, ensure_ascii=False))
-        return "\n\n".join(database_content_list)
-
-    def query_database(
-            self, database_id: str, query_dict: Dict[str, Any] = {}
-    ) -> List[str]:
-        """Get all the pages from a Notion database."""
-        res = requests.post \
-            (DATABASE_URL_TMPL.format(database_id=database_id),
-             headers=self.headers,
-             json=query_dict,
-             )
-        data = res.json()
-        page_ids = []
-        for result in data["results"]:
-            page_id = result["id"]
-            page_ids.append(page_id)
-        return page_ids
-
-    def search(self, query: str) -> List[str]:
-        """Search Notion page given a text query."""
-        done = False
-        next_cursor: Optional[str] = None
-        page_ids = []
-        while not done:
-            query_dict = {
-                "query": query,
-            }
-            if next_cursor is not None:
-                query_dict["start_cursor"] = next_cursor
-            res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
-            data = res.json()
-            for result in data["results"]:
-                page_id = result["id"]
-                page_ids.append(page_id)
-            if data["next_cursor"] is None:
-                done = True
-                break
-            else:
-                next_cursor = data["next_cursor"]
-        return page_ids
-
-    def load_data(
-            self, page_ids: List[str] = [], database_id: Optional[str] = None
-    ) -> List[Document]:
-        """Load data from the input directory.
-
-        Args:
-            page_ids (List[str]): List of page ids to load.
-
-        Returns:
-            List[Document]: List of documents.
-
-        """
-        if not page_ids and not database_id:
-            raise ValueError("Must specify either `page_ids` or `database_id`.")
-        docs = []
-        if database_id is not None:
-            # get all the pages in the database
-            page_ids = self.query_database(database_id)
-            for page_id in page_ids:
-                page_text = self.read_page(page_id)
-                docs.append(Document(page_text))
-        else:
-            for page_id in page_ids:
-                page_text = self.read_page(page_id)
-                docs.append(Document(page_text))
-        return docs
-
-    def load_data_as_documents(
-            self, page_ids: List[str] = [], database_id: Optional[str] = None
-    ) -> List[Document]:
-        if not page_ids and not database_id:
-            raise ValueError("Must specify either `page_ids` or `database_id`.")
-        docs = []
-        if database_id is not None:
-            # get all the pages in the database
-            page_text = self.query_database_data(database_id)
-            docs.append(Document(page_text))
-        else:
-            for page_id in page_ids:
-                page_text_list = self.read_page_as_documents(page_id)
-                for page_text in page_text_list:
-                    docs.append(Document(page_text))
-        return docs
-
-    def get_page_last_edited_time(self, page_id: str) -> str:
-        retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
-        query_dict: Dict[str, Any] = {}
-        res = requests.request(
-            "GET", retrieve_page_url, headers=self.headers, json=query_dict
-        )
-        data = res.json()
-        return data["last_edited_time"]
-
-    def get_database_last_edited_time(self, database_id: str) -> str:
-        retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id)
-        query_dict: Dict[str, Any] = {}
-        res = requests.request(
-            "GET", retrieve_page_url, headers=self.headers, json=query_dict
-        )
-        data = res.json()
-        return data["last_edited_time"]
-
-
-if __name__ == "__main__":
-    reader = NotionPageReader()
-    logger.info(reader.search("What I"))
+        result_lines = "\n".join(result_lines_arr)
+        return result_lines
+
+    def _read_table_rows(self, block_id: str) -> str:
+        """Read table rows."""
+        done = False
+        result_lines_arr = []
+        cur_block_id = block_id
+        while not done:
+            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
+            query_dict: Dict[str, Any] = {}
+
+            res = requests.request(
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
+            )
+            data = res.json()
+            # get table headers text
+            table_header_cell_texts = []
+            tabel_header_cells = data["results"][0]['table_row']['cells']
+            for tabel_header_cell in tabel_header_cells:
+                if tabel_header_cell:
+                    for table_header_cell_text in tabel_header_cell:
+                        text = table_header_cell_text["text"]["content"]
+                        table_header_cell_texts.append(text)
+            # get table columns text and format
+            results = data["results"]
+            for i in range(len(results) - 1):
+                column_texts = []
+                tabel_column_cells = data["results"][i + 1]['table_row']['cells']
+                for j in range(len(tabel_column_cells)):
+                    if tabel_column_cells[j]:
+                        for table_column_cell_text in tabel_column_cells[j]:
+                            column_text = table_column_cell_text["text"]["content"]
+                            column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
+
+                cur_result_text = "\n".join(column_texts)
+                result_lines_arr.append(cur_result_text)
+
+            if data["next_cursor"] is None:
+                done = True
+                break
+            else:
+                cur_block_id = data["next_cursor"]
+
+        result_lines = "\n".join(result_lines_arr)
+        return result_lines
+
+    def update_last_edited_time(self, document_model: DocumentModel):
+        if not document_model:
+            return
+
+        last_edited_time = self.get_notion_last_edited_time()
+        data_source_info = document_model.data_source_info_dict
+        data_source_info['last_edited_time'] = last_edited_time
+        update_params = {
+            DocumentModel.data_source_info: json.dumps(data_source_info)
+        }
+
+        DocumentModel.query.filter_by(id=document_model.id).update(update_params)
+        db.session.commit()
+
+    def get_notion_last_edited_time(self) -> str:
+        obj_id = self._notion_obj_id
+        page_type = self._notion_page_type
+        if page_type == 'database':
+            retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
+        else:
+            retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
+
+        query_dict: Dict[str, Any] = {}
+
+        res = requests.request(
+            "GET",
+            retrieve_page_url,
+            headers={
+                "Authorization": "Bearer " + self._notion_access_token,
+                "Content-Type": "application/json",
+                "Notion-Version": "2022-06-28",
+            },
+            json=query_dict
+        )
+
+        data = res.json()
+        return data["last_edited_time"]
+
+    @classmethod
+    def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
+        data_source_binding = DataSourceBinding.query.filter(
+            db.and_(
+                DataSourceBinding.tenant_id == tenant_id,
+                DataSourceBinding.provider == 'notion',
+                DataSourceBinding.disabled == False,
+                DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
+            )
+        ).first()
+
+        if not data_source_binding:
+            raise Exception(f'No notion data source binding found for tenant {tenant_id} '
+                            f'and notion workspace {notion_workspace_id}')
+
+        return data_source_binding.access_token
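A hedged usage sketch of the new loader: `document` is assumed to be a models.dataset.Document row whose data_source_info_dict carries notion_workspace_id, notion_page_id and type; from_document resolves the workspace access token through DataSourceBinding.

loader = NotionLoader.from_document(document)
docs = loader.load()                         # page or database content as langchain Documents
print(loader.get_notion_last_edited_time())  # ISO timestamp used for incremental re-sync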
api/core/data_loader/loader/pdf.py  0 → 100644  View file @ eea011bd

import logging
from typing import List, Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_storage import storage
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfLoader(BaseLoader):
    """Load pdf files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        upload_file: Optional[UploadFile] = None
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._upload_file = upload_file

    def load(self) -> List[Document]:
        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._upload_file:
            if self._upload_file.hash:
                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
                                     + self._upload_file.hash + '.0625.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return [Document(page_content=text)]
                except FileNotFoundError:
                    pass

        documents = PyPDFium2Loader(file_path=self._file_path).load()
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return documents
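An illustrative sketch of the plaintext-cache key layout used above; the tenant id and hash values are made up.

tenant_id = "tenant-123"
file_hash = "abcd1234"
plaintext_file_key = f"upload_files/{tenant_id}/{file_hash}.0625.plaintext"
# first load: parse the PDF, then storage.save(plaintext_file_key, text.encode('utf-8'))
# later loads: storage.load(plaintext_file_key) returns the cached plain text without re-parsing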
api/core/docstore/dataset_docstore.py  View file @ eea011bd

 from typing import Any, Dict, Optional, Sequence

-import tiktoken
-from llama_index.data_structs import Node
-from llama_index.docstore.types import BaseDocumentStore
-from llama_index.docstore.utils import json_to_doc
-from llama_index.schema import BaseDocument
+from langchain.schema import Document
 from sqlalchemy import func

 from core.llm.token_calculator import TokenCalculator
...
@@ -12,7 +8,7 @@ from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment


-class DatesetDocumentStore(BaseDocumentStore):
+class DatesetDocumentStore:
     def __init__(
         self,
         dataset: Dataset,
...
@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
         return self._embedding_model_name

     @property
-    def docs(self) -> Dict[str, BaseDocument]:
+    def docs(self) -> Dict[str, Document]:
         document_segments = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == self._dataset.id
         ).all()
...
@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
         output = {}
         for document_segment in document_segments:
             doc_id = document_segment.index_node_id
-            result = self.segment_to_dict(document_segment)
-            output[doc_id] = json_to_doc(result)
+            output[doc_id] = Document(
+                page_content=document_segment.content,
+                metadata={
+                    "doc_id": document_segment.index_node_id,
+                    "doc_hash": document_segment.index_node_hash,
+                    "document_id": document_segment.document_id,
+                    "dataset_id": document_segment.dataset_id,
+                }
+            )

         return output

     def add_documents(
-        self, docs: Sequence[BaseDocument], allow_update: bool = True
+        self, docs: Sequence[Document], allow_update: bool = True
     ) -> None:
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document == self._document_id
...
@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
             max_position = 0

         for doc in docs:
-            if doc.is_doc_id_none:
-                raise ValueError("doc_id not set")
-
-            if not isinstance(doc, Node):
-                raise ValueError("doc must be a Node")
-
-            segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
+            if not isinstance(doc, Document):
+                raise ValueError("doc must be a Document")
+
+            segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)

             # NOTE: doc could already exist in the store, but we overwrite it
             if not allow_update and segment_document:
                 raise ValueError(
-                    f"doc_id {doc.get_doc_id()} already exists. "
+                    f"doc_id {doc.metadata['doc_id']} already exists. "
                     "Set allow_update to True to overwrite."
                 )

             # calc embedding use tokens
-            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())
+            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)

             if not segment_document:
                 max_position += 1
...
@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
                     tenant_id=self._dataset.tenant_id,
                     dataset_id=self._dataset.id,
                     document_id=self._document_id,
-                    index_node_id=doc.get_doc_id(),
-                    index_node_hash=doc.get_doc_hash(),
+                    index_node_id=doc.metadata['doc_id'],
+                    index_node_hash=doc.metadata['doc_hash'],
                     position=max_position,
-                    content=doc.get_text(),
-                    word_count=len(doc.get_text()),
+                    content=doc.page_content,
+                    word_count=len(doc.page_content),
                     tokens=tokens,
                     created_by=self._user_id,
                 )
                 db.session.add(segment_document)
             else:
-                segment_document.content = doc.get_text()
-                segment_document.index_node_hash = doc.get_doc_hash()
-                segment_document.word_count = len(doc.get_text())
+                segment_document.content = doc.page_content
+                segment_document.index_node_hash = doc.metadata['doc_hash']
+                segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens

             db.session.commit()
...
@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
     def get_document(
         self, doc_id: str, raise_error: bool = True
-    ) -> Optional[BaseDocument]:
+    ) -> Optional[Document]:
         document_segment = self.get_document_segment(doc_id)

         if document_segment is None:
...
@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
             else:
                 return None

-        result = self.segment_to_dict(document_segment)
-        return json_to_doc(result)
+        return Document(
+            page_content=document_segment.content,
+            metadata={
+                "doc_id": document_segment.index_node_id,
+                "doc_hash": document_segment.index_node_hash,
+                "document_id": document_segment.document_id,
+                "dataset_id": document_segment.dataset_id,
+            }
+        )

     def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
         document_segment = self.get_document_segment(doc_id)
...
@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
         return document_segment.index_node_hash

-    def update_docstore(self, other: "BaseDocumentStore") -> None:
-        """Update docstore.
-
-        Args:
-            other (BaseDocumentStore): docstore to update from
-
-        """
-        self.add_documents(list(other.docs.values()))
-
     def get_document_segment(self, doc_id: str) -> DocumentSegment:
         document_segment = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == self._dataset.id,
...
@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
         ).first()

         return document_segment
-
-    def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
-        return {
-            "doc_id": segment.index_node_id,
-            "doc_hash": segment.index_node_hash,
-            "text": segment.content,
-            "__type__": Node.get_type()
-        }
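A minimal sketch of the Document shape add_documents() now expects: page_content holds the segment text and metadata carries doc_id / doc_hash, replacing the old llama-index Node accessors. The values below are placeholders.

from langchain.schema import Document

doc = Document(
    page_content="segment text ...",
    metadata={"doc_id": "node-1", "doc_hash": "hash-of-segment"}
)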
api/core/docstore/empty_docstore.py  deleted 100644 → 0  View file @ 3eb8e66b

from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument


class EmptyDocumentStore(BaseDocumentStore):
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {}

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        return {}

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        pass

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        return False

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        return None

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        pass

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        pass

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        return None

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from

        """
        self.add_documents(list(other.docs.values()))
api/core/embedding/cached_embedding.py  0 → 100644  View file @ eea011bd

import logging
from typing import List

from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding


class CacheEmbedding(Embeddings):
    def __init__(self, embeddings: Embeddings):
        self._embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        # use doc embedding cache or store if not exists
        text_embeddings = []
        embedding_queue_texts = []
        for text in texts:
            hash = helper.generate_text_hash(text)
            embedding = db.session.query(Embedding).filter_by(hash=hash).first()
            if embedding:
                text_embeddings.append(embedding.get_embedding())
            else:
                embedding_queue_texts.append(text)

        embedding_results = self._embeddings.embed_documents(embedding_queue_texts)

        i = 0
        for text in embedding_queue_texts:
            hash = helper.generate_text_hash(text)

            try:
                embedding = Embedding(hash=hash)
                embedding.set_embedding(embedding_results[i])
                db.session.add(embedding)
                db.session.commit()
            except IntegrityError:
                db.session.rollback()
                continue
            except:
                logging.exception('Failed to add embedding to db')
                continue

            i += 1

        text_embeddings.extend(embedding_results)
        return text_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        # use doc embedding cache or store if not exists
        hash = helper.generate_text_hash(text)
        embedding = db.session.query(Embedding).filter_by(hash=hash).first()
        if embedding:
            return embedding.get_embedding()

        embedding_results = self._embeddings.embed_query(text)

        try:
            embedding = Embedding(hash=hash)
            embedding.set_embedding(embedding_results)
            db.session.add(embedding)
            db.session.commit()
        except IntegrityError:
            db.session.rollback()
        except:
            logging.exception('Failed to add embedding to db')

        return embedding_results
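A minimal usage sketch for the new CacheEmbedding wrapper, assuming a Flask app context with the database extension initialised and a valid OpenAI key (the key value below is a placeholder): texts whose hash already exists in the Embedding table are served from the database instead of being re-sent to the provider.

from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding

# wrap the real embedder; results are stored per text hash in models.dataset.Embedding
embeddings = CacheEmbedding(OpenAIEmbeddings(openai_api_key="sk-..."))  # placeholder key

vectors = embeddings.embed_documents(["hello world", "goodbye world"])
query_vector = embeddings.embed_query("hello world")  # cache hit on subsequent runs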
api/core/embedding/openai_embedding.py  deleted 100644 → 0  View file @ 3eb8e66b

from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
        text: str,
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    text = text.replace("\n", " ")
    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(
        text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")
    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


class OpenAIEmbedding(BaseEmbedding):

    def __init__(
            self,
            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
            deployment_name: Optional[str] = None,
            openai_api_key: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Init params."""
        new_kwargs = {}

        if 'embed_batch_size' in kwargs:
            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']

        if 'tokenizer' in kwargs:
            new_kwargs['tokenizer'] = kwargs['tokenizer']

        super().__init__(**new_kwargs)
        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key
        self.openai_api_type = kwargs.get('openai_api_type')
        self.openai_api_version = kwargs.get('openai_api_version')
        self.openai_api_base = kwargs.get('openai_api_base')

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _QUERY_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _QUERY_MODE_MODEL_DICT[key]
        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overriden for batch queries.

        """
        if self.openai_api_type and self.openai_api_type == 'azure':
            embeddings = []
            for text in texts:
                embeddings.append(self._get_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)
        return embeddings

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        if self.openai_api_type and self.openai_api_type == 'azure':
            embeddings = []
            for text in texts:
                embeddings.append(await self._aget_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
                                           api_base=self.openai_api_base)
        return embeddings
api/core/index/base.py  0 → 100644  View file @ eea011bd

from __future__ import annotations

from abc import abstractmethod, ABC
from typing import List, Any

from langchain.schema import Document, BaseRetriever

from models.dataset import Dataset


class BaseIndex(ABC):

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    @abstractmethod
    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        raise NotImplementedError

    @abstractmethod
    def add_texts(self, texts: list[Document], **kwargs):
        raise NotImplementedError

    @abstractmethod
    def text_exists(self, id: str) -> bool:
        raise NotImplementedError

    @abstractmethod
    def delete_by_ids(self, ids: list[str]) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        raise NotImplementedError

    @abstractmethod
    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        raise NotImplementedError

    @abstractmethod
    def search(self, query: str, **kwargs: Any) -> List[Document]:
        raise NotImplementedError

    def delete(self) -> None:
        raise NotImplementedError

    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        for text in texts:
            doc_id = text.metadata['doc_id']
            exists_duplicate_node = self.text_exists(doc_id)
            if exists_duplicate_node:
                texts.remove(text)

        return texts

    def _get_uuids(self, texts: list[Document]) -> list[str]:
        return [text.metadata['doc_id'] for text in texts]
api/core/index/index.py  0 → 100644  View file @ eea011bd

from flask import current_app
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class IndexBuilder:
    @classmethod
    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
        if indexing_technique == "high_quality":
            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                return None

            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            return VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )
        elif indexing_technique == "economy":
            return KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=10
                )
            )
        else:
            raise ValueError('Unknown indexing technique')
\ No newline at end of file
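A sketch of how the new factory is meant to be called; `dataset` and `documents` are assumed to already exist (a Dataset model instance and a list of langchain Documents), and an app context is assumed to be active:

from core.index.index import IndexBuilder

# 'high_quality' yields a VectorIndex backed by cached OpenAI embeddings,
# 'economy' yields the jieba-based KeywordTableIndex; anything else raises ValueError.
index = IndexBuilder.get_index(dataset, dataset.indexing_technique)
if index:
    index.add_texts(documents)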
api/core/index/index_builder.py  deleted 100644 → 0  View file @ 3eb8e66b

from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder


class IndexBuilder:
    @classmethod
    def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
        # set number of output tokens
        num_output = 512

        # only for verbose
        callback_manager = CallbackManager([DifyStdOutCallbackHandler()])

        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='text-davinci-003',
            temperature=0,
            max_tokens=num_output,
            callback_manager=callback_manager,
        )

        llm_predictor = LLMPredictor(llm=llm)

        # These parameters here will affect the logic of segmenting the final synthesized response.
        # The number of refinement iterations in the synthesis process depends
        # on whether the length of the segmented output exceeds the max_input_size.
        prompt_helper = PromptHelper(
            max_input_size=3500,
            num_output=num_output,
            max_chunk_overlap=20
        )

        provider = LLMBuilder.get_default_provider(tenant_id)

        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=tenant_id,
            model_provider=provider,
            model_name='text-embedding-ada-002'
        )

        return ServiceContext.from_defaults(
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper,
            embed_model=OpenAIEmbedding(**model_credentials),
        )

    @classmethod
    def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='fake'
        )

        return ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )
api/core/index/keyword_table/jieba_keyword_table.py  deleted 100644 → 0  View file @ 3eb8e66b

import re
from typing import (
    Any,
    Dict,
    List,
    Set,
    Optional
)

import jieba.analyse

from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


def jieba_extract_keywords(
        text_chunk: str,
        max_keywords: Optional[int] = None,
        expand_with_subtokens: bool = True,
) -> Set[str]:
    """Extract keywords with JIEBA tfidf."""
    keywords = jieba.analyse.extract_tags(
        sentence=text_chunk,
        topK=max_keywords,
    )

    if expand_with_subtokens:
        return set(expand_tokens_with_subtokens(keywords))
    else:
        return set(keywords)


def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
    """Get subtokens from a list of tokens., filtering for stopwords."""
    results = set()
    for token in tokens:
        results.add(token)
        sub_tokens = re.findall(r"\w+", token)
        if len(sub_tokens) > 1:
            results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

    return results


class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
    """GPT JIEBA Keyword Table Index.

    This index uses a JIEBA keyword extractor to extract keywords from the text.

    """

    def _extract_keywords(self, text: str) -> Set[str]:
        """Extract keywords from text."""
        return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)

    @classmethod
    def get_query_map(self) -> QueryMap:
        """Get query map."""
        super_map = super().get_query_map()
        super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery

        return super_map

    def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document."""
        # get set of ids that correspond to node
        node_idxs_to_delete = {doc_id}

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in self._index_struct.table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                self._index_struct.table[keyword] = node_idxs.difference(
                    node_idxs_to_delete
                )
                if not self._index_struct.table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del self._index_struct.table[keyword]


class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
    """GPT Keyword Table Index JIEBA Query.

    Extracts keywords using JIEBA keyword extractor.
    Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.

    .. code-block:: python

        response = index.query("<query_str>", mode="jieba")

    See BaseGPTKeywordTableQuery for arguments.

    """

    @classmethod
    def from_args(
            cls,
            index_struct: IS,
            service_context: ServiceContext,
            docstore: Optional[BaseDocumentStore] = None,
            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
            verbose: bool = False,
            # response synthesizer args
            response_mode: ResponseMode = ResponseMode.DEFAULT,
            text_qa_template: Optional[QuestionAnswerPrompt] = None,
            refine_template: Optional[RefinePrompt] = None,
            simple_template: Optional[SimpleInputPrompt] = None,
            response_kwargs: Optional[Dict] = None,
            use_async: bool = False,
            streaming: bool = False,
            optimizer: Optional[BaseTokenUsageOptimizer] = None,
            # class-specific args
            **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )
api/core/index/keyword_table_index.py  deleted 100644 → 0  View file @ 3eb8e66b

import json
from typing import List, Optional

from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict

from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment


class KeywordTableIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            index_struct = KeywordTable()
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node in nodes:
            keywords = index._extract_keywords(node.get_text())
            self.update_segment_keywords(node.doc_id, list(keywords))
            index._index_struct.add_node(list(keywords), node)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    def del_nodes(self, node_ids: List[str]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node_id in node_ids:
            index.delete(node_id)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    @property
    def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
        docstore = DatesetDocumentStore(
            dataset=self._dataset,
            user_id=self._dataset.created_by,
            embedding_model_name="text-embedding-ada-002"
        )

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return None

        index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)

        return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)

    def get_keyword_table(self):
        dataset_keyword_table = self._dataset.dataset_keyword_table
        if dataset_keyword_table:
            return dataset_keyword_table
        return None

    def update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()
api/core/index/keyword_table_index/jieba_keyword_table_handler.py  0 → 100644  View file @ eea011bd

import re
from typing import Set

import jieba
from jieba.analyse import default_tfidf

from core.index.keyword_table_index.stopwords import STOPWORDS


class JiebaKeywordTableHandler:
    def __init__(self):
        default_tfidf.stop_words = STOPWORDS

    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
        """Extract keywords with JIEBA tfidf."""
        keywords = jieba.analyse.extract_tags(
            sentence=text,
            topK=max_keywords_per_chunk,
        )

        return set(self._expand_tokens_with_subtokens(keywords))

    def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
        """Get subtokens from a list of tokens., filtering for stopwords."""
        results = set()
        for token in tokens:
            results.add(token)
            sub_tokens = re.findall(r"\w+", token)
            if len(sub_tokens) > 1:
                results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

        return results
\ No newline at end of file
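A minimal usage sketch of the handler (the input sentence is only an example): it returns the TF-IDF keywords plus their \w+ subtokens with stopwords filtered out.

from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler

handler = JiebaKeywordTableHandler()
# example input; keyword count per chunk is capped by max_keywords_per_chunk
keywords = handler.extract_keywords("Dify 是一个开源的 LLM 应用开发平台", max_keywords_per_chunk=10)
print(sorted(keywords))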
api/core/index/keyword_table_index/keyword_table_index.py  0 → 100644  View file @ eea011bd

import json
from collections import defaultdict
from typing import Any, List, Optional, Dict

from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field, Extra

from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable


class KeywordTableConfig(BaseModel):
    max_keywords_per_chunk: int = 10


class KeywordTableIndex(BaseIndex):
    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
        super().__init__(dataset)
        self._config = config

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        keyword_table_handler = JiebaKeywordTableHandler()

        keyword_table = {}
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        dataset_keyword_table = DatasetKeywordTable(
            dataset_id=self.dataset.id,
            keyword_table=json.dumps({
                '__type__': 'keyword_table',
                '__data__': {
                    "index_id": self.dataset.id,
                    "summary": None,
                    "table": {}
                }
            }, cls=SetEncoder)
        )
        db.session.add(dataset_keyword_table)
        db.session.commit()

        self._save_dataset_keyword_table(keyword_table)

        return self

    def add_texts(self, texts: list[Document], **kwargs):
        keyword_table_handler = JiebaKeywordTableHandler()

        keyword_table = self._get_dataset_keyword_table()
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        self._save_dataset_keyword_table(keyword_table)

    def text_exists(self, id: str) -> bool:
        keyword_table = self._get_dataset_keyword_table()
        return id in set.union(*keyword_table.values())

    def delete_by_ids(self, ids: list[str]) -> None:
        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._save_dataset_keyword_table(keyword_table)

    def delete_by_document_id(self, document_id: str):
        # get segment ids by document_id
        segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self.dataset.id,
            DocumentSegment.document_id == document_id
        ).all()

        ids = [segment.id for segment in segments]

        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._save_dataset_keyword_table(keyword_table)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        return KeywordTableRetriever(index=self, **kwargs)

    def search(self, query: str, **kwargs: Any) -> List[Document]:
        keyword_table = self._get_dataset_keyword_table()

        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
        k = search_kwargs.get('k') if search_kwargs.get('k') else 4

        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)

        documents = []
        for chunk_index in sorted_chunk_indices:
            segment = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == self.dataset.id,
                DocumentSegment.index_node_id == chunk_index
            ).first()

            if segment:
                documents.append(Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": chunk_index,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                ))

        return documents

    def delete(self) -> None:
        dataset_keyword_table = self.dataset.dataset_keyword_table
        if dataset_keyword_table:
            db.session.delete(dataset_keyword_table)
            db.session.commit()

    def _save_dataset_keyword_table(self, keyword_table):
        keyword_table_dict = {
            '__type__': 'keyword_table',
            '__data__': {
                "index_id": self.dataset.id,
                "summary": None,
                "table": keyword_table
            }
        }

        self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
        db.session.commit()

    def _get_dataset_keyword_table(self) -> Optional[dict]:
        dataset_keyword_table = self.dataset.dataset_keyword_table
        if dataset_keyword_table:
            if dataset_keyword_table.keyword_table_dict:
                return dataset_keyword_table.keyword_table_dict['__data__']['table']
        else:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self.dataset.id,
                keyword_table=json.dumps({
                    '__type__': 'keyword_table',
                    '__data__': {
                        "index_id": self.dataset.id,
                        "summary": None,
                        "table": {}
                    }
                }, cls=SetEncoder)
            )
            db.session.add(dataset_keyword_table)
            db.session.commit()

        return {}

    def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
        for keyword in keywords:
            if keyword not in keyword_table:
                keyword_table[keyword] = set()
            keyword_table[keyword].add(id)
        return keyword_table

    def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
        # get set of ids that correspond to node
        node_idxs_to_delete = set(ids)

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in keyword_table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                keyword_table[keyword] = node_idxs.difference(
                    node_idxs_to_delete
                )
                if not keyword_table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del keyword_table[keyword]

        return keyword_table

    def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
        keyword_table_handler = JiebaKeywordTableHandler()
        keywords = keyword_table_handler.extract_keywords(query)

        # go through text chunks in order of most matching keywords
        chunk_indices_count: Dict[str, int] = defaultdict(int)
        keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
        for keyword in keywords:
            for node_id in keyword_table[keyword]:
                chunk_indices_count[node_id] += 1

        sorted_chunk_indices = sorted(
            list(chunk_indices_count.keys()),
            key=lambda x: chunk_indices_count[x],
            reverse=True,
        )

        return sorted_chunk_indices[:k]

    def _update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()


class KeywordTableRetriever(BaseRetriever, BaseModel):
    index: KeywordTableIndex
    search_kwargs: dict = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        return self.index.search(query, **self.search_kwargs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError("KeywordTableRetriever does not support async")


class SetEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)
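A usage sketch for the economy ("keyword table") index; `dataset` is assumed to be an existing Dataset model instance with a corresponding DocumentSegment row per doc_id, and the text content is illustrative:

from langchain.schema import Document

from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig

index = KeywordTableIndex(dataset=dataset, config=KeywordTableConfig(max_keywords_per_chunk=10))

# the doc_id metadata key is what the keyword table maps keywords onto
index.create([Document(page_content="text of a segment", metadata={"doc_id": "segment-node-id"})])

docs = index.search("query text", search_kwargs={"k": 4})
retriever = index.get_retriever(search_kwargs={"k": 2})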
api/core/index/keyword_table/stopwords.py → api/core/index/keyword_table_index/stopwords.py  View file @ eea011bd
File moved
api/core/index/query/synthesizer.py  deleted 100644 → 0  View file @ 3eb8e66b

from typing import (
    Any,
    Dict,
    Optional,
    Sequence,
)

from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE


class EnhanceResponseSynthesizer(ResponseSynthesizer):
    @classmethod
    def from_args(
            cls,
            service_context: ServiceContext,
            streaming: bool = False,
            use_async: bool = False,
            text_qa_template: Optional[QuestionAnswerPrompt] = None,
            refine_template: Optional[RefinePrompt] = None,
            simple_template: Optional[SimpleInputPrompt] = None,
            response_mode: ResponseMode = ResponseMode.DEFAULT,
            response_kwargs: Optional[Dict] = None,
            optimizer: Optional[BaseTokenUsageOptimizer] = None,
    ) -> "ResponseSynthesizer":
        response_builder: Optional[BaseResponseBuilder] = None
        if response_mode != ResponseMode.NO_TEXT:
            if response_mode == 'no_synthesizer':
                response_builder = NoSynthesizer(
                    service_context=service_context,
                    simple_template=simple_template,
                    streaming=streaming,
                )
            else:
                response_builder = get_response_builder(
                    service_context,
                    text_qa_template,
                    refine_template,
                    simple_template,
                    response_mode,
                    use_async=use_async,
                    streaming=streaming,
                )
        return cls(response_builder, response_mode, response_kwargs, optimizer)


class NoSynthesizer(BaseResponseBuilder):
    def __init__(
            self,
            service_context: ServiceContext,
            simple_template: Optional[SimpleInputPrompt] = None,
            streaming: bool = False,
    ) -> None:
        super().__init__(service_context, streaming)

    async def aget_response(
            self,
            query_str: str,
            text_chunks: Sequence[str],
            prev_response: Optional[str] = None,
            **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)

    def get_response(
            self,
            query_str: str,
            text_chunks: Sequence[str],
            prev_response: Optional[str] = None,
            **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)
\ No newline at end of file
api/core/index/readers/html_parser.py  deleted 100644 → 0  View file @ 3eb8e66b

from pathlib import Path
from typing import Dict

from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser


class HTMLParser(BaseParser):
    """HTML parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        with open(file, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
api/core/index/readers/pdf_parser.py  deleted 100644 → 0  View file @ 3eb8e66b

from pathlib import Path
from typing import Dict

from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader

from extensions.ext_storage import storage
from models.model import UploadFile


class PDFParser(BaseParser):
    """PDF parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        if not current_app.config.get('PDF_PREVIEW', True):
            return ''

        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
            upload_file: UploadFile = self._parser_config['upload_file']
            if upload_file.hash:
                plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return text
                except FileNotFoundError:
                    pass

        text_list = []
        with open(file, "rb") as fp:
            # Create a PDF object
            pdf = PdfReader(fp)

            # Get the number of pages in the PDF document
            num_pages = len(pdf.pages)

            # Iterate over every page
            for page in range(num_pages):
                # Extract the text from the page
                page_text = pdf.pages[page].extract_text()
                text_list.append(page_text)
        text = "\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return text
api/core/index/readers/xlsx_parser.py  deleted 100644 → 0  View file @ 3eb8e66b

from pathlib import Path
import json
from typing import Dict

from openpyxl import load_workbook
from llama_index.readers.file.base_parser import BaseParser
from flask import current_app


class XLSXParser(BaseParser):
    """XLSX parser."""

    def _init_parser(self) -> Dict:
        """Init parser"""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        data = []
        keys = []
        with open(file, "r") as fp:
            wb = load_workbook(filename=file, read_only=True)
            # loop over all sheets
            for sheet in wb:
                for row in sheet.iter_rows(values_only=True):
                    if all(v is None for v in row):
                        continue
                    if keys == []:
                        keys = list(map(str, row))
                    else:
                        row_dict = dict(zip(keys, row))
                        row_dict = {k: v for k, v in row_dict.items() if v}
                        data.append(json.dumps(row_dict, ensure_ascii=False))

        return '\n\n'.join(data)
api/core/index/vector_index.py  deleted 100644 → 0  View file @ 3eb8e66b

import json
import logging
from typing import List, Optional

from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type

from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding


class VectorIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedding_queue_nodes = []
        embedded_nodes = []
        for node in nodes:
            node_hash = node.doc_hash

            # if node hash in cached embedding tables, use cached embedding
            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
            if embedding:
                node.embedding = embedding.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            # pre embed nodes for cached embedding
            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding

                try:
                    embedding = Embedding(hash=node.doc_hash)
                    embedding.set_embedding(node.embedding)
                    db.session.add(embedding)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    continue
                except:
                    logging.exception('Failed to add embedding to db')
                    continue

                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        for node in nodes:
            node_id = node.doc_id
            exists_duplicate_node = index.exists_by_node_id(node_id)
            if exists_duplicate_node:
                nodes.remove(node)

        return nodes
api/core/index/vector_index/base.py  0 → 100644  View file @ eea011bd

import json
import logging
from abc import abstractmethod
from typing import List, Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


class BaseVectorIndex(BaseIndex):

    def __init__(self, dataset: Dataset, embeddings: Embeddings):
        super().__init__(dataset)
        self._embeddings = embeddings
        self._vector_store = None

    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset: Dataset) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    def search(self, query: str, **kwargs: Any) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}

        if search_type == 'similarity_score_threshold':
            score_threshold = search_kwargs.get("score_threshold")
            if (score_threshold is None) or (not isinstance(score_threshold, float)):
                search_kwargs['score_threshold'] = .0

            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
                query, **search_kwargs
            )

            docs = []
            for doc, similarity in docs_with_similarity:
                doc.metadata['score'] = similarity
                docs.append(doc)

            return docs

        # similarity k
        # mmr k, fetch_k, lambda_mult
        # similarity_score_threshold k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.as_retriever(**kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        if self._is_origin():
            self.recreate_dataset(self.dataset)

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if kwargs.get('duplicate_check', False):
            texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
        return False

    def recreate_dataset(self, dataset: Dataset):
        logging.info(f"Recreating dataset {dataset.id}")

        try:
            self.delete()
        except UnexpectedStatusCodeException as e:
            if e.status_code != 400:
                # 400 means index not exists
                raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        origin_index_struct = self.dataset.index_struct
        self.dataset.index_struct = None

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                self.dataset.index_struct = origin_index_struct
                raise e

            dataset.index_struct = json.dumps(self.to_index_struct())

        db.session.commit()

        self.dataset = dataset
        logging.info(f"Dataset {dataset.id} recreate successfully.")
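A sketch of the search path on any concrete subclass; `vector_index` is assumed to be an already-constructed WeaviateVectorIndex or QdrantVectorIndex, and the query string is an example:

# search_type defaults to 'similarity'; 'similarity_score_threshold' forces a float
# threshold (0.0 when unset or invalid) and attaches the score to each document's metadata
docs = vector_index.search(
    "how do I reset my password?",
    search_type='similarity_score_threshold',
    search_kwargs={'k': 4, 'score_threshold': 0.5}
)
for doc in docs:
    print(doc.metadata['score'], doc.metadata['doc_id'])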
api/core/index/vector_index/qdrant_vector_index.py  0 → 100644  View file @ eea011bd

import os
from typing import Optional, Any, List, cast

import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    root_path: Optional[str]

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(self.root_path, path)

            return {
                'path': path
            }
        else:
            return {
                'url': self.endpoint,
                'api_key': self.api_key,
            }


class QdrantVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client_config = config

    def get_type(self) -> str:
        return 'qdrant'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            return self.dataset.index_struct_dict['vector_store']['collection_name']

        dataset_id = dataset.id
        return "Index_" + dataset_id.replace("-", "_")

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self.dataset),
            ids=uuids,
            content_payload_key='text',
            **self._client_config.to_qdrant_params()
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        client = qdrant_client.QdrantClient(
            **self._client_config.to_qdrant_params()
        )

        return QdrantVectorStore(
            client=client,
            collection_name=self.get_index_name(self.dataset),
            embeddings=self._embeddings,
            content_payload_key='text'
        )

    def _get_vector_store_class(self) -> type:
        return QdrantVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models

        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.document_id",
                    match=models.MatchValue(value=document_id),
                ),
            ],
        ))

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
            if class_prefix.startswith('Vector_'):
                # original class_prefix
                return True

        return False
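A short sketch of how QdrantConfig.to_qdrant_params resolves its two modes; the endpoint values and root_path below are illustrative only:

from core.index.vector_index.qdrant_vector_index import QdrantConfig

# an endpoint prefixed with 'path:' selects local/embedded storage; anything else is a server URL
local = QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app/api')
remote = QdrantConfig(endpoint='https://qdrant.example.com:6333', api_key='secret', root_path=None)

assert local.to_qdrant_params() == {'path': '/app/api/storage/qdrant'}
assert remote.to_qdrant_params() == {'url': 'https://qdrant.example.com:6333', 'api_key': 'secret'}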
api/core/index/vector_index/vector_index.py  0 → 100644  View file @ eea011bd

import json

from flask import current_app
from langchain.embeddings.base import Embeddings

from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document


class VectorIndex:
    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
        self._dataset = dataset
        self._embeddings = embeddings
        self._vector_index = self._init_vector_index(dataset, config, embeddings)

    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
        vector_type = config.get('VECTOR_STORE')

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict['type']

        if not vector_type:
            raise ValueError(f"Vector store must be specified.")

        if vector_type == "weaviate":
            from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig

            return WeaviateVectorIndex(
                dataset=dataset,
                config=WeaviateConfig(
                    endpoint=config.get('WEAVIATE_ENDPOINT'),
                    api_key=config.get('WEAVIATE_API_KEY'),
                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
                ),
                embeddings=embeddings
            )
        elif vector_type == "qdrant":
            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

            return QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=config.get('QDRANT_URL'),
                    api_key=config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path
                ),
                embeddings=embeddings
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

    def add_texts(self, texts: list[Document], **kwargs):
        if not self._dataset.index_struct_dict:
            self._vector_index.create(texts, **kwargs)
            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
            db.session.commit()
            return

        self._vector_index.add_texts(texts, **kwargs)

    def __getattr__(self, name):
        if self._vector_index is not None:
            method = getattr(self._vector_index, name)
            if callable(method):
                return method

        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
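A usage sketch of the facade; `dataset`, `embeddings`, and `documents` are assumed to exist (a Dataset model, an Embeddings instance such as CacheEmbedding, and langchain Documents), and an app context is assumed to be active:

from flask import current_app

from core.index.vector_index.vector_index import VectorIndex

# config is a plain mapping; VECTOR_STORE selects the backend ('weaviate' or 'qdrant')
index = VectorIndex(dataset=dataset, config=current_app.config, embeddings=embeddings)
index.add_texts(documents)    # creates the collection on first use, then appends
docs = index.search("query")  # other calls are proxied to the backend via __getattr__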
api/core/index/vector_index/weaviate_vector_index.py
0 → 100644
View file @
eea011bd
from typing import Optional, cast

import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset


class WeaviateConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['endpoint']:
            raise ValueError("config WEAVIATE_ENDPOINT is required")
        return values


class WeaviateVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)

    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)

        weaviate.connect.connection.has_grpc = False

        try:
            client = weaviate.Client(
                url=config.endpoint,
                auth_client_secret=auth_config,
                timeout_config=(5, 60),
                startup_period=None
            )
        except requests.exceptions.ConnectionError:
            raise ConnectionError("Vector database connection error")

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=config.batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_type(self) -> str:
        return 'weaviate'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        attributes = ['doc_id', 'dataset_id', 'document_id']
        if self._is_origin():
            attributes = ['doc_id']

        return WeaviateVectorStore(
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False
        )

    def _get_vector_store_class(self) -> type:
        return WeaviateVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        })

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False
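For reference, a rough sketch of the class-name scheme that get_index_name above implements; the dataset id below is an assumption, not taken from the commit.

    # Hypothetical dataset id, for illustration only.
    dataset_id = "2d6e0e4b-1c3a-4c5e-9f6a-0a1b2c3d4e5f"
    class_name = "Vector_index_" + dataset_id.replace("-", "_") + "_Node"
    # -> "Vector_index_2d6e0e4b_1c3a_4c5e_9f6a_0a1b2c3d4e5f_Node"
    # Legacy datasets keep their stored class_prefix; the "_Node" suffix marks the new
    # layout, and _is_origin() treats a prefix without it as a pre-migration index.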
api/core/indexing_runner.py
 import datetime
 import json
 import logging
 import re
-import tempfile
 import time
-from pathlib import Path
-from typing import Optional, List
+import uuid
+from typing import Optional, List, cast

+from flask import current_app
 from flask_login import current_user
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from llama_index import SimpleDirectoryReader
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
-from llama_index.node_parser import SimpleNodeParser, NodeParser
-from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
-from llama_index.readers.file.markdown_parser import MarkdownParser
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

-from core.data_source.notion import NotionPageReader
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
+from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.markdown_parser import MarkdownParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.index import IndexBuilder
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.error import ProviderTokenNotInitError
+from core.llm.llm_builder import LLMBuilder
+from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
-from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
+from libs import helper
+from models.dataset import Document as DatasetDocument
+from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
 from models.model import UploadFile
 from models.source import DataSourceBinding
...
@@ -40,135 +39,171 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name

-    def run(self, documents: List[Document]):
+    def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
-        for document in documents:
+        for dataset_document in dataset_documents:
+            try:
+                # get dataset
+                dataset = Dataset.query.filter_by(
+                    id=dataset_document.dataset_id
+                ).first()
+
+                if not dataset:
+                    raise ValueError("no dataset found")
+
+                # load file
+                text_docs = self._load_data(dataset_document)
+
+                # get the process rule
+                processing_rule = db.session.query(DatasetProcessRule). \
+                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+                    first()
+
+                # get splitter
+                splitter = self._get_splitter(processing_rule)
+
+                # split to documents
+                documents = self._step_split(
+                    text_docs=text_docs,
+                    splitter=splitter,
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    processing_rule=processing_rule
+                )
+
+                # build index
+                self._build_index(
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    documents=documents
+                )
+            except DocumentIsPausedException:
+                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+            except ProviderTokenNotInitError as e:
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e.description)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+            except Exception as e:
+                logging.exception("consume document failed")
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+
+    def run_in_splitting_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is splitting."""
+        try:
             # get dataset
             dataset = Dataset.query.filter_by(
-                id=document.dataset_id
+                id=dataset_document.dataset_id
             ).first()

             if not dataset:
                 raise ValueError("no dataset found")

+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+            db.session.delete(document_segments)
+            db.session.commit()
+
             # load file
-            text_docs = self._load_data(document)
+            text_docs = self._load_data(dataset_document)

             # get the process rule
             processing_rule = db.session.query(DatasetProcessRule). \
-                filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                 first()

-            # get node parser for splitting
-            node_parser = self._get_node_parser(processing_rule)
+            # get splitter
+            splitter = self._get_splitter(processing_rule)

-            # split to nodes
-            nodes = self._step_split(
+            # split to documents
+            documents = self._step_split(
                 text_docs=text_docs,
-                node_parser=node_parser,
+                splitter=splitter,
                 dataset=dataset,
-                document=document,
+                dataset_document=dataset_document,
                 processing_rule=processing_rule
             )

             # build index
             self._build_index(
                 dataset=dataset,
-                document=document,
-                nodes=nodes
+                dataset_document=dataset_document,
+                documents=documents
             )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()

-    def run_in_splitting_status(self, document: Document):
-        """Run the indexing process when the index_status is splitting."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
-
-        if not dataset:
-            raise ValueError("no dataset found")
-
-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=document.id
-        ).all()
-        db.session.delete(document_segments)
-        db.session.commit()
-
-        # load file
-        text_docs = self._load_data(document)
-
-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
-            first()
-
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
-
-        # split to nodes
-        nodes = self._step_split(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            dataset=dataset,
-            document=document,
-            processing_rule=processing_rule
-        )
-
-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
-
-    def run_in_indexing_status(self, document: Document):
-        """Run the indexing process when the index_status is indexing."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
-
-        if not dataset:
-            raise ValueError("no dataset found")
-
-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=document.id
-        ).all()
-
-        nodes = []
-        if document_segments:
-            for document_segment in document_segments:
-                # transform segment to node
-                if document_segment.status != "completed":
-                    relationships = {
-                        DocumentRelationship.SOURCE: document_segment.document_id,
-                    }
-
-                    previous_segment = document_segment.previous_segment
-                    if previous_segment:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-                    next_segment = document_segment.next_segment
-                    if next_segment:
-                        relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-
-                    node = Node(
-                        doc_id=document_segment.index_node_id,
-                        doc_hash=document_segment.index_node_hash,
-                        text=document_segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
-                    )
-
-                    nodes.append(node)
-
-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+    def run_in_indexing_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is indexing."""
+        try:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=dataset_document.dataset_id
+            ).first()
+
+            if not dataset:
+                raise ValueError("no dataset found")
+
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+
+            documents = []
+            if document_segments:
+                for document_segment in document_segments:
+                    # transform segment to node
+                    if document_segment.status != "completed":
+                        document = Document(
+                            page_content=document_segment.content,
+                            metadata={
+                                "doc_id": document_segment.index_node_id,
+                                "doc_hash": document_segment.index_node_hash,
+                                "document_id": document_segment.document_id,
+                                "dataset_id": document_segment.dataset_id,
+                            }
+                        )
+                        documents.append(document)
+
+            # build index
+            self._build_index(
+                dataset=dataset,
+                dataset_document=dataset_document,
+                documents=documents
+            )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()

     def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """
...
@@ -179,28 +214,28 @@ class IndexingRunner:
         total_segments = 0
         for file_detail in file_details:
             # load data from file
-            text_docs = self._load_data_from_file(file_detail)
+            text_docs = FileExtractor.load(file_detail)

             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"],
                 rules=json.dumps(tmp_processing_rule["rules"])
             )

-            # get node parser for splitting
-            node_parser = self._get_node_parser(processing_rule)
+            # get splitter
+            splitter = self._get_splitter(processing_rule)

-            # split to nodes
-            nodes = self._split_to_nodes(
+            # split to documents
+            documents = self._split_to_documents(
                 text_docs=text_docs,
-                node_parser=node_parser,
+                splitter=splitter,
                 processing_rule=processing_rule
             )

-            total_segments += len(nodes)
-            for node in nodes:
+            total_segments += len(documents)
+            for document in documents:
                 if len(preview_texts) < 5:
-                    preview_texts.append(node.get_text())
+                    preview_texts.append(document.page_content)

-                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)

         return {
             "total_segments": total_segments,
...
@@ -230,35 +265,36 @@ class IndexingRunner:
             ).first()

             if not data_source_binding:
                 raise ValueError('Data source binding not found.')

-            reader = NotionPageReader(integration_token=data_source_binding.access_token)
             for page in notion_info['pages']:
-                if page['type'] == 'page':
-                    page_ids = [page['page_id']]
-                    documents = reader.load_data_as_documents(page_ids=page_ids)
-                elif page['type'] == 'database':
-                    documents = reader.load_data_as_documents(database_id=page['page_id'])
-                else:
-                    documents = []
+                loader = NotionLoader(
+                    notion_access_token=data_source_binding.access_token,
+                    notion_workspace_id=workspace_id,
+                    notion_obj_id=page['page_id'],
+                    notion_page_type=page['type']
+                )
+                documents = loader.load()

                 processing_rule = DatasetProcessRule(
                     mode=tmp_processing_rule["mode"],
                     rules=json.dumps(tmp_processing_rule["rules"])
                 )

-                # get node parser for splitting
-                node_parser = self._get_node_parser(processing_rule)
+                # get splitter
+                splitter = self._get_splitter(processing_rule)

-                # split to nodes
-                nodes = self._split_to_nodes(
+                # split to documents
+                documents = self._split_to_documents(
                     text_docs=documents,
-                    node_parser=node_parser,
+                    splitter=splitter,
                     processing_rule=processing_rule
                 )

-                total_segments += len(nodes)
-                for node in nodes:
+                total_segments += len(documents)
+                for document in documents:
                     if len(preview_texts) < 5:
-                        preview_texts.append(node.get_text())
+                        preview_texts.append(document.page_content)

-                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)

         return {
             "total_segments": total_segments,
...
@@ -268,14 +304,14 @@ class IndexingRunner:
             "preview": preview_texts
         }

-    def _load_data(self, document: Document) -> List[Document]:
+    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
         # load file
-        if document.data_source_type not in ["upload_file", "notion_import"]:
+        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []

-        data_source_info = document.data_source_info_dict
+        data_source_info = dataset_document.data_source_info_dict
         text_docs = []
-        if document.data_source_type == 'upload_file':
+        if dataset_document.data_source_type == 'upload_file':
             if not data_source_info or 'upload_file_id' not in data_source_info:
                 raise ValueError("no upload file found")
...
@@ -283,47 +319,28 @@ class IndexingRunner:
                 filter(UploadFile.id == data_source_info['upload_file_id']). \
                 one_or_none()

-            text_docs = self._load_data_from_file(file_detail)
-        elif document.data_source_type == 'notion_import':
-            if not data_source_info or 'notion_page_id' not in data_source_info \
-                    or 'notion_workspace_id' not in data_source_info:
-                raise ValueError("no notion page found")
-
-            workspace_id = data_source_info['notion_workspace_id']
-            page_id = data_source_info['notion_page_id']
-            page_type = data_source_info['type']
-            data_source_binding = DataSourceBinding.query.filter(
-                db.and_(
-                    DataSourceBinding.tenant_id == document.tenant_id,
-                    DataSourceBinding.provider == 'notion',
-                    DataSourceBinding.disabled == False,
-                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
-                )
-            ).first()
-
-            if not data_source_binding:
-                raise ValueError('Data source binding not found.')
-
-            if page_type == 'page':
-                # add page last_edited_time to data_source_info
-                self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
-            elif page_type == 'database':
-                # add page last_edited_time to data_source_info
-                self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
+            text_docs = FileExtractor.load(file_detail)
+        elif dataset_document.data_source_type == 'notion_import':
+            loader = NotionLoader.from_document(dataset_document)
+            text_docs = loader.load()

         # update document status to splitting
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="splitting",
             extra_update_params={
-                Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
-                Document.parsing_completed_at: datetime.datetime.utcnow()
+                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
+                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
             }
         )

         # replace doc id to document model id
+        text_docs = cast(List[Document], text_docs)
         for text_doc in text_docs:
             # remove invalid symbol
-            text_doc.text = self.filter_string(text_doc.get_text())
-            text_doc.doc_id = document.id
+            text_doc.page_content = self.filter_string(text_doc.page_content)
+            text_doc.metadata['document_id'] = dataset_document.id
+            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

         return text_docs
...
@@ -331,61 +348,7 @@ class IndexingRunner:
         pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
         return pattern.sub('', text)

-    def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            self.storage.download(upload_file.key, filepath)
-
-            file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
-            file_extractor[".markdown"] = MarkdownParser()
-            file_extractor[".md"] = MarkdownParser()
-            file_extractor[".html"] = HTMLParser()
-            file_extractor[".htm"] = HTMLParser()
-            file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
-            file_extractor[".xlsx"] = XLSXParser()
-
-            loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
-            text_docs = loader.load_data()
-
-            return text_docs
-
-    def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
-        page_ids = [page_id]
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(page_ids=page_ids)
-        return text_docs
-
-    def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(database_id=database_id)
-        return text_docs
-
-    def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_page_last_edited_time(page_id)
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_database_last_edited_time(page_id)
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
+    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
...
@@ -414,68 +377,83 @@ class IndexingRunner:
                 separators=["\n\n", "。", ".", " ", ""]
             )

-        return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True)
+        return character_splitter

-    def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
-                    dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
+    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
+                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
+            -> List[Document]:
         """
-        Split the text documents into nodes and save them to the document segment.
+        Split the text documents into documents and save them to the document segment.
         """
-        nodes = self._split_to_nodes(
+        documents = self._split_to_documents(
             text_docs=text_docs,
-            node_parser=node_parser,
+            splitter=splitter,
             processing_rule=processing_rule
         )

         # save node to document segment
         doc_store = DatesetDocumentStore(
             dataset=dataset,
-            user_id=document.created_by,
+            user_id=dataset_document.created_by,
             embedding_model_name=self.embedding_model_name,
-            document_id=document.id
+            document_id=dataset_document.id
         )

         # add document segments
-        doc_store.add_documents(nodes)
+        doc_store.add_documents(documents)

         # update document status to indexing
         cur_time = datetime.datetime.utcnow()
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="indexing",
             extra_update_params={
-                Document.cleaning_completed_at: cur_time,
-                Document.splitting_completed_at: cur_time,
+                DatasetDocument.cleaning_completed_at: cur_time,
+                DatasetDocument.splitting_completed_at: cur_time,
             }
         )

         # update segment status to indexing
         self._update_segments_by_document(
-            document_id=document.id,
+            dataset_document_id=dataset_document.id,
             update_params={
                 DocumentSegment.status: "indexing",
                 DocumentSegment.indexing_at: datetime.datetime.utcnow()
             }
         )

-        return nodes
+        return documents

-    def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
-                        processing_rule: DatasetProcessRule) -> List[Node]:
+    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
+                            processing_rule: DatasetProcessRule) -> List[Document]:
         """
         Split the text documents into nodes.
         """
-        all_nodes = []
+        all_documents = []
         for text_doc in text_docs:
             # document clean
-            document_text = self._document_clean(text_doc.get_text(), processing_rule)
-            text_doc.text = document_text
+            document_text = self._document_clean(text_doc.page_content, processing_rule)
+            text_doc.page_content = document_text

             # parse document to nodes
-            nodes = node_parser.get_nodes_from_documents([text_doc])
-            nodes = [node for node in nodes if node.text is not None and node.text.strip()]
-            all_nodes.extend(nodes)
+            documents = splitter.split_documents([text_doc])
+
+            split_documents = []
+            for document in documents:
+                if document.page_content is None or not document.page_content.strip():
+                    continue
+
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document.page_content)
+
+                document.metadata['doc_id'] = doc_id
+                document.metadata['doc_hash'] = hash
+
+                split_documents.append(document)
+
+            all_documents.extend(split_documents)

-        return all_nodes
+        return all_documents

     def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
         """
...
@@ -506,37 +484,38 @@ class IndexingRunner:
         return text

-    def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
+    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
         """
         Build the index for the document.
         """
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
         chunk_size = 100
-        for i in range(0, len(nodes), chunk_size):
+        for i in range(0, len(documents), chunk_size):
             # check document is paused
-            self._check_document_paused_status(document.id)
-            chunk_nodes = nodes[i:i + chunk_size]
+            self._check_document_paused_status(dataset_document.id)
+            chunk_documents = documents[i:i + chunk_size]

             tokens += sum(
-                TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
-                for node in chunk_nodes
+                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                for document in chunk_documents
             )

             # save vector index
-            if dataset.indexing_technique == "high_quality":
-                vector_index.add_nodes(chunk_nodes)
+            if vector_index:
+                vector_index.add_texts(chunk_documents)

             # save keyword index
-            keyword_table_index.add_nodes(chunk_nodes)
+            keyword_table_index.add_texts(chunk_documents)

-            node_ids = [node.doc_id for node in chunk_nodes]
+            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == document.id,
-                DocumentSegment.index_node_id.in_(node_ids),
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.index_node_id.in_(document_ids),
                 DocumentSegment.status == "indexing"
             ).update({
                 DocumentSegment.status: "completed",
...
@@ -549,12 +528,12 @@ class IndexingRunner:
         # update document status to completed
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="completed",
             extra_update_params={
-                Document.tokens: tokens,
-                Document.completed_at: datetime.datetime.utcnow(),
-                Document.indexing_latency: indexing_end_at - indexing_start_at,
+                DatasetDocument.tokens: tokens,
+                DatasetDocument.completed_at: datetime.datetime.utcnow(),
+                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
             }
         )
...
@@ -569,25 +548,25 @@ class IndexingRunner:
         """
         Update the document indexing status.
         """
-        count = Document.query.filter_by(id=document_id, is_paused=True).count()
+        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
         if count > 0:
             raise DocumentIsPausedException()

         update_params = {
-            Document.indexing_status: after_indexing_status
+            DatasetDocument.indexing_status: after_indexing_status
         }

         if extra_update_params:
             update_params.update(extra_update_params)

-        Document.query.filter_by(id=document_id).update(update_params)
+        DatasetDocument.query.filter_by(id=document_id).update(update_params)
         db.session.commit()

-    def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
+    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
         """
         Update the document segment by document id.
         """
-        DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
+        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()
...
api/core/llm/llm_builder.py
-from typing import Union, Optional
+from typing import Union, Optional, List

-from langchain.callbacks import CallbackManager
-from langchain.llms.fake import FakeListLLM
+from langchain.callbacks.base import BaseCallbackHandler

 from core.constant import llm_constant
 from core.llm.error import ProviderTokenNotInitError
...
@@ -32,12 +31,11 @@ class LLMBuilder:
     """

     @classmethod
-    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
-        if model_name == 'fake':
-            return FakeListLLM(responses=[])
-
+    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         provider = cls.get_default_provider(tenant_id)

+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
             if provider == 'openai':
...
@@ -52,16 +50,21 @@ class LLMBuilder:
         else:
             raise ValueError(f"model name {model_name} is not supported.")

-        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+        model_kwargs = {
+            'top_p': kwargs.get('top_p', 1),
+            'frequency_penalty': kwargs.get('frequency_penalty', 0),
+            'presence_penalty': kwargs.get('presence_penalty', 0),
+        }
+
+        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}

         return llm_cls(
             model_name=model_name,
             temperature=kwargs.get('temperature', 0),
             max_tokens=kwargs.get('max_tokens', 256),
-            top_p=kwargs.get('top_p', 1),
-            frequency_penalty=kwargs.get('frequency_penalty', 0),
-            presence_penalty=kwargs.get('presence_penalty', 0),
-            callback_manager=kwargs.get('callback_manager', None),
+            **model_extras_kwargs,
+            callbacks=kwargs.get('callbacks', None),
             streaming=kwargs.get('streaming', False),
             # request_timeout=None
             **model_credentials
...
@@ -69,7 +72,7 @@ class LLMBuilder:
     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
-                          callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
+                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         model_name = model.get("name")
         completion_params = model.get("completion_params", {})
...
@@ -82,7 +85,7 @@ class LLMBuilder:
             frequency_penalty=completion_params.get('frequency_penalty', 0.1),
             presence_penalty=completion_params.get('presence_penalty', 0.1),
             streaming=streaming,
-            callback_manager=callback_manager
+            callbacks=callbacks
         )

     @classmethod
...
api/core/llm/provider/azure_provider.py
...
@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider):
         """
         config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
-        config['deployment_name'] = model_id.replace('.', '') if model_id else None
+        if model_id == 'text-embedding-ada-002':
+            config['deployment'] = model_id.replace('.', '') if model_id else None
+        else:
+            config['deployment_name'] = model_id.replace('.', '') if model_id else None
         return config

     def get_provider_name(self):
...
api/core/llm/streamable_azure_chat_open_ai.py
+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
...
@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         return message_tokens

-    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__},
-            [(message.type + ": " + message.content) for message in messages],
-            verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
-                [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
-                [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None, **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None, **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
api/core/llm/streamable_azure_open_ai.py
 import os
+
+from langchain.callbacks.manager import Callbacks
 from langchain.llms import AzureOpenAI
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Mapping, Any
...
@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):
     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self, prompts: List[str], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None, **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self, prompts: List[str], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None, **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
api/core/llm/streamable_chat_open_ai.py
 import os

-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import ChatOpenAI
 from typing import Optional, List, Dict, Any
...
@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):
         return message_tokens

-    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__},
-            [(message.type + ": " + message.content) for message in messages],
-            verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
-                [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
-                [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None, **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None, **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
api/core/llm/streamable_open_ai.py
 import os

+from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
...
@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }}

     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self, prompts: List[str], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None, **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self, prompts: List[str], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None, **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py
 from typing import Any, List, Dict

 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string

 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
...
api/core/prompt/prompts.py
-from llama_index import QueryKeywordExtractPrompt
-
 CONVERSATION_TITLE_PROMPT = (
     "Human:{query}\n-----\n"
     "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
...
@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "[\"question1\",\"question2\",\"question3\"]\n"
 )

-QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
-    "A question is provided below. Given the question, extract up to {max_keywords} "
-    "keywords from the text. Focus on extracting the keywords that we can use "
-    "to best lookup answers to the question. Avoid stopwords."
-    "I am not sure which language the following question is in. "
-    "If the user asked the question in Chinese, please return the keywords in Chinese. "
-    "If the user asked the question in English, please return the keywords in English.\n"
-    "---------------------\n"
-    "{question}\n"
-    "---------------------\n"
-    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
-)
-
-QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL)
-
 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
the model prompt that best suits the input.
 You will be provided with the prompt, variables, and an opening statement.
...
api/core/index/spiltter/fixed_text_splitter.py → api/core/spiltter/fixed_text_splitter.py  (file moved)
api/core/tool/dataset_index_tool.py  (new file, 0 → 100644)
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class DatasetTool(BaseTool):
    """Tool for querying a Dataset."""

    dataset: Dataset
    k: int = 2

    def _run(self, tool_input: str) -> str:
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            vector_index = VectorIndex(
                dataset=self.dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                tool_input,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(
            **model_credentials
        ))

        vector_index = VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        documents = await vector_index.asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={
                'k': 10
            }
        )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))
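A minimal sketch of how the tool above might be wired up, not part of the commit; `dataset` is assumed to be a models.dataset.Dataset row, and the tool name and query text are placeholders. The description fallback mirrors the one used by the removed DatasetToolBuilder below.

    # Hypothetical wiring; `dataset` is an assumption.
    tool = DatasetTool(
        name=f"dataset-{dataset.id}",
        description=dataset.description or ('useful for when you want to answer queries about the ' + dataset.name),
        dataset=dataset,
        k=2
    )
    print(tool._run("What does the handbook say about refunds?"))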
api/core/tool/dataset_tool_builder.py  (deleted, 100644 → 0)
from typing import Optional

from langchain.callbacks import CallbackManager
from llama_index.langchain_helpers.agents import IndexToolConfig

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset


class DatasetToolBuilder:
    @classmethod
    def build_dataset_tool(cls, dataset: Dataset,
                           response_mode: str = "no_synthesizer",
                           callback_handler: Optional[DatasetToolCallbackHandler] = None):
        if dataset.indexing_technique == "economy":
            # use keyword table query
            index = KeywordTableIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
                "max_keywords_per_query": 5,
                # If num_chunks_per_query is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "num_chunks_per_query": 2
            }
        else:
            index = VectorIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                # If top_k is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "similarity_top_k": 2
            }

        # fulfill description when it is empty
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        index_tool_config = IndexToolConfig(
            index=index,
            name=f"dataset-{dataset.id}",
            description=description,
            index_query_kwargs=query_kwargs,
            tool_kwargs={
                "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
            },
            # tool_kwargs={"return_direct": True},
            # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
        )

        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)

        return EnhanceLlamaIndexTool.from_tool_config(
            tool_config=index_tool_config,
            callback_handler=index_callback_handler
        )
api/core/tool/llama_index_tool.py  (deleted, 100644 → 0)
from
typing
import
Dict
from
langchain.tools
import
BaseTool
from
llama_index.indices.base
import
BaseGPTIndex
from
llama_index.langchain_helpers.agents
import
IndexToolConfig
from
pydantic
import
Field
from
core.callback_handler.index_tool_callback_handler
import
IndexToolCallbackHandler
class
EnhanceLlamaIndexTool
(
BaseTool
):
"""Tool for querying a LlamaIndex."""
# NOTE: name/description still needs to be set
index
:
BaseGPTIndex
query_kwargs
:
Dict
=
Field
(
default_factory
=
dict
)
return_sources
:
bool
=
False
callback_handler
:
IndexToolCallbackHandler
@
classmethod
def
from_tool_config
(
cls
,
tool_config
:
IndexToolConfig
,
callback_handler
:
IndexToolCallbackHandler
)
->
"EnhanceLlamaIndexTool"
:
"""Create a tool from a tool config."""
return_sources
=
tool_config
.
tool_kwargs
.
pop
(
"return_sources"
,
False
)
return
cls
(
index
=
tool_config
.
index
,
callback_handler
=
callback_handler
,
name
=
tool_config
.
name
,
description
=
tool_config
.
description
,
return_sources
=
return_sources
,
query_kwargs
=
tool_config
.
index_query_kwargs
,
**
tool_config
.
tool_kwargs
,
)
def
_run
(
self
,
tool_input
:
str
)
->
str
:
response
=
self
.
index
.
query
(
tool_input
,
**
self
.
query_kwargs
)
self
.
callback_handler
.
on_tool_end
(
response
)
return
str
(
response
)
async
def
_arun
(
self
,
tool_input
:
str
)
->
str
:
response
=
await
self
.
index
.
aquery
(
tool_input
,
**
self
.
query_kwargs
)
self
.
callback_handler
.
on_tool_end
(
response
)
return
str
(
response
)
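The deleted tool above follows LangChain's standard `BaseTool` subclassing pattern: declared pydantic fields plus a synchronous `_run` and an asynchronous `_arun`. A minimal, self-contained sketch of that pattern, independent of any Dify or LlamaIndex types (the tool itself is a toy and not part of this commit):

from langchain.tools import BaseTool


class EchoTool(BaseTool):
    """Toy tool that returns its input; illustrates the _run/_arun contract only."""
    name = "echo"
    description = "Returns the input text unchanged."

    def _run(self, tool_input: str) -> str:
        # synchronous execution path used by agents
        return tool_input

    async def _arun(self, tool_input: str) -> str:
        # asynchronous execution path
        return tool_input


print(EchoTool().run("hello"))  # -> "hello"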
api/core/vector_store/base.py  (deleted, 100644 → 0)

from abc import ABC, abstractmethod
from typing import Optional

from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore


class BaseVectorStoreClient(ABC):
    @abstractmethod
    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        raise NotImplementedError

    @abstractmethod
    def to_index_config(self, index_id: str) -> dict:
        raise NotImplementedError


class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
    def delete_node(self, node_id: str):
        self._vector_store.delete_node(node_id)

    def exists_by_node_id(self, node_id: str) -> bool:
        return self._vector_store.exists_by_node_id(node_id)


class EnhanceVectorStore(ABC):
    @abstractmethod
    def delete_node(self, node_id: str):
        pass

    @abstractmethod
    def exists_by_node_id(self, node_id: str) -> bool:
        pass
api/core/vector_store/qdrant_vector_store.py  (new file, 0 → 100644)

from typing import cast, Any

from langchain.schema import Document
from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal


class QdrantVectorStore(Qdrant):
    def del_texts(self, filter: Filter):
        if not filter:
            raise ValueError('filter must not be empty')

        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=FilterSelector(
                filter=filter
            ),
        )

    def del_text(self, uuid: str) -> None:
        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(
                points=[uuid],
            ),
        )

    def text_exists(self, uuid: str) -> bool:
        self._reload_if_needed()

        response = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[uuid]
        )

        return len(response) > 0

    def delete(self):
        self._reload_if_needed()

        self.client.delete_collection(collection_name=self.collection_name)

    @classmethod
    def _document_from_scored_point(
            cls,
            scored_point: Any,
            content_payload_key: str,
            metadata_payload_key: str,
    ) -> Document:
        if scored_point.payload.get('doc_id'):
            return Document(
                page_content=scored_point.payload.get(content_payload_key),
                metadata={'doc_id': scored_point.id}
            )

        return Document(
            page_content=scored_point.payload.get(content_payload_key),
            metadata=scored_point.payload.get(metadata_payload_key) or {},
        )

    def _reload_if_needed(self):
        if isinstance(self.client, QdrantLocal):
            self.client = cast(QdrantLocal, self.client)
            self.client._load()
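The helpers above are thin wrappers over `qdrant_client` point operations (retrieve by id, delete by id, delete by payload filter). A small standalone sketch of those underlying calls against an in-memory Qdrant instance; the collection name, vector size, and ids are arbitrary, and the sketch assumes a qdrant-client version with local (":memory:") mode:

from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

client = QdrantClient(":memory:")  # local, in-process instance
client.recreate_collection(
    collection_name="demo",
    vectors_config=rest.VectorParams(size=4, distance=rest.Distance.COSINE),
)

point_id = "6f7cf1e0-0000-4000-8000-000000000001"
client.upsert(
    collection_name="demo",
    points=[rest.PointStruct(id=point_id, vector=[0.1, 0.2, 0.3, 0.4], payload={"doc_id": "doc-1"})],
)

# text_exists(): retrieve by id and check the result is non-empty
print(len(client.retrieve(collection_name="demo", ids=[point_id])) > 0)

# del_text(): delete a single point by id
client.delete(collection_name="demo", points_selector=rest.PointIdsList(points=[point_id]))

# del_texts(): delete every point matching a payload filter
client.delete(
    collection_name="demo",
    points_selector=rest.FilterSelector(
        filter=rest.Filter(must=[rest.FieldCondition(key="doc_id", match=rest.MatchValue(value="doc-1"))])
    ),
)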
api/core/vector_store/qdrant_vector_store_client.py  (deleted, 100644 → 0)

import os
from typing import cast, List

from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter

import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore


class QdrantVectorStoreClient(BaseVectorStoreClient):
    def __init__(self, url: str, api_key: str, root_path: str):
        self._client = self.init_from_config(url, api_key, root_path)

    @classmethod
    def init_from_config(cls, url: str, api_key: str, root_path: str):
        if url and url.startswith('path:'):
            path = url.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(root_path, path)

            return qdrant_client.QdrantClient(path=path)
        else:
            return qdrant_client.QdrantClient(
                url=url,
                api_key=api_key,
            )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = QdrantIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"collection_name": "Gpt_index_xxx"}
        collection_name = config.get('collection_name')
        if not collection_name:
            raise Exception("collection_name cannot be None.")

        return GPTQdrantEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=QdrantEnhanceVectorStore(
                client=self._client,
                collection_name=collection_name
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"collection_name": index_id}


class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
    pass


class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        from qdrant_client.http import models as rest

        self._reload_if_needed()

        self._client.delete(
            collection_name=self._collection_name,
            points_selector=rest.Filter(
                must=[
                    rest.FieldCondition(
                        key="id", match=rest.MatchValue(value=node_id)
                    )
                ]
            ),
        )

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        self._reload_if_needed()

        response = self._client.retrieve(
            collection_name=self._collection_name,
            ids=[node_id]
        )

        return len(response) > 0

    def query(
            self,
            query: VectorStoreQuery,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query
        """
        query_embedding = cast(List[float], query.query_embedding)

        self._reload_if_needed()

        response = self._client.search(
            collection_name=self._collection_name,
            query_vector=query_embedding,
            limit=cast(int, query.similarity_top_k),
            query_filter=cast(Filter, self._build_query_filter(query)),
            with_vectors=True
        )

        nodes = []
        similarities = []
        ids = []

        for point in response:
            payload = cast(Payload, point.payload)
            node = Node(
                doc_id=str(point.id),
                text=payload.get("text"),
                embedding=point.vector,
                extra_info=payload.get("extra_info"),
                relationships={
                    DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
                },
            )
            nodes.append(node)
            similarities.append(point.score)
            ids.append(str(point.id))

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def _reload_if_needed(self):
        if isinstance(self._client._client, QdrantLocal):
            self._client._client._load()
api/core/vector_store/vector_store.py  (deleted, 100644 → 0)

from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient

SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']


class VectorStore:
    def __init__(self):
        self._vector_store = None
        self._client = None

    def init_app(self, app: Flask):
        if not app.config['VECTOR_STORE']:
            return

        self._vector_store = app.config['VECTOR_STORE']
        if self._vector_store not in SUPPORTED_VECTOR_STORES:
            raise ValueError(f"Vector store {self._vector_store} is not supported.")

        if self._vector_store == 'weaviate':
            self._client = WeaviateVectorStoreClient(
                endpoint=app.config['WEAVIATE_ENDPOINT'],
                api_key=app.config['WEAVIATE_API_KEY'],
                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
                batch_size=app.config['WEAVIATE_BATCH_SIZE']
            )
        elif self._vector_store == 'qdrant':
            self._client = QdrantVectorStoreClient(
                url=app.config['QDRANT_URL'],
                api_key=app.config['QDRANT_API_KEY'],
                root_path=app.root_path
            )

        app.extensions['vector_store'] = self

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
        vector_store_config: dict = index_struct.get('vector_store')
        index = self.get_client().get_index(
            service_context=service_context,
            config=vector_store_config
        )

        return index

    def to_index_struct(self, index_id: str) -> dict:
        return {
            "type": self._vector_store,
            "vector_store": self.get_client().to_index_config(index_id)
        }

    def get_client(self):
        if not self._client:
            raise Exception("Vector store client is not initialized.")

        return self._client
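For reference, the configuration keys this removed wrapper consumed map directly onto Flask config. A minimal sketch of the relevant settings; all values here are placeholders, not defaults shipped with this commit:

from flask import Flask

app = Flask(__name__)
app.config.update(
    VECTOR_STORE='qdrant',                  # or 'weaviate'
    QDRANT_URL='path:storage/qdrant',       # the 'path:' prefix selected on-disk local mode above
    QDRANT_API_KEY=None,
    WEAVIATE_ENDPOINT='http://localhost:8080',
    WEAVIATE_API_KEY=None,
    WEAVIATE_GRPC_ENABLED=False,
    WEAVIATE_BATCH_SIZE=100,
)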
api/core/vector_store/vector_store_index_query.py  (deleted, 100644 → 0)

from llama_index.indices.query.base import IS
from typing import (
    Any,
    Dict,
    List,
    Optional,
)

from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
    @classmethod
    def from_args(
            cls,
            index_struct: IS,
            service_context: ServiceContext,
            docstore: Optional[BaseDocumentStore] = None,
            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
            verbose: bool = False,
            # response synthesizer args
            response_mode: ResponseMode = ResponseMode.DEFAULT,
            text_qa_template: Optional[QuestionAnswerPrompt] = None,
            refine_template: Optional[RefinePrompt] = None,
            simple_template: Optional[SimpleInputPrompt] = None,
            response_kwargs: Optional[Dict] = None,
            use_async: bool = False,
            streaming: bool = False,
            optimizer: Optional[BaseTokenUsageOptimizer] = None,
            # class-specific args
            **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )
api/core/vector_store/weaviate_vector_store.py  (new file, 0 → 100644)

from langchain.vectorstores import Weaviate


class WeaviateVectorStore(Weaviate):
    def del_texts(self, where_filter: dict):
        if not where_filter:
            raise ValueError('where_filter must not be empty')

        self._client.batch.delete_objects(
            class_name=self._index_name,
            where=where_filter,
            output='minimal'
        )

    def del_text(self, uuid: str) -> None:
        self._client.data_object.delete(
            uuid,
            class_name=self._index_name
        )

    def text_exists(self, uuid: str) -> bool:
        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
            "path": ["doc_id"],
            "operator": "Equal",
            "valueText": uuid,
        }).with_limit(1).do()

        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        entries = result["data"]["Get"][self._index_name]
        if len(entries) == 0:
            return False

        return True

    def delete(self):
        self._client.schema.delete_class(self._index_name)
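Both `text_exists` and `del_texts` rely on Weaviate's standard where-filter dictionaries. A sketch of the two filter shapes involved; the `doc_id` property comes straight from `text_exists` above, while the second filter's property name and the id values are illustrative assumptions only:

# Filter used internally by text_exists(): match a single chunk by its doc_id property
exists_filter = {
    "path": ["doc_id"],
    "operator": "Equal",
    "valueText": "7d2c3b6a-0000-4000-8000-000000000001",
}

# A filter del_texts() could receive: drop every object whose (hypothetical) document_id
# property points at one source document
delete_filter = {
    "path": ["document_id"],
    "operator": "Equal",
    "valueText": "source-document-id",
}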
api/core/vector_store/weaviate_vector_store_client.py  (deleted, 100644 → 0)

import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
    parse_get_response,
    validate_client,
)


class WeaviateVectorStoreClient(BaseVectorStoreClient):
    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
        self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)

    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
        auth_config = weaviate.auth.AuthApiKey(api_key=api_key)

        weaviate.connect.connection.has_grpc = grpc_enabled

        client = weaviate.Client(
            url=endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 60),
            startup_period=None
        )

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = WeaviateIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"class_prefix": "Gpt_index_xxx"}
        class_prefix = config.get('class_prefix')
        if not class_prefix:
            raise Exception("class_prefix cannot be None.")

        return GPTWeaviateEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=WeaviateWithSimilaritiesVectorStore(
                weaviate_client=self._client,
                class_prefix=class_prefix
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"class_prefix": index_id}


class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes."""
        nodes = self.weaviate_query(
            self._client,
            self._class_prefix,
            query,
        )
        nodes = nodes[: query.similarity_top_k]
        node_idxs = [str(i) for i in range(len(nodes))]

        similarities = []
        for node in nodes:
            similarities.append(node.extra_info['similarity'])
            del node.extra_info['similarity']

        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)

    def weaviate_query(
            self,
            client: Any,
            class_prefix: str,
            query_spec: VectorStoreQuery,
    ) -> List[Node]:
        """Convert to LlamaIndex list."""
        validate_client(client)

        class_name = _class_name(class_prefix)
        prop_names = [p["name"] for p in NODE_SCHEMA]
        vector = query_spec.query_embedding

        # build query
        query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
            _logger.debug("Using vector search")
            if vector is not None:
                query = query.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
            _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
            query = query.with_hybrid(
                query=query_spec.query_str,
                alpha=query_spec.alpha,
                vector=vector,
            )
        query = query.with_limit(query_spec.similarity_top_k)
        _logger.debug(f"Using limit of {query_spec.similarity_top_k}")

        # execute query
        query_result = query.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        results = [self._to_node(entry) for entry in entries]
        return results

    def _to_node(self, entry: Dict) -> Node:
        """Convert to Node."""
        extra_info_str = entry["extra_info"]
        if extra_info_str == "":
            extra_info = None
        else:
            extra_info = json.loads(extra_info_str)

        if 'certainty' in entry['_additional']:
            if extra_info:
                extra_info['similarity'] = entry['_additional']['certainty']
            else:
                extra_info = {'similarity': entry['_additional']['certainty']}

        node_info_str = entry["node_info"]
        if node_info_str == "":
            node_info = None
        else:
            node_info = json.loads(node_info_str)

        relationships_str = entry["relationships"]
        relationships: Dict[DocumentRelationship, str]
        if relationships_str == "":
            relationships = field(default_factory=dict)
        else:
            relationships = {
                DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
            }

        return Node(
            text=entry["text"],
            doc_id=entry["doc_id"],
            embedding=entry["_additional"]["vector"],
            extra_info=extra_info,
            node_info=node_info,
            relationships=relationships,
        )

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document.

        Args:
            doc_id (str): document id
        """
        delete_document(self._client, doc_id, self._class_prefix)

    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        delete_node(self._client, node_id, self._class_prefix)

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        entry = get_by_node_id(self._client, node_id, self._class_prefix)
        return True if entry else False


class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
    pass


def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["ref_doc_id"],
        "operator": "Equal",
        "valueString": ref_doc_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)

    while len(entries) > 0:
        query_result = query.do()
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        for entry in entries:
            client.data_object.delete(entry["_additional"]["id"], class_name)


def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)


def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    if len(entries) == 0:
        return None

    return entries[0]
api/extensions/ext_vector_store.py  (deleted, 100644 → 0)

from core.vector_store.vector_store import VectorStore

vector_store = VectorStore()


def init_app(app):
    vector_store.init_app(app)
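The removed extension used the conventional Flask "lazy init" pattern: a module-level instance plus an `init_app` hook that registers it in `app.extensions`. A generic, self-contained sketch of that pattern (the extension class and key name are illustrative, not Dify code):

from flask import Flask


class SimpleExtension:
    def __init__(self):
        self._configured = False

    def init_app(self, app: Flask):
        # read config here rather than in __init__ so one instance can serve many apps
        self._configured = True
        app.extensions['simple_extension'] = self


simple_extension = SimpleExtension()


def init_app(app: Flask):
    simple_extension.init_app(app)


app = Flask(__name__)
init_app(app)
print('simple_extension' in app.extensions)  # True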
api/libs/helper.py

@@ -3,6 +3,7 @@ import re
 import subprocess
 import uuid
 from datetime import datetime
+from hashlib import sha256
 from zoneinfo import available_timezones
 import random
 import string
@@ -147,3 +148,8 @@ def get_remote_ip(request):
         return request.headers.getlist("X-Forwarded-For")[0]
     else:
         return request.remote_addr
+
+
+def generate_text_hash(text: str) -> str:
+    hash_text = str(text) + 'None'
+    return sha256(hash_text.encode()).hexdigest()
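The added `generate_text_hash` helper is a plain SHA-256 over the text with a literal `'None'` suffix, so identical inputs always map to the same digest. A standalone usage sketch (the sample string is arbitrary):

from hashlib import sha256


def generate_text_hash(text: str) -> str:
    hash_text = str(text) + 'None'
    return sha256(hash_text.encode()).hexdigest()


# the same content always produces the same 64-character hex digest,
# which is what makes it usable as a duplicate check key
print(generate_text_hash('hello world') == generate_text_hash('hello world'))  # True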
api/models/account.py

@@ -38,8 +38,6 @@ class Account(UserMixin, db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
-    _current_tenant: db.Model = None
-
     @property
     def current_tenant(self):
         return self._current_tenant
api/models/dataset.py

@@ -66,6 +66,23 @@ class Dataset(db.Model):
     def document_count(self):
         return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
 
+    @property
+    def available_document_count(self):
+        return db.session.query(func.count(Document.id)).filter(
+            Document.dataset_id == self.id,
+            Document.indexing_status == 'completed',
+            Document.enabled == True,
+            Document.archived == False
+        ).scalar()
+
+    @property
+    def available_segment_count(self):
+        return db.session.query(func.count(DocumentSegment.id)).filter(
+            DocumentSegment.dataset_id == self.id,
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True
+        ).scalar()
+
     @property
     def word_count(self):
         return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
@@ -260,7 +277,7 @@ class Document(db.Model):
     @property
     def dataset(self):
-        return Dataset.query.get(self.dataset_id)
+        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()
 
     @property
     def segment_count(self):
@@ -395,7 +412,18 @@ class DatasetKeywordTable(db.Model):
     @property
     def keyword_table_dict(self):
-        return json.loads(self.keyword_table) if self.keyword_table else None
+        class SetDecoder(json.JSONDecoder):
+            def __init__(self, *args, **kwargs):
+                super().__init__(object_hook=self.object_hook, *args, **kwargs)
+
+            def object_hook(self, dct):
+                if isinstance(dct, dict):
+                    for keyword, node_idxs in dct.items():
+                        if isinstance(node_idxs, list):
+                            dct[keyword] = set(node_idxs)
+                return dct
+
+        return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
 
 
 class Embedding(db.Model):
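The new `SetDecoder` turns every JSON list value inside the stored keyword table back into a Python set on load, so membership checks on node ids stay cheap. A standalone sketch of that round-trip; the sample keywords and node ids are made up, and set printing order may vary:

import json


class SetDecoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
        super().__init__(object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, dct):
        # convert any list-valued entry (keyword -> node id list) into a set
        if isinstance(dct, dict):
            for keyword, node_idxs in dct.items():
                if isinstance(node_idxs, list):
                    dct[keyword] = set(node_idxs)
        return dct


stored = '{"table": {"flask": ["node-1", "node-2"], "celery": ["node-3"]}}'
print(json.loads(stored, cls=SetDecoder))
# {'table': {'flask': {'node-1', 'node-2'}, 'celery': {'node-3'}}}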
api/requirements.txt

@@ -2,6 +2,7 @@ coverage~=7.2.4
 beautifulsoup4==4.12.2
 flask~=2.3.2
 Flask-SQLAlchemy~=3.0.3
+SQLAlchemy~=1.4.28
 flask-login==0.6.2
 flask-migrate~=4.0.4
 flask-restful==0.3.9
@@ -9,8 +10,7 @@ flask-session2==1.3.1
 flask-cors==3.0.10
 gunicorn~=20.1.0
 gevent~=22.10.2
-langchain==0.0.142
-llama-index==0.5.27
+langchain==0.0.209
 openai~=0.27.5
 psycopg2-binary~=2.9.6
 pycryptodome==3.17
@@ -29,6 +29,7 @@ sentry-sdk[flask]~=1.21.1
 jieba==0.42.1
 celery==5.2.7
 redis~=4.5.4
-pypdf==3.8.1
 openpyxl==3.1.2
 chardet~=5.1.0
+docx2txt==0.8
+pypdfium2==4.16.0
\ No newline at end of file
api/services/app_model_config_service.py

@@ -4,7 +4,6 @@ import uuid
 from core.constant import llm_constant
 from models.account import Account
 from services.dataset_service import DatasetService
-from services.errors.account import NoPermissionError
 
 
 class AppModelConfigService:
api/services/dataset_service.py

@@ -7,7 +7,6 @@ from typing import Optional, List
 from extensions.ext_redis import redis_client
 from flask_login import current_user
 
-from core.index.index_builder import IndexBuilder
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
@@ -386,8 +385,6 @@ class DocumentService:
         dataset.indexing_technique = document_data["indexing_technique"]
 
-        if dataset.indexing_technique == 'high_quality':
-            IndexBuilder.get_default_service_context(dataset.tenant_id)
-
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         if 'original_document_id' in document_data and document_data["original_document_id"]:
api/services/hit_testing_service.py

@@ -3,47 +3,56 @@ import time
 from typing import List
 
 import numpy as np
-from llama_index.data_structs.node_v2 import NodeWithScore
-from llama_index.indices.query.schema import QueryBundle
-from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document
 from sklearn.manifold import TSNE
 
-from core.docstore.empty_docstore import EmptyDocumentStore
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
 from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, DocumentSegment, DatasetQuery
-from services.errors.index import IndexNotInitializedError
 
 
 class HitTestingService:
     @classmethod
     def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
-        index = VectorIndex(dataset=dataset).query_index
-
-        if not index:
-            raise IndexNotInitializedError()
-
-        index_query = GPTVectorStoreIndexQuery(
-            index_struct=index.index_struct,
-            service_context=index.service_context,
-            vector_store=index.query_context.get('vector_store'),
-            docstore=EmptyDocumentStore(),
-            response_synthesizer=None,
-            similarity_top_k=limit
-        )
-
-        query_bundle = QueryBundle(
-            query_str=query,
-            custom_embedding_strs=[query],
-        )
-
-        query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries(
-            query_bundle.embedding_strs
-        )
+        if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+            return {
+                "query": {
+                    "content": query,
+                    "tsne_position": {'x': 0, 'y': 0},
+                },
+                "records": []
+            }
+
+        model_credentials = LLMBuilder.get_model_credentials(
+            tenant_id=dataset.tenant_id,
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_name='text-embedding-ada-002'
+        )
+
+        embeddings = CacheEmbedding(OpenAIEmbeddings(
+            **model_credentials
+        ))
+
+        vector_index = VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
 
         start = time.perf_counter()
-        nodes = index_query.retrieve(query_bundle=query_bundle)
+        documents = vector_index.search(
+            query,
+            search_type='similarity_score_threshold',
+            search_kwargs={
+                'k': 10
+            }
+        )
         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
@@ -58,25 +67,24 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()
 
-        return cls.compact_retrieve_response(dataset, query_bundle, nodes)
+        return cls.compact_retrieve_response(dataset, embeddings, query, documents)
 
     @classmethod
-    def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]):
-        embeddings = [
-            query_bundle.embedding
-        ]
-
-        for node in nodes:
-            embeddings.append(node.node.embedding)
-
-        tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
+    def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
+        text_embeddings = [
+            embeddings.embed_query(query)
+        ]
+
+        text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
+
+        tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
 
         query_position = tsne_position_data.pop(0)
 
         i = 0
         records = []
-        for node in nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            index_node_id = document.metadata['doc_id']
 
             segment = db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == dataset.id,
@@ -91,7 +99,7 @@ class HitTestingService:
             record = {
                 "segment": segment,
-                "score": node.score,
+                "score": document.metadata['score'],
                 "tsne_position": tsne_position_data[i]
             }
@@ -101,7 +109,7 @@ class HitTestingService:
         return {
             "query": {
-                "content": query_bundle.query_str,
+                "content": query,
                 "tsne_position": query_position,
             },
             "records": records
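`compact_retrieve_response` feeds the query embedding plus the document embeddings into `get_tsne_positions_from_embeddings` (not shown in this hunk). A rough, self-contained sketch of how 2-D positions can be derived with scikit-learn's TSNE; the data here is random, and the small perplexity is only so the toy example stays valid for six samples:

import numpy as np
from sklearn.manifold import TSNE

# one query embedding followed by five document embeddings (toy 8-dimensional data)
embeddings = np.random.RandomState(0).rand(6, 8)

tsne = TSNE(n_components=2, perplexity=2, random_state=0)
positions = tsne.fit_transform(embeddings)

tsne_position_data = [{'x': float(x), 'y': float(y)} for x, y in positions]
query_position = tsne_position_data.pop(0)   # first row corresponds to the query, as in the service
print(query_position, len(tsne_position_data))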
api/tasks/add_document_to_index_task.py

@@ -4,96 +4,81 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment, Document
+from models.dataset import DocumentSegment
+from models.dataset import Document as DatasetDocument
 
 
 @shared_task
-def add_document_to_index_task(document_id: str):
+def add_document_to_index_task(dataset_document_id: str):
     """
     Async Add document to index
     :param document_id:
 
     Usage: add_document_to_index.delay(document_id)
     """
-    logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green'))
+    logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
     start_at = time.perf_counter()
 
-    document = db.session.query(Document).filter(Document.id == document_id).first()
-    if not document:
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
+    if not dataset_document:
         raise NotFound('Document not found')
 
-    if document.indexing_status != 'completed':
+    if dataset_document.indexing_status != 'completed':
         return
 
-    indexing_cache_key = 'document_{}_indexing'.format(document.id)
+    indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)
 
     try:
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.document_id == document.id,
+            DocumentSegment.document_id == dataset_document.id,
            DocumentSegment.enabled == True
         ) \
            .order_by(DocumentSegment.position.asc()).all()
 
-        nodes = []
-        previous_node = None
+        documents = []
         for segment in segments:
-            relationships = {
-                DocumentRelationship.SOURCE: document.id
-            }
-
-            if previous_node:
-                relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-
-                previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-            node = Node(
-                doc_id=segment.index_node_id,
-                doc_hash=segment.index_node_hash,
-                text=segment.content,
-                extra_info=None,
-                node_info=None,
-                relationships=relationships
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
             )
 
-            previous_node = node
-            nodes.append(node)
+            documents.append(document)
 
-        dataset = document.dataset
+        dataset = dataset_document.dataset
 
         if not dataset:
             raise Exception('Document has no dataset')
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=nodes,
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents)
 
         # save keyword index
-        keyword_table_index.add_nodes(nodes)
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts(documents)
 
         end_at = time.perf_counter()
         logging.info(
-            click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+            click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
     except Exception as e:
         logging.exception("add document to index failed")
-        document.enabled = False
-        document.disabled_at = datetime.datetime.utcnow()
-        document.status = 'error'
-        document.error = str(e)
+        dataset_document.enabled = False
+        dataset_document.disabled_at = datetime.datetime.utcnow()
+        dataset_document.status = 'error'
+        dataset_document.error = str(e)
         db.session.commit()
     finally:
         redis_client.delete(indexing_cache_key)
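After this migration the indexing tasks all share one convention: a `DocumentSegment` row becomes a LangChain `Document` whose metadata carries the ids used later for deletes and lookups. A standalone sketch of that mapping with made-up values (the `segment` dict here is a stand-in for the ORM row, not real data):

from langchain.schema import Document

# stand-in for a DocumentSegment row (values are illustrative)
segment = {
    'content': 'Dify is an LLM application development platform.',
    'index_node_id': 'node-0001',
    'index_node_hash': 'hash-0001',
    'document_id': 'doc-0001',
    'dataset_id': 'ds-0001',
}

document = Document(
    page_content=segment['content'],
    metadata={
        "doc_id": segment['index_node_id'],
        "doc_hash": segment['index_node_hash'],
        "document_id": segment['document_id'],
        "dataset_id": segment['dataset_id'],
    }
)

print(document.metadata['doc_id'])  # the id the vector stores above use as the point/object id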
api/tasks/add_segment_to_index_task.py

@@ -4,12 +4,10 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -36,44 +34,41 @@ def add_segment_to_index_task(segment_id: str):
     indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
 
     try:
-        relationships = {
-            DocumentRelationship.SOURCE: segment.document_id,
-        }
-
-        previous_segment = segment.previous_segment
-        if previous_segment:
-            relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-        next_segment = segment.next_segment
-        if next_segment:
-            relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-
-        node = Node(
-            doc_id=segment.index_node_id,
-            doc_hash=segment.index_node_hash,
-            text=segment.content,
-            extra_info=None,
-            node_info=None,
-            relationships=relationships
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            }
         )
 
         dataset = segment.dataset
 
         if not dataset:
-            raise Exception('Segment has no dataset')
+            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+            return
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        dataset_document = segment.document
+
+        if not dataset_document:
+            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+            return
 
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=[node],
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts([document], duplicate_check=True)
 
         # save keyword index
-        keyword_table_index.add_nodes([node])
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts([document])
 
         end_at = time.perf_counter()
         logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
api/tasks/clean_dataset_task.py

@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin
@@ -33,29 +32,24 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         index_struct=index_struct
     )
 
-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
-
     documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-    index_doc_ids = [document.id for document in documents]
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-    index_node_ids = [segment.index_node_id for segment in segments]
 
-    # delete from vector index
-    if dataset.indexing_technique == "high_quality":
-        for index_doc_id in index_doc_ids:
-            try:
-                vector_index.del_doc(index_doc_id)
-            except Exception:
-                logging.exception("Delete doc index failed when dataset deleted.")
-                continue
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')
 
-    # delete from keyword index
-    if index_node_ids:
-        try:
-            keyword_table_index.del_nodes(index_node_ids)
-        except Exception:
-            logging.exception("Delete nodes index failed when dataset deleted.")
+    # delete from vector index
+    if vector_index:
+        try:
+            vector_index.delete()
+        except Exception:
+            logging.exception("Delete doc index failed when dataset deleted.")
+
+    # delete from keyword index
+    try:
+        kw_index.delete()
+    except Exception:
+        logging.exception("Delete nodes index failed when dataset deleted.")
 
     for document in documents:
         db.session.delete(document)
@@ -63,7 +57,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
     for segment in segments:
         db.session.delete(segment)
 
-    db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete()
     db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
     db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
    db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
api/tasks/clean_document_task.py

@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset
@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
     if not dataset:
         raise Exception('Document has no dataset')
 
-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')
 
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
     index_node_ids = [segment.index_node_id for segment in segments]
 
     # delete from vector index
-    vector_index.del_nodes(index_node_ids)
+    if vector_index:
+        vector_index.delete_by_document_id(document_id)
 
     # delete from keyword index
     if index_node_ids:
-        keyword_table_index.del_nodes(index_node_ids)
+        kw_index.delete_by_ids(index_node_ids)
 
     for segment in segments:
         db.session.delete(segment)
 
     db.session.commit()
 
     end_at = time.perf_counter()
     logging.info(
api/tasks/clean_notion_document_task.py

@@ -5,8 +5,7 @@ from typing import List
 import click
 from celery import shared_task
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, Document
@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
     if not dataset:
         raise Exception('Document has no dataset')
 
-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')
 
     for document_id in document_ids:
         document = db.session.query(Document).filter(
             Document.id == document_id
         ).first()
         db.session.delete(document)
 
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
 
         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)
 
         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
 
         for segment in segments:
             db.session.delete(segment)
api/tasks/deal_dataset_vector_index_task.py

@@ -3,10 +3,12 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs.node_v2 import DocumentRelationship, Node
+from langchain.schema import Document
 
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
-from models.dataset import DocumentSegment, Document, Dataset
+from models.dataset import DocumentSegment, Dataset
+from models.dataset import Document as DatasetDocument
 
 
 @shared_task
@@ -24,49 +26,47 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
     dataset = Dataset.query.filter_by(
         id=dataset_id
     ).first()
 
     if not dataset:
         raise Exception('Dataset not found')
 
-    documents = Document.query.filter_by(dataset_id=dataset_id).all()
-    if documents:
-        vector_index = VectorIndex(dataset=dataset)
-        for document in documents:
-            # delete from vector index
-            if action == "remove":
-                vector_index.del_doc(document.id)
-            elif action == "add":
-                segments = db.session.query(DocumentSegment).filter(
-                    DocumentSegment.document_id == document.id,
-                    DocumentSegment.enabled == True
-                ).order_by(DocumentSegment.position.asc()).all()
-
-                nodes = []
-                previous_node = None
-                for segment in segments:
-                    relationships = {
-                        DocumentRelationship.SOURCE: document.id
-                    }
-
-                    if previous_node:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-
-                        previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-                    node = Node(
-                        doc_id=segment.index_node_id,
-                        doc_hash=segment.index_node_hash,
-                        text=segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
-                    )
-
-                    previous_node = node
-                    nodes.append(node)
-
-                # save vector index
-                vector_index.add_nodes(
-                    nodes=nodes,
-                    duplicate_check=True
-                )
+    if action == "remove":
+        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+        index.delete()
+    elif action == "add":
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset_id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        if dataset_documents:
+            # save vector index
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            for dataset_document in dataset_documents:
+                # delete from vector index
+                segments = db.session.query(DocumentSegment).filter(
+                    DocumentSegment.document_id == dataset_document.id,
+                    DocumentSegment.enabled == True
+                ).order_by(DocumentSegment.position.asc()).all()
+
+                documents = []
+                for segment in segments:
+                    document = Document(
+                        page_content=segment.content,
+                        metadata={
+                            "doc_id": segment.index_node_id,
+                            "doc_hash": segment.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        }
+                    )
+
+                    documents.append(document)
+
+                # save vector index
+                index.add_texts(documents)
 
     end_at = time.perf_counter()
     logging.info(
api/tasks/document_indexing_sync_task.py
View file @
eea011bd
...
@@ -6,11 +6,9 @@ import click
...
@@ -6,11 +6,9 @@ import click
from
celery
import
shared_task
from
celery
import
shared_task
from
werkzeug.exceptions
import
NotFound
from
werkzeug.exceptions
import
NotFound
from
core.data_source.notion
import
NotionPageReader
from
core.data_loader.loader.notion
import
NotionLoader
from
core.index.keyword_table_index
import
KeywordTableIndex
from
core.index.index
import
IndexBuilder
from
core.index.vector_index
import
VectorIndex
from
core.indexing_runner
import
IndexingRunner
,
DocumentIsPausedException
from
core.indexing_runner
import
IndexingRunner
,
DocumentIsPausedException
from
core.llm.error
import
ProviderTokenNotInitError
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.dataset
import
Document
,
Dataset
,
DocumentSegment
from
models.dataset
import
Document
,
Dataset
,
DocumentSegment
from
models.source
import
DataSourceBinding
from
models.source
import
DataSourceBinding
...
@@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
...
@@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
raise
ValueError
(
"no notion page found"
)
raise
ValueError
(
"no notion page found"
)
workspace_id
=
data_source_info
[
'notion_workspace_id'
]
workspace_id
=
data_source_info
[
'notion_workspace_id'
]
page_id
=
data_source_info
[
'notion_page_id'
]
page_id
=
data_source_info
[
'notion_page_id'
]
page_type
=
data_source_info
[
'type'
]
page_edited_time
=
data_source_info
[
'last_edited_time'
]
page_edited_time
=
data_source_info
[
'last_edited_time'
]
data_source_binding
=
DataSourceBinding
.
query
.
filter
(
data_source_binding
=
DataSourceBinding
.
query
.
filter
(
db
.
and_
(
db
.
and_
(
...
@@ -54,8 +53,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
     ).first()

     if not data_source_binding:
         raise ValueError('Data source binding not found.')

-    reader = NotionPageReader(integration_token=data_source_binding.access_token)
-    last_edited_time = reader.get_page_last_edited_time(page_id)
+    loader = NotionLoader(
+        notion_access_token=data_source_binding.access_token,
+        notion_workspace_id=workspace_id,
+        notion_obj_id=page_id,
+        notion_page_type=page_type
+    )
+
+    last_edited_time = loader.get_notion_last_edited_time()

     # check the page is updated
     if last_edited_time != page_edited_time:
         document.indexing_status = 'parsing'
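The sync task now derives page freshness from a NotionLoader built off the data source binding instead of a NotionPageReader. A hedged usage sketch follows; the credential and id values are placeholders, and only the constructor arguments and the get_notion_last_edited_time() call are taken from the hunk above.

# Sketch with placeholder values; only the NotionLoader arguments and the
# get_notion_last_edited_time() call come from the diff above.
from core.data_loader.loader.notion import NotionLoader

loader = NotionLoader(
    notion_access_token="<data_source_binding.access_token>",
    notion_workspace_id="<notion_workspace_id>",
    notion_obj_id="<notion_page_id>",
    notion_page_type="page"  # assumed value of data_source_info['type']
)

last_edited_time = loader.get_notion_last_edited_time()
if last_edited_time != "<previously synced last_edited_time>":
    # the page changed in Notion, so the task re-parses and re-indexes it
    pass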
...
@@ -68,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)

         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)
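The dedicated VectorIndex/KeywordTableIndex wrappers give way to IndexBuilder.get_index(dataset, 'high_quality') and IndexBuilder.get_index(dataset, 'economy'), and the vector side is now guarded because get_index can evidently return nothing for some datasets. A condensed sketch of the cleanup step, reusing only names that appear in the hunk above; the helper function itself is hypothetical.

# Condensed sketch of the new cleanup step; names mirror the hunk above and the
# helper function is hypothetical, not part of the commit.
from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import DocumentSegment

def clean_document_from_indexes(dataset, document_id):
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    segments = db.session.query(DocumentSegment) \
        .filter(DocumentSegment.document_id == document_id).all()
    index_node_ids = [segment.index_node_id for segment in segments]

    if vector_index:  # guard: the high-quality index may not exist for this dataset
        vector_index.delete_by_document_id(document_id)

    if index_node_ids:
        kw_index.delete_by_ids(index_node_ids)

    for segment in segments:
        db.session.delete(segment)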
...
@@ -89,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                 click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
         except Exception:
             logging.exception("Cleaned document when document update data source or process rule failed")

         try:
             indexing_runner = IndexingRunner()
             indexing_runner.run([document])
             end_at = time.perf_counter()
             logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-        except DocumentIsPausedException:
-            logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-        except ProviderTokenNotInitError as e:
-            document.indexing_status = 'error'
-            document.error = str(e.description)
-            document.stopped_at = datetime.datetime.utcnow()
-            db.session.commit()
-        except Exception as e:
-            logging.exception("consume update document failed")
-            document.indexing_status = 'error'
-            document.error = str(e)
-            document.stopped_at = datetime.datetime.utcnow()
-            db.session.commit()
+        except DocumentIsPausedException as ex:
+            logging.info(click.style(str(ex), fg='yellow'))
+        except Exception:
+            pass
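After the re-index step, the task now reports DocumentIsPausedException through its own message and silently drops any other exception; the per-error bookkeeping (indexing_status, error, stopped_at) no longer lives in this except chain. A compressed sketch of the resulting handler shape; wrap_reindex is hypothetical and only the two except branches are taken from the hunk above.

# Compressed sketch of the new handler shape; wrap_reindex is hypothetical.
import logging
import click
from core.indexing_runner import IndexingRunner, DocumentIsPausedException

def wrap_reindex(document):
    try:
        IndexingRunner().run([document])
    except DocumentIsPausedException as ex:
        logging.info(click.style(str(ex), fg='yellow'))
    except Exception:
        pass  # failures are no longer recorded on the document here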
api/tasks/document_indexing_task.py
View file @ eea011bd
...
@@ -7,7 +7,6 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound

 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document
...
@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
     documents = []
+    start_at = time.perf_counter()
     for document_id in document_ids:
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-        start_at = time.perf_counter()

         document = db.session.query(Document).filter(
             Document.id == document_id,
...
@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()
-        logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
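document_indexing_task now starts its timer once, before the loop, and logs the latency of the whole dataset ('Processed dataset: ...') instead of the last processed document. A trivial sketch of that measurement; the document ids are placeholders.

# Trivial sketch of the dataset-level timing; document ids are placeholders.
import time

start_at = time.perf_counter()
for document_id in ['doc-a', 'doc-b']:
    ...  # collect each document for IndexingRunner().run(documents)
end_at = time.perf_counter()
print('Processed dataset latency: {}'.format(end_at - start_at))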
api/tasks/document_indexing_update_task.py
View file @ eea011bd
...
@@ -6,10 +6,8 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
...
@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_ids(index_node_ids)

         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)
...
@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
                 click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
         except Exception:
             logging.exception("Cleaned document when document update data source or process rule failed")

         try:
             indexing_runner = IndexingRunner()
             indexing_runner.run([document])
             end_at = time.perf_counter()
             logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-        except DocumentIsPausedException:
-            logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-        except ProviderTokenNotInitError as e:
-            document.indexing_status = 'error'
-            document.error = str(e.description)
-            document.stopped_at = datetime.datetime.utcnow()
-            db.session.commit()
-        except Exception as e:
-            logging.exception("consume update document failed")
-            document.indexing_status = 'error'
-            document.error = str(e)
-            document.stopped_at = datetime.datetime.utcnow()
-            db.session.commit()
+        except DocumentIsPausedException as ex:
+            logging.info(click.style(str(ex), fg='yellow'))
+        except Exception:
+            pass
api/tasks/recover_document_indexing_task.py
View file @ eea011bd
-import datetime
 import logging
 import time
...
@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
         indexing_runner.run_in_indexing_status(document)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/remove_document_from_index_task.py
View file @ eea011bd
...
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment, Document
...
@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
     if not dataset:
         raise Exception('Document has no dataset')

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')

     # delete from vector index
-    vector_index.del_doc(document.id)
+    vector_index.delete_by_document_id(document.id)

     # delete from keyword index
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
     index_node_ids = [segment.index_node_id for segment in segments]
     if index_node_ids:
-        keyword_table_index.del_nodes(index_node_ids)
+        kw_index.delete_by_ids(index_node_ids)

     end_at = time.perf_counter()
     logging.info(
...
api/tasks/remove_segment_from_index_task.py
View file @ eea011bd
...
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
...
@@ -36,17 +35,28 @@ def remove_segment_from_index_task(segment_id: str):
     dataset = segment.dataset

     if not dataset:
-        raise Exception('Segment has no dataset')
+        logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+        return

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
+    dataset_document = segment.document
+
+    if not dataset_document:
+        logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+        return
+
+    if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+        logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+        return
+
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')

     # delete from vector index
-    if dataset.indexing_technique == "high_quality":
-        vector_index.del_nodes([segment.index_node_id])
+    if vector_index:
+        vector_index.delete_by_ids([segment.index_node_id])

     # delete from keyword index
-    keyword_table_index.del_nodes([segment.index_node_id])
+    kw_index.delete_by_ids([segment.index_node_id])

     end_at = time.perf_counter()
     logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
...
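remove_segment_from_index_task now returns early with an informational log instead of raising when the segment's dataset or document is missing, or when the document is disabled, archived, or not fully indexed. A condensed sketch of that guard sequence; should_remove_segment is a hypothetical helper and the log messages mirror the hunk above.

# Condensed sketch of the guard sequence; should_remove_segment is hypothetical.
import logging
import click

def should_remove_segment(segment) -> bool:
    if not segment.dataset:
        logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
        return False

    dataset_document = segment.document
    if not dataset_document:
        logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
        return False

    if not dataset_document.enabled or dataset_document.archived \
            or dataset_document.indexing_status != 'completed':
        logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
        return False

    return True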
sdks/python-client/dify_client/client.py
View file @ eea011bd
...
@@ -65,8 +65,8 @@ class ChatClient(DifyClient):
         return self._send_request("GET", "/messages", params=params)

-    def get_conversations(self, user, first_id=None, limit=None, pinned=None):
-        params = {"user": user, "first_id": first_id, "limit": limit, "pinned": pinned}
+    def get_conversations(self, user, last_id=None, limit=None, pinned=None):
+        params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
         return self._send_request("GET", "/conversations", params=params)

     def rename_conversation(self, conversation_id, name, user):
...
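The Python client renames the conversation pagination parameter from first_id to last_id. A quick usage sketch; the API key, user id, and the way the client is constructed are placeholders rather than part of this diff, and only the last_id rename comes from the hunk above.

# Placeholder credentials and construction style; only last_id is from the diff.
from dify_client.client import ChatClient

client = ChatClient("<your-app-api-key>")
response = client.get_conversations(
    user="end-user-123",
    last_id=None,   # id of the last conversation from the previous page, if paginating
    limit=20,
    pinned=None
)
print(response)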
sdks/python-client/setup.py
View file @ eea011bd
...
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 setup(
     name="dify-client",
-    version="0.1.7",
+    version="0.1.8",
     author="Dify",
     author_email="hello@dify.ai",
     description="A package for interacting with the Dify Service-API",
...