Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
dd500e3b
Commit
dd500e3b
authored
Jul 15, 2023
by
John Wang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix: providers list include system token
parent
05493c35
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
26 additions
and
30 deletions
+26
-30
providers.py
api/controllers/console/workspace/providers.py
+2
-1
anthropic_provider.py
api/core/llm/provider/anthropic_provider.py
+2
-2
azure_provider.py
api/core/llm/provider/azure_provider.py
+2
-2
base.py
api/core/llm/provider/base.py
+14
-19
llm_provider_service.py
api/core/llm/provider/llm_provider_service.py
+4
-4
provider_service.py
api/services/provider_service.py
+2
-2
No files found.
api/controllers/console/workspace/providers.py
View file @
dd500e3b
...
...
@@ -51,7 +51,8 @@ class ProviderListApi(Resource):
'quota_used'
:
p
.
quota_used
}
if
p
.
provider_type
==
ProviderType
.
SYSTEM
.
value
else
{}),
'token'
:
ProviderService
.
get_obfuscated_api_key
(
current_user
.
current_tenant
,
ProviderName
(
p
.
provider_name
))
ProviderName
(
p
.
provider_name
),
only_custom
=
True
)
if
p
.
provider_type
==
ProviderType
.
CUSTOM
.
value
else
None
}
for
p
in
providers
]
...
...
api/core/llm/provider/anthropic_provider.py
View file @
dd500e3b
...
...
@@ -32,12 +32,12 @@ class AnthropicProvider(BaseProvider):
def
get_provider_name
(
self
):
return
ProviderName
.
ANTHROPIC
def
get_provider_configs
(
self
,
obfuscated
:
bool
=
False
)
->
Union
[
str
|
dict
]:
def
get_provider_configs
(
self
,
obfuscated
:
bool
=
False
,
only_custom
:
bool
=
False
)
->
Union
[
str
|
dict
]:
"""
Returns the provider configs.
"""
try
:
config
=
self
.
get_provider_api_key
()
config
=
self
.
get_provider_api_key
(
only_custom
=
only_custom
)
except
:
config
=
{
'anthropic_api_key'
:
''
...
...
api/core/llm/provider/azure_provider.py
View file @
dd500e3b
...
...
@@ -52,12 +52,12 @@ class AzureProvider(BaseProvider):
def
get_provider_name
(
self
):
return
ProviderName
.
AZURE_OPENAI
def
get_provider_configs
(
self
,
obfuscated
:
bool
=
False
)
->
Union
[
str
|
dict
]:
def
get_provider_configs
(
self
,
obfuscated
:
bool
=
False
,
only_custom
:
bool
=
False
)
->
Union
[
str
|
dict
]:
"""
Returns the provider configs.
"""
try
:
config
=
self
.
get_provider_api_key
()
config
=
self
.
get_provider_api_key
(
only_custom
=
only_custom
)
except
:
config
=
{
'openai_api_type'
:
'azure'
,
...
...
api/core/llm/provider/base.py
View file @
dd500e3b
...
...
@@ -14,13 +14,13 @@ class BaseProvider(ABC):
def
__init__
(
self
,
tenant_id
:
str
):
self
.
tenant_id
=
tenant_id
def
get_provider_api_key
(
self
,
model_id
:
Optional
[
str
]
=
None
,
prefer_custom
:
bool
=
Tru
e
)
->
Union
[
str
|
dict
]:
def
get_provider_api_key
(
self
,
model_id
:
Optional
[
str
]
=
None
,
only_custom
:
bool
=
Fals
e
)
->
Union
[
str
|
dict
]:
"""
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
"""
provider
=
self
.
get_provider
(
prefer
_custom
)
provider
=
self
.
get_provider
(
only
_custom
)
if
not
provider
:
raise
ProviderTokenNotInitError
(
f
"No valid {llm_constant.models[model_id]} model provider credentials found. "
...
...
@@ -41,19 +41,19 @@ class BaseProvider(ABC):
else
:
return
self
.
get_decrypted_token
(
provider
.
encrypted_config
)
def
get_provider
(
self
,
prefer_custom
:
bool
)
->
Optional
[
Provider
]:
def
get_provider
(
self
,
only_custom
:
bool
=
False
)
->
Optional
[
Provider
]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
return
BaseProvider
.
get_valid_provider
(
self
.
tenant_id
,
self
.
get_provider_name
()
.
value
,
prefer
_custom
)
return
BaseProvider
.
get_valid_provider
(
self
.
tenant_id
,
self
.
get_provider_name
()
.
value
,
only
_custom
)
@
classmethod
def
get_valid_provider
(
cls
,
tenant_id
:
str
,
provider_name
:
str
=
None
,
prefer
_custom
:
bool
=
False
)
->
Optional
[
def
get_valid_provider
(
cls
,
tenant_id
:
str
,
provider_name
:
str
=
None
,
only
_custom
:
bool
=
False
)
->
Optional
[
Provider
]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist
, the preferred provider will be returned based on the prefer_custom flag
.
If both CUSTOM and System providers exist.
"""
query
=
db
.
session
.
query
(
Provider
)
.
filter
(
Provider
.
tenant_id
==
tenant_id
...
...
@@ -62,23 +62,18 @@ class BaseProvider(ABC):
if
provider_name
:
query
=
query
.
filter
(
Provider
.
provider_name
==
provider_name
)
providers
=
query
.
order_by
(
Provider
.
provider_type
.
desc
()
if
prefer_custom
else
Provider
.
provider_type
)
.
all
()
if
only_custom
:
query
=
query
.
filter
(
Provider
.
provider_type
==
ProviderType
.
CUSTOM
.
value
)
custom_provider
=
None
system_provider
=
None
providers
=
query
.
order_by
(
Provider
.
provider_type
.
asc
())
.
all
()
for
provider
in
providers
:
if
provider
.
provider_type
==
ProviderType
.
CUSTOM
.
value
and
provider
.
is_valid
and
provider
.
encrypted_config
:
custom_provider
=
provider
return
provider
elif
provider
.
provider_type
==
ProviderType
.
SYSTEM
.
value
and
provider
.
is_valid
:
system_provider
=
provider
return
provider
if
custom_provider
:
return
custom_provider
elif
system_provider
:
return
system_provider
else
:
return
None
return
None
def
get_hosted_credentials
(
self
)
->
Union
[
str
|
dict
]:
raise
ProviderTokenNotInitError
(
...
...
@@ -86,12 +81,12 @@ class BaseProvider(ABC):
f
"Please go to Settings -> Model Provider to complete your provider credentials."
)
def
get_provider_configs
(
self
,
obfuscated
:
bool
=
False
)
->
Union
[
str
|
dict
]:
def
get_provider_configs
(
self
,
obfuscated
:
bool
=
False
,
only_custom
:
bool
=
False
)
->
Union
[
str
|
dict
]:
"""
Returns the provider configs.
"""
try
:
config
=
self
.
get_provider_api_key
()
config
=
self
.
get_provider_api_key
(
only_custom
=
only_custom
)
except
:
config
=
''
...
...
api/core/llm/provider/llm_provider_service.py
View file @
dd500e3b
...
...
@@ -31,11 +31,11 @@ class LLMProviderService:
def
get_credentials
(
self
,
model_id
:
Optional
[
str
]
=
None
)
->
dict
:
return
self
.
provider
.
get_credentials
(
model_id
)
def
get_provider_configs
(
self
,
obfuscated
:
bool
=
False
)
->
Union
[
str
|
dict
]:
return
self
.
provider
.
get_provider_configs
(
obfuscated
)
def
get_provider_configs
(
self
,
obfuscated
:
bool
=
False
,
only_custom
:
bool
=
False
)
->
Union
[
str
|
dict
]:
return
self
.
provider
.
get_provider_configs
(
obfuscated
=
obfuscated
,
only_custom
=
only_custom
)
def
get_provider_db_record
(
self
,
prefer_custom
:
bool
=
False
)
->
Optional
[
Provider
]:
return
self
.
provider
.
get_provider
(
prefer_custom
)
def
get_provider_db_record
(
self
)
->
Optional
[
Provider
]:
return
self
.
provider
.
get_provider
()
def
config_validate
(
self
,
config
:
Union
[
dict
|
str
]):
"""
...
...
api/services/provider_service.py
View file @
dd500e3b
...
...
@@ -41,9 +41,9 @@ class ProviderService:
db
.
session
.
commit
()
@
staticmethod
def
get_obfuscated_api_key
(
tenant
,
provider_name
:
ProviderName
):
def
get_obfuscated_api_key
(
tenant
,
provider_name
:
ProviderName
,
only_custom
:
bool
=
False
):
llm_provider_service
=
LLMProviderService
(
tenant
.
id
,
provider_name
.
value
)
return
llm_provider_service
.
get_provider_configs
(
obfuscated
=
True
)
return
llm_provider_service
.
get_provider_configs
(
obfuscated
=
True
,
only_custom
=
only_custom
)
@
staticmethod
def
get_token_type
(
tenant
,
provider_name
:
ProviderName
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment