Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
5010706d
Unverified
Commit
5010706d
authored
Feb 05, 2024
by
Yeuoly
Committed by
GitHub
Feb 05, 2024
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: tool credentials cache and introduce _position.yaml (#2386)
parent
6278ff0f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
95 additions
and
25 deletions
+95
-25
tool_provider_cache.py
api/core/helper/tool_provider_cache.py
+49
-0
_position.yaml
api/core/tools/provider/_position.yaml
+15
-0
_positions.py
api/core/tools/provider/builtin/_positions.py
+17
-19
configuration.py
api/core/tools/utils/configuration.py
+14
-6
No files found.
api/core/helper/tool_provider_cache.py
0 → 100644
View file @
5010706d
import
json
from
enum
import
Enum
from
json
import
JSONDecodeError
from
typing
import
Optional
from
extensions.ext_redis
import
redis_client
class
ToolProviderCredentialsCacheType
(
Enum
):
PROVIDER
=
"tool_provider"
class
ToolProviderCredentialsCache
:
def
__init__
(
self
,
tenant_id
:
str
,
identity_id
:
str
,
cache_type
:
ToolProviderCredentialsCacheType
):
self
.
cache_key
=
f
"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def
get
(
self
)
->
Optional
[
dict
]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials
=
redis_client
.
get
(
self
.
cache_key
)
if
cached_provider_credentials
:
try
:
cached_provider_credentials
=
cached_provider_credentials
.
decode
(
'utf-8'
)
cached_provider_credentials
=
json
.
loads
(
cached_provider_credentials
)
except
JSONDecodeError
:
return
None
return
cached_provider_credentials
else
:
return
None
def
set
(
self
,
credentials
:
dict
)
->
None
:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client
.
setex
(
self
.
cache_key
,
86400
,
json
.
dumps
(
credentials
))
def
delete
(
self
)
->
None
:
"""
Delete cached model provider credentials.
:return:
"""
redis_client
.
delete
(
self
.
cache_key
)
\ No newline at end of file
api/core/tools/provider/_position.yaml
0 → 100644
View file @
5010706d
-
google
-
bing
-
wikipedia
-
dalle
-
azuredalle
-
webscraper
-
wolframalpha
-
github
-
chart
-
time
-
yahoo
-
stablediffusion
-
vectorizer
-
youtube
-
gaode
api/core/tools/provider/builtin/_positions.py
View file @
5010706d
from
typing
import
List
from
core.tools.entities.user_entities
import
UserToolProvider
from
core.tools.entities.user_entities
import
UserToolProvider
from
core.tools.entities.tool_entities
import
ToolProviderType
from
typing
import
List
from
yaml
import
load
,
FullLoader
position
=
{
import
os.path
'google'
:
1
,
'bing'
:
2
,
'wikipedia'
:
2
,
'dalle'
:
3
,
'webscraper'
:
4
,
'wolframalpha'
:
5
,
'chart'
:
6
,
'time'
:
7
,
'yahoo'
:
8
,
'stablediffusion'
:
9
,
'vectorizer'
:
10
,
'youtube'
:
11
,
'github'
:
12
,
'gaode'
:
13
}
position
=
{}
class
BuiltinToolProviderSort
:
class
BuiltinToolProviderSort
:
@
staticmethod
@
staticmethod
def
sort
(
providers
:
List
[
UserToolProvider
])
->
List
[
UserToolProvider
]:
def
sort
(
providers
:
List
[
UserToolProvider
])
->
List
[
UserToolProvider
]:
global
position
if
not
position
:
tmp_position
=
{}
file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
,
'_position.yaml'
)
with
open
(
file_path
,
'r'
)
as
f
:
for
pos
,
val
in
enumerate
(
load
(
f
,
Loader
=
FullLoader
)):
tmp_position
[
val
]
=
pos
position
=
tmp_position
def
sort_compare
(
provider
:
UserToolProvider
)
->
int
:
def
sort_compare
(
provider
:
UserToolProvider
)
->
int
:
# if provider.type == UserToolProvider.ProviderType.MODEL:
# return position.get(f'model_provider.{provider.name}', 10000)
return
position
.
get
(
provider
.
name
,
10000
)
return
position
.
get
(
provider
.
name
,
10000
)
sorted_providers
=
sorted
(
providers
,
key
=
sort_compare
)
sorted_providers
=
sorted
(
providers
,
key
=
sort_compare
)
return
sorted_providers
return
sorted_providers
\ No newline at end of file
api/core/tools/utils/configuration.py
View file @
5010706d
from
typing
import
Any
,
Dict
from
typing
import
Dict
,
Any
from
pydantic
import
BaseModel
from
core.helper
import
encrypter
from
core.tools.entities.tool_entities
import
ToolProviderCredentials
from
core.tools.entities.tool_entities
import
ToolProviderCredentials
from
core.tools.provider.tool_provider
import
ToolProviderController
from
core.tools.provider.tool_provider
import
ToolProviderController
from
pydantic
import
BaseModel
from
core.helper
import
encrypter
from
core.helper.tool_provider_cache
import
ToolProviderCredentialsCacheType
,
ToolProviderCredentialsCache
class
ToolConfiguration
(
BaseModel
):
class
ToolConfiguration
(
BaseModel
):
tenant_id
:
str
tenant_id
:
str
...
@@ -63,8 +63,15 @@ class ToolConfiguration(BaseModel):
...
@@ -63,8 +63,15 @@ class ToolConfiguration(BaseModel):
return a deep copy of credentials with decrypted values
return a deep copy of credentials with decrypted values
"""
"""
cache
=
ToolProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
f
'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}'
,
cache_type
=
ToolProviderCredentialsCacheType
.
PROVIDER
)
cached_credentials
=
cache
.
get
()
if
cached_credentials
:
return
cached_credentials
credentials
=
self
.
_deep_copy
(
credentials
)
credentials
=
self
.
_deep_copy
(
credentials
)
# get fields need to be decrypted
# get fields need to be decrypted
fields
=
self
.
provider_controller
.
get_credentials_schema
()
fields
=
self
.
provider_controller
.
get_credentials_schema
()
for
field_name
,
field
in
fields
.
items
():
for
field_name
,
field
in
fields
.
items
():
...
@@ -74,5 +81,6 @@ class ToolConfiguration(BaseModel):
...
@@ -74,5 +81,6 @@ class ToolConfiguration(BaseModel):
credentials
[
field_name
]
=
encrypter
.
decrypt_token
(
self
.
tenant_id
,
credentials
[
field_name
])
credentials
[
field_name
]
=
encrypter
.
decrypt_token
(
self
.
tenant_id
,
credentials
[
field_name
])
except
:
except
:
pass
pass
cache
.
set
(
credentials
)
return
credentials
return
credentials
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment