Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
0debd75b
Unverified
Commit
0debd75b
authored
Jan 30, 2024
by
Yeuoly
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update AzureDALL-E and ZhipuaiVision providers
parent
ee0e2dc1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
96 additions
and
17 deletions
+96
-17
_model_providers.yaml
api/core/tools/provider/_model_providers.yaml
+4
-4
azuredalle.yaml
api/core/tools/provider/builtin/azuredalle/azuredalle.yaml
+3
-3
tool_manager.py
api/core/tools/tool_manager.py
+89
-10
No files found.
api/core/tools/provider/_model_providers.yaml
View file @
0debd75b
...
@@ -10,8 +10,8 @@ providers:
...
@@ -10,8 +10,8 @@ providers:
zh_Hans
:
Gemini Pro 图像理解
zh_Hans
:
Gemini Pro 图像理解
-
provider
:
zhipuai
-
provider
:
zhipuai
alias
:
alias
:
en_US
:
Zhipuai
Model
en_US
:
Zhipuai
Vision
zh_Hans
:
智谱AI
模型能力
zh_Hans
:
智谱AI
图像理解
models
:
models
:
-
name
:
glm-4v
-
name
:
glm-4v
alias
:
alias
:
...
@@ -19,8 +19,8 @@ providers:
...
@@ -19,8 +19,8 @@ providers:
zh_Hans
:
GLM-4V 图像理解
zh_Hans
:
GLM-4V 图像理解
-
provider
:
openai
-
provider
:
openai
alias
:
alias
:
en_US
:
OpenAI
Model
en_US
:
OpenAI
Vision
zh_Hans
:
OpenAI
模型能力
zh_Hans
:
OpenAI
图像理解
models
:
models
:
-
name
:
gpt-4-vision-preview
-
name
:
gpt-4-vision-preview
alias
:
alias
:
...
...
api/core/tools/provider/builtin/azuredalle/azuredalle.yaml
View file @
0debd75b
...
@@ -2,9 +2,9 @@ identity:
...
@@ -2,9 +2,9 @@ identity:
author
:
Leslie
author
:
Leslie
name
:
azuredalle
name
:
azuredalle
label
:
label
:
en_US
:
A
ZURE
DALL-E
en_US
:
A
zure
DALL-E
zh_Hans
:
A
ZURE
DALL-E 绘画
zh_Hans
:
A
zure
DALL-E 绘画
pt_BR
:
A
ZURE
DALL-E
pt_BR
:
A
zure
DALL-E
description
:
description
:
en_US
:
AZURE DALL-E art
en_US
:
AZURE DALL-E art
zh_Hans
:
AZURE DALL-E 绘画
zh_Hans
:
AZURE DALL-E 绘画
...
...
api/core/tools/tool_manager.py
View file @
0debd75b
...
@@ -3,6 +3,7 @@ from os import listdir, path
...
@@ -3,6 +3,7 @@ from os import listdir, path
from
core.tools.entities.tool_entities
import
ToolInvokeMessage
,
ApiProviderAuthType
,
ToolProviderCredentials
from
core.tools.entities.tool_entities
import
ToolInvokeMessage
,
ApiProviderAuthType
,
ToolProviderCredentials
from
core.tools.provider.tool_provider
import
ToolProviderController
from
core.tools.provider.tool_provider
import
ToolProviderController
from
core.tools.provider.model_tool_provider
import
ModelToolProviderController
from
core.tools.tool.builtin_tool
import
BuiltinTool
from
core.tools.tool.builtin_tool
import
BuiltinTool
from
core.tools.tool.api_tool
import
ApiTool
from
core.tools.tool.api_tool
import
ApiTool
from
core.tools.provider.builtin_tool_provider
import
BuiltinToolProviderController
from
core.tools.provider.builtin_tool_provider
import
BuiltinToolProviderController
...
@@ -15,6 +16,7 @@ from core.tools.entities.user_entities import UserToolProvider
...
@@ -15,6 +16,7 @@ from core.tools.entities.user_entities import UserToolProvider
from
core.tools.utils.configration
import
ToolConfiguration
from
core.tools.utils.configration
import
ToolConfiguration
from
core.tools.utils.encoder
import
serialize_base_model_dict
from
core.tools.utils.encoder
import
serialize_base_model_dict
from
core.tools.provider.builtin._positions
import
BuiltinToolProviderSort
from
core.tools.provider.builtin._positions
import
BuiltinToolProviderSort
from
core.provider_manager
import
ProviderManager
from
core.model_runtime.entities.message_entities
import
PromptMessage
from
core.model_runtime.entities.message_entities
import
PromptMessage
from
core.callback_handler.agent_tool_callback_handler
import
DifyAgentCallbackHandler
from
core.callback_handler.agent_tool_callback_handler
import
DifyAgentCallbackHandler
...
@@ -139,7 +141,7 @@ class ToolManager:
...
@@ -139,7 +141,7 @@ class ToolManager:
raise
ToolProviderNotFoundError
(
f
'provider type {provider_type} not found'
)
raise
ToolProviderNotFoundError
(
f
'provider type {provider_type} not found'
)
@
staticmethod
@
staticmethod
def
get_tool_runtime
(
provider_type
:
str
,
provider_name
:
str
,
tool_name
:
str
,
t
ane
nt_id
,
def
get_tool_runtime
(
provider_type
:
str
,
provider_name
:
str
,
tool_name
:
str
,
t
ena
nt_id
,
agent_callback
:
DifyAgentCallbackHandler
=
None
)
\
agent_callback
:
DifyAgentCallbackHandler
=
None
)
\
->
Union
[
BuiltinTool
,
ApiTool
]:
->
Union
[
BuiltinTool
,
ApiTool
]:
"""
"""
...
@@ -158,13 +160,13 @@ class ToolManager:
...
@@ -158,13 +160,13 @@ class ToolManager:
provider_controller
=
ToolManager
.
get_builtin_provider
(
provider_name
)
provider_controller
=
ToolManager
.
get_builtin_provider
(
provider_name
)
if
not
provider_controller
.
need_credentials
:
if
not
provider_controller
.
need_credentials
:
return
builtin_tool
.
fork_tool_runtime
(
meta
=
{
return
builtin_tool
.
fork_tool_runtime
(
meta
=
{
'tenant_id'
:
t
ane
nt_id
,
'tenant_id'
:
t
ena
nt_id
,
'credentials'
:
{},
'credentials'
:
{},
},
agent_callback
=
agent_callback
)
},
agent_callback
=
agent_callback
)
# get credentials
# get credentials
builtin_provider
:
BuiltinToolProvider
=
db
.
session
.
query
(
BuiltinToolProvider
)
.
filter
(
builtin_provider
:
BuiltinToolProvider
=
db
.
session
.
query
(
BuiltinToolProvider
)
.
filter
(
BuiltinToolProvider
.
tenant_id
==
t
ane
nt_id
,
BuiltinToolProvider
.
tenant_id
==
t
ena
nt_id
,
BuiltinToolProvider
.
provider
==
provider_name
,
BuiltinToolProvider
.
provider
==
provider_name
,
)
.
first
()
)
.
first
()
...
@@ -174,30 +176,43 @@ class ToolManager:
...
@@ -174,30 +176,43 @@ class ToolManager:
# decrypt the credentials
# decrypt the credentials
credentials
=
builtin_provider
.
credentials
credentials
=
builtin_provider
.
credentials
controller
=
ToolManager
.
get_builtin_provider
(
provider_name
)
controller
=
ToolManager
.
get_builtin_provider
(
provider_name
)
tool_configuration
=
ToolConfiguration
(
tenant_id
=
t
ane
nt_id
,
provider_controller
=
controller
)
tool_configuration
=
ToolConfiguration
(
tenant_id
=
t
ena
nt_id
,
provider_controller
=
controller
)
decrypted_credentials
=
tool_configuration
.
decrypt_tool_credentials
(
credentials
)
decrypted_credentials
=
tool_configuration
.
decrypt_tool_credentials
(
credentials
)
return
builtin_tool
.
fork_tool_runtime
(
meta
=
{
return
builtin_tool
.
fork_tool_runtime
(
meta
=
{
'tenant_id'
:
t
ane
nt_id
,
'tenant_id'
:
t
ena
nt_id
,
'credentials'
:
decrypted_credentials
,
'credentials'
:
decrypted_credentials
,
'runtime_parameters'
:
{}
'runtime_parameters'
:
{}
},
agent_callback
=
agent_callback
)
},
agent_callback
=
agent_callback
)
elif
provider_type
==
'api'
:
elif
provider_type
==
'api'
:
if
t
ane
nt_id
is
None
:
if
t
ena
nt_id
is
None
:
raise
ValueError
(
'tanent id is required for api provider'
)
raise
ValueError
(
'tanent id is required for api provider'
)
api_provider
,
credentials
=
ToolManager
.
get_api_provider_controller
(
t
ane
nt_id
,
provider_name
)
api_provider
,
credentials
=
ToolManager
.
get_api_provider_controller
(
t
ena
nt_id
,
provider_name
)
# decrypt the credentials
# decrypt the credentials
tool_configuration
=
ToolConfiguration
(
tenant_id
=
t
ane
nt_id
,
provider_controller
=
api_provider
)
tool_configuration
=
ToolConfiguration
(
tenant_id
=
t
ena
nt_id
,
provider_controller
=
api_provider
)
decrypted_credentials
=
tool_configuration
.
decrypt_tool_credentials
(
credentials
)
decrypted_credentials
=
tool_configuration
.
decrypt_tool_credentials
(
credentials
)
return
api_provider
.
get_tool
(
tool_name
)
.
fork_tool_runtime
(
meta
=
{
return
api_provider
.
get_tool
(
tool_name
)
.
fork_tool_runtime
(
meta
=
{
'tenant_id'
:
t
ane
nt_id
,
'tenant_id'
:
t
ena
nt_id
,
'credentials'
:
decrypted_credentials
,
'credentials'
:
decrypted_credentials
,
})
})
elif
provider_type
==
'model'
:
if
tenant_id
is
None
:
raise
ValueError
(
'tanent id is required for model provider'
)
# get model provider
model_provider
=
ToolManager
.
get_model_provider
(
tenant_id
,
provider_name
)
# get tool
model_tool
=
model_provider
.
get_tool
(
tool_name
)
return
model_tool
.
fork_tool_runtime
(
meta
=
{
'tenant_id'
:
tenant_id
,
'credentials'
:
model_tool
.
_model_instance
.
credentials
})
elif
provider_type
==
'app'
:
elif
provider_type
==
'app'
:
raise
NotImplementedError
(
'app provider not implemented'
)
raise
NotImplementedError
(
'app provider not implemented'
)
else
:
else
:
...
@@ -270,6 +285,44 @@ class ToolManager:
...
@@ -270,6 +285,44 @@ class ToolManager:
return
builtin_providers
return
builtin_providers
@
staticmethod
def
list_model_providers
(
tenant_id
:
str
=
None
)
->
List
[
ModelToolProviderController
]:
"""
list all the model providers
:return: the list of the model providers
"""
tenant_id
=
tenant_id
or
'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get configurations
provider_manager
=
ProviderManager
()
configurations
=
provider_manager
.
get_configurations
(
tenant_id
)
.
values
()
# get model providers
model_providers
:
List
[
ModelToolProviderController
]
=
[]
for
configuration
in
configurations
:
if
not
ModelToolProviderController
.
is_configuration_valid
(
configuration
):
continue
model_providers
.
append
(
ModelToolProviderController
.
from_db
(
configuration
))
return
model_providers
@
staticmethod
def
get_model_provider
(
tenant_id
:
str
,
provider_name
:
str
)
->
ModelToolProviderController
:
"""
get the model provider
:param provider_name: the name of the provider
:return: the provider
"""
# get configurations
provider_manager
=
ProviderManager
()
configurations
=
provider_manager
.
get_configurations
(
tenant_id
)
configuration
=
configurations
.
get
(
provider_name
)
if
configuration
is
None
:
raise
ToolProviderNotFoundError
(
f
'model provider {provider_name} not found'
)
return
ModelToolProviderController
.
from_db
(
configuration
)
@
staticmethod
@
staticmethod
def
get_tool_label
(
tool_name
:
str
)
->
Union
[
I18nObject
,
None
]:
def
get_tool_label
(
tool_name
:
str
)
->
Union
[
I18nObject
,
None
]:
"""
"""
...
@@ -336,6 +389,9 @@ class ToolManager:
...
@@ -336,6 +389,9 @@ class ToolManager:
# add provider into providers
# add provider into providers
credentials
=
db_builtin_provider
.
credentials
credentials
=
db_builtin_provider
.
credentials
provider_name
=
db_builtin_provider
.
provider
provider_name
=
db_builtin_provider
.
provider
if
provider_name
not
in
result_providers
:
continue
result_providers
[
provider_name
]
.
is_team_authorization
=
True
result_providers
[
provider_name
]
.
is_team_authorization
=
True
# package builtin tool provider controller
# package builtin tool provider controller
...
@@ -349,6 +405,28 @@ class ToolManager:
...
@@ -349,6 +405,28 @@ class ToolManager:
result_providers
[
provider_name
]
.
team_credentials
=
masked_credentials
result_providers
[
provider_name
]
.
team_credentials
=
masked_credentials
# get model tool providers
model_providers
=
ToolManager
.
list_model_providers
(
tenant_id
=
tenant_id
)
# append model providers
for
provider
in
model_providers
:
result_providers
[
f
'model_provider.{provider.identity.name}'
]
=
UserToolProvider
(
id
=
provider
.
identity
.
name
,
author
=
provider
.
identity
.
author
,
name
=
provider
.
identity
.
name
,
description
=
I18nObject
(
en_US
=
provider
.
identity
.
description
.
en_US
,
zh_Hans
=
provider
.
identity
.
description
.
zh_Hans
,
),
icon
=
provider
.
identity
.
icon
,
label
=
I18nObject
(
en_US
=
provider
.
identity
.
label
.
en_US
,
zh_Hans
=
provider
.
identity
.
label
.
zh_Hans
,
),
type
=
UserToolProvider
.
ProviderType
.
MODEL
,
team_credentials
=
{},
is_team_authorization
=
provider
.
is_active
,
)
# get db api providers
# get db api providers
db_api_providers
:
List
[
ApiToolProvider
]
=
db
.
session
.
query
(
ApiToolProvider
)
.
\
db_api_providers
:
List
[
ApiToolProvider
]
=
db
.
session
.
query
(
ApiToolProvider
)
.
\
filter
(
ApiToolProvider
.
tenant_id
==
tenant_id
)
.
all
()
filter
(
ApiToolProvider
.
tenant_id
==
tenant_id
)
.
all
()
...
@@ -468,4 +546,5 @@ class ToolManager:
...
@@ -468,4 +546,5 @@ class ToolManager:
'description'
:
provider
.
description
,
'description'
:
provider
.
description
,
'credentials'
:
masked_credentials
,
'credentials'
:
masked_credentials
,
'privacy_policy'
:
provider
.
privacy_policy
'privacy_policy'
:
provider
.
privacy_policy
}))
}))
\ No newline at end of file
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment