Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
0d82aa8f
Commit
0d82aa8f
authored
Jun 19, 2023
by
John Wang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: use callbacks instead of callback manager
parent
9e9d15ec
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
144 additions
and
233 deletions
+144
-233
__init__.py
api/core/__init__.py
+0
-2
agent_builder.py
api/core/agent/agent_builder.py
+8
-11
index_tool_callback_handler.py
api/core/callback_handler/index_tool_callback_handler.py
+0
-2
chain_builder.py
api/core/chain/chain_builder.py
+2
-4
llm_router_chain.py
api/core/chain/llm_router_chain.py
+2
-2
main_chain_builder.py
api/core/chain/main_chain_builder.py
+4
-7
multi_dataset_router_chain.py
api/core/chain/multi_dataset_router_chain.py
+3
-6
completion.py
api/core/completion.py
+11
-12
empty_docstore.py
api/core/docstore/empty_docstore.py
+0
-51
base.py
api/core/index/vector_index/base.py
+68
-3
qdrant_vector_index.py
api/core/index/vector_index/qdrant_vector_index.py
+3
-45
weaviate_vector_index.py
api/core/index/vector_index/weaviate_vector_index.py
+3
-45
llm_builder.py
api/core/llm/llm_builder.py
+5
-5
llama_index_tool.py
api/core/tool/llama_index_tool.py
+0
-2
hit_testing_service.py
api/services/hit_testing_service.py
+35
-36
No files found.
api/core/__init__.py
View file @
0d82aa8f
...
@@ -3,7 +3,6 @@ from typing import Optional
...
@@ -3,7 +3,6 @@ from typing import Optional
import
langchain
import
langchain
from
flask
import
Flask
from
flask
import
Flask
from
langchain
import
set_handler
from
langchain.prompts.base
import
DEFAULT_FORMATTER_MAPPING
from
langchain.prompts.base
import
DEFAULT_FORMATTER_MAPPING
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
...
@@ -28,7 +27,6 @@ def init_app(app: Flask):
...
@@ -28,7 +27,6 @@ def init_app(app: Flask):
if
os
.
environ
.
get
(
"DEBUG"
)
and
os
.
environ
.
get
(
"DEBUG"
)
.
lower
()
==
'true'
:
if
os
.
environ
.
get
(
"DEBUG"
)
and
os
.
environ
.
get
(
"DEBUG"
)
.
lower
()
==
'true'
:
langchain
.
verbose
=
True
langchain
.
verbose
=
True
set_handler
(
DifyStdOutCallbackHandler
())
if
app
.
config
.
get
(
"OPENAI_API_KEY"
):
if
app
.
config
.
get
(
"OPENAI_API_KEY"
):
hosted_llm_credentials
.
openai
=
HostedOpenAICredential
(
api_key
=
app
.
config
.
get
(
"OPENAI_API_KEY"
))
hosted_llm_credentials
.
openai
=
HostedOpenAICredential
(
api_key
=
app
.
config
.
get
(
"OPENAI_API_KEY"
))
api/core/agent/agent_builder.py
View file @
0d82aa8f
...
@@ -2,7 +2,7 @@ from typing import Optional
...
@@ -2,7 +2,7 @@ from typing import Optional
from
langchain
import
LLMChain
from
langchain
import
LLMChain
from
langchain.agents
import
ZeroShotAgent
,
AgentExecutor
,
ConversationalAgent
from
langchain.agents
import
ZeroShotAgent
,
AgentExecutor
,
ConversationalAgent
from
langchain.callbacks
import
CallbackManager
from
langchain.callbacks
.manager
import
CallbackManager
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.memory.chat_memory
import
BaseChatMemory
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
...
@@ -16,23 +16,20 @@ class AgentBuilder:
...
@@ -16,23 +16,20 @@ class AgentBuilder:
def
to_agent_chain
(
cls
,
tenant_id
:
str
,
tools
,
memory
:
Optional
[
BaseChatMemory
],
def
to_agent_chain
(
cls
,
tenant_id
:
str
,
tools
,
memory
:
Optional
[
BaseChatMemory
],
dataset_tool_callback_handler
:
DatasetToolCallbackHandler
,
dataset_tool_callback_handler
:
DatasetToolCallbackHandler
,
agent_loop_gather_callback_handler
:
AgentLoopGatherCallbackHandler
):
agent_loop_gather_callback_handler
:
AgentLoopGatherCallbackHandler
):
llm_callback_manager
=
CallbackManager
([
agent_loop_gather_callback_handler
,
DifyStdOutCallbackHandler
()])
llm
=
LLMBuilder
.
to_llm
(
llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
tenant_id
,
tenant_id
=
tenant_id
,
model_name
=
agent_loop_gather_callback_handler
.
model_name
,
model_name
=
agent_loop_gather_callback_handler
.
model_name
,
temperature
=
0
,
temperature
=
0
,
max_tokens
=
1024
,
max_tokens
=
1024
,
callback
_manager
=
llm_callback_manager
callback
s
=
[
agent_loop_gather_callback_handler
,
DifyStdOutCallbackHandler
()]
)
)
tool_callback_manager
=
CallbackManager
([
agent_loop_gather_callback_handler
,
dataset_tool_callback_handler
,
DifyStdOutCallbackHandler
()
])
for
tool
in
tools
:
for
tool
in
tools
:
tool
.
callback_manager
=
tool_callback_manager
tool
.
callbacks
=
[
agent_loop_gather_callback_handler
,
dataset_tool_callback_handler
,
DifyStdOutCallbackHandler
()
]
prompt
=
cls
.
build_agent_prompt_template
(
prompt
=
cls
.
build_agent_prompt_template
(
tools
=
tools
,
tools
=
tools
,
...
@@ -54,7 +51,7 @@ class AgentBuilder:
...
@@ -54,7 +51,7 @@ class AgentBuilder:
tools
=
tools
,
tools
=
tools
,
agent
=
agent
,
agent
=
agent
,
memory
=
memory
,
memory
=
memory
,
callback
_manager
=
agent_callback_manager
,
callback
s
=
agent_callback_manager
,
max_iterations
=
6
,
max_iterations
=
6
,
early_stopping_method
=
"generate"
,
early_stopping_method
=
"generate"
,
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
...
...
api/core/callback_handler/index_tool_callback_handler.py
View file @
0d82aa8f
from
llama_index
import
Response
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.dataset
import
DocumentSegment
from
models.dataset
import
DocumentSegment
...
...
api/core/chain/chain_builder.py
View file @
0d82aa8f
from
typing
import
Optional
from
typing
import
Optional
from
langchain.callbacks
import
CallbackManager
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.chain.tool_chain
import
ToolChain
from
core.chain.tool_chain
import
ToolChain
...
@@ -14,7 +12,7 @@ class ChainBuilder:
...
@@ -14,7 +12,7 @@ class ChainBuilder:
tool
=
tool
,
tool
=
tool
,
input_key
=
kwargs
.
get
(
'input_key'
,
'input'
),
input_key
=
kwargs
.
get
(
'input_key'
,
'input'
),
output_key
=
kwargs
.
get
(
'output_key'
,
'tool_output'
),
output_key
=
kwargs
.
get
(
'output_key'
,
'tool_output'
),
callback
_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
callback
s
=
[
DifyStdOutCallbackHandler
()]
)
)
@
classmethod
@
classmethod
...
@@ -27,7 +25,7 @@ class ChainBuilder:
...
@@ -27,7 +25,7 @@ class ChainBuilder:
sensitive_words
=
sensitive_words
.
split
(
","
),
sensitive_words
=
sensitive_words
.
split
(
","
),
canned_response
=
tool_config
.
get
(
"canned_response"
,
''
),
canned_response
=
tool_config
.
get
(
"canned_response"
,
''
),
output_key
=
"sensitive_word_avoidance_output"
,
output_key
=
"sensitive_word_avoidance_output"
,
callback
_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
,
callback
s
=
[
DifyStdOutCallbackHandler
()]
,
**
kwargs
**
kwargs
)
)
...
...
api/core/chain/llm_router_chain.py
View file @
0d82aa8f
"""Base classes for LLM-powered router chains."""
"""Base classes for LLM-powered router chains."""
from
__future__
import
annotations
from
__future__
import
annotations
import
json
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
cast
,
NamedTuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
cast
,
NamedTuple
from
langchain.base_language
import
BaseLanguageModel
from
langchain.chains.base
import
Chain
from
langchain.chains.base
import
Chain
from
pydantic
import
root_validator
from
pydantic
import
root_validator
from
langchain.chains
import
LLMChain
from
langchain.chains
import
LLMChain
from
langchain.prompts
import
BasePromptTemplate
from
langchain.prompts
import
BasePromptTemplate
from
langchain.schema
import
BaseOutputParser
,
OutputParserException
,
BaseLanguageModel
from
langchain.schema
import
BaseOutputParser
,
OutputParserException
from
libs.json_in_md_parser
import
parse_and_check_json_markdown
from
libs.json_in_md_parser
import
parse_and_check_json_markdown
...
...
api/core/chain/main_chain_builder.py
View file @
0d82aa8f
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
cast
from
langchain.callbacks
import
SharedCallbackManager
,
CallbackManager
from
langchain.chains
import
SequentialChain
from
langchain.chains
import
SequentialChain
from
langchain.chains.base
import
Chain
from
langchain.chains.base
import
Chain
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.memory.chat_memory
import
BaseChatMemory
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.chain_builder
import
ChainBuilder
from
core.chain.chain_builder
import
ChainBuilder
...
@@ -42,9 +40,8 @@ class MainChainBuilder:
...
@@ -42,9 +40,8 @@ class MainChainBuilder:
return
None
return
None
for
chain
in
chains
:
for
chain
in
chains
:
# do not add handler into singleton callback manager
chain
=
cast
(
Chain
,
chain
)
if
not
isinstance
(
chain
.
callback_manager
,
SharedCallbackManager
):
chain
.
callbacks
.
append
(
chain_callback_handler
)
chain
.
callback_manager
.
add_handler
(
chain_callback_handler
)
# build main chain
# build main chain
overall_chain
=
SequentialChain
(
overall_chain
=
SequentialChain
(
...
@@ -93,7 +90,7 @@ class MainChainBuilder:
...
@@ -93,7 +90,7 @@ class MainChainBuilder:
tenant_id
=
tenant_id
,
tenant_id
=
tenant_id
,
datasets
=
datasets
,
datasets
=
datasets
,
conversation_message_task
=
conversation_message_task
,
conversation_message_task
=
conversation_message_task
,
callback
_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
callback
s
=
[
DifyStdOutCallbackHandler
()]
)
)
chains
.
append
(
multi_dataset_router_chain
)
chains
.
append
(
multi_dataset_router_chain
)
...
...
api/core/chain/multi_dataset_router_chain.py
View file @
0d82aa8f
from
typing
import
Mapping
,
List
,
Dict
,
Any
,
Optional
from
typing
import
Mapping
,
List
,
Dict
,
Any
from
langchain
import
LLMChain
,
PromptTemplate
,
ConversationChain
from
langchain
import
PromptTemplate
from
langchain.callbacks
import
CallbackManager
from
langchain.chains.base
import
Chain
from
langchain.chains.base
import
Chain
from
langchain.schema
import
BaseLanguageModel
from
pydantic
import
Extra
from
pydantic
import
Extra
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
...
@@ -82,13 +80,12 @@ class MultiDatasetRouterChain(Chain):
...
@@ -82,13 +80,12 @@ class MultiDatasetRouterChain(Chain):
**
kwargs
:
Any
,
**
kwargs
:
Any
,
):
):
"""Convenience constructor for instantiating from destination prompts."""
"""Convenience constructor for instantiating from destination prompts."""
llm_callback_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
llm
=
LLMBuilder
.
to_llm
(
llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
tenant_id
,
tenant_id
=
tenant_id
,
model_name
=
'gpt-3.5-turbo'
,
model_name
=
'gpt-3.5-turbo'
,
temperature
=
0
,
temperature
=
0
,
max_tokens
=
1024
,
max_tokens
=
1024
,
callback
_manager
=
llm_callback_manager
callback
s
=
[
DifyStdOutCallbackHandler
()]
)
)
destinations
=
[
"{}: {}"
.
format
(
d
.
id
,
d
.
description
.
replace
(
'
\n
'
,
' '
)
if
d
.
description
destinations
=
[
"{}: {}"
.
format
(
d
.
id
,
d
.
description
.
replace
(
'
\n
'
,
' '
)
if
d
.
description
...
...
api/core/completion.py
View file @
0d82aa8f
import
logging
import
logging
from
typing
import
Optional
,
List
,
Union
,
Tuple
from
typing
import
Optional
,
List
,
Union
,
Tuple
from
langchain.callbacks
import
CallbackManager
from
langchain.base_language
import
BaseLanguageModel
from
langchain.callbacks.base
import
BaseCallbackHandler
from
langchain.chat_models.base
import
BaseChatModel
from
langchain.chat_models.base
import
BaseChatModel
from
langchain.llms
import
BaseLLM
from
langchain.llms
import
BaseLLM
from
langchain.schema
import
BaseMessage
,
BaseLanguageModel
,
HumanMessage
from
langchain.schema
import
BaseMessage
,
HumanMessage
from
requests.exceptions
import
ChunkedEncodingError
from
requests.exceptions
import
ChunkedEncodingError
from
core.constant
import
llm_constant
from
core.constant
import
llm_constant
from
core.callback_handler.llm_callback_handler
import
LLMCallbackHandler
from
core.callback_handler.llm_callback_handler
import
LLMCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStreamingStdOutCallbackHandler
,
\
from
core.callback_handler.std_out_callback_handler
import
DifyStreamingStdOutCallbackHandler
,
\
DifyStdOutCallbackHandler
DifyStdOutCallbackHandler
from
core.conversation_message_task
import
ConversationMessageTask
,
ConversationTaskStoppedException
,
PubHandler
from
core.conversation_message_task
import
ConversationMessageTask
,
ConversationTaskStoppedException
from
core.llm.error
import
LLMBadRequestError
from
core.llm.error
import
LLMBadRequestError
from
core.llm.llm_builder
import
LLMBuilder
from
core.llm.llm_builder
import
LLMBuilder
from
core.chain.main_chain_builder
import
MainChainBuilder
from
core.chain.main_chain_builder
import
MainChainBuilder
...
@@ -115,7 +116,7 @@ class Completion:
...
@@ -115,7 +116,7 @@ class Completion:
memory
=
memory
memory
=
memory
)
)
final_llm
.
callback
_manager
=
cls
.
get_llm_callback_manager
(
final_llm
,
streaming
,
conversation_message_task
)
final_llm
.
callback
s
=
cls
.
get_llm_callbacks
(
final_llm
,
streaming
,
conversation_message_task
)
cls
.
recale_llm_max_tokens
(
cls
.
recale_llm_max_tokens
(
final_llm
=
final_llm
,
final_llm
=
final_llm
,
...
@@ -247,16 +248,14 @@ And answer according to the language of the user's question.
...
@@ -247,16 +248,14 @@ And answer according to the language of the user's question.
return
messages
,
[
'
\n
Human:'
]
return
messages
,
[
'
\n
Human:'
]
@
classmethod
@
classmethod
def
get_llm_callback
_manager
(
cls
,
llm
:
Union
[
StreamableOpenAI
,
StreamableChatOpenAI
],
def
get_llm_callback
s
(
cls
,
llm
:
Union
[
StreamableOpenAI
,
StreamableChatOpenAI
],
streaming
:
bool
,
streaming
:
bool
,
conversation_message_task
:
ConversationMessageTask
)
->
CallbackManager
:
conversation_message_task
:
ConversationMessageTask
)
->
List
[
BaseCallbackHandler
]
:
llm_callback_handler
=
LLMCallbackHandler
(
llm
,
conversation_message_task
)
llm_callback_handler
=
LLMCallbackHandler
(
llm
,
conversation_message_task
)
if
streaming
:
if
streaming
:
callback_handlers
=
[
llm_callback_handler
,
DifyStreamingStdOutCallbackHandler
()]
return
[
llm_callback_handler
,
DifyStreamingStdOutCallbackHandler
()]
else
:
else
:
callback_handlers
=
[
llm_callback_handler
,
DifyStdOutCallbackHandler
()]
return
[
llm_callback_handler
,
DifyStdOutCallbackHandler
()]
return
CallbackManager
(
callback_handlers
)
@
classmethod
@
classmethod
def
get_history_messages_from_memory
(
cls
,
memory
:
ReadOnlyConversationTokenDBBufferSharedMemory
,
def
get_history_messages_from_memory
(
cls
,
memory
:
ReadOnlyConversationTokenDBBufferSharedMemory
,
...
@@ -360,7 +359,7 @@ And answer according to the language of the user's question.
...
@@ -360,7 +359,7 @@ And answer according to the language of the user's question.
streaming
=
streaming
streaming
=
streaming
)
)
llm
.
callback
_manager
=
cls
.
get_llm_callback_manager
(
llm
,
streaming
,
conversation_message_task
)
llm
.
callback
s
=
cls
.
get_llm_callbacks
(
llm
,
streaming
,
conversation_message_task
)
cls
.
recale_llm_max_tokens
(
cls
.
recale_llm_max_tokens
(
final_llm
=
llm
,
final_llm
=
llm
,
...
...
api/core/docstore/empty_docstore.py
deleted
100644 → 0
View file @
9e9d15ec
from
typing
import
Any
,
Dict
,
Optional
,
Sequence
from
llama_index.docstore.types
import
BaseDocumentStore
from
llama_index.schema
import
BaseDocument
class
EmptyDocumentStore
(
BaseDocumentStore
):
@
classmethod
def
from_dict
(
cls
,
config_dict
:
Dict
[
str
,
Any
])
->
"EmptyDocumentStore"
:
return
cls
()
def
to_dict
(
self
)
->
Dict
[
str
,
Any
]:
"""Serialize to dict."""
return
{}
@
property
def
docs
(
self
)
->
Dict
[
str
,
BaseDocument
]:
return
{}
def
add_documents
(
self
,
docs
:
Sequence
[
BaseDocument
],
allow_update
:
bool
=
True
)
->
None
:
pass
def
document_exists
(
self
,
doc_id
:
str
)
->
bool
:
"""Check if document exists."""
return
False
def
get_document
(
self
,
doc_id
:
str
,
raise_error
:
bool
=
True
)
->
Optional
[
BaseDocument
]:
return
None
def
delete_document
(
self
,
doc_id
:
str
,
raise_error
:
bool
=
True
)
->
None
:
pass
def
set_document_hash
(
self
,
doc_id
:
str
,
doc_hash
:
str
)
->
None
:
"""Set the hash for a given doc_id."""
pass
def
get_document_hash
(
self
,
doc_id
:
str
)
->
Optional
[
str
]:
"""Get the stored hash for a document, if it exists."""
return
None
def
update_docstore
(
self
,
other
:
"BaseDocumentStore"
)
->
None
:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self
.
add_documents
(
list
(
other
.
docs
.
values
()))
api/core/index/vector_index/base.py
View file @
0d82aa8f
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
List
,
Any
,
Tuple
from
typing
import
List
,
Any
,
Tuple
,
cast
from
langchain.schema
import
Document
from
langchain.schema
import
Document
,
BaseRetriever
from
langchain.vectorstores
import
VectorStore
from
langchain.vectorstores
import
VectorStore
from
index.base
import
BaseIndex
from
core.
index.base
import
BaseIndex
class
BaseVectorIndex
(
BaseIndex
):
class
BaseVectorIndex
(
BaseIndex
):
...
@@ -22,3 +22,68 @@ class BaseVectorIndex(BaseIndex):
...
@@ -22,3 +22,68 @@ class BaseVectorIndex(BaseIndex):
@
abstractmethod
@
abstractmethod
def
_get_vector_store
(
self
)
->
VectorStore
:
def
_get_vector_store
(
self
)
->
VectorStore
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
_get_vector_store_class
(
self
)
->
type
:
raise
NotImplementedError
def
search
(
self
,
query
:
str
,
**
kwargs
:
Any
)
->
List
[
Document
]:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
search_type
=
kwargs
.
get
(
'search_type'
)
if
kwargs
.
get
(
'search_type'
)
else
'similarity'
search_kwargs
=
kwargs
.
get
(
'search_kwargs'
)
if
kwargs
.
get
(
'search_kwargs'
)
else
{}
if
search_type
==
'similarity_score_threshold'
:
score_threshold
=
search_kwargs
.
get
(
"score_threshold"
)
if
(
score_threshold
is
None
)
or
(
not
isinstance
(
score_threshold
,
float
)):
search_kwargs
[
'score_threshold'
]
=
.0
docs_with_similarity
=
vector_store
.
similarity_search_with_relevance_scores
(
query
,
**
search_kwargs
)
docs
=
[]
for
doc
,
similarity
in
docs_with_similarity
:
doc
.
metadata
[
'score'
]
=
similarity
docs
.
append
(
doc
)
return
docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return
vector_store
.
as_retriever
(
search_type
=
search_type
,
search_kwargs
=
search_kwargs
)
.
get_relevant_documents
(
query
)
def
get_retriever
(
self
,
**
kwargs
:
Any
)
->
BaseRetriever
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
return
vector_store
.
as_retriever
(
**
kwargs
)
def
add_texts
(
self
,
texts
:
list
[
Document
]):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
texts
=
self
.
_filter_duplicate_texts
(
texts
)
uuids
=
self
.
_get_uuids
(
texts
)
vector_store
.
add_documents
(
texts
,
uuids
=
uuids
)
def
text_exists
(
self
,
id
:
str
)
->
bool
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
return
vector_store
.
text_exists
(
id
)
def
delete_by_ids
(
self
,
ids
:
list
[
str
])
->
None
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
for
node_id
in
ids
:
vector_store
.
del_text
(
node_id
)
\ No newline at end of file
api/core/index/vector_index/qdrant_vector_index.py
View file @
0d82aa8f
...
@@ -80,54 +80,12 @@ class QdrantVectorIndex(BaseVectorIndex):
...
@@ -80,54 +80,12 @@ class QdrantVectorIndex(BaseVectorIndex):
embeddings
=
self
.
_embeddings
embeddings
=
self
.
_embeddings
)
)
def
get_retriever
(
self
,
**
kwargs
:
Any
)
->
BaseRetriever
:
def
_get_vector_store_class
(
self
)
->
type
:
vector_store
=
self
.
_get_vector_store
()
return
QdrantVectorStore
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
return
vector_store
.
as_retriever
(
**
kwargs
)
def
search
(
self
,
query
:
str
,
**
kwargs
:
Any
)
->
List
[
Document
]:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
search_type
=
kwargs
.
get
(
'search_type'
)
if
kwargs
.
get
(
'search_type'
)
else
'similarity'
search_kwargs
=
kwargs
.
get
(
'search_kwargs'
)
if
kwargs
.
get
(
'search_kwargs'
)
else
{}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return
vector_store
.
as_retriever
(
search_type
=
search_type
,
search_kwargs
=
search_kwargs
)
.
get_relevant_documents
(
query
)
def
add_texts
(
self
,
texts
:
list
[
Document
]):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
texts
=
self
.
_filter_duplicate_texts
(
texts
)
uuids
=
self
.
_get_uuids
(
texts
)
vector_store
.
add_documents
(
texts
,
uuids
=
uuids
)
def
text_exists
(
self
,
id
:
str
)
->
bool
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
return
vector_store
.
text_exists
(
id
)
def
delete_by_ids
(
self
,
ids
:
list
[
str
])
->
None
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
for
node_id
in
ids
:
vector_store
.
del_text
(
node_id
)
def
delete_by_document_id
(
self
,
document_id
:
str
):
def
delete_by_document_id
(
self
,
document_id
:
str
):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
vector_store
=
cast
(
self
.
_get_vector_store_class
()
,
vector_store
)
from
qdrant_client.http
import
models
from
qdrant_client.http
import
models
...
...
api/core/index/vector_index/weaviate_vector_index.py
View file @
0d82aa8f
...
@@ -89,54 +89,12 @@ class WeaviateVectorIndex(BaseVectorIndex):
...
@@ -89,54 +89,12 @@ class WeaviateVectorIndex(BaseVectorIndex):
by_text
=
False
by_text
=
False
)
)
def
get_retriever
(
self
,
**
kwargs
:
Any
)
->
BaseRetriever
:
def
_get_vector_store_class
(
self
)
->
type
:
vector_store
=
self
.
_get_vector_store
()
return
WeaviateVectorStore
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
return
vector_store
.
as_retriever
(
**
kwargs
)
def
search
(
self
,
query
:
str
,
**
kwargs
:
Any
)
->
List
[
Document
]:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
search_type
=
kwargs
.
get
(
'search_type'
)
if
kwargs
.
get
(
'search_type'
)
else
'similarity'
search_kwargs
=
kwargs
.
get
(
'search_kwargs'
)
if
kwargs
.
get
(
'search_kwargs'
)
else
{}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return
vector_store
.
as_retriever
(
search_type
=
search_type
,
search_kwargs
=
search_kwargs
)
.
get_relevant_documents
(
query
)
def
add_texts
(
self
,
texts
:
list
[
Document
]):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
texts
=
self
.
_filter_duplicate_texts
(
texts
)
uuids
=
self
.
_get_uuids
(
texts
)
vector_store
.
add_documents
(
texts
,
uuids
=
uuids
)
def
text_exists
(
self
,
id
:
str
)
->
bool
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
return
vector_store
.
text_exists
(
id
)
def
delete_by_ids
(
self
,
ids
:
list
[
str
])
->
None
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
for
node_id
in
ids
:
vector_store
.
del_text
(
node_id
)
def
delete_by_document_id
(
self
,
document_id
:
str
):
def
delete_by_document_id
(
self
,
document_id
:
str
):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
vector_store
=
cast
(
self
.
_get_vector_store_class
()
,
vector_store
)
vector_store
.
del_texts
({
vector_store
.
del_texts
({
"operator"
:
"Equal"
,
"operator"
:
"Equal"
,
...
...
api/core/llm/llm_builder.py
View file @
0d82aa8f
from
typing
import
Union
,
Optional
from
typing
import
Union
,
Optional
,
List
from
langchain.callbacks
import
CallbackManag
er
from
langchain.callbacks
.base
import
BaseCallbackHandl
er
from
langchain.llms.fake
import
FakeListLLM
from
langchain.llms.fake
import
FakeListLLM
from
core.constant
import
llm_constant
from
core.constant
import
llm_constant
...
@@ -61,7 +61,7 @@ class LLMBuilder:
...
@@ -61,7 +61,7 @@ class LLMBuilder:
top_p
=
kwargs
.
get
(
'top_p'
,
1
),
top_p
=
kwargs
.
get
(
'top_p'
,
1
),
frequency_penalty
=
kwargs
.
get
(
'frequency_penalty'
,
0
),
frequency_penalty
=
kwargs
.
get
(
'frequency_penalty'
,
0
),
presence_penalty
=
kwargs
.
get
(
'presence_penalty'
,
0
),
presence_penalty
=
kwargs
.
get
(
'presence_penalty'
,
0
),
callback
_manager
=
kwargs
.
get
(
'callback_manager
'
,
None
),
callback
s
=
kwargs
.
get
(
'callbacks
'
,
None
),
streaming
=
kwargs
.
get
(
'streaming'
,
False
),
streaming
=
kwargs
.
get
(
'streaming'
,
False
),
# request_timeout=None
# request_timeout=None
**
model_credentials
**
model_credentials
...
@@ -69,7 +69,7 @@ class LLMBuilder:
...
@@ -69,7 +69,7 @@ class LLMBuilder:
@
classmethod
@
classmethod
def
to_llm_from_model
(
cls
,
tenant_id
:
str
,
model
:
dict
,
streaming
:
bool
=
False
,
def
to_llm_from_model
(
cls
,
tenant_id
:
str
,
model
:
dict
,
streaming
:
bool
=
False
,
callback
_manager
:
Optional
[
CallbackManager
]
=
None
)
->
Union
[
StreamableOpenAI
,
StreamableChatOpenAI
]:
callback
s
:
Optional
[
List
[
BaseCallbackHandler
]
]
=
None
)
->
Union
[
StreamableOpenAI
,
StreamableChatOpenAI
]:
model_name
=
model
.
get
(
"name"
)
model_name
=
model
.
get
(
"name"
)
completion_params
=
model
.
get
(
"completion_params"
,
{})
completion_params
=
model
.
get
(
"completion_params"
,
{})
...
@@ -82,7 +82,7 @@ class LLMBuilder:
...
@@ -82,7 +82,7 @@ class LLMBuilder:
frequency_penalty
=
completion_params
.
get
(
'frequency_penalty'
,
0.1
),
frequency_penalty
=
completion_params
.
get
(
'frequency_penalty'
,
0.1
),
presence_penalty
=
completion_params
.
get
(
'presence_penalty'
,
0.1
),
presence_penalty
=
completion_params
.
get
(
'presence_penalty'
,
0.1
),
streaming
=
streaming
,
streaming
=
streaming
,
callback
_manager
=
callback_manager
callback
s
=
callbacks
)
)
@
classmethod
@
classmethod
...
...
api/core/tool/llama_index_tool.py
View file @
0d82aa8f
from
typing
import
Dict
from
typing
import
Dict
from
langchain.tools
import
BaseTool
from
langchain.tools
import
BaseTool
from
llama_index.indices.base
import
BaseGPTIndex
from
llama_index.langchain_helpers.agents
import
IndexToolConfig
from
pydantic
import
Field
from
pydantic
import
Field
from
core.callback_handler.index_tool_callback_handler
import
IndexToolCallbackHandler
from
core.callback_handler.index_tool_callback_handler
import
IndexToolCallbackHandler
...
...
api/services/hit_testing_service.py
View file @
0d82aa8f
...
@@ -3,47 +3,47 @@ import time
...
@@ -3,47 +3,47 @@ import time
from
typing
import
List
from
typing
import
List
import
numpy
as
np
import
numpy
as
np
from
llama_index.data_structs.node_v2
import
NodeWithScore
from
flask
import
current_app
from
llama_index.indices.query.schema
import
QueryBundle
from
langchain.embeddings
import
OpenAIEmbeddings
from
llama_index.indices.vector_store
import
GPTVectorStoreIndexQuery
from
langchain.embeddings.base
import
Embeddings
from
langchain.schema
import
Document
from
sklearn.manifold
import
TSNE
from
sklearn.manifold
import
TSNE
from
core.docstore.empty_docstore
import
EmptyDocumentStore
from
core.embedding.cached_embedding
import
CacheEmbedding
from
core.index.vector_index
import
VectorIndex
from
core.index.vector_index.vector_index
import
VectorIndex
from
core.llm.llm_builder
import
LLMBuilder
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.account
import
Account
from
models.account
import
Account
from
models.dataset
import
Dataset
,
DocumentSegment
,
DatasetQuery
from
models.dataset
import
Dataset
,
DocumentSegment
,
DatasetQuery
from
services.errors.index
import
IndexNotInitializedError
class
HitTestingService
:
class
HitTestingService
:
@
classmethod
@
classmethod
def
retrieve
(
cls
,
dataset
:
Dataset
,
query
:
str
,
account
:
Account
,
limit
:
int
=
10
)
->
dict
:
def
retrieve
(
cls
,
dataset
:
Dataset
,
query
:
str
,
account
:
Account
,
limit
:
int
=
10
)
->
dict
:
index
=
VectorIndex
(
dataset
=
dataset
)
.
query_index
model_credentials
=
LLMBuilder
.
get_model_credentials
(
tenant_id
=
dataset
.
tenant_id
,
if
not
index
:
model_provider
=
LLMBuilder
.
get_default_provider
(
dataset
.
tenant_id
),
raise
IndexNotInitializedError
()
model_name
=
'text-embedding-ada-002'
index_query
=
GPTVectorStoreIndexQuery
(
index_struct
=
index
.
index_struct
,
service_context
=
index
.
service_context
,
vector_store
=
index
.
query_context
.
get
(
'vector_store'
),
docstore
=
EmptyDocumentStore
(),
response_synthesizer
=
None
,
similarity_top_k
=
limit
)
)
query_bundle
=
QueryBundle
(
embeddings
=
CacheEmbedding
(
OpenAIEmbeddings
(
query_str
=
query
,
**
model_credentials
custom_embedding_strs
=
[
query
],
))
)
query_bundle
.
embedding
=
index
.
service_context
.
embed_model
.
get_agg_embedding_from_queries
(
vector_index
=
VectorIndex
(
query_bundle
.
embedding_strs
dataset
=
dataset
,
config
=
current_app
.
config
,
embeddings
=
embeddings
)
)
start
=
time
.
perf_counter
()
start
=
time
.
perf_counter
()
nodes
=
index_query
.
retrieve
(
query_bundle
=
query_bundle
)
documents
=
vector_index
.
search
(
query
,
search_type
=
'similarity_score_threshold'
,
search_kwargs
=
{
'k'
:
10
}
)
end
=
time
.
perf_counter
()
end
=
time
.
perf_counter
()
logging
.
debug
(
f
"Hit testing retrieve in {end - start:0.4f} seconds"
)
logging
.
debug
(
f
"Hit testing retrieve in {end - start:0.4f} seconds"
)
...
@@ -58,25 +58,24 @@ class HitTestingService:
...
@@ -58,25 +58,24 @@ class HitTestingService:
db
.
session
.
add
(
dataset_query
)
db
.
session
.
add
(
dataset_query
)
db
.
session
.
commit
()
db
.
session
.
commit
()
return
cls
.
compact_retrieve_response
(
dataset
,
query_bundle
,
node
s
)
return
cls
.
compact_retrieve_response
(
dataset
,
embeddings
,
query
,
document
s
)
@
classmethod
@
classmethod
def
compact_retrieve_response
(
cls
,
dataset
:
Dataset
,
query_bundle
:
QueryBundle
,
nodes
:
List
[
NodeWithScore
]):
def
compact_retrieve_response
(
cls
,
dataset
:
Dataset
,
embeddings
:
Embeddings
,
query
:
str
,
documents
:
List
[
Document
]):
embeddings
=
[
text_
embeddings
=
[
query_bundle
.
embedding
embeddings
.
embed_query
(
query
)
]
]
for
node
in
nodes
:
text_embeddings
.
extend
(
embeddings
.
embed_documents
([
document
.
page_content
for
document
in
documents
]))
embeddings
.
append
(
node
.
node
.
embedding
)
tsne_position_data
=
cls
.
get_tsne_positions_from_embeddings
(
embeddings
)
tsne_position_data
=
cls
.
get_tsne_positions_from_embeddings
(
text_
embeddings
)
query_position
=
tsne_position_data
.
pop
(
0
)
query_position
=
tsne_position_data
.
pop
(
0
)
i
=
0
i
=
0
records
=
[]
records
=
[]
for
node
in
node
s
:
for
document
in
document
s
:
index_node_id
=
node
.
node
.
doc_id
index_node_id
=
document
.
metadata
[
'doc_id'
]
segment
=
db
.
session
.
query
(
DocumentSegment
)
.
filter
(
segment
=
db
.
session
.
query
(
DocumentSegment
)
.
filter
(
DocumentSegment
.
dataset_id
==
dataset
.
id
,
DocumentSegment
.
dataset_id
==
dataset
.
id
,
...
@@ -91,7 +90,7 @@ class HitTestingService:
...
@@ -91,7 +90,7 @@ class HitTestingService:
record
=
{
record
=
{
"segment"
:
segment
,
"segment"
:
segment
,
"score"
:
node
.
score
,
"score"
:
document
.
metadata
[
'score'
]
,
"tsne_position"
:
tsne_position_data
[
i
]
"tsne_position"
:
tsne_position_data
[
i
]
}
}
...
@@ -101,7 +100,7 @@ class HitTestingService:
...
@@ -101,7 +100,7 @@ class HitTestingService:
return
{
return
{
"query"
:
{
"query"
:
{
"content"
:
query
_bundle
.
query_str
,
"content"
:
query
,
"tsne_position"
:
query_position
,
"tsne_position"
:
query_position
,
},
},
"records"
:
records
"records"
:
records
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment