Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
3fa5204b
Unverified
Commit
3fa5204b
authored
Jan 04, 2024
by
takatost
Committed by
GitHub
Jan 04, 2024
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: optimize performance (#1928)
parent
5a756ca9
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
215 additions
and
75 deletions
+215
-75
provider_configuration.py
api/core/entities/provider_configuration.py
+33
-0
model_provider_cache.py
api/core/helper/model_provider_cache.py
+51
-0
model_provider_factory.py
...e/model_runtime/model_providers/model_provider_factory.py
+4
-0
provider_manager.py
api/core/provider_manager.py
+126
-74
model_provider_service.py
api/services/model_provider_service.py
+1
-1
No files found.
api/core/entities/provider_configuration.py
View file @
3fa5204b
...
...
@@ -10,6 +10,7 @@ from pydantic import BaseModel
from
core.entities.model_entities
import
ModelWithProviderEntity
,
ModelStatus
,
SimpleModelProviderEntity
from
core.entities.provider_entities
import
SystemConfiguration
,
CustomConfiguration
,
SystemConfigurationStatus
from
core.helper
import
encrypter
from
core.helper.model_provider_cache
import
ProviderCredentialsCache
,
ProviderCredentialsCacheType
from
core.model_runtime.entities.model_entities
import
ModelType
from
core.model_runtime.entities.provider_entities
import
ProviderEntity
,
CredentialFormSchema
,
FormType
from
core.model_runtime.model_providers
import
model_provider_factory
...
...
@@ -171,6 +172,14 @@ class ProviderConfiguration(BaseModel):
db
.
session
.
add
(
provider_record
)
db
.
session
.
commit
()
provider_model_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
provider_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
PROVIDER
)
provider_model_credentials_cache
.
delete
()
self
.
switch_preferred_provider_type
(
ProviderType
.
CUSTOM
)
def
delete_custom_credentials
(
self
)
->
None
:
...
...
@@ -193,6 +202,14 @@ class ProviderConfiguration(BaseModel):
db
.
session
.
delete
(
provider_record
)
db
.
session
.
commit
()
provider_model_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
provider_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
PROVIDER
)
provider_model_credentials_cache
.
delete
()
def
get_custom_model_credentials
(
self
,
model_type
:
ModelType
,
model
:
str
,
obfuscated
:
bool
=
False
)
\
->
Optional
[
dict
]:
"""
...
...
@@ -314,6 +331,14 @@ class ProviderConfiguration(BaseModel):
db
.
session
.
add
(
provider_model_record
)
db
.
session
.
commit
()
provider_model_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
provider_model_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
MODEL
)
provider_model_credentials_cache
.
delete
()
def
delete_custom_model_credentials
(
self
,
model_type
:
ModelType
,
model
:
str
)
->
None
:
"""
Delete custom model credentials.
...
...
@@ -335,6 +360,14 @@ class ProviderConfiguration(BaseModel):
db
.
session
.
delete
(
provider_model_record
)
db
.
session
.
commit
()
provider_model_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
provider_model_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
MODEL
)
provider_model_credentials_cache
.
delete
()
def
get_provider_instance
(
self
)
->
ModelProvider
:
"""
Get provider instance.
...
...
api/core/helper/model_provider_cache.py
0 → 100644
View file @
3fa5204b
import
json
from
enum
import
Enum
from
json
import
JSONDecodeError
from
typing
import
Optional
from
extensions.ext_redis
import
redis_client
class
ProviderCredentialsCacheType
(
Enum
):
PROVIDER
=
"provider"
MODEL
=
"provider_model"
class
ProviderCredentialsCache
:
def
__init__
(
self
,
tenant_id
:
str
,
identity_id
:
str
,
cache_type
:
ProviderCredentialsCacheType
):
self
.
cache_key
=
f
"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def
get
(
self
)
->
Optional
[
dict
]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials
=
redis_client
.
get
(
self
.
cache_key
)
if
cached_provider_credentials
:
try
:
cached_provider_credentials
=
cached_provider_credentials
.
decode
(
'utf-8'
)
cached_provider_credentials
=
json
.
loads
(
cached_provider_credentials
)
except
JSONDecodeError
:
return
None
return
cached_provider_credentials
else
:
return
None
def
set
(
self
,
credentials
:
dict
)
->
None
:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client
.
setex
(
self
.
cache_key
,
3600
,
json
.
dumps
(
credentials
))
def
delete
(
self
)
->
None
:
"""
Delete cached model provider credentials.
:return:
"""
redis_client
.
delete
(
self
.
cache_key
)
api/core/model_runtime/model_providers/model_provider_factory.py
View file @
3fa5204b
...
...
@@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel):
class
ModelProviderFactory
:
model_provider_extensions
:
dict
[
str
,
ModelProviderExtension
]
=
None
def
__init__
(
self
)
->
None
:
# for cache in memory
self
.
get_providers
()
def
get_providers
(
self
)
->
list
[
ProviderEntity
]:
"""
Get all providers
...
...
api/core/provider_manager.py
View file @
3fa5204b
...
...
@@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide
from
core.entities.provider_entities
import
CustomConfiguration
,
CustomProviderConfiguration
,
CustomModelConfiguration
,
\
SystemConfiguration
,
QuotaConfiguration
from
core.helper
import
encrypter
from
core.helper.model_provider_cache
import
ProviderCredentialsCache
,
ProviderCredentialsCacheType
from
core.model_runtime.entities.model_entities
import
ModelType
from
core.model_runtime.entities.provider_entities
import
ProviderEntity
,
CredentialFormSchema
,
FormType
from
core.model_runtime.model_providers
import
model_provider_factory
...
...
@@ -79,9 +80,6 @@ class ProviderManager:
# Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict
=
self
.
_get_all_preferred_model_providers
(
tenant_id
)
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
provider_configurations
=
ProviderConfigurations
(
tenant_id
=
tenant_id
)
...
...
@@ -100,19 +98,17 @@ class ProviderManager:
# Convert to custom configuration
custom_configuration
=
self
.
_to_custom_configuration
(
tenant_id
,
provider_entity
,
provider_records
,
provider_model_records
,
decoding_rsa_key
,
decoding_cipher_rsa
provider_model_records
)
# Convert to system configuration
system_configuration
=
self
.
_to_system_configuration
(
tenant_id
,
provider_entity
,
provider_records
,
decoding_rsa_key
,
decoding_cipher_rsa
provider_records
)
# Get preferred provider type
...
...
@@ -413,19 +409,17 @@ class ProviderManager:
return
provider_name_to_provider_records_dict
def
_to_custom_configuration
(
self
,
tenant_id
:
str
,
provider_entity
:
ProviderEntity
,
provider_records
:
list
[
Provider
],
provider_model_records
:
list
[
ProviderModel
],
decoding_rsa_key
,
decoding_cipher_rsa
)
->
CustomConfiguration
:
provider_model_records
:
list
[
ProviderModel
])
->
CustomConfiguration
:
"""
Convert to custom configuration.
:param tenant_id: workspace id
:param provider_entity: provider entity
:param provider_records: provider records
:param provider_model_records: provider model records
:param decoding_rsa_key: decoding rsa key
:param decoding_cipher_rsa: decoding cipher rsa
:return:
"""
# Get provider credential secret variables
...
...
@@ -448,6 +442,16 @@ class ProviderManager:
# Get custom provider credentials
custom_provider_configuration
=
None
if
custom_provider_record
:
provider_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
tenant_id
,
identity_id
=
custom_provider_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
PROVIDER
)
# Get cached provider credentials
cached_provider_credentials
=
provider_credentials_cache
.
get
()
if
not
cached_provider_credentials
:
try
:
# fix origin data
if
(
custom_provider_record
.
encrypted_config
...
...
@@ -460,6 +464,9 @@ class ProviderManager:
except
JSONDecodeError
:
provider_credentials
=
{}
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
for
variable
in
provider_credential_secret_variables
:
if
variable
in
provider_credentials
:
try
:
...
...
@@ -471,6 +478,13 @@ class ProviderManager:
except
ValueError
:
pass
# cache provider credentials
provider_credentials_cache
.
set
(
credentials
=
provider_credentials
)
else
:
provider_credentials
=
cached_provider_credentials
custom_provider_configuration
=
CustomProviderConfiguration
(
credentials
=
provider_credentials
)
...
...
@@ -487,11 +501,24 @@ class ProviderManager:
if
not
provider_model_record
.
encrypted_config
:
continue
provider_model_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
tenant_id
,
identity_id
=
provider_model_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
MODEL
)
# Get cached provider model credentials
cached_provider_model_credentials
=
provider_model_credentials_cache
.
get
()
if
not
cached_provider_model_credentials
:
try
:
provider_model_credentials
=
json
.
loads
(
provider_model_record
.
encrypted_config
)
except
JSONDecodeError
:
continue
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
for
variable
in
model_credential_secret_variables
:
if
variable
in
provider_model_credentials
:
try
:
...
...
@@ -503,6 +530,13 @@ class ProviderManager:
except
ValueError
:
pass
# cache provider model credentials
provider_model_credentials_cache
.
set
(
credentials
=
provider_model_credentials
)
else
:
provider_model_credentials
=
cached_provider_model_credentials
custom_model_configurations
.
append
(
CustomModelConfiguration
(
model
=
provider_model_record
.
model_name
,
...
...
@@ -517,17 +551,15 @@ class ProviderManager:
)
def
_to_system_configuration
(
self
,
tenant_id
:
str
,
provider_entity
:
ProviderEntity
,
provider_records
:
list
[
Provider
],
decoding_rsa_key
,
decoding_cipher_rsa
)
->
SystemConfiguration
:
provider_records
:
list
[
Provider
])
->
SystemConfiguration
:
"""
Convert to system configuration.
:param tenant_id: workspace id
:param provider_entity: provider entity
:param provider_records: provider records
:param decoding_rsa_key: decoding rsa key
:param decoding_cipher_rsa: decoding cipher rsa
:return:
"""
# Get hosting configuration
...
...
@@ -580,6 +612,16 @@ class ProviderManager:
provider_record
=
quota_type_to_provider_records_dict
.
get
(
current_quota_type
)
if
provider_record
:
provider_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
tenant_id
,
identity_id
=
provider_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
PROVIDER
)
# Get cached provider credentials
cached_provider_credentials
=
provider_credentials_cache
.
get
()
if
not
cached_provider_credentials
:
try
:
provider_credentials
=
json
.
loads
(
provider_record
.
encrypted_config
)
except
JSONDecodeError
:
...
...
@@ -591,6 +633,9 @@ class ProviderManager:
if
provider_entity
.
provider_credential_schema
else
[]
)
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
for
variable
in
provider_credential_secret_variables
:
if
variable
in
provider_credentials
:
try
:
...
...
@@ -603,6 +648,13 @@ class ProviderManager:
pass
current_using_credentials
=
provider_credentials
# cache provider credentials
provider_credentials_cache
.
set
(
credentials
=
current_using_credentials
)
else
:
current_using_credentials
=
cached_provider_credentials
else
:
current_using_credentials
=
{}
...
...
api/services/model_provider_service.py
View file @
3fa5204b
...
...
@@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple
import
requests
from
flask
import
current_app
from
core.entities.model_entities
import
Model
WithProviderEntity
,
ModelStatus
,
DefaultModelEntity
from
core.entities.model_entities
import
Model
Status
from
core.model_runtime.entities.model_entities
import
ModelType
,
ParameterRule
from
core.model_runtime.model_providers
import
model_provider_factory
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment