Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
296bf443
Unverified
Commit
296bf443
authored
Jan 05, 2024
by
takatost
Committed by
GitHub
Jan 05, 2024
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: reuse decoding_rsa_key & decoding_cipher_rsa & optimize construct (#1937)
parent
af7be9bd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
16 deletions
+60
-16
provider_configuration.py
api/core/entities/provider_configuration.py
+21
-3
provider_manager.py
api/core/provider_manager.py
+15
-9
model_provider_service.py
api/services/model_provider_service.py
+24
-4
No files found.
api/core/entities/provider_configuration.py
View file @
296bf443
...
@@ -520,7 +520,13 @@ class ProviderConfiguration(BaseModel):
...
@@ -520,7 +520,13 @@ class ProviderConfiguration(BaseModel):
provider_models
.
extend
(
provider_models
.
extend
(
[
[
ModelWithProviderEntity
(
ModelWithProviderEntity
(
**
m
.
dict
(),
model
=
m
.
model
,
label
=
m
.
label
,
model_type
=
m
.
model_type
,
features
=
m
.
features
,
fetch_from
=
m
.
fetch_from
,
model_properties
=
m
.
model_properties
,
deprecated
=
m
.
deprecated
,
provider
=
SimpleModelProviderEntity
(
self
.
provider
),
provider
=
SimpleModelProviderEntity
(
self
.
provider
),
status
=
ModelStatus
.
ACTIVE
status
=
ModelStatus
.
ACTIVE
)
)
...
@@ -569,7 +575,13 @@ class ProviderConfiguration(BaseModel):
...
@@ -569,7 +575,13 @@ class ProviderConfiguration(BaseModel):
for
m
in
models
:
for
m
in
models
:
provider_models
.
append
(
provider_models
.
append
(
ModelWithProviderEntity
(
ModelWithProviderEntity
(
**
m
.
dict
(),
model
=
m
.
model
,
label
=
m
.
label
,
model_type
=
m
.
model_type
,
features
=
m
.
features
,
fetch_from
=
m
.
fetch_from
,
model_properties
=
m
.
model_properties
,
deprecated
=
m
.
deprecated
,
provider
=
SimpleModelProviderEntity
(
self
.
provider
),
provider
=
SimpleModelProviderEntity
(
self
.
provider
),
status
=
ModelStatus
.
ACTIVE
if
credentials
else
ModelStatus
.
NO_CONFIGURE
status
=
ModelStatus
.
ACTIVE
if
credentials
else
ModelStatus
.
NO_CONFIGURE
)
)
...
@@ -597,7 +609,13 @@ class ProviderConfiguration(BaseModel):
...
@@ -597,7 +609,13 @@ class ProviderConfiguration(BaseModel):
provider_models
.
append
(
provider_models
.
append
(
ModelWithProviderEntity
(
ModelWithProviderEntity
(
**
custom_model_schema
.
dict
(),
model
=
custom_model_schema
.
model
,
label
=
custom_model_schema
.
label
,
model_type
=
custom_model_schema
.
model_type
,
features
=
custom_model_schema
.
features
,
fetch_from
=
custom_model_schema
.
fetch_from
,
model_properties
=
custom_model_schema
.
model_properties
,
deprecated
=
custom_model_schema
.
deprecated
,
provider
=
SimpleModelProviderEntity
(
self
.
provider
),
provider
=
SimpleModelProviderEntity
(
self
.
provider
),
status
=
ModelStatus
.
ACTIVE
status
=
ModelStatus
.
ACTIVE
)
)
...
...
api/core/provider_manager.py
View file @
296bf443
...
@@ -24,6 +24,9 @@ class ProviderManager:
...
@@ -24,6 +24,9 @@ class ProviderManager:
"""
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
"""
"""
def
__init__
(
self
)
->
None
:
self
.
decoding_rsa_key
=
None
self
.
decoding_cipher_rsa
=
None
def
get_configurations
(
self
,
tenant_id
:
str
)
->
ProviderConfigurations
:
def
get_configurations
(
self
,
tenant_id
:
str
)
->
ProviderConfigurations
:
"""
"""
...
@@ -472,15 +475,16 @@ class ProviderManager:
...
@@ -472,15 +475,16 @@ class ProviderManager:
provider_credentials
=
{}
provider_credentials
=
{}
# Get decoding rsa key and cipher for decrypting credentials
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
if
self
.
decoding_rsa_key
is
None
or
self
.
decoding_cipher_rsa
is
None
:
self
.
decoding_rsa_key
,
self
.
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
for
variable
in
provider_credential_secret_variables
:
for
variable
in
provider_credential_secret_variables
:
if
variable
in
provider_credentials
:
if
variable
in
provider_credentials
:
try
:
try
:
provider_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_credentials
.
get
(
variable
),
provider_credentials
.
get
(
variable
),
decoding_rsa_key
,
self
.
decoding_rsa_key
,
decoding_cipher_rsa
self
.
decoding_cipher_rsa
)
)
except
ValueError
:
except
ValueError
:
pass
pass
...
@@ -524,15 +528,16 @@ class ProviderManager:
...
@@ -524,15 +528,16 @@ class ProviderManager:
continue
continue
# Get decoding rsa key and cipher for decrypting credentials
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
if
self
.
decoding_rsa_key
is
None
or
self
.
decoding_cipher_rsa
is
None
:
self
.
decoding_rsa_key
,
self
.
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
for
variable
in
model_credential_secret_variables
:
for
variable
in
model_credential_secret_variables
:
if
variable
in
provider_model_credentials
:
if
variable
in
provider_model_credentials
:
try
:
try
:
provider_model_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_model_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_model_credentials
.
get
(
variable
),
provider_model_credentials
.
get
(
variable
),
decoding_rsa_key
,
self
.
decoding_rsa_key
,
decoding_cipher_rsa
self
.
decoding_cipher_rsa
)
)
except
ValueError
:
except
ValueError
:
pass
pass
...
@@ -641,15 +646,16 @@ class ProviderManager:
...
@@ -641,15 +646,16 @@ class ProviderManager:
)
)
# Get decoding rsa key and cipher for decrypting credentials
# Get decoding rsa key and cipher for decrypting credentials
decoding_rsa_key
,
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
if
self
.
decoding_rsa_key
is
None
or
self
.
decoding_cipher_rsa
is
None
:
self
.
decoding_rsa_key
,
self
.
decoding_cipher_rsa
=
encrypter
.
get_decrypt_decoding
(
tenant_id
)
for
variable
in
provider_credential_secret_variables
:
for
variable
in
provider_credential_secret_variables
:
if
variable
in
provider_credentials
:
if
variable
in
provider_credentials
:
try
:
try
:
provider_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_credentials
[
variable
]
=
encrypter
.
decrypt_token_with_decoding
(
provider_credentials
.
get
(
variable
),
provider_credentials
.
get
(
variable
),
decoding_rsa_key
,
self
.
decoding_rsa_key
,
decoding_cipher_rsa
self
.
decoding_cipher_rsa
)
)
except
ValueError
:
except
ValueError
:
pass
pass
...
...
api/services/model_provider_service.py
View file @
296bf443
...
@@ -14,7 +14,7 @@ from core.provider_manager import ProviderManager
...
@@ -14,7 +14,7 @@ from core.provider_manager import ProviderManager
from
models.provider
import
ProviderType
from
models.provider
import
ProviderType
from
services.entities.model_provider_entities
import
ProviderResponse
,
CustomConfigurationResponse
,
\
from
services.entities.model_provider_entities
import
ProviderResponse
,
CustomConfigurationResponse
,
\
SystemConfigurationResponse
,
CustomConfigurationStatus
,
ProviderWithModelsResponse
,
ModelResponse
,
\
SystemConfigurationResponse
,
CustomConfigurationStatus
,
ProviderWithModelsResponse
,
ModelResponse
,
\
DefaultModelResponse
,
ModelWithProviderEntityResponse
DefaultModelResponse
,
ModelWithProviderEntityResponse
,
SimpleProviderEntityResponse
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -45,7 +45,17 @@ class ModelProviderService:
...
@@ -45,7 +45,17 @@ class ModelProviderService:
continue
continue
provider_response
=
ProviderResponse
(
provider_response
=
ProviderResponse
(
**
provider_configuration
.
provider
.
dict
(),
provider
=
provider_configuration
.
provider
.
provider
,
label
=
provider_configuration
.
provider
.
label
,
description
=
provider_configuration
.
provider
.
description
,
icon_small
=
provider_configuration
.
provider
.
icon_small
,
icon_large
=
provider_configuration
.
provider
.
icon_large
,
background
=
provider_configuration
.
provider
.
background
,
help
=
provider_configuration
.
provider
.
help
,
supported_model_types
=
provider_configuration
.
provider
.
supported_model_types
,
configurate_methods
=
provider_configuration
.
provider
.
configurate_methods
,
provider_credential_schema
=
provider_configuration
.
provider
.
provider_credential_schema
,
model_credential_schema
=
provider_configuration
.
provider
.
model_credential_schema
,
preferred_provider_type
=
provider_configuration
.
preferred_provider_type
,
preferred_provider_type
=
provider_configuration
.
preferred_provider_type
,
custom_configuration
=
CustomConfigurationResponse
(
custom_configuration
=
CustomConfigurationResponse
(
status
=
CustomConfigurationStatus
.
ACTIVE
status
=
CustomConfigurationStatus
.
ACTIVE
...
@@ -53,7 +63,9 @@ class ModelProviderService:
...
@@ -53,7 +63,9 @@ class ModelProviderService:
else
CustomConfigurationStatus
.
NO_CONFIGURE
else
CustomConfigurationStatus
.
NO_CONFIGURE
),
),
system_configuration
=
SystemConfigurationResponse
(
system_configuration
=
SystemConfigurationResponse
(
**
provider_configuration
.
system_configuration
.
dict
()
enabled
=
provider_configuration
.
system_configuration
.
enabled
,
current_quota_type
=
provider_configuration
.
system_configuration
.
current_quota_type
,
quota_configurations
=
provider_configuration
.
system_configuration
.
quota_configurations
)
)
)
)
...
@@ -369,7 +381,15 @@ class ModelProviderService:
...
@@ -369,7 +381,15 @@ class ModelProviderService:
)
)
return
DefaultModelResponse
(
return
DefaultModelResponse
(
**
result
.
dict
()
model
=
result
.
model
,
model_type
=
result
.
model_type
,
provider
=
SimpleProviderEntityResponse
(
provider
=
result
.
provider
.
provider
,
label
=
result
.
provider
.
label
,
icon_small
=
result
.
provider
.
icon_small
,
icon_large
=
result
.
provider
.
icon_large
,
supported_model_types
=
result
.
provider
.
supported_model_types
)
)
if
result
else
None
)
if
result
else
None
def
update_default_model_of_model_type
(
self
,
tenant_id
:
str
,
model_type
:
str
,
provider
:
str
,
model
:
str
)
->
None
:
def
update_default_model_of_model_type
(
self
,
tenant_id
:
str
,
model_type
:
str
,
provider
:
str
,
model
:
str
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment