ai-tech / dify / Commits

Commit eea011bd
Merge branch 'main' into feat/add-icons
Authored Jun 25, 2023 by StyleZhang
Parents: 3eb8e66b, 951afcaa

Showing 92 changed files with 2696 additions and 3031 deletions (+2696, -3031)
api/app.py (+1, -2)
api/commands.py (+35, -0)
api/config.py (+2, -0)
api/controllers/console/datasets/data_source.py (+11, -10)
api/controllers/console/datasets/file.py (+2, -28)
api/controllers/console/version.py (+7, -2)
api/core/__init__.py (+0, -20)
api/core/agent/agent_builder.py (+8, -11)
api/core/callback_handler/agent_loop_gather_callback_handler.py (+1, -29)
api/core/callback_handler/dataset_tool_callback_handler.py (+1, -50)
api/core/callback_handler/index_tool_callback_handler.py (+8, -21)
api/core/callback_handler/llm_callback_handler.py (+31, -85)
api/core/callback_handler/main_chain_gather_callback_handler.py (+12, -70)
api/core/callback_handler/std_out_callback_handler.py (+36, -9)
api/core/chain/chain_builder.py (+2, -4)
api/core/chain/llm_router_chain.py (+6, -4)
api/core/chain/main_chain_builder.py (+10, -8)
api/core/chain/multi_dataset_router_chain.py (+62, -16)
api/core/chain/sensitive_word_avoidance_chain.py (+7, -2)
api/core/chain/tool_chain.py (+12, -3)
api/core/completion.py (+42, -17)
api/core/conversation_message_task.py (+4, -4)
api/core/data_loader/file_extractor.py (+43, -0)
api/core/data_loader/loader/csv.py (+67, -0)
api/core/data_loader/loader/excel.py (+43, -0)
api/core/data_loader/loader/html.py (+35, -0)
api/core/data_loader/loader/markdown.py (+134, -0)
api/core/data_loader/loader/notion.py (+236, -236)
api/core/data_loader/loader/pdf.py (+55, -0)
api/core/docstore/dataset_docstore.py (+35, -45)
api/core/docstore/empty_docstore.py (+0, -51)
api/core/embedding/cached_embedding.py (+72, -0)
api/core/embedding/openai_embedding.py (+0, -214)
api/core/index/base.py (+59, -0)
api/core/index/index.py (+41, -0)
api/core/index/index_builder.py (+0, -60)
api/core/index/keyword_table/jieba_keyword_table.py (+0, -159)
api/core/index/keyword_table_index.py (+0, -135)
api/core/index/keyword_table_index/jieba_keyword_table_handler.py (+33, -0)
api/core/index/keyword_table_index/keyword_table_index.py (+238, -0)
api/core/index/keyword_table_index/stopwords.py (+0, -0)
api/core/index/query/synthesizer.py (+0, -79)
api/core/index/readers/html_parser.py (+0, -22)
api/core/index/readers/pdf_parser.py (+0, -56)
api/core/index/readers/xlsx_parser.py (+0, -33)
api/core/index/vector_index.py (+0, -136)
api/core/index/vector_index/base.py (+175, -0)
api/core/index/vector_index/qdrant_vector_index.py (+116, -0)
api/core/index/vector_index/vector_index.py (+69, -0)
api/core/index/vector_index/weaviate_vector_index.py (+136, -0)
api/core/indexing_runner.py (+262, -283)
api/core/llm/llm_builder.py (+17, -14)
api/core/llm/provider/azure_provider.py (+4, -1)
api/core/llm/streamable_azure_chat_open_ai.py (+13, -50)
api/core/llm/streamable_azure_open_ai.py (+13, -6)
api/core/llm/streamable_chat_open_ai.py (+14, -48)
api/core/llm/streamable_open_ai.py (+13, -5)
…/read_only_conversation_token_db_string_buffer_shared_memory.py (+1, -1)
api/core/prompt/prompts.py (+0, -19)
api/core/spiltter/fixed_text_splitter.py (+0, -0)
api/core/tool/dataset_index_tool.py (+87, -0)
api/core/tool/dataset_tool_builder.py (+0, -73)
api/core/tool/llama_index_tool.py (+0, -43)
api/core/vector_store/base.py (+0, -34)
api/core/vector_store/qdrant_vector_store.py (+69, -0)
api/core/vector_store/qdrant_vector_store_client.py (+0, -147)
api/core/vector_store/vector_store.py (+0, -62)
api/core/vector_store/vector_store_index_query.py (+0, -66)
api/core/vector_store/weaviate_vector_store.py (+38, -0)
api/core/vector_store/weaviate_vector_store_client.py (+0, -270)
api/extensions/ext_vector_store.py (+0, -7)
api/libs/helper.py (+6, -0)
api/models/account.py (+0, -2)
api/models/dataset.py (+30, -2)
api/requirements.txt (+5, -4)
api/services/app_model_config_service.py (+0, -1)
api/services/dataset_service.py (+0, -3)
api/services/hit_testing_service.py (+44, -36)
api/tasks/add_document_to_index_task.py (+33, -48)
api/tasks/add_segment_to_index_task.py (+27, -32)
api/tasks/clean_dataset_task.py (+13, -20)
api/tasks/clean_document_task.py (+7, -6)
api/tasks/clean_notion_document_task.py (+7, -6)
api/tasks/deal_dataset_vector_index_task.py (+36, -36)
api/tasks/document_indexing_sync_task.py (+23, -23)
api/tasks/document_indexing_task.py (+6, -16)
api/tasks/document_indexing_update_task.py (+11, -20)
api/tasks/recover_document_indexing_task.py (+4, -9)
api/tasks/remove_document_from_index_task.py (+5, -6)
api/tasks/remove_segment_from_index_task.py (+18, -8)
sdks/python-client/dify_client/client.py (+2, -2)
sdks/python-client/setup.py (+1, -1)
api/app.py (+1, -2)

@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
+from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
     ext_database, ext_storage
 from extensions.ext_database import db
 from extensions.ext_login import login_manager

@@ -79,7 +79,6 @@ def initialize_extensions(app):
     ext_database.init_app(app)
     ext_migrate.init(app, db)
     ext_redis.init_app(app)
-    ext_vector_store.init_app(app)
     ext_storage.init_app(app)
     ext_celery.init_app(app)
     ext_session.init_app(app)
api/commands.py (+35, -0)

 import datetime
 import logging
 import random
 import string

 import click
 from flask import current_app
 from werkzeug.exceptions import NotFound

 from core.index.index import IndexBuilder
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
 from models.dataset import Dataset
 from models.model import Account
 import secrets
 import base64

@@ -159,8 +163,39 @@ def generate_upper_string():
     return result


+@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
+def recreate_all_dataset_indexes():
+    click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
+    recreate_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Recreating dataset index: {}'.format(dataset.id))
+                index = IndexBuilder.get_index(dataset, 'high_quality')
+                if index and index._is_origin():
+                    index.recreate_dataset(dataset)
+                    recreate_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
+    app.cli.add_command(recreate_all_dataset_indexes)
api/config.py (+2, -0)

@@ -187,11 +187,13 @@ class Config:
         # For temp use only
         # set default LLM provider, default is 'openai', support `azure_openai`
         self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')

         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
         self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
         self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
         self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')


 class CloudEditionConfig(Config):
api/controllers/console/datasets/data_source.py (+11, -10)

@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_source.notion import NotionPageReader
+from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
 from extensions.ext_database import db
 from libs.helper import TimestampField
 from libs.oauth_data_source import NotionOAuth
 from models.dataset import Document
 from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService

@@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource):
         ).first()

         if not data_source_binding:
             raise NotFound('Data source binding not found.')

-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        if page_type == 'page':
-            page_content = reader.read_page(page_id)
-        elif page_type == 'database':
-            page_content = reader.query_database_data(page_id)
-        else:
-            page_content = ""
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        text_docs = loader.load()
         return {
-            'content': page_content
+            'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200

     @setup_required
api/controllers/console/datasets/file.py (+2, -28)

@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
     UnsupportedFileTypeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db

@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
         if extension not in ALLOWED_EXTENSIONS:
             raise UnsupportedFileTypeError()

-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, filepath)
-
-            if extension == 'pdf':
-                parser = PDFParser({'upload_file': upload_file})
-                text = parser.parse_file(Path(filepath))
-            elif extension in ['html', 'htm']:
-                # Use BeautifulSoup to extract text
-                parser = HTMLParser()
-                text = parser.parse_file(Path(filepath))
-            elif extension == 'xlsx':
-                parser = XLSXParser()
-                text = parser.parse_file(filepath)
-            else:
-                # ['txt', 'markdown', 'md']
-                with open(filepath, "rb") as fp:
-                    data = fp.read()
-                    encoding = chardet.detect(data)['encoding']
-                    if encoding:
-                        text = data.decode(encoding=encoding).strip() if data else ''
-                    else:
-                        text = data.decode(encoding='utf-8').strip() if data else ''
+        text = FileExtractor.load(upload_file, return_text=True)

         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
         return {'content': text}
api/controllers/console/version.py (+7, -2)

@@ -32,8 +32,13 @@ class VersionApi(Resource):
                 'current_version': args.get('current_version')
             })
         except Exception as error:
-            logging.exception("Check update error.")
-            raise InternalServerError()
+            logging.warning("Check update version error: {}.".format(str(error)))
+            return {
+                'version': args.get('current_version'),
+                'release_date': '',
+                'release_notes': '',
+                'can_auto_update': False
+            }

         content = json.loads(response.content)
         return {
api/core/__init__.py (+0, -20)

@@ -3,19 +3,11 @@ from typing import Optional

 import langchain
 from flask import Flask
-from jieba.analyse import default_tfidf
-from langchain import set_handler
 from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
-from llama_index import IndexStructType, QueryMode
-from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
 from pydantic import BaseModel

 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.index.keyword_table.stopwords import STOPWORDS
 from core.prompt.prompt_template import OneLineFormatter
-from core.vector_store.vector_store import VectorStore
-from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery


 class HostedOpenAICredential(BaseModel):

@@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials()

 def init_app(app: Flask):
     formatter = OneLineFormatter()
     DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    default_tfidf.stop_words = STOPWORDS

     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
-        set_handler(DifyStdOutCallbackHandler())

     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
api/core/agent/agent_builder.py (+8, -11)

@@ -2,7 +2,7 @@ from typing import Optional

 from langchain import LLMChain
 from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
-from langchain.callbacks import CallbackManager
+from langchain.callbacks.manager import CallbackManager
 from langchain.memory.chat_memory import BaseChatMemory

 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler

@@ -16,23 +16,20 @@ class AgentBuilder:
     def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                        dataset_tool_callback_handler: DatasetToolCallbackHandler,
                        agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
-        llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name=agent_loop_gather_callback_handler.model_name,
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
         )

-        tool_callback_manager = CallbackManager([
-            agent_loop_gather_callback_handler,
-            dataset_tool_callback_handler,
-            DifyStdOutCallbackHandler()
-        ])
-
         for tool in tools:
-            tool.callback_manager = tool_callback_manager
+            tool.callbacks = [
+                agent_loop_gather_callback_handler,
+                dataset_tool_callback_handler,
+                DifyStdOutCallbackHandler()
+            ]

         prompt = cls.build_agent_prompt_template(
             tools=tools,

@@ -54,7 +51,7 @@ class AgentBuilder:
             tools=tools,
             agent=agent,
             memory=memory,
-            callback_manager=agent_callback_manager,
+            callbacks=agent_callback_manager,
             max_iterations=6,
             early_stopping_method="generate",
             # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
api/core/callback_handler/agent_loop_gather_callback_handler.py (+1, -29)

@@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask

 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""

@@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completion = response.generations[0][0].text
             self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']

-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
     def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:

@@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None

-    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
     def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
         logging.error(error)

     def on_tool_start(self, serialized: Dict[str, Any],

@@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None

-    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
         # Final Answer
api/core/callback_handler/dataset_tool_callback_handler.py (+1, -50)

@@ -3,7 +3,6 @@ import logging
 from typing import Any, Dict, List, Union, Optional

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult

 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.conversation_message_task import ConversationMessageTask

@@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask

 class DatasetToolCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""

@@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Do nothing."""
         logging.error(error)
-
-    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        pass
-
-    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        logging.error(error)
-
-    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
-        pass
-
-    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
api/core/callback_handler/index_tool_callback_handler.py (+8, -21)

-from llama_index import Response
+from typing import List
+
+from langchain.schema import Document

 from extensions.ext_database import db
 from models.dataset import DocumentSegment


-class IndexToolCallbackHandler:
-
-    def __init__(self) -> None:
-        self._response = None
-
-    @property
-    def response(self) -> Response:
-        return self._response
-
-    def on_tool_end(self, response: Response) -> None:
-        """Handle tool end."""
-        self._response = response
-
-
-class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
+class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""

     def __init__(self, dataset_id: str) -> None:
-        super().__init__()
         self.dataset_id = dataset_id

-    def on_tool_end(self, response: Response) -> None:
+    def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
-        for node in response.source_nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            doc_id = document.metadata['doc_id']

             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == self.dataset_id,
-                DocumentSegment.index_node_id == index_node_id
+                DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                 synchronize_session=False
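Note on the change above: the callback no longer receives a llama-index Response with source_nodes; it now gets a list of langchain Documents and reads the segment id from each document's metadata. A standalone sketch of that new contract (the Doc dataclass is a stand-in for langchain.schema.Document, and the metadata values are invented):

    from dataclasses import dataclass, field

    @dataclass
    class Doc:  # stand-in for langchain.schema.Document
        page_content: str
        metadata: dict = field(default_factory=dict)

    docs = [Doc("segment text", {"doc_id": "node-123"}),
            Doc("another segment", {"doc_id": "node-456"})]

    # what the new on_tool_end iterates over to bump hit counts
    hit_node_ids = [d.metadata["doc_id"] for d in docs]
    print(hit_node_ids)  # ['node-123', 'node-456']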
api/core/callback_handler/llm_callback_handler.py (+31, -85)

@@ -3,7 +3,7 @@ import time
 from typing import Any, Dict, List, Union, Optional

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage

 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException

@@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI

 class LLMCallbackHandler(BaseCallbackHandler):
+    raise_error: bool = True

     def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                  conversation_message_task: ConversationMessageTask):

@@ -25,41 +26,41 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Whether to call verbose callbacks even if verbose is False."""
         return True

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        self.start_at = time.perf_counter()
+        real_prompts = []
+        for message in messages[0]:
+            if message.type == 'human':
+                role = 'user'
+            elif message.type == 'ai':
+                role = 'assistant'
+            else:
+                role = 'system'
+
+            real_prompts.append({
+                "role": role,
+                "text": message.content
+            })
+
+        self.llm_message.prompt = real_prompts
+        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+
     def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
         self.start_at = time.perf_counter()
-        if 'Chat' in serialized['name']:
-            real_prompts = []
-            messages = []
-            for prompt in prompts:
-                role, content = prompt.split(': ', maxsplit=1)
-                if role == 'human':
-                    role = 'user'
-                    message = HumanMessage(content=content)
-                elif role == 'ai':
-                    role = 'assistant'
-                    message = AIMessage(content=content)
-                else:
-                    message = SystemMessage(content=content)
-
-                real_prompt = {
-                    "role": role,
-                    "text": content
-                }
-                real_prompts.append(real_prompt)
-                messages.append(message)
-
-            self.llm_message.prompt = real_prompts
-            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
-        else:
-            self.llm_message.prompt = [{
-                "role": 'user',
-                "text": prompts[0]
-            }]
+        self.llm_message.prompt = [{
+            "role": 'user',
+            "text": prompts[0]
+        }]

-            self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()

@@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler):
             self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)
-
-    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        pass
-
-    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
-        pass
-
-    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
-        pass
-
-    def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None,
-                    llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
-        pass
-
-    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        pass
-
-    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
-        pass
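The on_chat_model_start handler added above maps langchain message types onto the prompt roles stored with the conversation. A minimal, self-contained sketch of that mapping, with messages mocked as (type, content) pairs instead of real BaseMessage objects:

    def to_prompt_entries(messages):
        role_map = {'human': 'user', 'ai': 'assistant'}  # anything else is treated as 'system'
        return [{"role": role_map.get(m_type, 'system'), "text": content}
                for m_type, content in messages]

    print(to_prompt_entries([('human', 'Hi'), ('ai', 'Hello!'), ('system', 'Be terse')]))
    # [{'role': 'user', 'text': 'Hi'}, {'role': 'assistant', 'text': 'Hello!'}, {'role': 'system', 'text': 'Be terse'}]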
api/core/callback_handler/main_chain_gather_callback_handler.py (+12, -70)

 import logging
 import time

-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, Union

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult

 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult

@@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask

 class MainChainGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""

@@ -50,13 +50,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Print out that we are entering a chain."""
         if not self._current_chain_result:
-            self._current_chain_result = ChainResult(
-                type=serialized['name'],
-                prompt=inputs,
-                started_at=time.perf_counter()
-            )
-            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
+            chain_type = serialized['id'][-1]
+            if chain_type:
+                self._current_chain_result = ChainResult(
+                    type=chain_type,
+                    prompt=inputs,
+                    started_at=time.perf_counter()
+                )
+                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
+                self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message

     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""

@@ -74,64 +76,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         logging.error(error)
-        self.clear_chain_results()
-
-    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        logging.error(error)
-
-    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
-        pass
-
-    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
-        pass
-
-    def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None,
-                    llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
-        pass
-
-    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
-        """Do nothing."""
-        logging.error(error)
-
-    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
+        self.clear_chain_results()
\ No newline at end of file
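In the hunk above the chain type is no longer read from serialized['name'] but from serialized['id'][-1]: newer langchain serializes the callback payload with an 'id' that is the import path as a list, so the last element is the class name. A tiny illustration with a fabricated payload of that shape:

    serialized = {"id": ["langchain", "chains", "llm", "LLMChain"]}  # invented example payload
    chain_type = serialized["id"][-1] if serialized.get("id") else None
    print(chain_type)  # LLMChain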
api/core/callback_handler/std_out_callback_handler.py (+36, -9)

 import os
 import sys
 from typing import Any, Dict, List, Optional, Union

 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage


 class DifyStdOutCallbackHandler(BaseCallbackHandler):

@@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Initialize callback handler."""
         self.color = color

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        print_text("\n[on_chat_model_start]\n", color='blue')
+        for sub_messages in messages:
+            for sub_message in sub_messages:
+                print_text(str(sub_message) + "\n", color='blue')
+
     def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
-        if 'Chat' in serialized['name']:
-            for prompt in prompts:
-                print_text(prompt + "\n", color='blue')
-        else:
-            print_text(prompts[0] + "\n", color='blue')
+        print_text(prompts[0] + "\n", color='blue')

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""

@@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
-        class_name = serialized["name"]
-        print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
+        chain_type = serialized['id'][-1]
+        print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')

     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""

@@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Run on agent end."""
         print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")

+    @property
+    def ignore_llm(self) -> bool:
+        """Whether to ignore LLM callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chain(self) -> bool:
+        """Whether to ignore chain callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_agent(self) -> bool:
+        """Whether to ignore agent callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+

 class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
     """Callback handler for streaming. Only works with LLMs that support streaming."""
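The new ignore_* properties above all gate on the same condition: callbacks are suppressed unless the DEBUG environment variable is exactly 'true' (case-insensitive). A standalone restatement of that check:

    import os

    def ignore_callbacks() -> bool:
        debug = os.environ.get("DEBUG")
        return not debug or debug.lower() != 'true'

    os.environ["DEBUG"] = "true"
    print(ignore_callbacks())  # False -> callbacks are printed
    os.environ["DEBUG"] = "0"
    print(ignore_callbacks())  # True  -> callbacks are suppressed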
api/core/chain/chain_builder.py (+2, -4)

 from typing import Optional

-from langchain.callbacks import CallbackManager
-
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.chain.tool_chain import ToolChain

@@ -14,7 +12,7 @@ class ChainBuilder:
             tool=tool,
             input_key=kwargs.get('input_key', 'input'),
             output_key=kwargs.get('output_key', 'tool_output'),
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+            callbacks=[DifyStdOutCallbackHandler()]
         )

     @classmethod

@@ -27,7 +25,7 @@ class ChainBuilder:
             sensitive_words=sensitive_words.split(","),
             canned_response=tool_config.get("canned_response", ''),
             output_key="sensitive_word_avoidance_output",
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
+            callbacks=[DifyStdOutCallbackHandler()],
             **kwargs
         )
api/core/chain/llm_router_chain.py (+6, -4)

 """Base classes for LLM-powered router chains."""
 from __future__ import annotations

 import json
 from typing import Any, Dict, List, Optional, Type, cast, NamedTuple

+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from pydantic import root_validator

 from langchain.chains import LLMChain
 from langchain.prompts import BasePromptTemplate
-from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
+from langchain.schema import BaseOutputParser, OutputParserException

 from libs.json_in_md_parser import parse_and_check_json_markdown

@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
             raise ValueError

     def _call(
-        self,
-        inputs: Dict[str, Any]
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         output = cast(
             Dict[str, Any],
api/core/chain/main_chain_builder.py (+10, -8)

-from typing import Optional, List
+from typing import Optional, List, cast

-from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory

 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder

@@ -18,6 +16,7 @@ from models.dataset import Dataset
 class MainChainBuilder:
     @classmethod
     def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+                                rest_tokens: int,
                                 conversation_message_task: ConversationMessageTask):
         first_input_key = "input"
         final_output_key = "output"

@@ -30,6 +29,7 @@ class MainChainBuilder:
         tool_chains, chains_output_key = cls.get_agent_chains(
             tenant_id=tenant_id,
             agent_mode=agent_mode,
+            rest_tokens=rest_tokens,
             memory=memory,
             conversation_message_task=conversation_message_task
         )

@@ -42,9 +42,8 @@ class MainChainBuilder:
             return None

         for chain in chains:
-            # do not add handler into singleton callback manager
-            if not isinstance(chain.callback_manager, SharedCallbackManager):
-                chain.callback_manager.add_handler(chain_callback_handler)
+            chain = cast(Chain, chain)
+            chain.callbacks.append(chain_callback_handler)

         # build main chain
         overall_chain = SequentialChain(

@@ -57,7 +56,9 @@ class MainChainBuilder:
         return overall_chain

     @classmethod
-    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
+                         rest_tokens: int,
+                         memory: Optional[BaseChatMemory],
                          conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []

@@ -93,7 +94,8 @@ class MainChainBuilder:
                 tenant_id=tenant_id,
                 datasets=datasets,
                 conversation_message_task=conversation_message_task,
-                callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+                rest_tokens=rest_tokens,
+                callbacks=[DifyStdOutCallbackHandler()]
             )
             chains.append(multi_dataset_router_chain)
api/core/chain/multi_dataset_router_chain.py (+62, -16)

+import math
 from typing import Mapping, List, Dict, Any, Optional

-from langchain import LLMChain, PromptTemplate, ConversationChain
-from langchain.callbacks import CallbackManager
+from langchain import PromptTemplate
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.schema import BaseLanguageModel
 from pydantic import Extra

 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler

@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
 from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
-from core.tool.dataset_tool_builder import DatasetToolBuilder
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from models.dataset import Dataset
+from core.tool.dataset_index_tool import DatasetTool
+from models.dataset import Dataset, DatasetProcessRule

+DEFAULT_K = 2
+CONTEXT_TOKENS_PERCENT = 0.3
 MULTI_PROMPT_ROUTER_TEMPLATE = """
 Given a raw text input to a language model select the model prompt best suited for \
 the input. You will be given the names of the available prompts and a description of \

@@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain):
     router_chain: LLMRouterChain
     """Chain for deciding a destination chain and the input to it."""
-    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    dataset_tools: Mapping[str, DatasetTool]
     """Map of name to candidate chains that inputs can be routed to."""

     class Config:

@@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain):
             tenant_id: str,
             datasets: List[Dataset],
             conversation_message_task: ConversationMessageTask,
+            rest_tokens: int,
             **kwargs: Any,
     ):
         """Convenience constructor for instantiating from destination prompts."""
-        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[DifyStdOutCallbackHandler()]
         )

-        destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
+        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                         else ('useful for when you want to answer queries about the ' + d.name))
                         for d in datasets]
         destinations_str = "\n".join(destinations)
         router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
             destinations=destinations_str
         )
         router_prompt = PromptTemplate(
             template=router_template,
             input_variables=["input"],
             output_parser=RouterOutputParser(),
         )
         router_chain = LLMRouterChain.from_llm(llm, router_prompt)
         dataset_tools = {}
         for dataset in datasets:
-            dataset_tool = DatasetToolBuilder.build_dataset_tool(
+            # fulfill description when it is empty
+            if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+                continue
+
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
+            if k == 0:
+                continue
+
+            dataset_tool = DatasetTool(
+                name=f"dataset-{dataset.id}",
+                description=description,
+                k=k,
                 dataset=dataset,
-                response_mode='no_synthesizer',  # "compact"
-                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
+                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
             )

-            if dataset_tool:
-                dataset_tools[dataset.id] = dataset_tool
+            dataset_tools[str(dataset.id)] = dataset_tool

         return cls(
             router_chain=router_chain,

@@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain):
             **kwargs,
         )

+    @classmethod
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
+        processing_rule = dataset.latest_process_rule
+        if not processing_rule:
+            return DEFAULT_K
+
+        if processing_rule.mode == "custom":
+            rules = processing_rule.rules_dict
+            if not rules:
+                return DEFAULT_K
+
+            segmentation = rules["segmentation"]
+            segment_max_tokens = segmentation["max_tokens"]
+        else:
+            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
+
+        # when rest_tokens is less than default context tokens
+        if rest_tokens < segment_max_tokens * DEFAULT_K:
+            return rest_tokens // segment_max_tokens
+
+        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
+
+        # when context_limit_tokens is less than default context tokens, use default_k
+        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
+            return DEFAULT_K
+
+        # Expand the k value when there's still some room left in the 30% rest tokens space
+        return context_limit_tokens // segment_max_tokens
+
     def _call(
-        self,
-        inputs: Dict[str, Any]
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         if len(self.dataset_tools) == 0:
             return {"text": ''}
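The new _dynamic_calc_retrieve_k above sizes the retrieval k from the remaining token budget. A standalone restatement with a worked example; the constants mirror the ones defined in this file, the inputs are made-up numbers:

    import math

    DEFAULT_K = 2
    CONTEXT_TOKENS_PERCENT = 0.3

    def dynamic_k(rest_tokens: int, segment_max_tokens: int) -> int:
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens
        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K
        return context_limit_tokens // segment_max_tokens

    print(dynamic_k(rest_tokens=700, segment_max_tokens=500))    # 1: not enough room for two segments
    print(dynamic_k(rest_tokens=2000, segment_max_tokens=500))   # 2: 30% budget (600) still only fits DEFAULT_K
    print(dynamic_k(rest_tokens=10000, segment_max_tokens=500))  # 6: 30% budget (3000) allows a larger k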
api/core/chain/sensitive_word_avoidance_chain.py (+7, -2)

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain

@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
                 return self.canned_response
         return text

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         return {self.output_key: output}
api/core/chain/tool_chain.py (+12, -3)

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool

@@ -30,12 +31,20 @@ class ToolChain(Chain):
         """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)
api/core/completion.py (+42, -17)

 import logging
 from typing import Optional, List, Union, Tuple

-from langchain.callbacks import CallbackManager
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms import BaseLLM
-from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
+from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError

 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
     DifyStdOutCallbackHandler
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
 from core.llm.llm_builder import LLMBuilder
 from core.chain.main_chain_builder import MainChainBuilder

@@ -34,8 +35,6 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
-        cls.validate_query_tokens(app.tenant_id, app_model_config, query)
-
         memory = None
         if conversation:
             # get memory of conversation (read-only)

@@ -48,6 +47,14 @@ class Completion:
             inputs = conversation.inputs

+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            tenant_id=app.tenant_id,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
+
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
             app=app,

@@ -64,6 +71,7 @@ class Completion:
         main_chain = MainChainBuilder.to_langchain_components(
             tenant_id=app.tenant_id,
             agent_mode=app_model_config.agent_mode_dict,
+            rest_tokens=rest_tokens_for_context_and_memory,
             memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
             conversation_message_task=conversation_message_task
         )

@@ -115,7 +123,7 @@ class Completion:
             memory=memory
         )

-        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
+        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)

         cls.recale_llm_max_tokens(
             final_llm=final_llm,

@@ -247,16 +255,14 @@ And answer according to the language of the user's question.
         return messages, ['\nHuman:']

     @classmethod
-    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool,
-                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
+    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+                          streaming: bool,
+                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
-            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
+            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
         else:
-            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
-
-        return CallbackManager(callback_handlers)
+            return [llm_callback_handler, DifyStdOutCallbackHandler()]

     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,

@@ -293,7 +299,8 @@ And answer according to the language of the user's question.
         return memory

     @classmethod
-    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
+    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+                                 query: str, inputs: dict) -> int:
         llm = LLMBuilder.to_llm_from_model(
             tenant_id=tenant_id,
             model=app_model_config.model_dict

@@ -302,8 +309,26 @@ And answer according to the language of the user's question.
         model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
         max_tokens = llm.max_tokens

-        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
-            raise LLMBadRequestError("Query is too long")
+        # get prompt without memory and context
+        prompt, _ = cls.get_main_llm_prompt(
+            mode=mode,
+            llm=llm,
+            pre_prompt=app_model_config.pre_prompt,
+            query=query,
+            inputs=inputs,
+            chain_output=None,
+            memory=None
+        )
+
+        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
+            else llm.get_num_tokens_from_messages(prompt)
+
+        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                                     "or shrink the max token, or switch to a llm with a larger token limit size.")
+
+        return rest_tokens

     @classmethod
     def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],

@@ -360,7 +385,7 @@ And answer according to the language of the user's question.
             streaming=streaming
         )

-        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
+        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)

         cls.recale_llm_max_tokens(
             final_llm=llm,
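get_validate_rest_tokens above replaces the old query-length check with a budget: what remains of the model's context window after the configured completion budget and the prompt rendered without memory or context. A standalone sketch of that arithmetic with invented numbers (the real code raises LLMBadRequestError rather than ValueError):

    model_limited_tokens = 4096   # context window of the chosen model
    max_tokens = 512              # completion budget configured for the app
    prompt_tokens = 1200          # tokens of the prompt built without memory/context

    rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
    if rest_tokens < 0:
        raise ValueError("Query or prefix prompt is too long")
    print(rest_tokens)            # 2384 tokens left for retrieved context and chat history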
api/core/conversation_message_task.py (+4, -4)

@@ -293,12 +293,12 @@ class PubHandler:
         if not user:
             raise ValueError("user is required")

-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result:{}-{}".format(user_str, task_id)

     @classmethod
     def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result_stopped:{}-{}".format(user_str, task_id)

     def pub_text(self, text: str):

@@ -306,10 +306,10 @@ class PubHandler:
             'event': 'message',
             'data': {
                 'task_id': self._task_id,
-                'message_id': self._message.id,
+                'message_id': str(self._message.id),
                 'text': text,
                 'mode': self._conversation.mode,
-                'conversation_id': self._conversation.id
+                'conversation_id': str(self._conversation.id)
             }
         }
api/core/data_loader/file_extractor.py (new file, +43)

import tempfile
from pathlib import Path
from typing import List, Union

from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document

from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            input_file = Path(file_path)
            delimiter = '\n'
            if input_file.suffix == '.xlsx':
                loader = ExcelLoader(file_path)
            elif input_file.suffix == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif input_file.suffix in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif input_file.suffix in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif input_file.suffix == '.docx':
                loader = Docx2txtLoader(file_path)
            elif input_file.suffix == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

            return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
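FileExtractor.load above picks a loader purely from the file suffix and falls back to plain text. A minimal sketch of that dispatch with the loader classes replaced by strings so it runs without langchain installed:

    from pathlib import Path

    LOADER_BY_SUFFIX = {
        '.xlsx': 'ExcelLoader', '.pdf': 'PdfLoader',
        '.md': 'MarkdownLoader', '.markdown': 'MarkdownLoader',
        '.htm': 'HTMLLoader', '.html': 'HTMLLoader',
        '.docx': 'Docx2txtLoader', '.csv': 'CSVLoader',
    }

    def pick_loader(file_path: str) -> str:
        suffix = Path(file_path).suffix
        return LOADER_BY_SUFFIX.get(suffix, 'TextLoader')  # everything else is treated as text

    print(pick_loader('report.pdf'))  # PdfLoader
    print(pick_loader('notes.txt'))   # TextLoader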
api/core/data_loader/loader/csv.py (new file, +67)

import csv
import logging
from typing import Optional, Dict, List

from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings
from models.dataset import Document

logger = logging.getLogger(__name__)


class CSVLoader(LCCSVLoader):
    def __init__(
            self,
            file_path: str,
            source_column: Optional[str] = None,
            csv_args: Optional[Dict] = None,
            encoding: Optional[str] = None,
            autodetect_encoding: bool = True,
    ):
        self.file_path = file_path
        self.source_column = source_column
        self.encoding = encoding
        self.csv_args = csv_args or {}
        self.autodetect_encoding = autodetect_encoding

    def load(self) -> List[Document]:
        """Load data into document objects."""
        try:
            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
                docs = self._read_from_file(csvfile)
        except UnicodeDecodeError as e:
            if self.autodetect_encoding:
                detected_encodings = detect_file_encodings(self.file_path)
                for encoding in detected_encodings:
                    logger.debug("Trying encoding: ", encoding.encoding)
                    try:
                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
                            docs = self._read_from_file(csvfile)
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self.file_path}") from e

        return docs

    def _read_from_file(self, csvfile):
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        for i, row in enumerate(csv_reader):
            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
            try:
                source = (
                    row[self.source_column]
                    if self.source_column is not None
                    else ''
                )
            except KeyError:
                raise ValueError(
                    f"Source column '{self.source_column}' not found in CSV file."
                )
            metadata = {"source": source, "row": i}
            doc = Document(page_content=content, metadata=metadata)
            docs.append(doc)

        return docs
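CSVLoader.load above tries the configured encoding first and, on a UnicodeDecodeError, retries with encodings detected from the file. A simplified, self-contained sketch of that fallback pattern (the hard-coded candidate list stands in for langchain's detect_file_encodings result, and unlike the real loader this version always raises if every candidate fails):

    def read_text_with_fallback(path: str, encoding=None,
                                candidates=('utf-8', 'gbk', 'latin-1')) -> str:
        try:
            with open(path, newline="", encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            for candidate in candidates:  # detected encodings would go here
                try:
                    with open(path, newline="", encoding=candidate) as f:
                        return f.read()
                except UnicodeDecodeError:
                    continue
            raise RuntimeError(f"Error loading {path}")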
api/core/data_loader/loader/excel.py  0 → 100644

import json
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook

logger = logging.getLogger(__name__)


class ExcelLoader(BaseLoader):
    """Load xlsx files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(self, file_path: str):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        data = []
        keys = []
        wb = load_workbook(filename=self._file_path, read_only=True)
        # loop over all sheets
        for sheet in wb:
            for row in sheet.iter_rows(values_only=True):
                if all(v is None for v in row):
                    continue
                if keys == []:
                    keys = list(map(str, row))
                else:
                    row_dict = dict(zip(keys, row))
                    row_dict = {k: v for k, v in row_dict.items() if v}
                    data.append(json.dumps(row_dict, ensure_ascii=False))

        return [Document(page_content='\n\n'.join(data))]
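The loader treats the first non-empty row as the header and serializes every later row as a JSON object, joining them all into a single Document. A minimal sketch; inventory.xlsx is a made-up workbook:

# Usage sketch only: inventory.xlsx is a made-up workbook.
from core.data_loader.loader.excel import ExcelLoader

docs = ExcelLoader('inventory.xlsx').load()
print(docs[0].page_content)  # one JSON object per row, e.g. {"sku": "A-1", "qty": 3}

Note that `keys` is populated only once, so every sheet in the workbook reuses the header row of the first sheet.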
api/core/data_loader/loader/html.py  0 → 100644

import logging
from typing import List

from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class HTMLLoader(BaseLoader):
    """Load html files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(self, file_path: str):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        return [Document(page_content=self._load_as_text())]

    def _load_as_text(self) -> str:
        with open(self._file_path, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
api/core/index/readers/markdown_parser.py → api/core/data_loader/loader/markdown.py

import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class MarkdownLoader(BaseLoader):
    """Load md files.

    Args:
        file_path: Path to the file to load.
        remove_hyperlinks: Whether to remove hyperlinks from the text.
        remove_images: Whether to remove images from the text.
        encoding: File encoding to use. If `None`, the file will be loaded
            with the default system encoding.
        autodetect_encoding: Whether to try to autodetect the file encoding
            if the specified encoding fails.
    """

    def __init__(
            self,
            file_path: str,
            remove_hyperlinks: bool = True,
            remove_images: bool = True,
            encoding: Optional[str] = None,
            autodetect_encoding: bool = True,
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._remove_hyperlinks = remove_hyperlinks
        self._remove_images = remove_images
        self._encoding = encoding
        self._autodetect_encoding = autodetect_encoding

    def load(self) -> List[Document]:
        tups = self.parse_tups(self._file_path)
        documents = []
        for header, value in tups:
            value = value.strip()
            if header is None:
                documents.append(Document(page_content=value))
            else:
                documents.append(Document(page_content=f"\n\n{header}\n{value}"))

        return documents

    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
        """Convert a markdown file to a dictionary.
...
@@ -79,33 +103,32 @@ class MarkdownParser(BaseParser):
        content = re.sub(pattern, r"\1", content)
        return content

    def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
        """Parse file into tuples."""
        content = ""
        try:
            with open(filepath, "r", encoding=self._encoding) as f:
                content = f.read()
        except UnicodeDecodeError as e:
            if self._autodetect_encoding:
                detected_encodings = detect_file_encodings(filepath)
                for encoding in detected_encodings:
                    logger.debug("Trying encoding: %s", encoding.encoding)
                    try:
                        with open(filepath, encoding=encoding.encoding) as f:
                            content = f.read()
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {filepath}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading {filepath}") from e

        if self._remove_hyperlinks:
            content = self.remove_hyperlinks(content)
        if self._remove_images:
            content = self.remove_images(content)

        return self.markdown_to_tups(content)

(Removed by this commit: the previous llama_index-based MarkdownParser(BaseParser) from api/core/index/readers/markdown_parser.py, including its *args/**kwargs __init__, the _init_parser helper, the utf-8-only parse_tups(filepath: Path, errors="ignore") and the parse_file method.)
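After the rewrite the loader splits a markdown file on headers and emits one Document per section, optionally stripping hyperlinks and images first. A minimal sketch; README.md is a placeholder path:

# Usage sketch only: README.md is a placeholder path.
from core.data_loader.loader.markdown import MarkdownLoader

loader = MarkdownLoader('README.md', remove_hyperlinks=True, remove_images=True)
for doc in loader.load():
    print(repr(doc.page_content[:80]))  # one Document per header section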
api/core/data_source/notion.py → api/core/data_loader/loader/notion.py

import json
import logging
from typing import List, Dict, Any, Optional

import requests
from flask import current_app
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding

logger = logging.getLogger(__name__)

BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']


class NotionLoader(BaseLoader):
    def __init__(
            self,
            notion_access_token: str,
            notion_workspace_id: str,
            notion_obj_id: str,
            notion_page_type: str,
            document_model: Optional[DocumentModel] = None
    ):
        self._document_model = document_model
        self._notion_workspace_id = notion_workspace_id
        self._notion_obj_id = notion_obj_id
        self._notion_page_type = notion_page_type
        self._notion_access_token = notion_access_token

        if not self._notion_access_token:
            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
            if integration_token is None:
                raise ValueError(
                    "Must specify `integration_token` or set environment "
                    "variable `NOTION_INTEGRATION_TOKEN`."
                )

            self._notion_access_token = integration_token

    @classmethod
    def from_document(cls, document_model: DocumentModel):
        data_source_info = document_model.data_source_info_dict
        if not data_source_info or 'notion_page_id' not in data_source_info \
                or 'notion_workspace_id' not in data_source_info:
            raise ValueError("no notion page found")

        notion_workspace_id = data_source_info['notion_workspace_id']
        notion_obj_id = data_source_info['notion_page_id']
        notion_page_type = data_source_info['type']
        notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)

        return cls(
            notion_access_token=notion_access_token,
            notion_workspace_id=notion_workspace_id,
            notion_obj_id=notion_obj_id,
            notion_page_type=notion_page_type,
            document_model=document_model
        )

    def load(self) -> List[Document]:
        self.update_last_edited_time(self._document_model)

        text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)

        return text_docs

    def _load_data_as_documents(
            self, notion_obj_id: str, notion_page_type: str
    ) -> List[Document]:
        docs = []
        if notion_page_type == 'database':
            # get all the pages in the database
            page_text = self._get_notion_database_data(notion_obj_id)
            docs.append(Document(page_content=page_text))
        elif notion_page_type == 'page':
            page_text_list = self._get_notion_block_data(notion_obj_id)
            for page_text in page_text_list:
                docs.append(Document(page_content=page_text))
        else:
            raise ValueError("notion page type not supported")

        return docs

    def _get_notion_database_data(
            self, database_id: str, query_dict: Dict[str, Any] = {}
    ) -> str:
        """Get all the pages from a Notion database."""
        res = requests.post(
            DATABASE_URL_TMPL.format(database_id=database_id),
            headers={
                "Authorization": "Bearer " + self._notion_access_token,
                "Content-Type": "application/json",
                "Notion-Version": "2022-06-28",
            },
            json=query_dict,
        )

        data = res.json()

        database_content_list = []
        if 'results' not in data or data["results"] is None:
            return ""
        for result in data["results"]:
            properties = result['properties']
            data = {}
            for property_name, property_value in properties.items():
                type = property_value['type']
                if type == 'multi_select':
                    value = []
                    multi_select_list = property_value[type]
                    for multi_select in multi_select_list:
                        value.append(multi_select['name'])
                elif type == 'rich_text' or type == 'title':
                    if len(property_value[type]) > 0:
                        value = property_value[type][0]['plain_text']
                    else:
                        value = ''
                elif type == 'select' or type == 'status':
                    if property_value[type]:
                        value = property_value[type]['name']
                    else:
                        value = ''
                else:
                    value = property_value[type]
                data[property_name] = value
            database_content_list.append(json.dumps(data, ensure_ascii=False))

        return "\n\n".join(database_content_list)

    def _get_notion_block_data(self, page_id: str) -> List[str]:
        result_lines_arr = []
        cur_block_id = page_id
        while True:
            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
            query_dict: Dict[str, Any] = {}

            res = requests.request(
                "GET",
                block_url,
                headers={
                    "Authorization": "Bearer " + self._notion_access_token,
                    "Content-Type": "application/json",
                    "Notion-Version": "2022-06-28",
                },
                json=query_dict
            )
            data = res.json()
            if 'results' not in data or data["results"] is None:
                break
            # current block's heading
            heading = ''
            for result in data["results"]:
                result_type = result["type"]
...
@@ -71,6 +165,7 @@ class NotionPageReader(BaseReader):
                if result_type == 'table':
                    result_block_id = result["id"]
                    text = self._read_table_rows(result_block_id)
                    text += "\n\n"
                    result_lines_arr.append(text)
                else:
                    if "rich_text" in result_obj:
...
@@ -78,91 +173,53 @@ class NotionPageReader(BaseReader):
                            # skip if doesn't have text object
                            if "text" in rich_text:
                                text = rich_text["text"]["content"]
                                cur_result_text_arr.append(text)
                                if result_type in HEADING_TYPE:
                                    heading = text

                    result_block_id = result["id"]
                    has_children = result["has_children"]
                    block_type = result["type"]
                    if has_children and block_type != 'child_page':
                        children_text = self._read_block(
                            result_block_id, num_tabs=1
                        )
                        cur_result_text_arr.append(children_text)

                    cur_result_text = "\n".join(cur_result_text_arr)
                    cur_result_text += "\n\n"
                    if result_type in HEADING_TYPE:
                        result_lines_arr.append(cur_result_text)
                    else:
                        result_lines_arr.append(f'{heading}\n{cur_result_text}')

            if data["next_cursor"] is None:
                break
            else:
                cur_block_id = data["next_cursor"]

        return result_lines_arr

    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
        """Read a block."""
        result_lines_arr = []
        cur_block_id = block_id
        while True:
            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
            query_dict: Dict[str, Any] = {}

            res = requests.request(
                "GET",
                block_url,
                headers={
                    "Authorization": "Bearer " + self._notion_access_token,
                    "Content-Type": "application/json",
                    "Notion-Version": "2022-06-28",
                },
                json=query_dict
            )
            data = res.json()
            if 'results' not in data or data["results"] is None:
                break
            # current block's heading
            heading = ''
            for result in data["results"]:
                result_type = result["type"]
...
@@ -171,7 +228,6 @@ class NotionPageReader(BaseReader):
                if result_type == 'table':
                    result_block_id = result["id"]
                    text = self._read_table_rows(result_block_id)
                    result_lines_arr.append(text)
                else:
                    if "rich_text" in result_obj:
...
@@ -179,10 +235,10 @@ class NotionPageReader(BaseReader):
                            # skip if doesn't have text object
                            if "text" in rich_text:
                                text = rich_text["text"]["content"]
                                prefix = "\t" * num_tabs
                                cur_result_text_arr.append(prefix + text)
                                if result_type in HEADING_TYPE:
                                    heading = text

                    result_block_id = result["id"]
                    has_children = result["has_children"]
                    block_type = result["type"]
...
@@ -193,177 +249,121 @@ class NotionPageReader(BaseReader):
                        cur_result_text_arr.append(children_text)

                    cur_result_text = "\n".join(cur_result_text_arr)
                    if result_type in HEADING_TYPE:
                        result_lines_arr.append(cur_result_text)
                    else:
                        result_lines_arr.append(f'{heading}\n{cur_result_text}')

            if data["next_cursor"] is None:
                break
            else:
                cur_block_id = data["next_cursor"]

        result_lines = "\n".join(result_lines_arr)
        return result_lines

    def _read_table_rows(self, block_id: str) -> str:
        """Read table rows."""
        done = False
        result_lines_arr = []
        cur_block_id = block_id
        while not done:
            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
            query_dict: Dict[str, Any] = {}

            res = requests.request(
                "GET",
                block_url,
                headers={
                    "Authorization": "Bearer " + self._notion_access_token,
                    "Content-Type": "application/json",
                    "Notion-Version": "2022-06-28",
                },
                json=query_dict
            )
            data = res.json()
            # get table headers text
            table_header_cell_texts = []
            tabel_header_cells = data["results"][0]['table_row']['cells']
            for tabel_header_cell in tabel_header_cells:
                if tabel_header_cell:
                    for table_header_cell_text in tabel_header_cell:
                        text = table_header_cell_text["text"]["content"]
                        table_header_cell_texts.append(text)
            # get table columns text and format
            results = data["results"]
            for i in range(len(results) - 1):
                column_texts = []
                tabel_column_cells = data["results"][i + 1]['table_row']['cells']
                for j in range(len(tabel_column_cells)):
                    if tabel_column_cells[j]:
                        for table_column_cell_text in tabel_column_cells[j]:
                            column_text = table_column_cell_text["text"]["content"]
                            column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')

                cur_result_text = "\n".join(column_texts)
                result_lines_arr.append(cur_result_text)

            if data["next_cursor"] is None:
                done = True
                break
            else:
                cur_block_id = data["next_cursor"]

        result_lines = "\n".join(result_lines_arr)
        return result_lines

    def update_last_edited_time(self, document_model: DocumentModel):
        if not document_model:
            return

        last_edited_time = self.get_notion_last_edited_time()
        data_source_info = document_model.data_source_info_dict
        data_source_info['last_edited_time'] = last_edited_time
        update_params = {
            DocumentModel.data_source_info: json.dumps(data_source_info)
        }

        DocumentModel.query.filter_by(id=document_model.id).update(update_params)
        db.session.commit()

    def get_notion_last_edited_time(self) -> str:
        obj_id = self._notion_obj_id
        page_type = self._notion_page_type
        if page_type == 'database':
            retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
        else:
            retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)

        query_dict: Dict[str, Any] = {}

        res = requests.request(
            "GET",
            retrieve_page_url,
            headers={
                "Authorization": "Bearer " + self._notion_access_token,
                "Content-Type": "application/json",
                "Notion-Version": "2022-06-28",
            },
            json=query_dict
        )

        data = res.json()
        return data["last_edited_time"]

    @classmethod
    def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
        data_source_binding = DataSourceBinding.query.filter(
            db.and_(
                DataSourceBinding.tenant_id == tenant_id,
                DataSourceBinding.provider == 'notion',
                DataSourceBinding.disabled == False,
                DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
            )
        ).first()

        if not data_source_binding:
            raise Exception(f'No notion data source binding found for tenant {tenant_id} '
                            f'and notion workspace {notion_workspace_id}')

        return data_source_binding.access_token

(Removed by this commit: the former llama_index NotionPageReader(BaseReader), its os.getenv(INTEGRATION_TOKEN_NAME) token lookup and shared self.headers, the read_page, read_page_as_documents, query_database_data, query_database, search, load_data and load_data_as_documents helpers, the page/database last-edited-time helpers tied to them, and the `if __name__ == "__main__":` demo block.)
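A minimal sketch of constructing the loader directly rather than through from_document; the token and IDs are placeholders, and with no document_model the last-edited-time update is skipped:

# Usage sketch only: the token and IDs are placeholders, not real values.
from core.data_loader.loader.notion import NotionLoader

loader = NotionLoader(
    notion_access_token='secret_xxx',      # Notion integration token
    notion_workspace_id='workspace-uuid',
    notion_obj_id='page-uuid',
    notion_page_type='page',               # or 'database'
)
docs = loader.load()  # one Document per extracted block group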
api/core/data_loader/loader/pdf.py  0 → 100644

import logging
from typing import List, Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_storage import storage
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfLoader(BaseLoader):
    """Load pdf files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        upload_file: Optional[UploadFile] = None
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._upload_file = upload_file

    def load(self) -> List[Document]:
        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._upload_file:
            if self._upload_file.hash:
                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
                                     + self._upload_file.hash + '.0625.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return [Document(page_content=text)]
                except FileNotFoundError:
                    pass

        documents = PyPDFium2Loader(file_path=self._file_path).load()
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return documents
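A minimal sketch of the loader without an UploadFile, so the plaintext cache branch is skipped and PyPDFium2Loader parses the file directly; paper.pdf is a placeholder path:

# Usage sketch only: paper.pdf is a placeholder path.
from core.data_loader.loader.pdf import PdfLoader

docs = PdfLoader('paper.pdf').load()  # no upload_file, so no plaintext cache
print(len(docs), 'page documents')

When an UploadFile with a hash is supplied, a cache hit returns a single joined Document, while a miss returns the per-page Documents and stores the joined plaintext under the '.0625.plaintext' key.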
api/core/docstore/dataset_docstore.py

from typing import Any, Dict, Optional, Sequence

from langchain.schema import Document
from sqlalchemy import func

from core.llm.token_calculator import TokenCalculator
...
@@ -12,7 +8,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment


class DatesetDocumentStore:
    def __init__(
        self,
        dataset: Dataset,
...
@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
        return self._embedding_model_name

    @property
    def docs(self) -> Dict[str, Document]:
        document_segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id
        ).all()
...
@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
        output = {}
        for document_segment in document_segments:
            doc_id = document_segment.index_node_id
            output[doc_id] = Document(
                page_content=document_segment.content,
                metadata={
                    "doc_id": document_segment.index_node_id,
                    "doc_hash": document_segment.index_node_hash,
                    "document_id": document_segment.document_id,
                    "dataset_id": document_segment.dataset_id,
                }
            )

        return output

    def add_documents(
            self, docs: Sequence[Document], allow_update: bool = True
    ) -> None:
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
            DocumentSegment.document == self._document_id
...
@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
            max_position = 0

        for doc in docs:
            if not isinstance(doc, Document):
                raise ValueError("doc must be a Document")

            segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)

            # NOTE: doc could already exist in the store, but we overwrite it
            if not allow_update and segment_document:
                raise ValueError(
                    f"doc_id {doc.metadata['doc_id']} already exists. "
                    "Set allow_update to True to overwrite."
                )

            # calc embedding use tokens
            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)

            if not segment_document:
                max_position += 1
...
@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
                    tenant_id=self._dataset.tenant_id,
                    dataset_id=self._dataset.id,
                    document_id=self._document_id,
                    index_node_id=doc.metadata['doc_id'],
                    index_node_hash=doc.metadata['doc_hash'],
                    position=max_position,
                    content=doc.page_content,
                    word_count=len(doc.page_content),
                    tokens=tokens,
                    created_by=self._user_id,
                )
                db.session.add(segment_document)
            else:
                segment_document.content = doc.page_content
                segment_document.index_node_hash = doc.metadata['doc_hash']
                segment_document.word_count = len(doc.page_content)
                segment_document.tokens = tokens

        db.session.commit()
...
@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
    def get_document(
            self, doc_id: str, raise_error: bool = True
    ) -> Optional[Document]:
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
...
@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
        else:
            return None

        return Document(
            page_content=document_segment.content,
            metadata={
                "doc_id": document_segment.index_node_id,
                "doc_hash": document_segment.index_node_hash,
                "document_id": document_segment.document_id,
                "dataset_id": document_segment.dataset_id,
            }
        )

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        document_segment = self.get_document_segment(doc_id)
...
@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
        return document_segment.index_node_hash

    def get_document_segment(self, doc_id: str) -> DocumentSegment:
        document_segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id,
...
@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
        ).first()

        return document_segment

(Removed by this commit: the tiktoken and llama_index imports, the BaseDocumentStore base class, the old Node/BaseDocument handling via get_doc_id()/get_doc_hash()/get_text(), and the update_docstore() and segment_to_dict() helpers.)
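After this change the docstore works on plain langchain Documents, and add_documents reads doc_id and doc_hash from metadata. A sketch of the expected shape only; the store construction is left commented out because its remaining constructor arguments are elided in this hunk:

# Shape sketch only: the docstore constructor call is omitted because its
# remaining arguments are elided in this hunk.
from langchain.schema import Document

doc = Document(
    page_content="Dify supports dataset indexing.",
    metadata={"doc_id": "node-0001", "doc_hash": "hash-of-content"},
)
# store = DatesetDocumentStore(dataset=..., ...)  # remaining arguments elided above
# store.add_documents([doc])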
api/core/docstore/empty_docstore.py  deleted 100644 → 0

from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument


class EmptyDocumentStore(BaseDocumentStore):
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {}

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        return {}

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        pass

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        return False

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        return None

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        pass

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        pass

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        return None

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from
        """
        self.add_documents(list(other.docs.values()))
api/core/embedding/cached_embedding.py  0 → 100644

import logging
from typing import List

from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding


class CacheEmbedding(Embeddings):
    def __init__(self, embeddings: Embeddings):
        self._embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        # use doc embedding cache or store if not exists
        text_embeddings = []
        embedding_queue_texts = []
        for text in texts:
            hash = helper.generate_text_hash(text)
            embedding = db.session.query(Embedding).filter_by(hash=hash).first()
            if embedding:
                text_embeddings.append(embedding.get_embedding())
            else:
                embedding_queue_texts.append(text)

        embedding_results = self._embeddings.embed_documents(embedding_queue_texts)

        i = 0
        for text in embedding_queue_texts:
            hash = helper.generate_text_hash(text)

            try:
                embedding = Embedding(hash=hash)
                embedding.set_embedding(embedding_results[i])
                db.session.add(embedding)
                db.session.commit()
            except IntegrityError:
                db.session.rollback()
                continue
            except:
                logging.exception('Failed to add embedding to db')
                continue
            i += 1

        text_embeddings.extend(embedding_results)
        return text_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        # use doc embedding cache or store if not exists
        hash = helper.generate_text_hash(text)
        embedding = db.session.query(Embedding).filter_by(hash=hash).first()
        if embedding:
            return embedding.get_embedding()

        embedding_results = self._embeddings.embed_query(text)

        try:
            embedding = Embedding(hash=hash)
            embedding.set_embedding(embedding_results)
            db.session.add(embedding)
            db.session.commit()
        except IntegrityError:
            db.session.rollback()
        except:
            logging.exception('Failed to add embedding to db')

        return embedding_results
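CacheEmbedding wraps any langchain Embeddings implementation and consults the Embedding table by text hash before calling the wrapped model. A minimal sketch, assuming valid OpenAI credentials:

# Usage sketch only: the API key is a placeholder.
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding

embeddings = CacheEmbedding(OpenAIEmbeddings(openai_api_key='sk-...'))
vector = embeddings.embed_query("What is Dify?")  # served from the cache on repeat calls

One caveat visible in embed_documents above: cached vectors are appended before the freshly computed ones, so the returned list is not guaranteed to follow the input order when cached and uncached texts are mixed.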
api/core/embedding/openai_embedding.py  deleted 100644 → 0

from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
        text: str,
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    text = text.replace("\n", " ")
    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(
        text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")
    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


class OpenAIEmbedding(BaseEmbedding):

    def __init__(
            self,
            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
            deployment_name: Optional[str] = None,
            openai_api_key: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Init params."""
        new_kwargs = {}

        if 'embed_batch_size' in kwargs:
            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']

        if 'tokenizer' in kwargs:
            new_kwargs['tokenizer'] = kwargs['tokenizer']

        super().__init__(**new_kwargs)
        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key
        self.openai_api_type = kwargs.get('openai_api_type')
        self.openai_api_version = kwargs.get('openai_api_version')
        self.openai_api_base = kwargs.get('openai_api_base')

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _QUERY_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _QUERY_MODE_MODEL_DICT[key]
        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overriden for batch queries.
        """
        if self.openai_api_type and self.openai_api_type == 'azure':
            embeddings = []
            for text in texts:
                embeddings.append(self._get_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)
        return embeddings

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        if self.openai_api_type and self.openai_api_type == 'azure':
            embeddings = []
            for text in texts:
                embeddings.append(await self._aget_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
                                           api_base=self.openai_api_base)
        return embeddings
api/core/index/base.py  0 → 100644

from __future__ import annotations

from abc import abstractmethod, ABC
from typing import List, Any

from langchain.schema import Document, BaseRetriever

from models.dataset import Dataset


class BaseIndex(ABC):

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    @abstractmethod
    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        raise NotImplementedError

    @abstractmethod
    def add_texts(self, texts: list[Document], **kwargs):
        raise NotImplementedError

    @abstractmethod
    def text_exists(self, id: str) -> bool:
        raise NotImplementedError

    @abstractmethod
    def delete_by_ids(self, ids: list[str]) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        raise NotImplementedError

    @abstractmethod
    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        raise NotImplementedError

    @abstractmethod
    def search(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
        raise NotImplementedError

    def delete(self) -> None:
        raise NotImplementedError

    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        for text in texts:
            doc_id = text.metadata['doc_id']
            exists_duplicate_node = self.text_exists(doc_id)
            if exists_duplicate_node:
                texts.remove(text)

        return texts

    def _get_uuids(self, texts: list[Document]) -> list[str]:
        return [text.metadata['doc_id'] for text in texts]
api/core/index/index.py  0 → 100644

from flask import current_app
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class IndexBuilder:
    @classmethod
    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
        if indexing_technique == "high_quality":
            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                return None

            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            return VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )
        elif indexing_technique == "economy":
            return KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=10
                )
            )
        else:
            raise ValueError('Unknown indexing technique')
\ No newline at end of file
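get_index returns a vector index backed by cached OpenAI embeddings for "high_quality" datasets and a jieba keyword-table index for "economy" ones, or None when the high-quality check fails. A minimal sketch; fetching the dataset this way and the query string are placeholders:

# Usage sketch only: the dataset lookup and query string are placeholders.
from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import Dataset

dataset = db.session.query(Dataset).first()
index = IndexBuilder.get_index(dataset, dataset.indexing_technique)
if index:
    docs = index.search("pricing policy")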
api/core/index/index_builder.py  deleted 100644 → 0

from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder


class IndexBuilder:
    @classmethod
    def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
        # set number of output tokens
        num_output = 512

        # only for verbose
        callback_manager = CallbackManager([DifyStdOutCallbackHandler()])

        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='text-davinci-003',
            temperature=0,
            max_tokens=num_output,
            callback_manager=callback_manager,
        )

        llm_predictor = LLMPredictor(llm=llm)

        # These parameters here will affect the logic of segmenting the final synthesized response.
        # The number of refinement iterations in the synthesis process depends
        # on whether the length of the segmented output exceeds the max_input_size.
        prompt_helper = PromptHelper(
            max_input_size=3500,
            num_output=num_output,
            max_chunk_overlap=20
        )

        provider = LLMBuilder.get_default_provider(tenant_id)

        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=tenant_id,
            model_provider=provider,
            model_name='text-embedding-ada-002'
        )

        return ServiceContext.from_defaults(
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper,
            embed_model=OpenAIEmbedding(**model_credentials),
        )

    @classmethod
    def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='fake'
        )

        return ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )
api/core/index/keyword_table/jieba_keyword_table.py  deleted 100644 → 0

import re
from typing import (
    Any,
    Dict,
    List,
    Set,
    Optional
)

import jieba.analyse

from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


def jieba_extract_keywords(
        text_chunk: str,
        max_keywords: Optional[int] = None,
        expand_with_subtokens: bool = True,
) -> Set[str]:
    """Extract keywords with JIEBA tfidf."""
    keywords = jieba.analyse.extract_tags(
        sentence=text_chunk,
        topK=max_keywords,
    )

    if expand_with_subtokens:
        return set(expand_tokens_with_subtokens(keywords))
    else:
        return set(keywords)


def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
    """Get subtokens from a list of tokens., filtering for stopwords."""
    results = set()
    for token in tokens:
        results.add(token)
        sub_tokens = re.findall(r"\w+", token)
        if len(sub_tokens) > 1:
            results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

    return results


class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
    """GPT JIEBA Keyword Table Index.

    This index uses a JIEBA keyword extractor to extract keywords from the text.
    """

    def _extract_keywords(self, text: str) -> Set[str]:
        """Extract keywords from text."""
        return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)

    @classmethod
    def get_query_map(self) -> QueryMap:
        """Get query map."""
        super_map = super().get_query_map()
        super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery

        return super_map

    def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document."""
        # get set of ids that correspond to node
        node_idxs_to_delete = {doc_id}

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in self._index_struct.table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                self._index_struct.table[keyword] = node_idxs.difference(
                    node_idxs_to_delete
                )
                if not self._index_struct.table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del self._index_struct.table[keyword]


class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
    """GPT Keyword Table Index JIEBA Query.

    Extracts keywords using JIEBA keyword extractor.
    Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.

    .. code-block:: python

        response = index.query("<query_str>", mode="jieba")

    See BaseGPTKeywordTableQuery for arguments.
    """

    @classmethod
    def from_args(
            cls,
            index_struct: IS,
            service_context: ServiceContext,
            docstore: Optional[BaseDocumentStore] = None,
            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
            verbose: bool = False,
            # response synthesizer args
            response_mode: ResponseMode = ResponseMode.DEFAULT,
            text_qa_template: Optional[QuestionAnswerPrompt] = None,
            refine_template: Optional[RefinePrompt] = None,
            simple_template: Optional[SimpleInputPrompt] = None,
            response_kwargs: Optional[Dict] = None,
            use_async: bool = False,
            streaming: bool = False,
            optimizer: Optional[BaseTokenUsageOptimizer] = None,
            # class-specific args
            **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )
api/core/index/keyword_table_index.py  deleted 100644 → 0

import json
from typing import List, Optional

from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict

from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment


class KeywordTableIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            index_struct = KeywordTable()
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node in nodes:
            keywords = index._extract_keywords(node.get_text())
            self.update_segment_keywords(node.doc_id, list(keywords))
            index._index_struct.add_node(list(keywords), node)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    def del_nodes(self, node_ids: List[str]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node_id in node_ids:
            index.delete(node_id)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    @property
    def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
        docstore = DatesetDocumentStore(
            dataset=self._dataset,
            user_id=self._dataset.created_by,
            embedding_model_name="text-embedding-ada-002"
        )

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return None

        index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)

        return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)

    def get_keyword_table(self):
        dataset_keyword_table = self._dataset.dataset_keyword_table
        if dataset_keyword_table:
            return dataset_keyword_table
        return None

    def update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()
api/core/index/keyword_table_index/jieba_keyword_table_handler.py 0 → 100644

import re
from typing import Set

import jieba
from jieba.analyse import default_tfidf

from core.index.keyword_table_index.stopwords import STOPWORDS


class JiebaKeywordTableHandler:
    def __init__(self):
        default_tfidf.stop_words = STOPWORDS

    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
        """Extract keywords with JIEBA tfidf."""
        keywords = jieba.analyse.extract_tags(
            sentence=text,
            topK=max_keywords_per_chunk,
        )

        return set(self._expand_tokens_with_subtokens(keywords))

    def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
        """Get subtokens from a list of tokens, filtering for stopwords."""
        results = set()
        for token in tokens:
            results.add(token)
            sub_tokens = re.findall(r"\w+", token)
            if len(sub_tokens) > 1:
                results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

        return results
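As a rough usage sketch (the sample sentence and variable names below are illustrative, not from the commit), the handler can be exercised on its own to see which TF-IDF keywords survive the stopword filter:

# Illustrative only: assumes the jieba package and the STOPWORDS set above are importable.
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler

handler = JiebaKeywordTableHandler()
keywords = handler.extract_keywords("Dify 是一个开源的 LLM 应用开发平台", max_keywords_per_chunk=5)
print(keywords)  # a set: at most 5 top keywords, plus any sub-tokens split out of them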
api/core/index/keyword_table_index/keyword_table_index.py 0 → 100644

import json
from collections import defaultdict
from typing import Any, List, Optional, Dict

from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field, Extra

from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable


class KeywordTableConfig(BaseModel):
    max_keywords_per_chunk: int = 10


class KeywordTableIndex(BaseIndex):
    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
        super().__init__(dataset)
        self._config = config

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        keyword_table_handler = JiebaKeywordTableHandler()

        keyword_table = {}
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        dataset_keyword_table = DatasetKeywordTable(
            dataset_id=self.dataset.id,
            keyword_table=json.dumps({
                '__type__': 'keyword_table',
                '__data__': {
                    "index_id": self.dataset.id,
                    "summary": None,
                    "table": {}
                }
            }, cls=SetEncoder)
        )
        db.session.add(dataset_keyword_table)
        db.session.commit()

        self._save_dataset_keyword_table(keyword_table)

        return self

    def add_texts(self, texts: list[Document], **kwargs):
        keyword_table_handler = JiebaKeywordTableHandler()

        keyword_table = self._get_dataset_keyword_table()
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        self._save_dataset_keyword_table(keyword_table)

    def text_exists(self, id: str) -> bool:
        keyword_table = self._get_dataset_keyword_table()
        return id in set.union(*keyword_table.values())

    def delete_by_ids(self, ids: list[str]) -> None:
        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._save_dataset_keyword_table(keyword_table)

    def delete_by_document_id(self, document_id: str):
        # get segment ids by document_id
        segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self.dataset.id,
            DocumentSegment.document_id == document_id
        ).all()

        ids = [segment.id for segment in segments]

        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._save_dataset_keyword_table(keyword_table)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        return KeywordTableRetriever(index=self, **kwargs)

    def search(self, query: str, **kwargs: Any) -> List[Document]:
        keyword_table = self._get_dataset_keyword_table()

        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
        k = search_kwargs.get('k') if search_kwargs.get('k') else 4

        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)

        documents = []
        for chunk_index in sorted_chunk_indices:
            segment = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == self.dataset.id,
                DocumentSegment.index_node_id == chunk_index
            ).first()

            if segment:
                documents.append(Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": chunk_index,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                ))

        return documents

    def delete(self) -> None:
        dataset_keyword_table = self.dataset.dataset_keyword_table
        if dataset_keyword_table:
            db.session.delete(dataset_keyword_table)
            db.session.commit()

    def _save_dataset_keyword_table(self, keyword_table):
        keyword_table_dict = {
            '__type__': 'keyword_table',
            '__data__': {
                "index_id": self.dataset.id,
                "summary": None,
                "table": keyword_table
            }
        }
        self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
        db.session.commit()

    def _get_dataset_keyword_table(self) -> Optional[dict]:
        dataset_keyword_table = self.dataset.dataset_keyword_table
        if dataset_keyword_table:
            if dataset_keyword_table.keyword_table_dict:
                return dataset_keyword_table.keyword_table_dict['__data__']['table']
        else:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self.dataset.id,
                keyword_table=json.dumps({
                    '__type__': 'keyword_table',
                    '__data__': {
                        "index_id": self.dataset.id,
                        "summary": None,
                        "table": {}
                    }
                }, cls=SetEncoder)
            )
            db.session.add(dataset_keyword_table)
            db.session.commit()

        return {}

    def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
        for keyword in keywords:
            if keyword not in keyword_table:
                keyword_table[keyword] = set()
            keyword_table[keyword].add(id)
        return keyword_table

    def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
        # get set of ids that correspond to node
        node_idxs_to_delete = set(ids)

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in keyword_table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete)
                if not keyword_table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del keyword_table[keyword]

        return keyword_table

    def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
        keyword_table_handler = JiebaKeywordTableHandler()
        keywords = keyword_table_handler.extract_keywords(query)

        # go through text chunks in order of most matching keywords
        chunk_indices_count: Dict[str, int] = defaultdict(int)
        keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
        for keyword in keywords:
            for node_id in keyword_table[keyword]:
                chunk_indices_count[node_id] += 1

        sorted_chunk_indices = sorted(
            list(chunk_indices_count.keys()),
            key=lambda x: chunk_indices_count[x],
            reverse=True,
        )

        return sorted_chunk_indices[:k]

    def _update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()


class KeywordTableRetriever(BaseRetriever, BaseModel):
    index: KeywordTableIndex
    search_kwargs: dict = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        return self.index.search(query, **self.search_kwargs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError("KeywordTableRetriever does not support async")


class SetEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)
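The retrieval path is easiest to see without the database: _add_text_to_keyword_table builds an inverted map from keyword to segment ids, and _retrieve_ids_by_query counts keyword hits per segment and returns the top k. A minimal standalone sketch of that ranking (hypothetical ids and keywords, no DB or jieba involved):

from collections import defaultdict

# hypothetical inverted table: keyword -> set of segment ids
keyword_table = {
    "weaviate": {"seg-1", "seg-3"},
    "qdrant": {"seg-2"},
    "vector": {"seg-1", "seg-2", "seg-3"},
}

def retrieve_ids(keyword_table: dict, query_keywords: set, k: int = 4) -> list:
    # count how many query keywords each segment matches
    hits = defaultdict(int)
    for keyword in query_keywords & keyword_table.keys():
        for node_id in keyword_table[keyword]:
            hits[node_id] += 1
    # most-matching segments first, truncated to k
    return sorted(hits, key=lambda x: hits[x], reverse=True)[:k]

print(retrieve_ids(keyword_table, {"qdrant", "vector"}, k=2))  # 'seg-2' ranks first with two hits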
api/core/index/keyword_table/stopwords.py → api/core/index/keyword_table_index/stopwords.py (file moved)
api/core/index/query/synthesizer.py deleted 100644 → 0

from typing import (
    Any,
    Dict,
    Optional,
    Sequence,
)

from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE


class EnhanceResponseSynthesizer(ResponseSynthesizer):
    @classmethod
    def from_args(
        cls,
        service_context: ServiceContext,
        streaming: bool = False,
        use_async: bool = False,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        response_kwargs: Optional[Dict] = None,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
    ) -> "ResponseSynthesizer":
        response_builder: Optional[BaseResponseBuilder] = None
        if response_mode != ResponseMode.NO_TEXT:
            if response_mode == 'no_synthesizer':
                response_builder = NoSynthesizer(
                    service_context=service_context,
                    simple_template=simple_template,
                    streaming=streaming,
                )
            else:
                response_builder = get_response_builder(
                    service_context,
                    text_qa_template,
                    refine_template,
                    simple_template,
                    response_mode,
                    use_async=use_async,
                    streaming=streaming,
                )
        return cls(response_builder, response_mode, response_kwargs, optimizer)


class NoSynthesizer(BaseResponseBuilder):
    def __init__(
        self,
        service_context: ServiceContext,
        simple_template: Optional[SimpleInputPrompt] = None,
        streaming: bool = False,
    ) -> None:
        super().__init__(service_context, streaming)

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[str] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[str] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)
api/core/index/readers/html_parser.py deleted 100644 → 0

from pathlib import Path
from typing import Dict

from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser


class HTMLParser(BaseParser):
    """HTML parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        with open(file, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
api/core/index/readers/pdf_parser.py deleted 100644 → 0

from pathlib import Path
from typing import Dict

from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader

from extensions.ext_storage import storage
from models.model import UploadFile


class PDFParser(BaseParser):
    """PDF parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        if not current_app.config.get('PDF_PREVIEW', True):
            return ''

        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
            upload_file: UploadFile = self._parser_config['upload_file']
            if upload_file.hash:
                plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return text
                except FileNotFoundError:
                    pass

        text_list = []
        with open(file, "rb") as fp:
            # Create a PDF object
            pdf = PdfReader(fp)

            # Get the number of pages in the PDF document
            num_pages = len(pdf.pages)

            # Iterate over every page
            for page in range(num_pages):
                # Extract the text from the page
                page_text = pdf.pages[page].extract_text()
                text_list.append(page_text)

        text = "\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return text
api/core/index/readers/xlsx_parser.py deleted 100644 → 0

from pathlib import Path
import json
from typing import Dict

from openpyxl import load_workbook
from llama_index.readers.file.base_parser import BaseParser
from flask import current_app


class XLSXParser(BaseParser):
    """XLSX parser."""

    def _init_parser(self) -> Dict:
        """Init parser"""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        data = []
        keys = []
        with open(file, "r") as fp:
            wb = load_workbook(filename=file, read_only=True)
            # loop over all sheets
            for sheet in wb:
                for row in sheet.iter_rows(values_only=True):
                    if all(v is None for v in row):
                        continue
                    if keys == []:
                        keys = list(map(str, row))
                    else:
                        row_dict = dict(zip(keys, row))
                        row_dict = {k: v for k, v in row_dict.items() if v}
                        data.append(json.dumps(row_dict, ensure_ascii=False))

        return '\n\n'.join(data)
api/core/index/vector_index.py deleted 100644 → 0

import json
import logging
from typing import List, Optional

from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type

from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding


class VectorIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedding_queue_nodes = []
        embedded_nodes = []
        for node in nodes:
            node_hash = node.doc_hash

            # if node hash in cached embedding tables, use cached embedding
            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
            if embedding:
                node.embedding = embedding.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            # pre embed nodes for cached embedding
            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding

                try:
                    embedding = Embedding(hash=node.doc_hash)
                    embedding.set_embedding(node.embedding)
                    db.session.add(embedding)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    continue
                except:
                    logging.exception('Failed to add embedding to db')
                    continue

                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        for node in nodes:
            node_id = node.doc_id
            exists_duplicate_node = index.exists_by_node_id(node_id)
            if exists_duplicate_node:
                nodes.remove(node)

        return nodes
api/core/index/vector_index/base.py 0 → 100644

import json
import logging
from abc import abstractmethod
from typing import List, Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


class BaseVectorIndex(BaseIndex):

    def __init__(self, dataset: Dataset, embeddings: Embeddings):
        super().__init__(dataset)
        self._embeddings = embeddings
        self._vector_store = None

    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset: Dataset) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    def search(self, query: str, **kwargs: Any) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}

        if search_type == 'similarity_score_threshold':
            score_threshold = search_kwargs.get("score_threshold")
            if (score_threshold is None) or (not isinstance(score_threshold, float)):
                search_kwargs['score_threshold'] = .0

            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
                query, **search_kwargs
            )

            docs = []
            for doc, similarity in docs_with_similarity:
                doc.metadata['score'] = similarity
                docs.append(doc)

            return docs

        # similarity k
        # mmr k, fetch_k, lambda_mult
        # similarity_score_threshold k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.as_retriever(**kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        if self._is_origin():
            self.recreate_dataset(self.dataset)

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if kwargs.get('duplicate_check', False):
            texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
        return False

    def recreate_dataset(self, dataset: Dataset):
        logging.info(f"Recreating dataset {dataset.id}")

        try:
            self.delete()
        except UnexpectedStatusCodeException as e:
            if e.status_code != 400:
                # 400 means index not exists
                raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        origin_index_struct = self.dataset.index_struct
        self.dataset.index_struct = None

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                self.dataset.index_struct = origin_index_struct
                raise e

            dataset.index_struct = json.dumps(self.to_index_struct())

        db.session.commit()

        self.dataset = dataset

        logging.info(f"Dataset {dataset.id} recreate successfully.")
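One detail worth noting in search: when search_type is 'similarity_score_threshold' and no float threshold is supplied, the threshold silently defaults to 0.0, so every hit passes and only the attached score metadata differs. A small sketch of that defaulting rule in isolation (plain dicts, no vector store involved):

def normalize_search_kwargs(search_kwargs: dict) -> dict:
    # mirrors the defaulting step in BaseVectorIndex.search: a missing or
    # non-float score_threshold is coerced to 0.0 before querying
    score_threshold = search_kwargs.get("score_threshold")
    if score_threshold is None or not isinstance(score_threshold, float):
        search_kwargs["score_threshold"] = 0.0
    return search_kwargs

print(normalize_search_kwargs({}))                        # {'score_threshold': 0.0}
print(normalize_search_kwargs({"score_threshold": 0.8}))  # unchanged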
api/core/index/vector_index/qdrant_vector_index.py 0 → 100644

import os
from typing import Optional, Any, List, cast

import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    root_path: Optional[str]

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(self.root_path, path)

            return {
                'path': path
            }
        else:
            return {
                'url': self.endpoint,
                'api_key': self.api_key,
            }


class QdrantVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client_config = config

    def get_type(self) -> str:
        return 'qdrant'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            return self.dataset.index_struct_dict['vector_store']['collection_name']

        dataset_id = dataset.id
        return "Index_" + dataset_id.replace("-", "_")

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self.dataset),
            ids=uuids,
            content_payload_key='text',
            **self._client_config.to_qdrant_params()
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        client = qdrant_client.QdrantClient(
            **self._client_config.to_qdrant_params()
        )

        return QdrantVectorStore(
            client=client,
            collection_name=self.get_index_name(self.dataset),
            embeddings=self._embeddings,
            content_payload_key='text'
        )

    def _get_vector_store_class(self) -> type:
        return QdrantVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models

        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.document_id",
                    match=models.MatchValue(value=document_id),
                ),
            ],
        ))

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
            if class_prefix.startswith('Vector_'):
                # original class_prefix
                return True

        return False
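The 'path:' prefix convention in QdrantConfig is easy to miss: an endpoint such as 'path:storage/qdrant' selects Qdrant's embedded local mode, resolved against root_path, while anything else is treated as a server URL. A quick sketch of the two branches (the endpoint values and paths below are examples, not project defaults):

cfg_local = QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app/api')
print(cfg_local.to_qdrant_params())   # {'path': '/app/api/storage/qdrant'}

cfg_remote = QdrantConfig(endpoint='https://qdrant.example.com:6333', api_key='secret', root_path=None)
print(cfg_remote.to_qdrant_params())  # {'url': 'https://qdrant.example.com:6333', 'api_key': 'secret'}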
api/core/index/vector_index/vector_index.py 0 → 100644

import json

from flask import current_app
from langchain.embeddings.base import Embeddings

from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document


class VectorIndex:
    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
        self._dataset = dataset
        self._embeddings = embeddings
        self._vector_index = self._init_vector_index(dataset, config, embeddings)

    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
        vector_type = config.get('VECTOR_STORE')

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict['type']

        if not vector_type:
            raise ValueError(f"Vector store must be specified.")

        if vector_type == "weaviate":
            from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig

            return WeaviateVectorIndex(
                dataset=dataset,
                config=WeaviateConfig(
                    endpoint=config.get('WEAVIATE_ENDPOINT'),
                    api_key=config.get('WEAVIATE_API_KEY'),
                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
                ),
                embeddings=embeddings
            )
        elif vector_type == "qdrant":
            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

            return QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=config.get('QDRANT_URL'),
                    api_key=config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path
                ),
                embeddings=embeddings
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

    def add_texts(self, texts: list[Document], **kwargs):
        if not self._dataset.index_struct_dict:
            self._vector_index.create(texts, **kwargs)
            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
            db.session.commit()
            return

        self._vector_index.add_texts(texts, **kwargs)

    def __getattr__(self, name):
        if self._vector_index is not None:
            method = getattr(self._vector_index, name)
            if callable(method):
                return method

        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
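The facade's __getattr__ means callers can invoke any method of the concrete index (delete_by_document_id, search, and so on) directly on VectorIndex without the wrapper redeclaring each one. A stripped-down sketch of the same delegation pattern, using hypothetical classes and no database:

class Backend:
    def search(self, query: str) -> str:
        return f"results for {query!r}"

class Facade:
    def __init__(self):
        self._backend = Backend()

    def __getattr__(self, name):
        # only reached for attributes not found on Facade itself;
        # forward callables to the wrapped backend, as VectorIndex does
        method = getattr(self._backend, name)
        if callable(method):
            return method
        raise AttributeError(f"'Facade' object has no attribute '{name}'")

print(Facade().search("keyword tables"))  # delegated to Backend.search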
api/core/index/vector_index/weaviate_vector_index.py 0 → 100644

from typing import Optional, cast

import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset


class WeaviateConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['endpoint']:
            raise ValueError("config WEAVIATE_ENDPOINT is required")
        return values


class WeaviateVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)

    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)

        weaviate.connect.connection.has_grpc = False

        try:
            client = weaviate.Client(
                url=config.endpoint,
                auth_client_secret=auth_config,
                timeout_config=(5, 60),
                startup_period=None
            )
        except requests.exceptions.ConnectionError:
            raise ConnectionError("Vector database connection error")

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=config.batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_type(self) -> str:
        return 'weaviate'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        attributes = ['doc_id', 'dataset_id', 'document_id']
        if self._is_origin():
            attributes = ['doc_id']

        return WeaviateVectorStore(
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False
        )

    def _get_vector_store_class(self) -> type:
        return WeaviateVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        })

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False
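The '_Node' suffix is how this index tells collections created by the refactored code apart from "origin" collections built by the earlier llama_index-based path: an index_struct whose class_prefix lacks the suffix is treated as legacy, and writes trigger recreate_dataset. A tiny sketch of that naming rule only (the dataset id below is made up):

from typing import Optional

def weaviate_index_name(dataset_id: str, existing_class_prefix: Optional[str] = None) -> str:
    # same convention as WeaviateVectorIndex.get_index_name: reuse a stored
    # prefix (appending '_Node' if it is a legacy one), otherwise derive it
    if existing_class_prefix is not None:
        return existing_class_prefix if existing_class_prefix.endswith('_Node') \
            else existing_class_prefix + '_Node'
    return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

print(weaviate_index_name("made-up-dataset-id"))            # Vector_index_made_up_dataset_id_Node
print(weaviate_index_name("ignored", "Vector_index_abc"))   # Vector_index_abc_Node (legacy prefix upgraded)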
api/core/indexing_runner.py
View file @
eea011bd
import
datetime
import
json
import
logging
import
re
import
tempfile
import
time
from
pathlib
import
Path
from
typing
import
Optional
,
List
import
uuid
from
typing
import
Optional
,
List
,
cast
from
flask
import
current_app
from
flask_login
import
current_user
from
langchain.text_splitter
import
RecursiveCharacterTextSplitter
from
langchain.embeddings
import
OpenAIEmbeddings
from
langchain.schema
import
Document
from
langchain.text_splitter
import
RecursiveCharacterTextSplitter
,
TextSplitter
from
llama_index
import
SimpleDirectoryReader
from
llama_index.data_structs
import
Node
from
llama_index.data_structs.node_v2
import
DocumentRelationship
from
llama_index.node_parser
import
SimpleNodeParser
,
NodeParser
from
llama_index.readers.file.base
import
DEFAULT_FILE_EXTRACTOR
from
llama_index.readers.file.markdown_parser
import
MarkdownParser
from
core.data_source.notion
import
NotionPageReader
from
core.index.readers.xlsx_parser
import
XLSXParser
from
core.data_loader.file_extractor
import
FileExtractor
from
core.data_loader.loader.notion
import
NotionLoader
from
core.docstore.dataset_docstore
import
DatesetDocumentStore
from
core.index.keyword_table_index
import
KeywordTableIndex
from
core.index.readers.html_parser
import
HTMLParser
from
core.index.readers.markdown_parser
import
MarkdownParser
from
core.index.readers.pdf_parser
import
PDFParser
from
core.index.spiltter.fixed_text_splitter
import
FixedRecursiveCharacterTextSplitter
from
core.index.vector_index
import
VectorIndex
from
core.embedding.cached_embedding
import
CacheEmbedding
from
core.index.index
import
IndexBuilder
from
core.index.keyword_table_index.keyword_table_index
import
KeywordTableIndex
,
KeywordTableConfig
from
core.index.vector_index.vector_index
import
VectorIndex
from
core.llm.error
import
ProviderTokenNotInitError
from
core.llm.llm_builder
import
LLMBuilder
from
core.spiltter.fixed_text_splitter
import
FixedRecursiveCharacterTextSplitter
from
core.llm.token_calculator
import
TokenCalculator
from
extensions.ext_database
import
db
from
extensions.ext_redis
import
redis_client
from
extensions.ext_storage
import
storage
from
models.dataset
import
Document
,
Dataset
,
DocumentSegment
,
DatasetProcessRule
from
libs
import
helper
from
models.dataset
import
Document
as
DatasetDocument
from
models.dataset
import
Dataset
,
DocumentSegment
,
DatasetProcessRule
from
models.model
import
UploadFile
from
models.source
import
DataSourceBinding
...
...
@@ -40,135 +39,171 @@ class IndexingRunner:
self
.
storage
=
storage
self
.
embedding_model_name
=
embedding_model_name
def
run
(
self
,
d
ocuments
:
List
[
Document
]):
def
run
(
self
,
d
ataset_documents
:
List
[
Dataset
Document
]):
"""Run the indexing process."""
for
document
in
documents
:
for
dataset_document
in
dataset_documents
:
try
:
# get dataset
dataset
=
Dataset
.
query
.
filter_by
(
id
=
dataset_document
.
dataset_id
)
.
first
()
if
not
dataset
:
raise
ValueError
(
"no dataset found"
)
# load file
text_docs
=
self
.
_load_data
(
dataset_document
)
# get the process rule
processing_rule
=
db
.
session
.
query
(
DatasetProcessRule
)
.
\
filter
(
DatasetProcessRule
.
id
==
dataset_document
.
dataset_process_rule_id
)
.
\
first
()
# get splitter
splitter
=
self
.
_get_splitter
(
processing_rule
)
# split to documents
documents
=
self
.
_step_split
(
text_docs
=
text_docs
,
splitter
=
splitter
,
dataset
=
dataset
,
dataset_document
=
dataset_document
,
processing_rule
=
processing_rule
)
# build index
self
.
_build_index
(
dataset
=
dataset
,
dataset_document
=
dataset_document
,
documents
=
documents
)
except
DocumentIsPausedException
:
raise
DocumentIsPausedException
(
'Document paused, document id: {}'
.
format
(
dataset_document
.
id
))
except
ProviderTokenNotInitError
as
e
:
dataset_document
.
indexing_status
=
'error'
dataset_document
.
error
=
str
(
e
.
description
)
dataset_document
.
stopped_at
=
datetime
.
datetime
.
utcnow
()
db
.
session
.
commit
()
except
Exception
as
e
:
logging
.
exception
(
"consume document failed"
)
dataset_document
.
indexing_status
=
'error'
dataset_document
.
error
=
str
(
e
)
dataset_document
.
stopped_at
=
datetime
.
datetime
.
utcnow
()
db
.
session
.
commit
()
def
run_in_splitting_status
(
self
,
dataset_document
:
DatasetDocument
):
"""Run the indexing process when the index_status is splitting."""
try
:
# get dataset
dataset
=
Dataset
.
query
.
filter_by
(
id
=
document
.
dataset_id
id
=
d
ataset_d
ocument
.
dataset_id
)
.
first
()
if
not
dataset
:
raise
ValueError
(
"no dataset found"
)
# get exist document_segment list and delete
document_segments
=
DocumentSegment
.
query
.
filter_by
(
dataset_id
=
dataset
.
id
,
document_id
=
dataset_document
.
id
)
.
all
()
db
.
session
.
delete
(
document_segments
)
db
.
session
.
commit
()
# load file
text_docs
=
self
.
_load_data
(
document
)
text_docs
=
self
.
_load_data
(
d
ataset_d
ocument
)
# get the process rule
processing_rule
=
db
.
session
.
query
(
DatasetProcessRule
)
.
\
filter
(
DatasetProcessRule
.
id
==
document
.
dataset_process_rule_id
)
.
\
filter
(
DatasetProcessRule
.
id
==
d
ataset_d
ocument
.
dataset_process_rule_id
)
.
\
first
()
# get
node parser for splitting
node_parser
=
self
.
_get_node_pars
er
(
processing_rule
)
# get
splitter
splitter
=
self
.
_get_splitt
er
(
processing_rule
)
# split to
node
s
node
s
=
self
.
_step_split
(
# split to
document
s
document
s
=
self
.
_step_split
(
text_docs
=
text_docs
,
node_parser
=
node_pars
er
,
splitter
=
splitt
er
,
dataset
=
dataset
,
d
ocument
=
document
,
d
ataset_document
=
dataset_
document
,
processing_rule
=
processing_rule
)
# build index
self
.
_build_index
(
dataset
=
dataset
,
d
ocument
=
document
,
nodes
=
node
s
d
ataset_document
=
dataset_
document
,
documents
=
document
s
)
except
DocumentIsPausedException
:
raise
DocumentIsPausedException
(
'Document paused, document id: {}'
.
format
(
dataset_document
.
id
))
except
ProviderTokenNotInitError
as
e
:
dataset_document
.
indexing_status
=
'error'
dataset_document
.
error
=
str
(
e
.
description
)
dataset_document
.
stopped_at
=
datetime
.
datetime
.
utcnow
()
db
.
session
.
commit
()
except
Exception
as
e
:
logging
.
exception
(
"consume document failed"
)
dataset_document
.
indexing_status
=
'error'
dataset_document
.
error
=
str
(
e
)
dataset_document
.
stopped_at
=
datetime
.
datetime
.
utcnow
()
db
.
session
.
commit
()
def
run_in_splitting_status
(
self
,
document
:
Document
):
"""Run the indexing process when the index_status is splitting."""
# get dataset
dataset
=
Dataset
.
query
.
filter_by
(
id
=
document
.
dataset_id
)
.
first
()
if
not
dataset
:
raise
ValueError
(
"no dataset found"
)
# get exist document_segment list and delete
document_segments
=
DocumentSegment
.
query
.
filter_by
(
dataset_id
=
dataset
.
id
,
document_id
=
document
.
id
)
.
all
()
db
.
session
.
delete
(
document_segments
)
db
.
session
.
commit
()
# load file
text_docs
=
self
.
_load_data
(
document
)
# get the process rule
processing_rule
=
db
.
session
.
query
(
DatasetProcessRule
)
.
\
filter
(
DatasetProcessRule
.
id
==
document
.
dataset_process_rule_id
)
.
\
first
()
# get node parser for splitting
node_parser
=
self
.
_get_node_parser
(
processing_rule
)
def
run_in_indexing_status
(
self
,
dataset_document
:
DatasetDocument
):
"""Run the indexing process when the index_status is indexing."""
try
:
# get dataset
dataset
=
Dataset
.
query
.
filter_by
(
id
=
dataset_document
.
dataset_id
)
.
first
()
# split to nodes
nodes
=
self
.
_step_split
(
text_docs
=
text_docs
,
node_parser
=
node_parser
,
dataset
=
dataset
,
document
=
document
,
processing_rule
=
processing_rule
)
if
not
dataset
:
raise
ValueError
(
"no dataset found"
)
# build index
self
.
_build_index
(
dataset
=
dataset
,
document
=
document
,
nodes
=
nodes
)
# get exist document_segment list and delete
document_segments
=
DocumentSegment
.
query
.
filter_by
(
dataset_id
=
dataset
.
id
,
document_id
=
dataset_document
.
id
)
.
all
()
documents
=
[]
if
document_segments
:
for
document_segment
in
document_segments
:
# transform segment to node
if
document_segment
.
status
!=
"completed"
:
document
=
Document
(
page_content
=
document_segment
.
content
,
metadata
=
{
"doc_id"
:
document_segment
.
index_node_id
,
"doc_hash"
:
document_segment
.
index_node_hash
,
"document_id"
:
document_segment
.
document_id
,
"dataset_id"
:
document_segment
.
dataset_id
,
}
)
documents
.
append
(
document
)
def
run_in_indexing_status
(
self
,
document
:
Document
):
"""Run the indexing process when the index_status is indexing."""
# get dataset
dataset
=
Dataset
.
query
.
filter_by
(
id
=
document
.
dataset_id
)
.
first
()
if
not
dataset
:
raise
ValueError
(
"no dataset found"
)
# get exist document_segment list and delete
document_segments
=
DocumentSegment
.
query
.
filter_by
(
dataset_id
=
dataset
.
id
,
document_id
=
document
.
id
)
.
all
()
nodes
=
[]
if
document_segments
:
for
document_segment
in
document_segments
:
# transform segment to node
if
document_segment
.
status
!=
"completed"
:
relationships
=
{
DocumentRelationship
.
SOURCE
:
document_segment
.
document_id
,
}
previous_segment
=
document_segment
.
previous_segment
if
previous_segment
:
relationships
[
DocumentRelationship
.
PREVIOUS
]
=
previous_segment
.
index_node_id
next_segment
=
document_segment
.
next_segment
if
next_segment
:
relationships
[
DocumentRelationship
.
NEXT
]
=
next_segment
.
index_node_id
node
=
Node
(
doc_id
=
document_segment
.
index_node_id
,
doc_hash
=
document_segment
.
index_node_hash
,
text
=
document_segment
.
content
,
extra_info
=
None
,
node_info
=
None
,
relationships
=
relationships
)
nodes
.
append
(
node
)
# build index
self
.
_build_index
(
dataset
=
dataset
,
document
=
document
,
nodes
=
nodes
)
# build index
self
.
_build_index
(
dataset
=
dataset
,
dataset_document
=
dataset_document
,
documents
=
documents
)
except
DocumentIsPausedException
:
raise
DocumentIsPausedException
(
'Document paused, document id: {}'
.
format
(
dataset_document
.
id
))
except
ProviderTokenNotInitError
as
e
:
dataset_document
.
indexing_status
=
'error'
dataset_document
.
error
=
str
(
e
.
description
)
dataset_document
.
stopped_at
=
datetime
.
datetime
.
utcnow
()
db
.
session
.
commit
()
except
Exception
as
e
:
logging
.
exception
(
"consume document failed"
)
dataset_document
.
indexing_status
=
'error'
dataset_document
.
error
=
str
(
e
)
dataset_document
.
stopped_at
=
datetime
.
datetime
.
utcnow
()
db
.
session
.
commit
()
def
file_indexing_estimate
(
self
,
file_details
:
List
[
UploadFile
],
tmp_processing_rule
:
dict
)
->
dict
:
"""
...
...
@@ -179,28 +214,28 @@ class IndexingRunner:
total_segments
=
0
for
file_detail
in
file_details
:
# load data from file
text_docs
=
self
.
_load_data_from_file
(
file_detail
)
text_docs
=
FileExtractor
.
load
(
file_detail
)
processing_rule
=
DatasetProcessRule
(
mode
=
tmp_processing_rule
[
"mode"
],
rules
=
json
.
dumps
(
tmp_processing_rule
[
"rules"
])
)
# get
node parser for splitting
node_parser
=
self
.
_get_node_pars
er
(
processing_rule
)
# get
splitter
splitter
=
self
.
_get_splitt
er
(
processing_rule
)
# split to
node
s
nodes
=
self
.
_split_to_node
s
(
# split to
document
s
documents
=
self
.
_split_to_document
s
(
text_docs
=
text_docs
,
node_parser
=
node_pars
er
,
splitter
=
splitt
er
,
processing_rule
=
processing_rule
)
total_segments
+=
len
(
node
s
)
for
node
in
node
s
:
total_segments
+=
len
(
document
s
)
for
document
in
document
s
:
if
len
(
preview_texts
)
<
5
:
preview_texts
.
append
(
node
.
get_text
()
)
preview_texts
.
append
(
document
.
page_content
)
tokens
+=
TokenCalculator
.
get_num_tokens
(
self
.
embedding_model_name
,
node
.
get_text
()
)
tokens
+=
TokenCalculator
.
get_num_tokens
(
self
.
embedding_model_name
,
document
.
page_content
)
return
{
"total_segments"
:
total_segments
,
...
...
@@ -230,35 +265,36 @@ class IndexingRunner:
)
.
first
()
if
not
data_source_binding
:
raise
ValueError
(
'Data source binding not found.'
)
reader
=
NotionPageReader
(
integration_token
=
data_source_binding
.
access_token
)
for
page
in
notion_info
[
'pages'
]:
if
page
[
'type'
]
==
'page'
:
page_ids
=
[
page
[
'page_id'
]]
documents
=
reader
.
load_data_as_documents
(
page_ids
=
page_ids
)
elif
page
[
'type'
]
==
'database'
:
documents
=
reader
.
load_data_as_documents
(
database_id
=
page
[
'page_id'
])
else
:
documents
=
[]
loader
=
NotionLoader
(
notion_access_token
=
data_source_binding
.
access_token
,
notion_workspace_id
=
workspace_id
,
notion_obj_id
=
page
[
'page_id'
],
notion_page_type
=
page
[
'type'
]
)
documents
=
loader
.
load
()
processing_rule
=
DatasetProcessRule
(
mode
=
tmp_processing_rule
[
"mode"
],
rules
=
json
.
dumps
(
tmp_processing_rule
[
"rules"
])
)
# get
node parser for splitting
node_parser
=
self
.
_get_node_pars
er
(
processing_rule
)
# get
splitter
splitter
=
self
.
_get_splitt
er
(
processing_rule
)
# split to
node
s
nodes
=
self
.
_split_to_node
s
(
# split to
document
s
documents
=
self
.
_split_to_document
s
(
text_docs
=
documents
,
node_parser
=
node_pars
er
,
splitter
=
splitt
er
,
processing_rule
=
processing_rule
)
total_segments
+=
len
(
node
s
)
for
node
in
node
s
:
total_segments
+=
len
(
document
s
)
for
document
in
document
s
:
if
len
(
preview_texts
)
<
5
:
preview_texts
.
append
(
node
.
get_text
()
)
preview_texts
.
append
(
document
.
page_content
)
tokens
+=
TokenCalculator
.
get_num_tokens
(
self
.
embedding_model_name
,
node
.
get_text
()
)
tokens
+=
TokenCalculator
.
get_num_tokens
(
self
.
embedding_model_name
,
document
.
page_content
)
return
{
"total_segments"
:
total_segments
,
...
...
@@ -268,14 +304,14 @@ class IndexingRunner:
"preview"
:
preview_texts
}
def
_load_data
(
self
,
d
ocument
:
Document
)
->
List
[
Document
]:
def
_load_data
(
self
,
d
ataset_document
:
Dataset
Document
)
->
List
[
Document
]:
# load file
if
document
.
data_source_type
not
in
[
"upload_file"
,
"notion_import"
]:
if
d
ataset_d
ocument
.
data_source_type
not
in
[
"upload_file"
,
"notion_import"
]:
return
[]
data_source_info
=
document
.
data_source_info_dict
data_source_info
=
d
ataset_d
ocument
.
data_source_info_dict
text_docs
=
[]
if
document
.
data_source_type
==
'upload_file'
:
if
d
ataset_d
ocument
.
data_source_type
==
'upload_file'
:
if
not
data_source_info
or
'upload_file_id'
not
in
data_source_info
:
raise
ValueError
(
"no upload file found"
)
...
...
@@ -283,47 +319,28 @@ class IndexingRunner:
filter
(
UploadFile
.
id
==
data_source_info
[
'upload_file_id'
])
.
\
one_or_none
()
text_docs
=
self
.
_load_data_from_file
(
file_detail
)
elif
document
.
data_source_type
==
'notion_import'
:
if
not
data_source_info
or
'notion_page_id'
not
in
data_source_info
\
or
'notion_workspace_id'
not
in
data_source_info
:
raise
ValueError
(
"no notion page found"
)
workspace_id
=
data_source_info
[
'notion_workspace_id'
]
page_id
=
data_source_info
[
'notion_page_id'
]
page_type
=
data_source_info
[
'type'
]
data_source_binding
=
DataSourceBinding
.
query
.
filter
(
db
.
and_
(
DataSourceBinding
.
tenant_id
==
document
.
tenant_id
,
DataSourceBinding
.
provider
==
'notion'
,
DataSourceBinding
.
disabled
==
False
,
DataSourceBinding
.
source_info
[
'workspace_id'
]
==
f
'"{workspace_id}"'
)
)
.
first
()
if
not
data_source_binding
:
raise
ValueError
(
'Data source binding not found.'
)
if
page_type
==
'page'
:
# add page last_edited_time to data_source_info
self
.
_get_notion_page_last_edited_time
(
page_id
,
data_source_binding
.
access_token
,
document
)
text_docs
=
self
.
_load_page_data_from_notion
(
page_id
,
data_source_binding
.
access_token
)
elif
page_type
==
'database'
:
# add page last_edited_time to data_source_info
self
.
_get_notion_database_last_edited_time
(
page_id
,
data_source_binding
.
access_token
,
document
)
text_docs
=
self
.
_load_database_data_from_notion
(
page_id
,
data_source_binding
.
access_token
)
text_docs
=
FileExtractor
.
load
(
file_detail
)
elif
dataset_document
.
data_source_type
==
'notion_import'
:
loader
=
NotionLoader
.
from_document
(
dataset_document
)
text_docs
=
loader
.
load
()
# update document status to splitting
self
.
_update_document_index_status
(
document_id
=
document
.
id
,
document_id
=
d
ataset_d
ocument
.
id
,
after_indexing_status
=
"splitting"
,
extra_update_params
=
{
D
ocument
.
word_count
:
sum
([
len
(
text_doc
.
tex
t
)
for
text_doc
in
text_docs
]),
Document
.
parsing_completed_at
:
datetime
.
datetime
.
utcnow
()
D
atasetDocument
.
word_count
:
sum
([
len
(
text_doc
.
page_conten
t
)
for
text_doc
in
text_docs
]),
D
atasetD
ocument
.
parsing_completed_at
:
datetime
.
datetime
.
utcnow
()
}
)
# replace doc id to document model id
text_docs
=
cast
(
List
[
Document
],
text_docs
)
for
text_doc
in
text_docs
:
# remove invalid symbol
text_doc
.
text
=
self
.
filter_string
(
text_doc
.
get_text
())
text_doc
.
doc_id
=
document
.
id
text_doc
.
page_content
=
self
.
filter_string
(
text_doc
.
page_content
)
text_doc
.
metadata
[
'document_id'
]
=
dataset_document
.
id
text_doc
.
metadata
[
'dataset_id'
]
=
dataset_document
.
dataset_id
return
text_docs
...
...
@@ -331,61 +348,7 @@ class IndexingRunner:
pattern
=
re
.
compile
(
'[
\x00
-
\x08\x0B\x0C\x0E
-
\x1F\x7F\x80
-
\xFF
]'
)
return
pattern
.
sub
(
''
,
text
)
def
_load_data_from_file
(
self
,
upload_file
:
UploadFile
)
->
List
[
Document
]:
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
suffix
=
Path
(
upload_file
.
key
)
.
suffix
filepath
=
f
"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
self
.
storage
.
download
(
upload_file
.
key
,
filepath
)
file_extractor
=
DEFAULT_FILE_EXTRACTOR
.
copy
()
file_extractor
[
".markdown"
]
=
MarkdownParser
()
file_extractor
[
".md"
]
=
MarkdownParser
()
file_extractor
[
".html"
]
=
HTMLParser
()
file_extractor
[
".htm"
]
=
HTMLParser
()
file_extractor
[
".pdf"
]
=
PDFParser
({
'upload_file'
:
upload_file
})
file_extractor
[
".xlsx"
]
=
XLSXParser
()
loader
=
SimpleDirectoryReader
(
input_files
=
[
filepath
],
file_extractor
=
file_extractor
)
text_docs
=
loader
.
load_data
()
return
text_docs
def
_load_page_data_from_notion
(
self
,
page_id
:
str
,
access_token
:
str
)
->
List
[
Document
]:
page_ids
=
[
page_id
]
reader
=
NotionPageReader
(
integration_token
=
access_token
)
text_docs
=
reader
.
load_data_as_documents
(
page_ids
=
page_ids
)
return
text_docs
def
_load_database_data_from_notion
(
self
,
database_id
:
str
,
access_token
:
str
)
->
List
[
Document
]:
reader
=
NotionPageReader
(
integration_token
=
access_token
)
text_docs
=
reader
.
load_data_as_documents
(
database_id
=
database_id
)
return
text_docs
def
_get_notion_page_last_edited_time
(
self
,
page_id
:
str
,
access_token
:
str
,
document
:
Document
):
reader
=
NotionPageReader
(
integration_token
=
access_token
)
last_edited_time
=
reader
.
get_page_last_edited_time
(
page_id
)
data_source_info
=
document
.
data_source_info_dict
data_source_info
[
'last_edited_time'
]
=
last_edited_time
update_params
=
{
Document
.
data_source_info
:
json
.
dumps
(
data_source_info
)
}
Document
.
query
.
filter_by
(
id
=
document
.
id
)
.
update
(
update_params
)
db
.
session
.
commit
()
def
_get_notion_database_last_edited_time
(
self
,
page_id
:
str
,
access_token
:
str
,
document
:
Document
):
reader
=
NotionPageReader
(
integration_token
=
access_token
)
last_edited_time
=
reader
.
get_database_last_edited_time
(
page_id
)
data_source_info
=
document
.
data_source_info_dict
data_source_info
[
'last_edited_time'
]
=
last_edited_time
update_params
=
{
Document
.
data_source_info
:
json
.
dumps
(
data_source_info
)
}
Document
.
query
.
filter_by
(
id
=
document
.
id
)
.
update
(
update_params
)
db
.
session
.
commit
()
def
_get_node_parser
(
self
,
processing_rule
:
DatasetProcessRule
)
->
NodeParser
:
def
_get_splitter
(
self
,
processing_rule
:
DatasetProcessRule
)
->
TextSplitter
:
"""
Get the NodeParser object according to the processing rule.
"""
...
...
@@ -414,68 +377,83 @@ class IndexingRunner:
separators
=
[
"
\n\n
"
,
"。"
,
"."
,
" "
,
""
]
)
return
SimpleNodeParser
(
text_splitter
=
character_splitter
,
include_extra_info
=
True
)
return
character_splitter
def
_step_split
(
self
,
text_docs
:
List
[
Document
],
node_parser
:
NodeParser
,
dataset
:
Dataset
,
document
:
Document
,
processing_rule
:
DatasetProcessRule
)
->
List
[
Node
]:
def
_step_split
(
self
,
text_docs
:
List
[
Document
],
splitter
:
TextSplitter
,
dataset
:
Dataset
,
dataset_document
:
DatasetDocument
,
processing_rule
:
DatasetProcessRule
)
\
->
List
[
Document
]:
"""
Split the text documents into
node
s and save them to the document segment.
Split the text documents into
document
s and save them to the document segment.
"""
nodes
=
self
.
_split_to_node
s
(
documents
=
self
.
_split_to_document
s
(
text_docs
=
text_docs
,
node_parser
=
node_pars
er
,
splitter
=
splitt
er
,
processing_rule
=
processing_rule
)
# save node to document segment
doc_store
=
DatesetDocumentStore
(
dataset
=
dataset
,
user_id
=
document
.
created_by
,
user_id
=
d
ataset_d
ocument
.
created_by
,
embedding_model_name
=
self
.
embedding_model_name
,
document_id
=
document
.
id
document_id
=
d
ataset_d
ocument
.
id
)
# add document segments
doc_store
.
add_documents
(
node
s
)
doc_store
.
add_documents
(
document
s
)
# update document status to indexing
cur_time
=
datetime
.
datetime
.
utcnow
()
self
.
_update_document_index_status
(
document_id
=
document
.
id
,
document_id
=
d
ataset_d
ocument
.
id
,
after_indexing_status
=
"indexing"
,
extra_update_params
=
{
Document
.
cleaning_completed_at
:
cur_time
,
Document
.
splitting_completed_at
:
cur_time
,
D
atasetD
ocument
.
cleaning_completed_at
:
cur_time
,
D
atasetD
ocument
.
splitting_completed_at
:
cur_time
,
}
)
# update segment status to indexing
self
.
_update_segments_by_document
(
d
ocument_id
=
document
.
id
,
d
ataset_document_id
=
dataset_
document
.
id
,
update_params
=
{
DocumentSegment
.
status
:
"indexing"
,
DocumentSegment
.
indexing_at
:
datetime
.
datetime
.
utcnow
()
}
)
return
node
s
return
document
s
-    def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
-                        processing_rule: DatasetProcessRule) -> List[Node]:
+    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
+                            processing_rule: DatasetProcessRule) -> List[Document]:
         """
         Split the text documents into nodes.
         """
-        all_nodes = []
+        all_documents = []
         for text_doc in text_docs:
             # document clean
-            document_text = self._document_clean(text_doc.get_text(), processing_rule)
-            text_doc.text = document_text
+            document_text = self._document_clean(text_doc.page_content, processing_rule)
+            text_doc.page_content = document_text

             # parse document to nodes
-            nodes = node_parser.get_nodes_from_documents([text_doc])
-            nodes = [node for node in nodes if node.text is not None and node.text.strip()]
-            all_nodes.extend(nodes)
+            documents = splitter.split_documents([text_doc])
+            split_documents = []
+            for document in documents:
+                if document.page_content is None or not document.page_content.strip():
+                    continue
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document.page_content)
+
+                document.metadata['doc_id'] = doc_id
+                document.metadata['doc_hash'] = hash
+
+                split_documents.append(document)
+            all_documents.extend(split_documents)

-        return all_nodes
+        return all_documents
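The reworked splitting hunk above tags every split chunk with a doc_id and a doc_hash before anything is persisted. A minimal standalone sketch of the same tagging pattern, assuming only a LangChain text splitter and the standard library (the hash mirrors the generate_text_hash helper added in api/libs/helper.py further down; all names here are illustrative, not the project's API):

import uuid
from hashlib import sha256

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

def tag_split_documents(text: str) -> list:
    # split the raw text into LangChain Documents
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    documents = splitter.split_documents([Document(page_content=text)])

    tagged = []
    for document in documents:
        if not document.page_content or not document.page_content.strip():
            continue
        # same idea as the diff: a random id plus a content hash for later de-duplication
        document.metadata['doc_id'] = str(uuid.uuid4())
        document.metadata['doc_hash'] = sha256((document.page_content + 'None').encode()).hexdigest()
        tagged.append(document)
    return tagged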
     def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
         """
...
@@ -506,37 +484,38 @@ class IndexingRunner:
         return text
-    def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
+    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
         """
         Build the index for the document.
         """
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
         chunk_size = 100
-        for i in range(0, len(nodes), chunk_size):
+        for i in range(0, len(documents), chunk_size):
             # check document is paused
-            self._check_document_paused_status(document.id)
-            chunk_nodes = nodes[i:i + chunk_size]
+            self._check_document_paused_status(dataset_document.id)
+            chunk_documents = documents[i:i + chunk_size]

             tokens += sum(
-                TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
+                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                for document in chunk_documents
             )

             # save vector index
-            if dataset.indexing_technique == "high_quality":
-                vector_index.add_nodes(chunk_nodes)
+            if vector_index:
+                vector_index.add_texts(chunk_documents)

             # save keyword index
-            keyword_table_index.add_nodes(chunk_nodes)
+            keyword_table_index.add_texts(chunk_documents)

-            node_ids = [node.doc_id for node in chunk_nodes]
+            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == document.id,
-                DocumentSegment.index_node_id.in_(node_ids),
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.index_node_id.in_(document_ids),
                 DocumentSegment.status == "indexing"
             ).update({
                 DocumentSegment.status: "completed",
...
@@ -549,12 +528,12 @@ class IndexingRunner:

         # update document status to completed
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="completed",
             extra_update_params={
-                Document.tokens: tokens,
-                Document.completed_at: datetime.datetime.utcnow(),
-                Document.indexing_latency: indexing_end_at - indexing_start_at,
+                DatasetDocument.tokens: tokens,
+                DatasetDocument.completed_at: datetime.datetime.utcnow(),
+                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
             }
         )
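The loop above pushes the split documents into the index in fixed-size batches of 100 while tallying embedding tokens. A minimal sketch of that batching pattern in isolation; the count_tokens and add_texts callables are hypothetical stand-ins, not project APIs:

from typing import Callable, List

def index_in_chunks(texts: List[str],
                    count_tokens: Callable[[str], int],      # hypothetical token counter
                    add_texts: Callable[[List[str]], None],  # hypothetical index writer
                    chunk_size: int = 100) -> int:
    """Index texts in fixed-size batches and return the total token count."""
    total_tokens = 0
    for i in range(0, len(texts), chunk_size):
        chunk = texts[i:i + chunk_size]
        total_tokens += sum(count_tokens(text) for text in chunk)
        add_texts(chunk)  # one bulk write per batch keeps round-trips low
    return total_tokens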
...
@@ -569,25 +548,25 @@ class IndexingRunner:
         """
         Update the document indexing status.
         """
-        count = Document.query.filter_by(id=document_id, is_paused=True).count()
+        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
         if count > 0:
             raise DocumentIsPausedException()

         update_params = {
-            Document.indexing_status: after_indexing_status
+            DatasetDocument.indexing_status: after_indexing_status
         }

         if extra_update_params:
             update_params.update(extra_update_params)

-        Document.query.filter_by(id=document_id).update(update_params)
+        DatasetDocument.query.filter_by(id=document_id).update(update_params)
         db.session.commit()

-    def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
+    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
         """
         Update the document segment by document id.
         """
-        DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
+        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()
...
api/core/llm/llm_builder.py
View file @ eea011bd

-from typing import Union, Optional
+from typing import Union, Optional, List

-from langchain.callbacks import CallbackManager
-from langchain.llms.fake import FakeListLLM
+from langchain.callbacks.base import BaseCallbackHandler

 from core.constant import llm_constant
 from core.llm.error import ProviderTokenNotInitError
...
@@ -32,12 +31,11 @@ class LLMBuilder:
     """

     @classmethod
-    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
-        if model_name == 'fake':
-            return FakeListLLM(responses=[])
-
+    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         provider = cls.get_default_provider(tenant_id)

+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
             if provider == 'openai':
...
@@ -52,16 +50,21 @@ class LLMBuilder:
         else:
             raise ValueError(f"model name {model_name} is not supported.")

-        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+        model_kwargs = {
+            'top_p': kwargs.get('top_p', 1),
+            'frequency_penalty': kwargs.get('frequency_penalty', 0),
+            'presence_penalty': kwargs.get('presence_penalty', 0),
+        }
+
+        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}

         return llm_cls(
             model_name=model_name,
             temperature=kwargs.get('temperature', 0),
             max_tokens=kwargs.get('max_tokens', 256),
-            top_p=kwargs.get('top_p', 1),
-            frequency_penalty=kwargs.get('frequency_penalty', 0),
-            presence_penalty=kwargs.get('presence_penalty', 0),
-            callback_manager=kwargs.get('callback_manager', None),
+            **model_extras_kwargs,
+            callbacks=kwargs.get('callbacks', None),
+            streaming=kwargs.get('streaming', False),
             # request_timeout=None
             **model_credentials
...
@@ -69,7 +72,7 @@ class LLMBuilder:
     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
-                          callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
+                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         model_name = model.get("name")
         completion_params = model.get("completion_params", {})
...
@@ -82,7 +85,7 @@ class LLMBuilder:
             frequency_penalty=completion_params.get('frequency_penalty', 0.1),
             presence_penalty=completion_params.get('presence_penalty', 0.1),
             streaming=streaming,
-            callback_manager=callback_manager
+            callbacks=callbacks
         )

     @classmethod
...
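After this change LLMBuilder wires a plain list of LangChain callback handlers (instead of a CallbackManager) into the model it constructs. A rough call-site sketch, assuming a tenant id is already at hand; the argument values are placeholders, not taken from the diff:

from langchain.callbacks.base import BaseCallbackHandler

class PrintTokenHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # stream tokens to stdout as they arrive
        print(token, end="", flush=True)

# hypothetical usage: tenant_id comes from the current workspace
llm = LLMBuilder.to_llm(
    tenant_id="tenant-id-placeholder",
    model_name="gpt-3.5-turbo",
    streaming=True,
    callbacks=[PrintTokenHandler()],
)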
api/core/llm/provider/azure_provider.py
View file @ eea011bd
...
@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider):
         """
         config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
-        config['deployment_name'] = model_id.replace('.', '') if model_id else None
+        if model_id == 'text-embedding-ada-002':
+            config['deployment'] = model_id.replace('.', '') if model_id else None
+        else:
+            config['deployment_name'] = model_id.replace('.', '') if model_id else None
         return config

     def get_provider_name(self):
...
api/core/llm/streamable_azure_chat_open_ai.py
View file @ eea011bd

+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
...
@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         return message_tokens

-    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__},
-            [(message.type + ": " + message.content) for message in messages],
-            verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(generations=[chat_result.generations], llm_output=chat_result.llm_output)
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
-                [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
-                [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(generations=[chat_result.generations], llm_output=chat_result.llm_output)
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
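The four streamable wrappers touched by this commit all follow the same shape: drop the hand-rolled callback_manager plumbing and simply forward callbacks and **kwargs to the parent generate/agenerate. A generic sketch of that pass-through pattern, independent of the concrete dify classes (the subclass name is illustrative):

from typing import Any, List, Optional

from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, LLMResult

class ForwardingChatOpenAI(ChatOpenAI):
    # illustrative subclass, not part of the diff
    def generate(self, messages: List[List[BaseMessage]],
                 stop: Optional[List[str]] = None,
                 callbacks: Callbacks = None,
                 **kwargs: Any) -> LLMResult:
        # delegate entirely; LangChain's callback manager now drives on_llm_start/on_llm_end
        return super().generate(messages, stop, callbacks, **kwargs)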
api/core/llm/streamable_azure_open_ai.py
View file @ eea011bd

 import os

+from langchain.callbacks.manager import Callbacks
 from langchain.llms import AzureOpenAI
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Mapping, Any
...
@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):

     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self, prompts: List[str], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self, prompts: List[str], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
api/core/llm/streamable_chat_open_ai.py
View file @ eea011bd

 import os

-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import ChatOpenAI
 from typing import Optional, List, Dict, Any
...
@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):
         return message_tokens

-    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__},
-            [(message.type + ": " + message.content) for message in messages],
-            verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(generations=[chat_result.generations], llm_output=chat_result.llm_output)
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
-                [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
-                [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(generations=[chat_result.generations], llm_output=chat_result.llm_output)
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
api/core/llm/streamable_open_ai.py
View file @ eea011bd

 import os

+from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
...
@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
                 "organization": self.openai_organization if self.openai_organization else None,
             }}

     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self, prompts: List[str], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self, prompts: List[str], stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py
View file @ eea011bd

 from typing import Any, List, Dict
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
...
api/core/prompt/prompts.py
View file @ eea011bd

-from llama_index import QueryKeywordExtractPrompt
-
 CONVERSATION_TITLE_PROMPT = (
     "Human:{query}\n-----\n"
     "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
...
@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "[\"question1\",\"question2\",\"question3\"]\n"
 )

-QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
-    "A question is provided below. Given the question, extract up to {max_keywords} "
-    "keywords from the text. Focus on extracting the keywords that we can use "
-    "to best lookup answers to the question. Avoid stopwords."
-    "I am not sure which language the following question is in. "
-    "If the user asked the question in Chinese, please return the keywords in Chinese. "
-    "If the user asked the question in English, please return the keywords in English.\n"
-    "---------------------\n"
-    "{question}\n"
-    "---------------------\n"
-    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
-)
-
-QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
-    QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
-)
-
 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
 the model prompt that best suits the input.
 You will be provided with the prompt, variables, and an opening statement.
...

api/core/index/spiltter/fixed_text_splitter.py → api/core/spiltter/fixed_text_splitter.py
View file @ eea011bd
File moved
api/core/tool/dataset_index_tool.py  0 → 100644
View file @ eea011bd

from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class DatasetTool(BaseTool):
    """Tool for querying a Dataset."""

    dataset: Dataset
    k: int = 2

    def _run(self, tool_input: str) -> str:
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            vector_index = VectorIndex(
                dataset=self.dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                tool_input,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(
            **model_credentials
        ))

        vector_index = VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        documents = await vector_index.asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={
                'k': 10
            }
        )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))
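A rough usage sketch for the new DatasetTool, assuming a Dataset row has already been loaded inside a Flask application context; the query string and tool name are placeholders:

# hypothetical call site inside a Flask app context
dataset = Dataset.query.first()                  # any existing dataset row
tool = DatasetTool(
    name=f"dataset-{dataset.id}",                # BaseTool requires name/description
    description=dataset.description or "dataset lookup",
    dataset=dataset,
    k=2,
)
print(tool.run("What does the handbook say about refunds?"))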
api/core/tool/dataset_tool_builder.py  deleted 100644 → 0
View file @ 3eb8e66b

from typing import Optional

from langchain.callbacks import CallbackManager
from llama_index.langchain_helpers.agents import IndexToolConfig

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset


class DatasetToolBuilder:
    @classmethod
    def build_dataset_tool(cls, dataset: Dataset,
                           response_mode: str = "no_synthesizer",
                           callback_handler: Optional[DatasetToolCallbackHandler] = None):
        if dataset.indexing_technique == "economy":
            # use keyword table query
            index = KeywordTableIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
                "max_keywords_per_query": 5,
                # If num_chunks_per_query is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "num_chunks_per_query": 2
            }
        else:
            index = VectorIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                # If top_k is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "similarity_top_k": 2
            }

        # fulfill description when it is empty
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        index_tool_config = IndexToolConfig(
            index=index,
            name=f"dataset-{dataset.id}",
            description=description,
            index_query_kwargs=query_kwargs,
            tool_kwargs={
                "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
            },
            # tool_kwargs={"return_direct": True},
            # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
        )

        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)

        return EnhanceLlamaIndexTool.from_tool_config(
            tool_config=index_tool_config,
            callback_handler=index_callback_handler
        )
api/core/tool/llama_index_tool.py  deleted 100644 → 0
View file @ 3eb8e66b

from typing import Dict

from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field

from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler


class EnhanceLlamaIndexTool(BaseTool):
    """Tool for querying a LlamaIndex."""

    # NOTE: name/description still needs to be set
    index: BaseGPTIndex
    query_kwargs: Dict = Field(default_factory=dict)
    return_sources: bool = False
    callback_handler: IndexToolCallbackHandler

    @classmethod
    def from_tool_config(cls, tool_config: IndexToolConfig,
                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
        """Create a tool from a tool config."""
        return_sources = tool_config.tool_kwargs.pop("return_sources", False)
        return cls(
            index=tool_config.index,
            callback_handler=callback_handler,
            name=tool_config.name,
            description=tool_config.description,
            return_sources=return_sources,
            query_kwargs=tool_config.index_query_kwargs,
            **tool_config.tool_kwargs,
        )

    def _run(self, tool_input: str) -> str:
        response = self.index.query(tool_input, **self.query_kwargs)
        self.callback_handler.on_tool_end(response)
        return str(response)

    async def _arun(self, tool_input: str) -> str:
        response = await self.index.aquery(tool_input, **self.query_kwargs)
        self.callback_handler.on_tool_end(response)
        return str(response)
api/core/vector_store/base.py  deleted 100644 → 0
View file @ 3eb8e66b

from abc import ABC, abstractmethod
from typing import Optional

from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore


class BaseVectorStoreClient(ABC):
    @abstractmethod
    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        raise NotImplementedError

    @abstractmethod
    def to_index_config(self, index_id: str) -> dict:
        raise NotImplementedError


class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
    def delete_node(self, node_id: str):
        self._vector_store.delete_node(node_id)

    def exists_by_node_id(self, node_id: str) -> bool:
        return self._vector_store.exists_by_node_id(node_id)


class EnhanceVectorStore(ABC):
    @abstractmethod
    def delete_node(self, node_id: str):
        pass

    @abstractmethod
    def exists_by_node_id(self, node_id: str) -> bool:
        pass
api/core/vector_store/qdrant_vector_store.py  0 → 100644
View file @ eea011bd

from typing import cast, Any

from langchain.schema import Document
from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal


class QdrantVectorStore(Qdrant):
    def del_texts(self, filter: Filter):
        if not filter:
            raise ValueError('filter must not be empty')

        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=FilterSelector(
                filter=filter
            ),
        )

    def del_text(self, uuid: str) -> None:
        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(
                points=[uuid],
            ),
        )

    def text_exists(self, uuid: str) -> bool:
        self._reload_if_needed()

        response = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[uuid]
        )

        return len(response) > 0

    def delete(self):
        self._reload_if_needed()

        self.client.delete_collection(collection_name=self.collection_name)

    @classmethod
    def _document_from_scored_point(
            cls,
            scored_point: Any,
            content_payload_key: str,
            metadata_payload_key: str,
    ) -> Document:
        if scored_point.payload.get('doc_id'):
            return Document(
                page_content=scored_point.payload.get(content_payload_key),
                metadata={'doc_id': scored_point.id}
            )

        return Document(
            page_content=scored_point.payload.get(content_payload_key),
            metadata=scored_point.payload.get(metadata_payload_key) or {},
        )

    def _reload_if_needed(self):
        if isinstance(self.client, QdrantLocal):
            self.client = cast(QdrantLocal, self.client)
            self.client._load()
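A brief sketch of how the added Qdrant helpers might be exercised against an in-memory collection; the collection name, texts, and point ids are placeholders, and an OpenAI API key is assumed to be configured for the embeddings:

from langchain.embeddings import OpenAIEmbeddings

# build a local Qdrant collection from a couple of texts
store = QdrantVectorStore.from_texts(
    ["hello world", "goodbye world"],
    OpenAIEmbeddings(),                      # assumes OPENAI_API_KEY is set
    location=":memory:",
    collection_name="example",               # placeholder collection name
    ids=["11111111-1111-1111-1111-111111111111",
         "22222222-2222-2222-2222-222222222222"],
)

print(store.text_exists("11111111-1111-1111-1111-111111111111"))  # True
store.del_text("11111111-1111-1111-1111-111111111111")            # remove one point
store.delete()                                                     # drop the collection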
api/core/vector_store/qdrant_vector_store_client.py  deleted 100644 → 0
View file @ 3eb8e66b

import os
from typing import cast, List

from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter

import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore


class QdrantVectorStoreClient(BaseVectorStoreClient):
    def __init__(self, url: str, api_key: str, root_path: str):
        self._client = self.init_from_config(url, api_key, root_path)

    @classmethod
    def init_from_config(cls, url: str, api_key: str, root_path: str):
        if url and url.startswith('path:'):
            path = url.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(root_path, path)

            return qdrant_client.QdrantClient(path=path)
        else:
            return qdrant_client.QdrantClient(
                url=url,
                api_key=api_key,
            )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = QdrantIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"collection_name": "Gpt_index_xxx"}
        collection_name = config.get('collection_name')
        if not collection_name:
            raise Exception("collection_name cannot be None.")

        return GPTQdrantEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=QdrantEnhanceVectorStore(
                client=self._client,
                collection_name=collection_name
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"collection_name": index_id}


class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
    pass


class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        from qdrant_client.http import models as rest

        self._reload_if_needed()

        self._client.delete(
            collection_name=self._collection_name,
            points_selector=rest.Filter(
                must=[
                    rest.FieldCondition(
                        key="id", match=rest.MatchValue(value=node_id)
                    )
                ]
            ),
        )

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        self._reload_if_needed()

        response = self._client.retrieve(
            collection_name=self._collection_name,
            ids=[node_id]
        )

        return len(response) > 0

    def query(
            self,
            query: VectorStoreQuery,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query
        """
        query_embedding = cast(List[float], query.query_embedding)

        self._reload_if_needed()

        response = self._client.search(
            collection_name=self._collection_name,
            query_vector=query_embedding,
            limit=cast(int, query.similarity_top_k),
            query_filter=cast(Filter, self._build_query_filter(query)),
            with_vectors=True
        )

        nodes = []
        similarities = []
        ids = []

        for point in response:
            payload = cast(Payload, point.payload)
            node = Node(
                doc_id=str(point.id),
                text=payload.get("text"),
                embedding=point.vector,
                extra_info=payload.get("extra_info"),
                relationships={
                    DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
                },
            )
            nodes.append(node)
            similarities.append(point.score)
            ids.append(str(point.id))

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def _reload_if_needed(self):
        if isinstance(self._client._client, QdrantLocal):
            self._client._client._load()
api/core/vector_store/vector_store.py  deleted 100644 → 0
View file @ 3eb8e66b

from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient

SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']


class VectorStore:
    def __init__(self):
        self._vector_store = None
        self._client = None

    def init_app(self, app: Flask):
        if not app.config['VECTOR_STORE']:
            return

        self._vector_store = app.config['VECTOR_STORE']
        if self._vector_store not in SUPPORTED_VECTOR_STORES:
            raise ValueError(f"Vector store {self._vector_store} is not supported.")

        if self._vector_store == 'weaviate':
            self._client = WeaviateVectorStoreClient(
                endpoint=app.config['WEAVIATE_ENDPOINT'],
                api_key=app.config['WEAVIATE_API_KEY'],
                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
                batch_size=app.config['WEAVIATE_BATCH_SIZE']
            )
        elif self._vector_store == 'qdrant':
            self._client = QdrantVectorStoreClient(
                url=app.config['QDRANT_URL'],
                api_key=app.config['QDRANT_API_KEY'],
                root_path=app.root_path
            )

        app.extensions['vector_store'] = self

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
        vector_store_config: dict = index_struct.get('vector_store')
        index = self.get_client().get_index(
            service_context=service_context,
            config=vector_store_config
        )

        return index

    def to_index_struct(self, index_id: str) -> dict:
        return {
            "type": self._vector_store,
            "vector_store": self.get_client().to_index_config(index_id)
        }

    def get_client(self):
        if not self._client:
            raise Exception("Vector store client is not initialized.")

        return self._client
api/core/vector_store/vector_store_index_query.py  deleted 100644 → 0
View file @ 3eb8e66b

from llama_index.indices.query.base import IS
from typing import (
    Any,
    Dict,
    List,
    Optional,
)
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
    @classmethod
    def from_args(
            cls,
            index_struct: IS,
            service_context: ServiceContext,
            docstore: Optional[BaseDocumentStore] = None,
            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
            verbose: bool = False,
            # response synthesizer args
            response_mode: ResponseMode = ResponseMode.DEFAULT,
            text_qa_template: Optional[QuestionAnswerPrompt] = None,
            refine_template: Optional[RefinePrompt] = None,
            simple_template: Optional[SimpleInputPrompt] = None,
            response_kwargs: Optional[Dict] = None,
            use_async: bool = False,
            streaming: bool = False,
            optimizer: Optional[BaseTokenUsageOptimizer] = None,
            # class-specific args
            **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )
api/core/vector_store/weaviate_vector_store.py  0 → 100644
View file @ eea011bd

from langchain.vectorstores import Weaviate


class WeaviateVectorStore(Weaviate):
    def del_texts(self, where_filter: dict):
        if not where_filter:
            raise ValueError('where_filter must not be empty')

        self._client.batch.delete_objects(
            class_name=self._index_name,
            where=where_filter,
            output='minimal'
        )

    def del_text(self, uuid: str) -> None:
        self._client.data_object.delete(
            uuid,
            class_name=self._index_name
        )

    def text_exists(self, uuid: str) -> bool:
        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
            "path": ["doc_id"],
            "operator": "Equal",
            "valueText": uuid,
        }).with_limit(1).do()

        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        entries = result["data"]["Get"][self._index_name]
        if len(entries) == 0:
            return False

        return True

    def delete(self):
        self._client.schema.delete_class(self._index_name)
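A brief sketch of how these Weaviate helpers might be called once a store has been constructed; the endpoint, class name, and doc_id values are placeholders and an embedding model is assumed to be configured:

import weaviate
from langchain.embeddings import OpenAIEmbeddings

client = weaviate.Client("http://localhost:8080")       # placeholder endpoint
store = WeaviateVectorStore(
    client=client,
    index_name="Example_index",                          # placeholder class name
    text_key="text",
    embedding=OpenAIEmbeddings(),
)

# membership check and deletion are keyed on the doc_id property
if store.text_exists("doc-id-placeholder"):
    store.del_texts({
        "path": ["doc_id"],
        "operator": "Equal",
        "valueText": "doc-id-placeholder",
    })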
api/core/vector_store/weaviate_vector_store_client.py  deleted 100644 → 0
View file @ 3eb8e66b

import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
    parse_get_response,
    validate_client,
)


class WeaviateVectorStoreClient(BaseVectorStoreClient):
    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
        self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)

    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
        auth_config = weaviate.auth.AuthApiKey(api_key=api_key)

        weaviate.connect.connection.has_grpc = grpc_enabled

        client = weaviate.Client(
            url=endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 60),
            startup_period=None
        )

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = WeaviateIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"class_prefix": "Gpt_index_xxx"}
        class_prefix = config.get('class_prefix')
        if not class_prefix:
            raise Exception("class_prefix cannot be None.")

        return GPTWeaviateEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=WeaviateWithSimilaritiesVectorStore(
                weaviate_client=self._client,
                class_prefix=class_prefix
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"class_prefix": index_id}


class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes."""
        nodes = self.weaviate_query(
            self._client,
            self._class_prefix,
            query,
        )
        nodes = nodes[: query.similarity_top_k]
        node_idxs = [str(i) for i in range(len(nodes))]

        similarities = []
        for node in nodes:
            similarities.append(node.extra_info['similarity'])
            del node.extra_info['similarity']

        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)

    def weaviate_query(
            self,
            client: Any,
            class_prefix: str,
            query_spec: VectorStoreQuery,
    ) -> List[Node]:
        """Convert to LlamaIndex list."""
        validate_client(client)

        class_name = _class_name(class_prefix)
        prop_names = [p["name"] for p in NODE_SCHEMA]
        vector = query_spec.query_embedding

        # build query
        query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
            _logger.debug("Using vector search")
            if vector is not None:
                query = query.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
            _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
            query = query.with_hybrid(
                query=query_spec.query_str,
                alpha=query_spec.alpha,
                vector=vector,
            )
        query = query.with_limit(query_spec.similarity_top_k)
        _logger.debug(f"Using limit of {query_spec.similarity_top_k}")

        # execute query
        query_result = query.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        results = [self._to_node(entry) for entry in entries]
        return results

    def _to_node(self, entry: Dict) -> Node:
        """Convert to Node."""
        extra_info_str = entry["extra_info"]
        if extra_info_str == "":
            extra_info = None
        else:
            extra_info = json.loads(extra_info_str)

        if 'certainty' in entry['_additional']:
            if extra_info:
                extra_info['similarity'] = entry['_additional']['certainty']
            else:
                extra_info = {'similarity': entry['_additional']['certainty']}

        node_info_str = entry["node_info"]
        if node_info_str == "":
            node_info = None
        else:
            node_info = json.loads(node_info_str)

        relationships_str = entry["relationships"]
        relationships: Dict[DocumentRelationship, str]
        if relationships_str == "":
            relationships = field(default_factory=dict)
        else:
            relationships = {
                DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
            }

        return Node(
            text=entry["text"],
            doc_id=entry["doc_id"],
            embedding=entry["_additional"]["vector"],
            extra_info=extra_info,
            node_info=node_info,
            relationships=relationships,
        )

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document.

        Args:
            doc_id (str): document id
        """
        delete_document(self._client, doc_id, self._class_prefix)

    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        delete_node(self._client, node_id, self._class_prefix)

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        entry = get_by_node_id(self._client, node_id, self._class_prefix)
        return True if entry else False


class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
    pass


def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["ref_doc_id"],
        "operator": "Equal",
        "valueString": ref_doc_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)

    while len(entries) > 0:
        query_result = query.do()
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        for entry in entries:
            client.data_object.delete(entry["_additional"]["id"], class_name)


def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)


def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    if len(entries) == 0:
        return None

    return entries[0]
api/extensions/ext_vector_store.py  deleted 100644 → 0
View file @ 3eb8e66b

from core.vector_store.vector_store import VectorStore

vector_store = VectorStore()


def init_app(app):
    vector_store.init_app(app)
api/libs/helper.py
View file @ eea011bd
...
@@ -3,6 +3,7 @@ import re
 import subprocess
 import uuid
 from datetime import datetime
+from hashlib import sha256
 from zoneinfo import available_timezones
 import random
 import string
...
@@ -147,3 +148,8 @@ def get_remote_ip(request):
         return request.headers.getlist("X-Forwarded-For")[0]
     else:
         return request.remote_addr
+
+
+def generate_text_hash(text: str) -> str:
+    hash_text = str(text) + 'None'
+    return sha256(hash_text.encode()).hexdigest()
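The new generate_text_hash helper appends the literal string 'None' to the input before hashing, so identical chunk texts always collapse to the same digest and can be de-duplicated by doc_hash. A quick illustrative check of that behaviour, mirroring the helper above:

from hashlib import sha256

def generate_text_hash(text: str) -> str:
    # mirror of the helper added in api/libs/helper.py
    hash_text = str(text) + 'None'
    return sha256(hash_text.encode()).hexdigest()

assert generate_text_hash("hello") == generate_text_hash("hello")   # deterministic
assert generate_text_hash("hello") != generate_text_hash("hello!")  # content-sensitive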
api/models/account.py
View file @ eea011bd
...
@@ -38,8 +38,6 @@ class Account(UserMixin, db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

-    _current_tenant: db.Model = None
-
     @property
     def current_tenant(self):
         return self._current_tenant
...
api/models/dataset.py
View file @ eea011bd
...
@@ -66,6 +66,23 @@ class Dataset(db.Model):
     def document_count(self):
         return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

+    @property
+    def available_document_count(self):
+        return db.session.query(func.count(Document.id)).filter(
+            Document.dataset_id == self.id,
+            Document.indexing_status == 'completed',
+            Document.enabled == True,
+            Document.archived == False
+        ).scalar()
+
+    @property
+    def available_segment_count(self):
+        return db.session.query(func.count(DocumentSegment.id)).filter(
+            DocumentSegment.dataset_id == self.id,
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True
+        ).scalar()
+
     @property
     def word_count(self):
         return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
...
@@ -260,7 +277,7 @@ class Document(db.Model):

     @property
     def dataset(self):
-        return Dataset.query.get(self.dataset_id)
+        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

     @property
     def segment_count(self):
...
@@ -395,7 +412,18 @@ class DatasetKeywordTable(db.Model):

     @property
     def keyword_table_dict(self):
-        return json.loads(self.keyword_table) if self.keyword_table else None
+        class SetDecoder(json.JSONDecoder):
+            def __init__(self, *args, **kwargs):
+                super().__init__(object_hook=self.object_hook, *args, **kwargs)
+
+            def object_hook(self, dct):
+                if isinstance(dct, dict):
+                    for keyword, node_idxs in dct.items():
+                        if isinstance(node_idxs, list):
+                            dct[keyword] = set(node_idxs)
+                return dct
+
+        return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None


 class Embedding(db.Model):
...
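The keyword_table_dict property now parses the stored JSON with a custom decoder that turns every list value back into a set, so keyword lookups behave like the in-memory structure that was serialized. A standalone sketch of that decode step with placeholder data:

import json

class SetDecoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
        super().__init__(object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, dct):
        # convert every list value back into a set
        if isinstance(dct, dict):
            for keyword, node_idxs in dct.items():
                if isinstance(node_idxs, list):
                    dct[keyword] = set(node_idxs)
        return dct

table = json.loads('{"alpha": ["seg-1", "seg-2"]}', cls=SetDecoder)
print(table["alpha"])   # {'seg-1', 'seg-2'}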
api/requirements.txt
View file @ eea011bd
...
@@ -2,6 +2,7 @@ coverage~=7.2.4
 beautifulsoup4==4.12.2
 flask~=2.3.2
 Flask-SQLAlchemy~=3.0.3
+SQLAlchemy~=1.4.28
 flask-login==0.6.2
 flask-migrate~=4.0.4
 flask-restful==0.3.9
...
@@ -9,8 +10,7 @@ flask-session2==1.3.1
 flask-cors==3.0.10
 gunicorn~=20.1.0
 gevent~=22.10.2
-langchain==0.0.142
-llama-index==0.5.27
+langchain==0.0.209
 openai~=0.27.5
 psycopg2-binary~=2.9.6
 pycryptodome==3.17
...
@@ -29,6 +29,7 @@ sentry-sdk[flask]~=1.21.1
 jieba==0.42.1
 celery==5.2.7
 redis~=4.5.4
 pypdf==3.8.1
 openpyxl==3.1.2
-chardet~=5.1.0
\ No newline at end of file
+chardet~=5.1.0
+docx2txt==0.8
+pypdfium2==4.16.0
\ No newline at end of file
api/services/app_model_config_service.py
View file @ eea011bd
...
@@ -4,7 +4,6 @@ import uuid

 from core.constant import llm_constant
 from models.account import Account
-from services.dataset_service import DatasetService
 from services.errors.account import NoPermissionError


 class AppModelConfigService:
...
api/services/dataset_service.py
View file @ eea011bd
...
@@ -7,7 +7,6 @@ from typing import Optional, List

 from extensions.ext_redis import redis_client
 from flask_login import current_user
-from core.index.index_builder import IndexBuilder
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
...
@@ -386,8 +385,6 @@ class DocumentService:
             dataset.indexing_technique = document_data["indexing_technique"]

-        if dataset.indexing_technique == 'high_quality':
-            IndexBuilder.get_default_service_context(dataset.tenant_id)
-
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         if 'original_document_id' in document_data and document_data["original_document_id"]:
...
api/services/hit_testing_service.py
View file @ eea011bd
...
@@ -3,47 +3,56 @@ import time
 from typing import List

 import numpy as np
-from llama_index.data_structs.node_v2 import NodeWithScore
-from llama_index.indices.query.schema import QueryBundle
-from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document
 from sklearn.manifold import TSNE

-from core.docstore.empty_docstore import EmptyDocumentStore
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.vector_index.vector_index import VectorIndex
 from core.llm.llm_builder import LLMBuilder
 from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, DocumentSegment, DatasetQuery
-from services.errors.index import IndexNotInitializedError


 class HitTestingService:
     @classmethod
     def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
-        index = VectorIndex(dataset=dataset).query_index
-
-        if not index:
-            raise IndexNotInitializedError()
-
-        index_query = GPTVectorStoreIndexQuery(
-            index_struct=index.index_struct,
-            service_context=index.service_context,
-            vector_store=index.query_context.get('vector_store'),
-            docstore=EmptyDocumentStore(),
-            response_synthesizer=None,
-            similarity_top_k=limit
-        )
+        if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+            return {
+                "query": {
+                    "content": query,
+                    "tsne_position": {'x': 0, 'y': 0},
+                },
+                "records": []
+            }

-        query_bundle = QueryBundle(
-            query_str=query,
-            custom_embedding_strs=[query],
-        )
+        model_credentials = LLMBuilder.get_model_credentials(
+            tenant_id=dataset.tenant_id,
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_name='text-embedding-ada-002'
+        )

-        query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries(
-            query_bundle.embedding_strs
-        )
+        embeddings = CacheEmbedding(OpenAIEmbeddings(
+            **model_credentials
+        ))
+
+        vector_index = VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )

         start = time.perf_counter()
-        nodes = index_query.retrieve(query_bundle=query_bundle)
+        documents = vector_index.search(
+            query,
+            search_type='similarity_score_threshold',
+            search_kwargs={
+                'k': 10
+            }
+        )
         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
...
@@ -58,25 +67,24 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()

-        return cls.compact_retrieve_response(dataset, query_bundle, nodes)
+        return cls.compact_retrieve_response(dataset, embeddings, query, documents)

     @classmethod
-    def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]):
-        embeddings = [
-            query_bundle.embedding
-        ]
-
-        for node in nodes:
-            embeddings.append(node.node.embedding)
+    def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
+        text_embeddings = [
+            embeddings.embed_query(query)
+        ]
+
+        text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))

-        tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
+        tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)

         query_position = tsne_position_data.pop(0)

         i = 0
         records = []
-        for node in nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            index_node_id = document.metadata['doc_id']

             segment = db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == dataset.id,
...
@@ -91,7 +99,7 @@ class HitTestingService:

             record = {
                 "segment": segment,
-                "score": node.score,
+                "score": document.metadata['score'],
                 "tsne_position": tsne_position_data[i]
             }
...
@@ -101,7 +109,7 @@ class HitTestingService:

         return {
             "query": {
-                "content": query_bundle.query_str,
+                "content": query,
                 "tsne_position": query_position,
             },
             "records": records
...
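compact_retrieve_response projects the query embedding plus the document embeddings down to 2-D with scikit-learn's TSNE so the UI can plot them. A minimal sketch of that projection step, with random vectors standing in for real embeddings:

import numpy as np
from sklearn.manifold import TSNE

# one query embedding followed by ten document embeddings (placeholder data)
text_embeddings = np.random.rand(11, 1536)

# perplexity must stay below the number of samples
tsne = TSNE(n_components=2, perplexity=min(30, len(text_embeddings) - 1), random_state=0)
positions = tsne.fit_transform(text_embeddings)

query_position = {'x': float(positions[0][0]), 'y': float(positions[0][1])}
print(query_position)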
api/tasks/add_document_to_index_task.py View file @ eea011bd
...
@@ -4,96 +4,81 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment, Document
+from models.dataset import DocumentSegment
+from models.dataset import Document as DatasetDocument


 @shared_task
-def add_document_to_index_task(document_id: str):
+def add_document_to_index_task(dataset_document_id: str):
     """
     Async Add document to index
     :param document_id:

     Usage: add_document_to_index.delay(document_id)
     """
-    logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green'))
+    logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
     start_at = time.perf_counter()

-    document = db.session.query(Document).filter(Document.id == document_id).first()
-    if not document:
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
+    if not dataset_document:
         raise NotFound('Document not found')

-    if document.indexing_status != 'completed':
+    if dataset_document.indexing_status != 'completed':
         return

-    indexing_cache_key = 'document_{}_indexing'.format(document.id)
+    indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)

     try:
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.document_id == document.id,
+            DocumentSegment.document_id == dataset_document.id,
             DocumentSegment.enabled == True
         ) \
             .order_by(DocumentSegment.position.asc()).all()

-        nodes = []
-        previous_node = None
+        documents = []
         for segment in segments:
-            relationships = {
-                DocumentRelationship.SOURCE: document.id
-            }
-
-            if previous_node:
-                relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-                previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-            node = Node(
-                doc_id=segment.index_node_id,
-                doc_hash=segment.index_node_hash,
-                text=segment.content,
-                extra_info=None,
-                node_info=None,
-                relationships=relationships
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
             )

-            previous_node = node
-            nodes.append(node)
+            documents.append(document)

-        dataset = document.dataset
+        dataset = dataset_document.dataset

         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=nodes,
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents)

         # save keyword index
-        keyword_table_index.add_nodes(nodes)
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts(documents)

         end_at = time.perf_counter()
         logging.info(
-            click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+            click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
     except Exception as e:
         logging.exception("add document to index failed")
-        document.enabled = False
-        document.disabled_at = datetime.datetime.utcnow()
-        document.status = 'error'
-        document.error = str(e)
+        dataset_document.enabled = False
+        dataset_document.disabled_at = datetime.datetime.utcnow()
+        dataset_document.status = 'error'
+        dataset_document.error = str(e)
         db.session.commit()
     finally:
         redis_client.delete(indexing_cache_key)
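In short, this task drops the llama_index Node plumbing and feeds plain langchain Document objects into whichever index IndexBuilder returns. A minimal sketch of that new write path, assuming only the IndexBuilder.get_index / add_texts interface shown in the diff above (the helper names are illustrative and "segments" stands for the ORM rows queried by the task):

from langchain.schema import Document
from core.index.index import IndexBuilder

def build_documents(segments):
    # One langchain Document per stored segment; the metadata keys mirror
    # the ones written by add_document_to_index_task above.
    return [
        Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        for segment in segments
    ]

def index_documents(dataset, segments):
    # Both index flavours expose the same add_texts() call, so the caller only
    # needs to check whether get_index() returned an instance.
    documents = build_documents(segments)
    for index_type in ('high_quality', 'economy'):
        index = IndexBuilder.get_index(dataset, index_type)
        if index:
            index.add_texts(documents)

The same Document construction appears, with small variations, in the tasks below, which is what lets the vector and keyword indexes share one code path.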
api/tasks/add_segment_to_index_task.py View file @ eea011bd
...
@@ -4,12 +4,10 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
...
@@ -36,44 +34,41 @@ def add_segment_to_index_task(segment_id: str):
     indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

     try:
-        relationships = {
-            DocumentRelationship.SOURCE: segment.document_id,
-        }
-
-        previous_segment = segment.previous_segment
-        if previous_segment:
-            relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-        next_segment = segment.next_segment
-        if next_segment:
-            relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-
-        node = Node(
-            doc_id=segment.index_node_id,
-            doc_hash=segment.index_node_hash,
-            text=segment.content,
-            extra_info=None,
-            node_info=None,
-            relationships=relationships
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            }
         )

         dataset = segment.dataset

         if not dataset:
-            raise Exception('Segment has no dataset')
+            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+            return

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        dataset_document = segment.document
+
+        if not dataset_document:
+            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+            return

         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=[node],
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts([document], duplicate_check=True)

         # save keyword index
-        keyword_table_index.add_nodes([node])
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts([document])

         end_at = time.perf_counter()
         logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
...
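The single-segment variant follows the same shape with a one-element list, and the vector write keeps the duplicate check so a re-run does not insert the segment twice. A hedged sketch of the calling side, assuming the usual Celery delay contract:

# enqueue an async re-index of one segment; only the id crosses the broker
add_segment_to_index_task.delay(segment.id)

# inside the task the vector write stays idempotent:
#     index.add_texts([document], duplicate_check=True)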
api/tasks/clean_dataset_task.py View file @ eea011bd
...
@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin
...
@@ -33,29 +32,24 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         index_struct=index_struct
     )

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
-
     documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-    index_doc_ids = [document.id for document in documents]
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-    index_node_ids = [segment.index_node_id for segment in segments]

-    # delete from vector index
-    if dataset.indexing_technique == "high_quality":
-        for index_doc_id in index_doc_ids:
-            try:
-                vector_index.del_doc(index_doc_id)
-            except Exception:
-                logging.exception("Delete doc index failed when dataset deleted.")
-                continue
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')

-    # delete from keyword index
-    if index_node_ids:
+    # delete from vector index
+    if vector_index:
         try:
-            keyword_table_index.del_nodes(index_node_ids)
+            vector_index.delete()
         except Exception:
-            logging.exception("Delete nodes index failed when dataset deleted.")
+            logging.exception("Delete doc index failed when dataset deleted.")
+
+    # delete from keyword index
+    try:
+        kw_index.delete()
+    except Exception:
+        logging.exception("Delete nodes index failed when dataset deleted.")

     for document in documents:
         db.session.delete(document)
...
@@ -63,7 +57,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
     for segment in segments:
         db.session.delete(segment)

     db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete()
     db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
     db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
     db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
...
api/tasks/clean_document_task.py View file @ eea011bd
...
@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset
...
@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
     if not dataset:
         raise Exception('Document has no dataset')

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')

     segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
     index_node_ids = [segment.index_node_id for segment in segments]

     # delete from vector index
-    vector_index.del_nodes(index_node_ids)
+    if vector_index:
+        vector_index.delete_by_document_id(document_id)

     # delete from keyword index
     if index_node_ids:
-        keyword_table_index.del_nodes(index_node_ids)
+        kw_index.delete_by_ids(index_node_ids)

     for segment in segments:
         db.session.delete(segment)

     db.session.commit()
     end_at = time.perf_counter()
     logging.info(
...
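Taken together, the cleanup tasks exercise three delete operations on the new index objects: delete() for a whole dataset, delete_by_document_id() for a document's vectors, and delete_by_ids() for keyword entries keyed by segment node id. A small illustrative helper built from only those calls (the function name is hypothetical):

from core.index.index import IndexBuilder

def remove_document_from_indexes(dataset, document_id, index_node_ids):
    # get_index() may return None on the vector side when the dataset was
    # never indexed in 'high_quality' mode, hence the guard.
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    if vector_index:
        # vector entries are grouped by the owning document id
        vector_index.delete_by_document_id(document_id)

    if index_node_ids:
        # keyword-table entries are keyed by the segments' node ids
        kw_index.delete_by_ids(index_node_ids)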
api/tasks/clean_notion_document_task.py View file @ eea011bd
...
@@ -5,8 +5,7 @@ from typing import List
 import click
 from celery import shared_task
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, Document
...
@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
     if not dataset:
         raise Exception('Document has no dataset')

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')

     for document_id in document_ids:
         document = db.session.query(Document).filter(Document.id == document_id).first()
         db.session.delete(document)

         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)

         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)
...
api/tasks/deal_dataset_vector_index_task.py View file @ eea011bd
...
@@ -3,10 +3,12 @@ import time
 import click
 from celery import shared_task
-from llama_index.data_structs.node_v2 import DocumentRelationship, Node
-from core.index.vector_index import VectorIndex
+from langchain.schema import Document
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
-from models.dataset import DocumentSegment, Document, Dataset
+from models.dataset import DocumentSegment, Dataset
+from models.dataset import Document as DatasetDocument


 @shared_task
...
@@ -24,49 +26,47 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
     dataset = Dataset.query.filter_by(id=dataset_id).first()

     if not dataset:
         raise Exception('Dataset not found')

-    documents = Document.query.filter_by(dataset_id=dataset_id).all()
-    if documents:
-        vector_index = VectorIndex(dataset=dataset)
-        for document in documents:
-            # delete from vector index
-            if action == "remove":
-                vector_index.del_doc(document.id)
-            elif action == "add":
+    if action == "remove":
+        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+        index.delete()
+    elif action == "add":
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset_id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        if dataset_documents:
+            # save vector index
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            for dataset_document in dataset_documents:
                 segments = db.session.query(DocumentSegment).filter(
-                    DocumentSegment.document_id == document.id,
+                    DocumentSegment.document_id == dataset_document.id,
                     DocumentSegment.enabled == True
                 ).order_by(DocumentSegment.position.asc()).all()

-                nodes = []
-                previous_node = None
+                documents = []
                 for segment in segments:
-                    relationships = {
-                        DocumentRelationship.SOURCE: document.id
-                    }
-
-                    if previous_node:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-                        previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-                    node = Node(
-                        doc_id=segment.index_node_id,
-                        doc_hash=segment.index_node_hash,
-                        text=segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
+                    document = Document(
+                        page_content=segment.content,
+                        metadata={
+                            "doc_id": segment.index_node_id,
+                            "doc_hash": segment.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        }
                     )

-                    previous_node = node
-                    nodes.append(node)
+                    documents.append(document)

-            # save vector index
-            vector_index.add_nodes(nodes=nodes, duplicate_check=True)
+                index.add_texts(documents)

     end_at = time.perf_counter()
     logging.info(
...
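This task now switches a dataset's vector index on or off based on the action argument, rebuilding it from completed, enabled, non-archived documents when the action is "add". A hedged sketch of how it would be enqueued (the call sites are assumed; only the task signature comes from the diff above):

# drop the vector index when a dataset leaves 'high_quality' mode
deal_dataset_vector_index_task.delay(dataset.id, "remove")

# rebuild it from the dataset's completed documents when it comes back
deal_dataset_vector_index_task.delay(dataset.id, "add")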
api/tasks/document_indexing_sync_task.py View file @ eea011bd
...
@@ -6,11 +6,9 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
-from core.data_source.notion import NotionPageReader
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.data_loader.loader.notion import NotionLoader
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
 from models.source import DataSourceBinding
...
@@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         raise ValueError("no notion page found")

     workspace_id = data_source_info['notion_workspace_id']
     page_id = data_source_info['notion_page_id']
+    page_type = data_source_info['type']
     page_edited_time = data_source_info['last_edited_time']
     data_source_binding = DataSourceBinding.query.filter(
         db.and_(
...
@@ -54,8 +53,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
     ).first()

     if not data_source_binding:
         raise ValueError('Data source binding not found.')

-    reader = NotionPageReader(integration_token=data_source_binding.access_token)
-    last_edited_time = reader.get_page_last_edited_time(page_id)
+    loader = NotionLoader(
+        notion_access_token=data_source_binding.access_token,
+        notion_workspace_id=workspace_id,
+        notion_obj_id=page_id,
+        notion_page_type=page_type
+    )
+
+    last_edited_time = loader.get_notion_last_edited_time()

     # check the page is updated
     if last_edited_time != page_edited_time:
         document.indexing_status = 'parsing'
...
@@ -68,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             if not dataset:
                 raise Exception('Dataset not found')

-            vector_index = VectorIndex(dataset=dataset)
-            keyword_table_index = KeywordTableIndex(dataset=dataset)
+            vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+            kw_index = IndexBuilder.get_index(dataset, 'economy')

             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]

             # delete from vector index
-            vector_index.del_nodes(index_node_ids)
+            if vector_index:
+                vector_index.delete_by_document_id(document_id)

             # delete from keyword index
             if index_node_ids:
-                keyword_table_index.del_nodes(index_node_ids)
+                kw_index.delete_by_ids(index_node_ids)

             for segment in segments:
                 db.session.delete(segment)
...
@@ -89,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                 click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
         except Exception:
             logging.exception("Cleaned document when document update data source or process rule failed")

     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
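The sync task now talks to Notion through the new NotionLoader rather than NotionPageReader, and only re-indexes when the page has actually changed. A sketch of that freshness check using the constructor arguments and the get_notion_last_edited_time() call visible above (the surrounding variables are assumed to come from the document's data source info):

from core.data_loader.loader.notion import NotionLoader

loader = NotionLoader(
    notion_access_token=data_source_binding.access_token,
    notion_workspace_id=workspace_id,
    notion_obj_id=page_id,
    notion_page_type=page_type
)

# re-index only when Notion reports a newer edit than the stored timestamp;
# the task then clears the old segments and reruns IndexingRunner as above
if loader.get_notion_last_edited_time() != page_edited_time:
    IndexingRunner().run([document])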
api/tasks/document_indexing_task.py View file @ eea011bd
...
@@ -7,7 +7,6 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document
...
@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
     documents = []
+    start_at = time.perf_counter()
+
     for document_id in document_ids:
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-        start_at = time.perf_counter()

         document = db.session.query(Document).filter(
             Document.id == document_id,
...
@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()
-        logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/document_indexing_update_task.py View file @ eea011bd
...
@@ -6,10 +6,8 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
...
@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
     if not dataset:
         raise Exception('Dataset not found')

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')

     segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
     index_node_ids = [segment.index_node_id for segment in segments]

     # delete from vector index
-    vector_index.del_nodes(index_node_ids)
+    if vector_index:
+        vector_index.delete_by_ids(index_node_ids)

     # delete from keyword index
     if index_node_ids:
-        keyword_table_index.del_nodes(index_node_ids)
+        kw_index.delete_by_ids(index_node_ids)

     for segment in segments:
         db.session.delete(segment)
...
@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
     except Exception:
         logging.exception("Cleaned document when document update data source or process rule failed")

     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/recover_document_indexing_task.py View file @ eea011bd
-import datetime
 import logging
 import time
...
@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
         indexing_runner.run_in_indexing_status(document)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
api/tasks/remove_document_from_index_task.py View file @ eea011bd
...
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment, Document
...
@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
     if not dataset:
         raise Exception('Document has no dataset')

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')

     # delete from vector index
-    vector_index.del_doc(document.id)
+    vector_index.delete_by_document_id(document.id)

     # delete from keyword index
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
     index_node_ids = [segment.index_node_id for segment in segments]
     if index_node_ids:
-        keyword_table_index.del_nodes(index_node_ids)
+        kw_index.delete_by_ids(index_node_ids)

     end_at = time.perf_counter()
     logging.info(
...
api/tasks/remove_segment_from_index_task.py View file @ eea011bd
...
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
...
@@ -36,17 +35,28 @@ def remove_segment_from_index_task(segment_id: str):
     dataset = segment.dataset

     if not dataset:
-        raise Exception('Segment has no dataset')
+        logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+        return

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
+    dataset_document = segment.document
+
+    if not dataset_document:
+        logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+        return
+
+    if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+        logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+        return
+
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')

     # delete from vector index
-    if dataset.indexing_technique == "high_quality":
-        vector_index.del_nodes([segment.index_node_id])
+    if vector_index:
+        vector_index.delete_by_ids([segment.index_node_id])

     # delete from keyword index
-    keyword_table_index.del_nodes([segment.index_node_id])
+    kw_index.delete_by_ids([segment.index_node_id])

     end_at = time.perf_counter()
     logging.info(
         click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
...
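Several of the segment tasks now return early instead of raising when the owning document is missing or not in a serviceable state. An illustrative predicate built from the attributes those guards check (enabled, archived, indexing_status); the helper name is hypothetical:

def segment_is_indexable(segment) -> bool:
    # mirrors the early-return checks in the add/remove segment tasks above
    dataset_document = segment.document
    if not segment.dataset or not dataset_document:
        return False
    return (dataset_document.enabled
            and not dataset_document.archived
            and dataset_document.indexing_status == 'completed')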
sdks/python-client/dify_client/client.py View file @ eea011bd
...
@@ -65,8 +65,8 @@ class ChatClient(DifyClient):
         return self._send_request("GET", "/messages", params=params)

-    def get_conversations(self, user, first_id=None, limit=None, pinned=None):
-        params = {"user": user, "first_id": first_id, "limit": limit, "pinned": pinned}
+    def get_conversations(self, user, last_id=None, limit=None, pinned=None):
+        params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
         return self._send_request("GET", "/conversations", params=params)

     def rename_conversation(self, conversation_id, name, user):
...
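On the SDK side, get_conversations now pages with last_id instead of first_id. A short usage sketch, assuming a valid API key and that the client returns a requests.Response whose JSON body carries a data list with id fields:

from dify_client import ChatClient

client = ChatClient("your-api-key")  # placeholder key
first_page = client.get_conversations(user="user-123", limit=20).json()

# fetch the next page by anchoring on the last conversation id seen so far
next_page = client.get_conversations(
    user="user-123",
    last_id=first_page["data"][-1]["id"],  # assumed response shape
    limit=20,
).json()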
sdks/python-client/setup.py View file @ eea011bd
...
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 setup(
     name="dify-client",
-    version="0.1.7",
+    version="0.1.8",
     author="Dify",
     author_email="hello@dify.ai",
     description="A package for interacting with the Dify Service-API",
...