Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
ced9fc52
Commit
ced9fc52
authored
Jun 20, 2023
by
John Wang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix: some bugs
parent
85a25148
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
30 additions
and
21 deletions
+30
-21
main_chain_gather_callback_handler.py
...re/callback_handler/main_chain_gather_callback_handler.py
+9
-7
std_out_callback_handler.py
api/core/callback_handler/std_out_callback_handler.py
+2
-2
cached_embedding.py
api/core/embedding/cached_embedding.py
+2
-2
qdrant_vector_index.py
api/core/index/vector_index/qdrant_vector_index.py
+3
-3
vector_index.py
api/core/index/vector_index/vector_index.py
+3
-2
weaviate_vector_index.py
api/core/index/vector_index/weaviate_vector_index.py
+10
-4
azure_provider.py
api/core/llm/provider/azure_provider.py
+1
-1
No files found.
api/core/callback_handler/main_chain_gather_callback_handler.py
View file @
ced9fc52
...
...
@@ -49,8 +49,10 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
)
->
None
:
"""Print out that we are entering a chain."""
if
not
self
.
_current_chain_result
:
chain_type
=
serialized
[
'id'
][
-
1
]
if
chain_type
:
self
.
_current_chain_result
=
ChainResult
(
type
=
serialized
[
'name'
]
,
type
=
chain_type
,
prompt
=
inputs
,
started_at
=
time
.
perf_counter
()
)
...
...
api/core/callback_handler/std_out_callback_handler.py
View file @
ced9fc52
...
...
@@ -50,8 +50,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
self
,
serialized
:
Dict
[
str
,
Any
],
inputs
:
Dict
[
str
,
Any
],
**
kwargs
:
Any
)
->
None
:
"""Print out that we are entering a chain."""
c
lass_name
=
serialized
[
"name"
]
print_text
(
"
\n
[on_chain_start]
\n
Chain: "
+
c
lass_nam
e
+
"
\n
Inputs: "
+
str
(
inputs
)
+
"
\n
"
,
color
=
'pink'
)
c
hain_type
=
serialized
[
'id'
][
-
1
]
print_text
(
"
\n
[on_chain_start]
\n
Chain: "
+
c
hain_typ
e
+
"
\n
Inputs: "
+
str
(
inputs
)
+
"
\n
"
,
color
=
'pink'
)
def
on_chain_end
(
self
,
outputs
:
Dict
[
str
,
Any
],
**
kwargs
:
Any
)
->
None
:
"""Print out that we finished a chain."""
...
...
api/core/embedding/cached_embedding.py
View file @
ced9fc52
...
...
@@ -22,7 +22,7 @@ class CacheEmbedding(Embeddings):
hash
=
helper
.
generate_text_hash
(
text
)
embedding
=
db
.
session
.
query
(
Embedding
)
.
filter_by
(
hash
=
hash
)
.
first
()
if
embedding
:
text_embeddings
.
append
(
embedding
.
embedding
)
text_embeddings
.
append
(
embedding
.
get_embedding
()
)
else
:
embedding_queue_texts
.
append
(
text
)
...
...
@@ -55,7 +55,7 @@ class CacheEmbedding(Embeddings):
hash
=
helper
.
generate_text_hash
(
text
)
embedding
=
db
.
session
.
query
(
Embedding
)
.
filter_by
(
hash
=
hash
)
.
first
()
if
embedding
:
return
embedding
.
embedding
return
embedding
.
get_embedding
()
embedding_results
=
self
.
_embeddings
.
embed_query
(
text
)
...
...
api/core/index/vector_index/qdrant_vector_index.py
View file @
ced9fc52
...
...
@@ -50,7 +50,7 @@ class QdrantVectorIndex(BaseVectorIndex):
def
to_index_struct
(
self
)
->
dict
:
return
{
"type"
:
self
.
get_type
(),
"vector_store"
:
{
"collection_name"
:
self
.
get_index_name
(
self
.
_dataset
.
get_id
()
)}
"vector_store"
:
{
"collection_name"
:
self
.
get_index_name
(
self
.
_dataset
.
id
)}
}
def
create
(
self
,
texts
:
list
[
Document
],
**
kwargs
)
->
BaseIndex
:
...
...
@@ -58,7 +58,7 @@ class QdrantVectorIndex(BaseVectorIndex):
self
.
_vector_store
=
QdrantVectorStore
.
from_documents
(
texts
,
self
.
_embeddings
,
collection_name
=
self
.
get_index_name
(
self
.
_dataset
.
get_id
()
),
collection_name
=
self
.
get_index_name
(
self
.
_dataset
.
id
),
ids
=
uuids
,
**
self
.
_client_config
.
to_qdrant_params
()
)
...
...
@@ -76,7 +76,7 @@ class QdrantVectorIndex(BaseVectorIndex):
return
QdrantVectorStore
(
client
=
client
,
collection_name
=
self
.
get_index_name
(
self
.
_dataset
.
get_id
()
),
collection_name
=
self
.
get_index_name
(
self
.
_dataset
.
id
),
embeddings
=
self
.
_embeddings
)
...
...
api/core/index/vector_index/vector_index.py
View file @
ced9fc52
...
...
@@ -29,12 +29,13 @@ class VectorIndex:
return
WeaviateVectorIndex
(
dataset
=
dataset
,
config
=
WeaviateConfig
(
endpoint
=
config
.
get
(
'WEAVIATE_
URL
'
),
endpoint
=
config
.
get
(
'WEAVIATE_
ENDPOINT
'
),
api_key
=
config
.
get
(
'WEAVIATE_API_KEY'
),
batch_size
=
int
(
config
.
get
(
'WEAVIATE_BATCH_SIZE'
))
),
embeddings
=
embeddings
,
attributes
=
[
'doc_id'
,
'dataset_id'
,
'document_id'
,
'source'
],
# attributes=['doc_id', 'dataset_id', 'document_id', 'source'],
attributes
=
[
'doc_id'
],
)
elif
vector_type
==
"qdrant"
:
from
core.index.vector_index.qdrant_vector_index
import
QdrantVectorIndex
,
QdrantConfig
...
...
api/core/index/vector_index/weaviate_vector_index.py
View file @
ced9fc52
...
...
@@ -4,7 +4,7 @@ import weaviate
from
langchain.embeddings.base
import
Embeddings
from
langchain.schema
import
Document
,
BaseRetriever
from
langchain.vectorstores
import
VectorStore
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
,
root_validator
from
core.index.base
import
BaseIndex
from
core.index.vector_index.base
import
BaseVectorIndex
...
...
@@ -17,6 +17,12 @@ class WeaviateConfig(BaseModel):
api_key
:
Optional
[
str
]
batch_size
:
int
=
100
@
root_validator
()
def
validate_config
(
cls
,
values
:
dict
)
->
dict
:
if
not
values
[
'endpoint'
]:
raise
ValueError
(
"config WEAVIATE_ENDPOINT is required"
)
return
values
class
WeaviateVectorIndex
(
BaseVectorIndex
):
def
__init__
(
self
,
dataset
:
Dataset
,
config
:
WeaviateConfig
,
embeddings
:
Embeddings
,
attributes
:
list
[
str
]):
...
...
@@ -59,7 +65,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
def
to_index_struct
(
self
)
->
dict
:
return
{
"type"
:
self
.
get_type
(),
"vector_store"
:
{
"class_prefix"
:
self
.
get_index_name
(
self
.
_dataset
.
get_id
()
)}
"vector_store"
:
{
"class_prefix"
:
self
.
get_index_name
(
self
.
_dataset
.
id
)}
}
def
create
(
self
,
texts
:
list
[
Document
],
**
kwargs
)
->
BaseIndex
:
...
...
@@ -68,7 +74,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
texts
,
self
.
_embeddings
,
client
=
self
.
_client
,
index_name
=
self
.
get_index_name
(
self
.
_dataset
.
get_id
()
),
index_name
=
self
.
get_index_name
(
self
.
_dataset
.
id
),
uuids
=
uuids
,
by_text
=
False
)
...
...
@@ -82,7 +88,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
return
WeaviateVectorStore
(
client
=
self
.
_client
,
index_name
=
self
.
get_index_name
(
self
.
_dataset
.
get_id
()
),
index_name
=
self
.
get_index_name
(
self
.
_dataset
.
id
),
text_key
=
'text'
,
embedding
=
self
.
_embeddings
,
attributes
=
self
.
_attributes
,
...
...
api/core/llm/provider/azure_provider.py
View file @
ced9fc52
...
...
@@ -42,7 +42,7 @@ class AzureProvider(BaseProvider):
"""
config
=
self
.
get_provider_api_key
(
model_id
=
model_id
)
config
[
'openai_api_type'
]
=
'azure'
config
[
'deployment
'
]
=
config
[
'deployment
_name'
]
=
model_id
.
replace
(
'.'
,
''
)
if
model_id
else
None
config
[
'deployment_name'
]
=
model_id
.
replace
(
'.'
,
''
)
if
model_id
else
None
return
config
def
get_provider_name
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment