Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
3fa5204b
Unverified
Commit
3fa5204b
authored
Jan 04, 2024
by
takatost
Committed by
GitHub
Jan 04, 2024
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: optimize performance (#1928)
parent
5a756ca9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
215 additions
and
75 deletions
+215
-75
provider_configuration.py
api/core/entities/provider_configuration.py
+33
-0
model_provider_cache.py
api/core/helper/model_provider_cache.py
+51
-0
model_provider_factory.py
...e/model_runtime/model_providers/model_provider_factory.py
+4
-0
provider_manager.py
api/core/provider_manager.py
+126
-74
model_provider_service.py
api/services/model_provider_service.py
+1
-1
No files found.
api/core/entities/provider_configuration.py
View file @
3fa5204b
...
@@ -10,6 +10,7 @@ from pydantic import BaseModel
...
@@ -10,6 +10,7 @@ from pydantic import BaseModel
from
core.entities.model_entities
import
ModelWithProviderEntity
,
ModelStatus
,
SimpleModelProviderEntity
from
core.entities.model_entities
import
ModelWithProviderEntity
,
ModelStatus
,
SimpleModelProviderEntity
from
core.entities.provider_entities
import
SystemConfiguration
,
CustomConfiguration
,
SystemConfigurationStatus
from
core.entities.provider_entities
import
SystemConfiguration
,
CustomConfiguration
,
SystemConfigurationStatus
from
core.helper
import
encrypter
from
core.helper
import
encrypter
from
core.helper.model_provider_cache
import
ProviderCredentialsCache
,
ProviderCredentialsCacheType
from
core.model_runtime.entities.model_entities
import
ModelType
from
core.model_runtime.entities.model_entities
import
ModelType
from
core.model_runtime.entities.provider_entities
import
ProviderEntity
,
CredentialFormSchema
,
FormType
from
core.model_runtime.entities.provider_entities
import
ProviderEntity
,
CredentialFormSchema
,
FormType
from
core.model_runtime.model_providers
import
model_provider_factory
from
core.model_runtime.model_providers
import
model_provider_factory
...
@@ -171,6 +172,14 @@ class ProviderConfiguration(BaseModel):
...
@@ -171,6 +172,14 @@ class ProviderConfiguration(BaseModel):
db
.
session
.
add
(
provider_record
)
db
.
session
.
add
(
provider_record
)
db
.
session
.
commit
()
db
.
session
.
commit
()
provider_model_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
provider_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
PROVIDER
)
provider_model_credentials_cache
.
delete
()
self
.
switch_preferred_provider_type
(
ProviderType
.
CUSTOM
)
self
.
switch_preferred_provider_type
(
ProviderType
.
CUSTOM
)
def
delete_custom_credentials
(
self
)
->
None
:
def
delete_custom_credentials
(
self
)
->
None
:
...
@@ -193,6 +202,14 @@ class ProviderConfiguration(BaseModel):
...
@@ -193,6 +202,14 @@ class ProviderConfiguration(BaseModel):
db
.
session
.
delete
(
provider_record
)
db
.
session
.
delete
(
provider_record
)
db
.
session
.
commit
()
db
.
session
.
commit
()
provider_model_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
provider_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
PROVIDER
)
provider_model_credentials_cache
.
delete
()
def
get_custom_model_credentials
(
self
,
model_type
:
ModelType
,
model
:
str
,
obfuscated
:
bool
=
False
)
\
def
get_custom_model_credentials
(
self
,
model_type
:
ModelType
,
model
:
str
,
obfuscated
:
bool
=
False
)
\
->
Optional
[
dict
]:
->
Optional
[
dict
]:
"""
"""
...
@@ -314,6 +331,14 @@ class ProviderConfiguration(BaseModel):
...
@@ -314,6 +331,14 @@ class ProviderConfiguration(BaseModel):
db
.
session
.
add
(
provider_model_record
)
db
.
session
.
add
(
provider_model_record
)
db
.
session
.
commit
()
db
.
session
.
commit
()
provider_model_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
provider_model_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
MODEL
)
provider_model_credentials_cache
.
delete
()
def
delete_custom_model_credentials
(
self
,
model_type
:
ModelType
,
model
:
str
)
->
None
:
def
delete_custom_model_credentials
(
self
,
model_type
:
ModelType
,
model
:
str
)
->
None
:
"""
"""
Delete custom model credentials.
Delete custom model credentials.
...
@@ -335,6 +360,14 @@ class ProviderConfiguration(BaseModel):
...
@@ -335,6 +360,14 @@ class ProviderConfiguration(BaseModel):
db
.
session
.
delete
(
provider_model_record
)
db
.
session
.
delete
(
provider_model_record
)
db
.
session
.
commit
()
db
.
session
.
commit
()
provider_model_credentials_cache
=
ProviderCredentialsCache
(
tenant_id
=
self
.
tenant_id
,
identity_id
=
provider_model_record
.
id
,
cache_type
=
ProviderCredentialsCacheType
.
MODEL
)
provider_model_credentials_cache
.
delete
()
def
get_provider_instance
(
self
)
->
ModelProvider
:
def
get_provider_instance
(
self
)
->
ModelProvider
:
"""
"""
Get provider instance.
Get provider instance.
...
...
api/core/helper/model_provider_cache.py
0 → 100644
View file @
3fa5204b
import
json
from
enum
import
Enum
from
json
import
JSONDecodeError
from
typing
import
Optional
from
extensions.ext_redis
import
redis_client
class
ProviderCredentialsCacheType
(
Enum
):
PROVIDER
=
"provider"
MODEL
=
"provider_model"
class
ProviderCredentialsCache
:
def
__init__
(
self
,
tenant_id
:
str
,
identity_id
:
str
,
cache_type
:
ProviderCredentialsCacheType
):
self
.
cache_key
=
f
"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def
get
(
self
)
->
Optional
[
dict
]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials
=
redis_client
.
get
(
self
.
cache_key
)
if
cached_provider_credentials
:
try
:
cached_provider_credentials
=
cached_provider_credentials
.
decode
(
'utf-8'
)
cached_provider_credentials
=
json
.
loads
(
cached_provider_credentials
)
except
JSONDecodeError
:
return
None
return
cached_provider_credentials
else
:
return
None
def
set
(
self
,
credentials
:
dict
)
->
None
:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client
.
setex
(
self
.
cache_key
,
3600
,
json
.
dumps
(
credentials
))
def
delete
(
self
)
->
None
:
"""
Delete cached model provider credentials.
:return:
"""
redis_client
.
delete
(
self
.
cache_key
)
api/core/model_runtime/model_providers/model_provider_factory.py
View file @
3fa5204b
...
@@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel):
...
@@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel):
class
ModelProviderFactory
:
class
ModelProviderFactory
:
model_provider_extensions
:
dict
[
str
,
ModelProviderExtension
]
=
None
model_provider_extensions
:
dict
[
str
,
ModelProviderExtension
]
=
None
def
__init__
(
self
)
->
None
:
# for cache in memory
self
.
get_providers
()
def
get_providers
(
self
)
->
list
[
ProviderEntity
]:
def
get_providers
(
self
)
->
list
[
ProviderEntity
]:
"""
"""
Get all providers
Get all providers
...
...
api/core/provider_manager.py
View file @
3fa5204b
...
@@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide
...
@@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide
from
core.entities.provider_entities
import
CustomConfiguration
,
CustomProviderConfiguration
,
CustomModelConfiguration
,
\
from
core.entities.provider_entities
import
CustomConfiguration
,
CustomProviderConfiguration
,
CustomModelConfiguration
,
\
SystemConfiguration
,
QuotaConfiguration
SystemConfiguration
,
QuotaConfiguration
from
core.helper
import
encrypter
from
core.helper
import
encrypter
from
core.helper.model_provider_cache
import
ProviderCredentialsCache
,
ProviderCredentialsCacheType
from
core.model_runtime.entities.model_entities
import
ModelType
from
core.model_runtime.entities.model_entities
import
ModelType
from
core.model_runtime.entities.provider_entities
import
ProviderEntity
,
CredentialFormSchema
,
FormType
from
core.model_runtime.entities.provider_entities
import
ProviderEntity
,
CredentialFormSchema
,
FormType
from
core.model_runtime.model_providers
import
model_provider_factory
from
core.model_runtime.model_providers
import
model_provider_factory
...
@@ -79,9 +80,6 @@ class ProviderManager:
...
@@ -79,9 +80,6 @@ class ProviderManager:
# Get All preferred provider types of the workspace
# Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict
=
self
.
_get_all_preferred_model_providers
(
tenant_id
)
provider_name_to_preferred_model_provider_records_dict
=
self
.
_get_all_preferred_model_providers
(
tenant_id
)
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
provider_configurations
=
ProviderConfigurations
(
provider_configurations
=
ProviderConfigurations
(
tenant_id
=
tenant_id
tenant_id
=
tenant_id
)
)
...
@@ -100,19 +98,17 @@ class ProviderManager:
...
@@ -100,19 +98,17 @@ class ProviderManager:
# Convert to custom configuration
# Convert to custom configuration
custom_configuration
=
self
.
_to_custom_configuration
(
custom_configuration
=
self
.
_to_custom_configuration
(
tenant_id
,
provider_entity
,
provider_entity
,
provider_records
,
provider_records
,
provider_model_records
,
provider_model_records
decoding_rsa_key
,
decoding_cipher_rsa
)
)
# Convert to system configuration
# Convert to system configuration
system_configuration
=
self
.
_to_system_configuration
(
system_configuration
=
self
.
_to_system_configuration
(
tenant_id
,
provider_entity
,
provider_entity
,
provider_records
,
provider_records
decoding_rsa_key
,
decoding_cipher_rsa
)
)
# Get preferred provider type
# Get preferred provider type
...
@@ -413,19 +409,17 @@ class ProviderManager:
...
@@ -413,19 +409,17 @@ class ProviderManager:
return
provider_name_to_provider_records_dict
return
provider_name_to_provider_records_dict
def
_to_custom_configuration
(
self
,
def
_to_custom_configuration
(
self
,
tenant_id
:
str
,
provider_entity
:
ProviderEntity
,
provider_entity
:
ProviderEntity
,
provider_records
:
list
[
Provider
],
provider_records
:
list
[
Provider
],
provider_model_records
:
list
[
ProviderModel
],
provider_model_records
:
list
[
ProviderModel
])
->
CustomConfiguration
:
decoding_rsa_key
,
decoding_cipher_rsa
)
->
CustomConfiguration
:
"""
"""
Convert to custom configuration.
Convert to custom configuration.
:param tenant_id: workspace id
:param provider_entity: provider entity
:param provider_entity: provider entity
:param provider_records: provider records
:param provider_records: provider records
:param provider_model_records: provider model records
:param provider_model_records: provider model records
:param decoding_rsa_key: decoding rsa key
:param decoding_cipher_rsa: decoding cipher rsa
:return:
:return:
"""
"""
# Get provider credential secret variables
# Get provider credential secret variables
...
@@ -448,28 +442,48 @@ class ProviderManager:
...
@@ -448,28 +442,48 @@ class ProviderManager:
# Get custom provider credentials
# Get custom provider credentials
custom_provider_configuration
=
None
custom_provider_configuration
=
None
if
custom_provider_record
:
if
custom_provider_record
:
try
:
provider_credentials_cache
=
ProviderCredentialsCache
(
# fix origin data
tenant_id
=
tenant_id
,
if
(
custom_provider_record
.
encrypted_config
identity_id
=
custom_provider_record
.
id
,
and
not
custom_provider_record
.
encrypted_config
.
startswith
(
"{"
)):
cache_type
=
ProviderCredentialsCacheType
.
PROVIDER
provider_credentials
=
{
)
"openai_api_key"
:
custom_provider_record
.
encrypted_config
}
else
:
provider_credentials
=
json
.
loads
(
custom_provider_record
.
encrypted_config
)
except
JSONDecodeError
:
provider_credentials
=
{}
for
variable
in
provider_credential_secret_variables
:
# Get cached provider credentials
if
variable
in
provider_credentials
:
cached_provider_credentials
=
provider_credentials_cache
.
get
()
try
:
provider_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
if
not
cached_provider_credentials
:
provider_credentials
.
get
(
variable
),
try
:
decoding_rsa_key
,
# fix origin data
decoding_cipher_rsa
if
(
custom_provider_record
.
encrypted_config
)
and
not
custom_provider_record
.
encrypted_config
.
startswith
(
"{"
)):
except
ValueError
:
provider_credentials
=
{
pass
"openai_api_key"
:
custom_provider_record
.
encrypted_config
}
else
:
provider_credentials
=
json
.
loads
(
custom_provider_record
.
encrypted_config
)
except
JSONDecodeError
:
provider_credentials
=
{}
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
for
variable
in
provider_credential_secret_variables
:
if
variable
in
provider_credentials
:
try
:
provider_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_credentials
.
get
(
variable
),
decoding_rsa_key
,
decoding_cipher_rsa
)
except
ValueError
:
pass
# cache provider credentials
provider_credentials_cache
.
set
(
credentials
=
provider_credentials
)
else
:
provider_credentials
=
cached_provider_credentials
custom_provider_configuration
=
CustomProviderConfiguration
(
custom_provider_configuration
=
CustomProviderConfiguration
(
credentials
=
provider_credentials
credentials
=
provider_credentials
...
@@ -487,21 +501,41 @@ class ProviderManager:
...
@@ -487,21 +501,41 @@ class ProviderManager:
if
not
provider_model_record
.
encrypted_config
:
if
not
provider_model_record
.
encrypted_config
:
continue
continue
try
:
provider_model_credentials_cache
=
ProviderCredentialsCache
(
provider_model_credentials
=
json
.
loads
(
provider_model_record
.
encrypted_config
)
tenant_id
=
tenant_id
,
except
JSONDecodeError
:
identity_id
=
provider_model_record
.
id
,
continue
cache_type
=
ProviderCredentialsCacheType
.
MODEL
)
for
variable
in
model_credential_secret_variables
:
# Get cached provider model credentials
if
variable
in
provider_model_credentials
:
cached_provider_model_credentials
=
provider_model_credentials_cache
.
get
()
try
:
provider_model_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
if
not
cached_provider_model_credentials
:
provider_model_credentials
.
get
(
variable
),
try
:
decoding_rsa_key
,
provider_model_credentials
=
json
.
loads
(
provider_model_record
.
encrypted_config
)
decoding_cipher_rsa
except
JSONDecodeError
:
)
continue
except
ValueError
:
pass
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
for
variable
in
model_credential_secret_variables
:
if
variable
in
provider_model_credentials
:
try
:
provider_model_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_model_credentials
.
get
(
variable
),
decoding_rsa_key
,
decoding_cipher_rsa
)
except
ValueError
:
pass
# cache provider model credentials
provider_model_credentials_cache
.
set
(
credentials
=
provider_model_credentials
)
else
:
provider_model_credentials
=
cached_provider_model_credentials
custom_model_configurations
.
append
(
custom_model_configurations
.
append
(
CustomModelConfiguration
(
CustomModelConfiguration
(
...
@@ -517,17 +551,15 @@ class ProviderManager:
...
@@ -517,17 +551,15 @@ class ProviderManager:
)
)
def
_to_system_configuration
(
self
,
def
_to_system_configuration
(
self
,
tenant_id
:
str
,
provider_entity
:
ProviderEntity
,
provider_entity
:
ProviderEntity
,
provider_records
:
list
[
Provider
],
provider_records
:
list
[
Provider
])
->
SystemConfiguration
:
decoding_rsa_key
,
decoding_cipher_rsa
)
->
SystemConfiguration
:
"""
"""
Convert to system configuration.
Convert to system configuration.
:param tenant_id: workspace id
:param provider_entity: provider entity
:param provider_entity: provider entity
:param provider_records: provider records
:param provider_records: provider records
:param decoding_rsa_key: decoding rsa key
:param decoding_cipher_rsa: decoding cipher rsa
:return:
:return:
"""
"""
# Get hosting configuration
# Get hosting configuration
...
@@ -580,29 +612,49 @@ class ProviderManager:
...
@@ -580,29 +612,49 @@ class ProviderManager:
provider_record
=
quota_type_to_provider_records_dict
.
get
(
current_quota_type
)
provider_record
=
quota_type_to_provider_records_dict
.
get
(
current_quota_type
)
if
provider_record
:
if
provider_record
:
try
:
provider_credentials_cache
=
ProviderCredentialsCache
(
provider_credentials
=
json
.
loads
(
provider_record
.
encrypted_config
)
tenant_id
=
tenant_id
,
except
JSONDecodeError
:
identity_id
=
provider_record
.
id
,
provider_credentials
=
{}
cache_type
=
ProviderCredentialsCacheType
.
PROVIDER
# Get provider credential secret variables
provider_credential_secret_variables
=
self
.
_extract_secret_variables
(
provider_entity
.
provider_credential_schema
.
credential_form_schemas
if
provider_entity
.
provider_credential_schema
else
[]
)
)
for
variable
in
provider_credential_secret_variables
:
# Get cached provider credentials
if
variable
in
provider_credentials
:
cached_provider_credentials
=
provider_credentials_cache
.
get
()
try
:
provider_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_credentials
.
get
(
variable
),
decoding_rsa_key
,
decoding_cipher_rsa
)
except
ValueError
:
pass
current_using_credentials
=
provider_credentials
if
not
cached_provider_credentials
:
try
:
provider_credentials
=
json
.
loads
(
provider_record
.
encrypted_config
)
except
JSONDecodeError
:
provider_credentials
=
{}
# Get provider credential secret variables
provider_credential_secret_variables
=
self
.
_extract_secret_variables
(
provider_entity
.
provider_credential_schema
.
credential_form_schemas
if
provider_entity
.
provider_credential_schema
else
[]
)
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
for
variable
in
provider_credential_secret_variables
:
if
variable
in
provider_credentials
:
try
:
provider_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_credentials
.
get
(
variable
),
decoding_rsa_key
,
decoding_cipher_rsa
)
except
ValueError
:
pass
current_using_credentials
=
provider_credentials
# cache provider credentials
provider_credentials_cache
.
set
(
credentials
=
current_using_credentials
)
else
:
current_using_credentials
=
cached_provider_credentials
else
:
else
:
current_using_credentials
=
{}
current_using_credentials
=
{}
...
...
api/services/model_provider_service.py
View file @
3fa5204b
...
@@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple
...
@@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple
import
requests
import
requests
from
flask
import
current_app
from
flask
import
current_app
from
core.entities.model_entities
import
Model
WithProviderEntity
,
ModelStatus
,
DefaultModelEntity
from
core.entities.model_entities
import
Model
Status
from
core.model_runtime.entities.model_entities
import
ModelType
,
ParameterRule
from
core.model_runtime.entities.model_entities
import
ModelType
,
ParameterRule
from
core.model_runtime.model_providers
import
model_provider_factory
from
core.model_runtime.model_providers
import
model_provider_factory
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment