Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
0d82aa8f
Commit
0d82aa8f
authored
Jun 19, 2023
by
John Wang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: use callbacks instead of callback manager
parent
9e9d15ec
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
144 additions
and
233 deletions
+144
-233
__init__.py
api/core/__init__.py
+0
-2
agent_builder.py
api/core/agent/agent_builder.py
+8
-11
index_tool_callback_handler.py
api/core/callback_handler/index_tool_callback_handler.py
+0
-2
chain_builder.py
api/core/chain/chain_builder.py
+2
-4
llm_router_chain.py
api/core/chain/llm_router_chain.py
+2
-2
main_chain_builder.py
api/core/chain/main_chain_builder.py
+4
-7
multi_dataset_router_chain.py
api/core/chain/multi_dataset_router_chain.py
+3
-6
completion.py
api/core/completion.py
+11
-12
empty_docstore.py
api/core/docstore/empty_docstore.py
+0
-51
base.py
api/core/index/vector_index/base.py
+68
-3
qdrant_vector_index.py
api/core/index/vector_index/qdrant_vector_index.py
+3
-45
weaviate_vector_index.py
api/core/index/vector_index/weaviate_vector_index.py
+3
-45
llm_builder.py
api/core/llm/llm_builder.py
+5
-5
llama_index_tool.py
api/core/tool/llama_index_tool.py
+0
-2
hit_testing_service.py
api/services/hit_testing_service.py
+35
-36
No files found.
api/core/__init__.py
View file @
0d82aa8f
...
...
@@ -3,7 +3,6 @@ from typing import Optional
import
langchain
from
flask
import
Flask
from
langchain
import
set_handler
from
langchain.prompts.base
import
DEFAULT_FORMATTER_MAPPING
from
pydantic
import
BaseModel
...
...
@@ -28,7 +27,6 @@ def init_app(app: Flask):
if
os
.
environ
.
get
(
"DEBUG"
)
and
os
.
environ
.
get
(
"DEBUG"
)
.
lower
()
==
'true'
:
langchain
.
verbose
=
True
set_handler
(
DifyStdOutCallbackHandler
())
if
app
.
config
.
get
(
"OPENAI_API_KEY"
):
hosted_llm_credentials
.
openai
=
HostedOpenAICredential
(
api_key
=
app
.
config
.
get
(
"OPENAI_API_KEY"
))
api/core/agent/agent_builder.py
View file @
0d82aa8f
...
...
@@ -2,7 +2,7 @@ from typing import Optional
from
langchain
import
LLMChain
from
langchain.agents
import
ZeroShotAgent
,
AgentExecutor
,
ConversationalAgent
from
langchain.callbacks
import
CallbackManager
from
langchain.callbacks
.manager
import
CallbackManager
from
langchain.memory.chat_memory
import
BaseChatMemory
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
...
...
@@ -16,23 +16,20 @@ class AgentBuilder:
def
to_agent_chain
(
cls
,
tenant_id
:
str
,
tools
,
memory
:
Optional
[
BaseChatMemory
],
dataset_tool_callback_handler
:
DatasetToolCallbackHandler
,
agent_loop_gather_callback_handler
:
AgentLoopGatherCallbackHandler
):
llm_callback_manager
=
CallbackManager
([
agent_loop_gather_callback_handler
,
DifyStdOutCallbackHandler
()])
llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
tenant_id
,
model_name
=
agent_loop_gather_callback_handler
.
model_name
,
temperature
=
0
,
max_tokens
=
1024
,
callback
_manager
=
llm_callback_manager
callback
s
=
[
agent_loop_gather_callback_handler
,
DifyStdOutCallbackHandler
()]
)
tool_callback_manager
=
CallbackManager
([
agent_loop_gather_callback_handler
,
dataset_tool_callback_handler
,
DifyStdOutCallbackHandler
()
])
for
tool
in
tools
:
tool
.
callback_manager
=
tool_callback_manager
tool
.
callbacks
=
[
agent_loop_gather_callback_handler
,
dataset_tool_callback_handler
,
DifyStdOutCallbackHandler
()
]
prompt
=
cls
.
build_agent_prompt_template
(
tools
=
tools
,
...
...
@@ -54,7 +51,7 @@ class AgentBuilder:
tools
=
tools
,
agent
=
agent
,
memory
=
memory
,
callback
_manager
=
agent_callback_manager
,
callback
s
=
agent_callback_manager
,
max_iterations
=
6
,
early_stopping_method
=
"generate"
,
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
...
...
api/core/callback_handler/index_tool_callback_handler.py
View file @
0d82aa8f
from
llama_index
import
Response
from
extensions.ext_database
import
db
from
models.dataset
import
DocumentSegment
...
...
api/core/chain/chain_builder.py
View file @
0d82aa8f
from
typing
import
Optional
from
langchain.callbacks
import
CallbackManager
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.chain.tool_chain
import
ToolChain
...
...
@@ -14,7 +12,7 @@ class ChainBuilder:
tool
=
tool
,
input_key
=
kwargs
.
get
(
'input_key'
,
'input'
),
output_key
=
kwargs
.
get
(
'output_key'
,
'tool_output'
),
callback
_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
callback
s
=
[
DifyStdOutCallbackHandler
()]
)
@
classmethod
...
...
@@ -27,7 +25,7 @@ class ChainBuilder:
sensitive_words
=
sensitive_words
.
split
(
","
),
canned_response
=
tool_config
.
get
(
"canned_response"
,
''
),
output_key
=
"sensitive_word_avoidance_output"
,
callback
_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
,
callback
s
=
[
DifyStdOutCallbackHandler
()]
,
**
kwargs
)
...
...
api/core/chain/llm_router_chain.py
View file @
0d82aa8f
"""Base classes for LLM-powered router chains."""
from
__future__
import
annotations
import
json
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
cast
,
NamedTuple
from
langchain.base_language
import
BaseLanguageModel
from
langchain.chains.base
import
Chain
from
pydantic
import
root_validator
from
langchain.chains
import
LLMChain
from
langchain.prompts
import
BasePromptTemplate
from
langchain.schema
import
BaseOutputParser
,
OutputParserException
,
BaseLanguageModel
from
langchain.schema
import
BaseOutputParser
,
OutputParserException
from
libs.json_in_md_parser
import
parse_and_check_json_markdown
...
...
api/core/chain/main_chain_builder.py
View file @
0d82aa8f
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
cast
from
langchain.callbacks
import
SharedCallbackManager
,
CallbackManager
from
langchain.chains
import
SequentialChain
from
langchain.chains.base
import
Chain
from
langchain.memory.chat_memory
import
BaseChatMemory
from
core.callback_handler.agent_loop_gather_callback_handler
import
AgentLoopGatherCallbackHandler
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.chain_builder
import
ChainBuilder
...
...
@@ -42,9 +40,8 @@ class MainChainBuilder:
return
None
for
chain
in
chains
:
# do not add handler into singleton callback manager
if
not
isinstance
(
chain
.
callback_manager
,
SharedCallbackManager
):
chain
.
callback_manager
.
add_handler
(
chain_callback_handler
)
chain
=
cast
(
Chain
,
chain
)
chain
.
callbacks
.
append
(
chain_callback_handler
)
# build main chain
overall_chain
=
SequentialChain
(
...
...
@@ -93,7 +90,7 @@ class MainChainBuilder:
tenant_id
=
tenant_id
,
datasets
=
datasets
,
conversation_message_task
=
conversation_message_task
,
callback
_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
callback
s
=
[
DifyStdOutCallbackHandler
()]
)
chains
.
append
(
multi_dataset_router_chain
)
...
...
api/core/chain/multi_dataset_router_chain.py
View file @
0d82aa8f
from
typing
import
Mapping
,
List
,
Dict
,
Any
,
Optional
from
typing
import
Mapping
,
List
,
Dict
,
Any
from
langchain
import
LLMChain
,
PromptTemplate
,
ConversationChain
from
langchain.callbacks
import
CallbackManager
from
langchain
import
PromptTemplate
from
langchain.chains.base
import
Chain
from
langchain.schema
import
BaseLanguageModel
from
pydantic
import
Extra
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
...
...
@@ -82,13 +80,12 @@ class MultiDatasetRouterChain(Chain):
**
kwargs
:
Any
,
):
"""Convenience constructor for instantiating from destination prompts."""
llm_callback_manager
=
CallbackManager
([
DifyStdOutCallbackHandler
()])
llm
=
LLMBuilder
.
to_llm
(
tenant_id
=
tenant_id
,
model_name
=
'gpt-3.5-turbo'
,
temperature
=
0
,
max_tokens
=
1024
,
callback
_manager
=
llm_callback_manager
callback
s
=
[
DifyStdOutCallbackHandler
()]
)
destinations
=
[
"{}: {}"
.
format
(
d
.
id
,
d
.
description
.
replace
(
'
\n
'
,
' '
)
if
d
.
description
...
...
api/core/completion.py
View file @
0d82aa8f
import
logging
from
typing
import
Optional
,
List
,
Union
,
Tuple
from
langchain.callbacks
import
CallbackManager
from
langchain.base_language
import
BaseLanguageModel
from
langchain.callbacks.base
import
BaseCallbackHandler
from
langchain.chat_models.base
import
BaseChatModel
from
langchain.llms
import
BaseLLM
from
langchain.schema
import
BaseMessage
,
BaseLanguageModel
,
HumanMessage
from
langchain.schema
import
BaseMessage
,
HumanMessage
from
requests.exceptions
import
ChunkedEncodingError
from
core.constant
import
llm_constant
from
core.callback_handler.llm_callback_handler
import
LLMCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStreamingStdOutCallbackHandler
,
\
DifyStdOutCallbackHandler
from
core.conversation_message_task
import
ConversationMessageTask
,
ConversationTaskStoppedException
,
PubHandler
from
core.conversation_message_task
import
ConversationMessageTask
,
ConversationTaskStoppedException
from
core.llm.error
import
LLMBadRequestError
from
core.llm.llm_builder
import
LLMBuilder
from
core.chain.main_chain_builder
import
MainChainBuilder
...
...
@@ -115,7 +116,7 @@ class Completion:
memory
=
memory
)
final_llm
.
callback
_manager
=
cls
.
get_llm_callback_manager
(
final_llm
,
streaming
,
conversation_message_task
)
final_llm
.
callback
s
=
cls
.
get_llm_callbacks
(
final_llm
,
streaming
,
conversation_message_task
)
cls
.
recale_llm_max_tokens
(
final_llm
=
final_llm
,
...
...
@@ -247,16 +248,14 @@ And answer according to the language of the user's question.
return
messages
,
[
'
\n
Human:'
]
@
classmethod
def
get_llm_callback
_manager
(
cls
,
llm
:
Union
[
StreamableOpenAI
,
StreamableChatOpenAI
],
streaming
:
bool
,
conversation_message_task
:
ConversationMessageTask
)
->
CallbackManager
:
def
get_llm_callback
s
(
cls
,
llm
:
Union
[
StreamableOpenAI
,
StreamableChatOpenAI
],
streaming
:
bool
,
conversation_message_task
:
ConversationMessageTask
)
->
List
[
BaseCallbackHandler
]
:
llm_callback_handler
=
LLMCallbackHandler
(
llm
,
conversation_message_task
)
if
streaming
:
callback_handlers
=
[
llm_callback_handler
,
DifyStreamingStdOutCallbackHandler
()]
return
[
llm_callback_handler
,
DifyStreamingStdOutCallbackHandler
()]
else
:
callback_handlers
=
[
llm_callback_handler
,
DifyStdOutCallbackHandler
()]
return
CallbackManager
(
callback_handlers
)
return
[
llm_callback_handler
,
DifyStdOutCallbackHandler
()]
@
classmethod
def
get_history_messages_from_memory
(
cls
,
memory
:
ReadOnlyConversationTokenDBBufferSharedMemory
,
...
...
@@ -360,7 +359,7 @@ And answer according to the language of the user's question.
streaming
=
streaming
)
llm
.
callback
_manager
=
cls
.
get_llm_callback_manager
(
llm
,
streaming
,
conversation_message_task
)
llm
.
callback
s
=
cls
.
get_llm_callbacks
(
llm
,
streaming
,
conversation_message_task
)
cls
.
recale_llm_max_tokens
(
final_llm
=
llm
,
...
...
api/core/docstore/empty_docstore.py
deleted
100644 → 0
View file @
9e9d15ec
from
typing
import
Any
,
Dict
,
Optional
,
Sequence
from
llama_index.docstore.types
import
BaseDocumentStore
from
llama_index.schema
import
BaseDocument
class
EmptyDocumentStore
(
BaseDocumentStore
):
@
classmethod
def
from_dict
(
cls
,
config_dict
:
Dict
[
str
,
Any
])
->
"EmptyDocumentStore"
:
return
cls
()
def
to_dict
(
self
)
->
Dict
[
str
,
Any
]:
"""Serialize to dict."""
return
{}
@
property
def
docs
(
self
)
->
Dict
[
str
,
BaseDocument
]:
return
{}
def
add_documents
(
self
,
docs
:
Sequence
[
BaseDocument
],
allow_update
:
bool
=
True
)
->
None
:
pass
def
document_exists
(
self
,
doc_id
:
str
)
->
bool
:
"""Check if document exists."""
return
False
def
get_document
(
self
,
doc_id
:
str
,
raise_error
:
bool
=
True
)
->
Optional
[
BaseDocument
]:
return
None
def
delete_document
(
self
,
doc_id
:
str
,
raise_error
:
bool
=
True
)
->
None
:
pass
def
set_document_hash
(
self
,
doc_id
:
str
,
doc_hash
:
str
)
->
None
:
"""Set the hash for a given doc_id."""
pass
def
get_document_hash
(
self
,
doc_id
:
str
)
->
Optional
[
str
]:
"""Get the stored hash for a document, if it exists."""
return
None
def
update_docstore
(
self
,
other
:
"BaseDocumentStore"
)
->
None
:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self
.
add_documents
(
list
(
other
.
docs
.
values
()))
api/core/index/vector_index/base.py
View file @
0d82aa8f
from
abc
import
abstractmethod
from
typing
import
List
,
Any
,
Tuple
from
typing
import
List
,
Any
,
Tuple
,
cast
from
langchain.schema
import
Document
from
langchain.schema
import
Document
,
BaseRetriever
from
langchain.vectorstores
import
VectorStore
from
index.base
import
BaseIndex
from
core.
index.base
import
BaseIndex
class
BaseVectorIndex
(
BaseIndex
):
...
...
@@ -22,3 +22,68 @@ class BaseVectorIndex(BaseIndex):
@
abstractmethod
def
_get_vector_store
(
self
)
->
VectorStore
:
raise
NotImplementedError
@
abstractmethod
def
_get_vector_store_class
(
self
)
->
type
:
raise
NotImplementedError
def
search
(
self
,
query
:
str
,
**
kwargs
:
Any
)
->
List
[
Document
]:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
search_type
=
kwargs
.
get
(
'search_type'
)
if
kwargs
.
get
(
'search_type'
)
else
'similarity'
search_kwargs
=
kwargs
.
get
(
'search_kwargs'
)
if
kwargs
.
get
(
'search_kwargs'
)
else
{}
if
search_type
==
'similarity_score_threshold'
:
score_threshold
=
search_kwargs
.
get
(
"score_threshold"
)
if
(
score_threshold
is
None
)
or
(
not
isinstance
(
score_threshold
,
float
)):
search_kwargs
[
'score_threshold'
]
=
.0
docs_with_similarity
=
vector_store
.
similarity_search_with_relevance_scores
(
query
,
**
search_kwargs
)
docs
=
[]
for
doc
,
similarity
in
docs_with_similarity
:
doc
.
metadata
[
'score'
]
=
similarity
docs
.
append
(
doc
)
return
docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return
vector_store
.
as_retriever
(
search_type
=
search_type
,
search_kwargs
=
search_kwargs
)
.
get_relevant_documents
(
query
)
def
get_retriever
(
self
,
**
kwargs
:
Any
)
->
BaseRetriever
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
return
vector_store
.
as_retriever
(
**
kwargs
)
def
add_texts
(
self
,
texts
:
list
[
Document
]):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
texts
=
self
.
_filter_duplicate_texts
(
texts
)
uuids
=
self
.
_get_uuids
(
texts
)
vector_store
.
add_documents
(
texts
,
uuids
=
uuids
)
def
text_exists
(
self
,
id
:
str
)
->
bool
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
return
vector_store
.
text_exists
(
id
)
def
delete_by_ids
(
self
,
ids
:
list
[
str
])
->
None
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
self
.
_get_vector_store_class
(),
vector_store
)
for
node_id
in
ids
:
vector_store
.
del_text
(
node_id
)
\ No newline at end of file
api/core/index/vector_index/qdrant_vector_index.py
View file @
0d82aa8f
...
...
@@ -80,54 +80,12 @@ class QdrantVectorIndex(BaseVectorIndex):
embeddings
=
self
.
_embeddings
)
def
get_retriever
(
self
,
**
kwargs
:
Any
)
->
BaseRetriever
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
return
vector_store
.
as_retriever
(
**
kwargs
)
def
search
(
self
,
query
:
str
,
**
kwargs
:
Any
)
->
List
[
Document
]:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
search_type
=
kwargs
.
get
(
'search_type'
)
if
kwargs
.
get
(
'search_type'
)
else
'similarity'
search_kwargs
=
kwargs
.
get
(
'search_kwargs'
)
if
kwargs
.
get
(
'search_kwargs'
)
else
{}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return
vector_store
.
as_retriever
(
search_type
=
search_type
,
search_kwargs
=
search_kwargs
)
.
get_relevant_documents
(
query
)
def
add_texts
(
self
,
texts
:
list
[
Document
]):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
texts
=
self
.
_filter_duplicate_texts
(
texts
)
uuids
=
self
.
_get_uuids
(
texts
)
vector_store
.
add_documents
(
texts
,
uuids
=
uuids
)
def
text_exists
(
self
,
id
:
str
)
->
bool
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
return
vector_store
.
text_exists
(
id
)
def
delete_by_ids
(
self
,
ids
:
list
[
str
])
->
None
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
for
node_id
in
ids
:
vector_store
.
del_text
(
node_id
)
def
_get_vector_store_class
(
self
)
->
type
:
return
QdrantVectorStore
def
delete_by_document_id
(
self
,
document_id
:
str
):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
QdrantVectorStore
,
vector_store
)
vector_store
=
cast
(
self
.
_get_vector_store_class
()
,
vector_store
)
from
qdrant_client.http
import
models
...
...
api/core/index/vector_index/weaviate_vector_index.py
View file @
0d82aa8f
...
...
@@ -89,54 +89,12 @@ class WeaviateVectorIndex(BaseVectorIndex):
by_text
=
False
)
def
get_retriever
(
self
,
**
kwargs
:
Any
)
->
BaseRetriever
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
return
vector_store
.
as_retriever
(
**
kwargs
)
def
search
(
self
,
query
:
str
,
**
kwargs
:
Any
)
->
List
[
Document
]:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
search_type
=
kwargs
.
get
(
'search_type'
)
if
kwargs
.
get
(
'search_type'
)
else
'similarity'
search_kwargs
=
kwargs
.
get
(
'search_kwargs'
)
if
kwargs
.
get
(
'search_kwargs'
)
else
{}
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return
vector_store
.
as_retriever
(
search_type
=
search_type
,
search_kwargs
=
search_kwargs
)
.
get_relevant_documents
(
query
)
def
add_texts
(
self
,
texts
:
list
[
Document
]):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
texts
=
self
.
_filter_duplicate_texts
(
texts
)
uuids
=
self
.
_get_uuids
(
texts
)
vector_store
.
add_documents
(
texts
,
uuids
=
uuids
)
def
text_exists
(
self
,
id
:
str
)
->
bool
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
return
vector_store
.
text_exists
(
id
)
def
delete_by_ids
(
self
,
ids
:
list
[
str
])
->
None
:
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
for
node_id
in
ids
:
vector_store
.
del_text
(
node_id
)
def
_get_vector_store_class
(
self
)
->
type
:
return
WeaviateVectorStore
def
delete_by_document_id
(
self
,
document_id
:
str
):
vector_store
=
self
.
_get_vector_store
()
vector_store
=
cast
(
WeaviateVectorStore
,
vector_store
)
vector_store
=
cast
(
self
.
_get_vector_store_class
()
,
vector_store
)
vector_store
.
del_texts
({
"operator"
:
"Equal"
,
...
...
api/core/llm/llm_builder.py
View file @
0d82aa8f
from
typing
import
Union
,
Optional
from
typing
import
Union
,
Optional
,
List
from
langchain.callbacks
import
CallbackManag
er
from
langchain.callbacks
.base
import
BaseCallbackHandl
er
from
langchain.llms.fake
import
FakeListLLM
from
core.constant
import
llm_constant
...
...
@@ -61,7 +61,7 @@ class LLMBuilder:
top_p
=
kwargs
.
get
(
'top_p'
,
1
),
frequency_penalty
=
kwargs
.
get
(
'frequency_penalty'
,
0
),
presence_penalty
=
kwargs
.
get
(
'presence_penalty'
,
0
),
callback
_manager
=
kwargs
.
get
(
'callback_manager
'
,
None
),
callback
s
=
kwargs
.
get
(
'callbacks
'
,
None
),
streaming
=
kwargs
.
get
(
'streaming'
,
False
),
# request_timeout=None
**
model_credentials
...
...
@@ -69,7 +69,7 @@ class LLMBuilder:
@
classmethod
def
to_llm_from_model
(
cls
,
tenant_id
:
str
,
model
:
dict
,
streaming
:
bool
=
False
,
callback
_manager
:
Optional
[
CallbackManager
]
=
None
)
->
Union
[
StreamableOpenAI
,
StreamableChatOpenAI
]:
callback
s
:
Optional
[
List
[
BaseCallbackHandler
]
]
=
None
)
->
Union
[
StreamableOpenAI
,
StreamableChatOpenAI
]:
model_name
=
model
.
get
(
"name"
)
completion_params
=
model
.
get
(
"completion_params"
,
{})
...
...
@@ -82,7 +82,7 @@ class LLMBuilder:
frequency_penalty
=
completion_params
.
get
(
'frequency_penalty'
,
0.1
),
presence_penalty
=
completion_params
.
get
(
'presence_penalty'
,
0.1
),
streaming
=
streaming
,
callback
_manager
=
callback_manager
callback
s
=
callbacks
)
@
classmethod
...
...
api/core/tool/llama_index_tool.py
View file @
0d82aa8f
from
typing
import
Dict
from
langchain.tools
import
BaseTool
from
llama_index.indices.base
import
BaseGPTIndex
from
llama_index.langchain_helpers.agents
import
IndexToolConfig
from
pydantic
import
Field
from
core.callback_handler.index_tool_callback_handler
import
IndexToolCallbackHandler
...
...
api/services/hit_testing_service.py
View file @
0d82aa8f
...
...
@@ -3,47 +3,47 @@ import time
from
typing
import
List
import
numpy
as
np
from
llama_index.data_structs.node_v2
import
NodeWithScore
from
llama_index.indices.query.schema
import
QueryBundle
from
llama_index.indices.vector_store
import
GPTVectorStoreIndexQuery
from
flask
import
current_app
from
langchain.embeddings
import
OpenAIEmbeddings
from
langchain.embeddings.base
import
Embeddings
from
langchain.schema
import
Document
from
sklearn.manifold
import
TSNE
from
core.docstore.empty_docstore
import
EmptyDocumentStore
from
core.index.vector_index
import
VectorIndex
from
core.embedding.cached_embedding
import
CacheEmbedding
from
core.index.vector_index.vector_index
import
VectorIndex
from
core.llm.llm_builder
import
LLMBuilder
from
extensions.ext_database
import
db
from
models.account
import
Account
from
models.dataset
import
Dataset
,
DocumentSegment
,
DatasetQuery
from
services.errors.index
import
IndexNotInitializedError
class
HitTestingService
:
@
classmethod
def
retrieve
(
cls
,
dataset
:
Dataset
,
query
:
str
,
account
:
Account
,
limit
:
int
=
10
)
->
dict
:
index
=
VectorIndex
(
dataset
=
dataset
)
.
query_index
if
not
index
:
raise
IndexNotInitializedError
()
index_query
=
GPTVectorStoreIndexQuery
(
index_struct
=
index
.
index_struct
,
service_context
=
index
.
service_context
,
vector_store
=
index
.
query_context
.
get
(
'vector_store'
),
docstore
=
EmptyDocumentStore
(),
response_synthesizer
=
None
,
similarity_top_k
=
limit
model_credentials
=
LLMBuilder
.
get_model_credentials
(
tenant_id
=
dataset
.
tenant_id
,
model_provider
=
LLMBuilder
.
get_default_provider
(
dataset
.
tenant_id
),
model_name
=
'text-embedding-ada-002'
)
query_bundle
=
QueryBundle
(
query_str
=
query
,
custom_embedding_strs
=
[
query
],
)
embeddings
=
CacheEmbedding
(
OpenAIEmbeddings
(
**
model_credentials
))
query_bundle
.
embedding
=
index
.
service_context
.
embed_model
.
get_agg_embedding_from_queries
(
query_bundle
.
embedding_strs
vector_index
=
VectorIndex
(
dataset
=
dataset
,
config
=
current_app
.
config
,
embeddings
=
embeddings
)
start
=
time
.
perf_counter
()
nodes
=
index_query
.
retrieve
(
query_bundle
=
query_bundle
)
documents
=
vector_index
.
search
(
query
,
search_type
=
'similarity_score_threshold'
,
search_kwargs
=
{
'k'
:
10
}
)
end
=
time
.
perf_counter
()
logging
.
debug
(
f
"Hit testing retrieve in {end - start:0.4f} seconds"
)
...
...
@@ -58,25 +58,24 @@ class HitTestingService:
db
.
session
.
add
(
dataset_query
)
db
.
session
.
commit
()
return
cls
.
compact_retrieve_response
(
dataset
,
query_bundle
,
node
s
)
return
cls
.
compact_retrieve_response
(
dataset
,
embeddings
,
query
,
document
s
)
@
classmethod
def
compact_retrieve_response
(
cls
,
dataset
:
Dataset
,
query_bundle
:
QueryBundle
,
nodes
:
List
[
NodeWithScore
]):
embeddings
=
[
query_bundle
.
embedding
def
compact_retrieve_response
(
cls
,
dataset
:
Dataset
,
embeddings
:
Embeddings
,
query
:
str
,
documents
:
List
[
Document
]):
text_
embeddings
=
[
embeddings
.
embed_query
(
query
)
]
for
node
in
nodes
:
embeddings
.
append
(
node
.
node
.
embedding
)
text_embeddings
.
extend
(
embeddings
.
embed_documents
([
document
.
page_content
for
document
in
documents
]))
tsne_position_data
=
cls
.
get_tsne_positions_from_embeddings
(
embeddings
)
tsne_position_data
=
cls
.
get_tsne_positions_from_embeddings
(
text_
embeddings
)
query_position
=
tsne_position_data
.
pop
(
0
)
i
=
0
records
=
[]
for
node
in
node
s
:
index_node_id
=
node
.
node
.
doc_id
for
document
in
document
s
:
index_node_id
=
document
.
metadata
[
'doc_id'
]
segment
=
db
.
session
.
query
(
DocumentSegment
)
.
filter
(
DocumentSegment
.
dataset_id
==
dataset
.
id
,
...
...
@@ -91,7 +90,7 @@ class HitTestingService:
record
=
{
"segment"
:
segment
,
"score"
:
node
.
score
,
"score"
:
document
.
metadata
[
'score'
]
,
"tsne_position"
:
tsne_position_data
[
i
]
}
...
...
@@ -101,7 +100,7 @@ class HitTestingService:
return
{
"query"
:
{
"content"
:
query
_bundle
.
query_str
,
"content"
:
query
,
"tsne_position"
:
query_position
,
},
"records"
:
records
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment