Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
0a21da5b
Unverified
Commit
0a21da5b
authored
Feb 05, 2024
by
Yeuoly
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: dynamic tool parameters
parent
70992609
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
82 additions
and
6 deletions
+82
-6
stable_diffusion.py
...rovider/builtin/stablediffusion/tools/stable_diffusion.py
+40
-0
tools_manage_service.py
api/services/tools_manage_service.py
+42
-6
No files found.
api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py
View file @
0a21da5b
...
...
@@ -164,6 +164,22 @@ class StableDiffusionTool(BuiltinTool):
except
Exception
as
e
:
raise
ToolProviderCredentialValidationError
(
f
'Failed to get models, {e}'
)
def
get_sd_models
(
self
)
->
List
[
str
]:
"""
get sd models
"""
try
:
base_url
=
self
.
runtime
.
credentials
.
get
(
'base_url'
,
None
)
if
not
base_url
:
return
[]
response
=
get
(
url
=
f
'{base_url}/sdapi/v1/sd-models'
,
timeout
=
120
)
if
response
.
status_code
!=
200
:
return
[]
else
:
return
[
d
[
'model_name'
]
for
d
in
response
.
json
()]
except
Exception
as
e
:
return
[]
def
img2img
(
self
,
base_url
:
str
,
lora
:
str
,
image_binary
:
bytes
,
prompt
:
str
,
negative_prompt
:
str
,
width
:
int
,
height
:
int
,
steps
:
int
)
\
...
...
@@ -268,5 +284,29 @@ class StableDiffusionTool(BuiltinTool):
label
=
I18nObject
(
en_US
=
i
.
name
,
zh_Hans
=
i
.
name
)
)
for
i
in
self
.
list_default_image_variables
()])
)
if
self
.
runtime
.
credentials
:
try
:
models
=
self
.
get_sd_models
()
if
len
(
models
)
!=
0
:
parameters
.
append
(
ToolParameter
(
name
=
'model'
,
label
=
I18nObject
(
en_US
=
'Model'
,
zh_Hans
=
'Model'
),
human_description
=
I18nObject
(
en_US
=
'Model of Stable Diffusion, you can check the official documentation of Stable Diffusion'
,
zh_Hans
=
'Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档'
,
),
type
=
ToolParameter
.
ToolParameterType
.
SELECT
,
form
=
ToolParameter
.
ToolParameterForm
.
FORM
,
llm_description
=
'Model of Stable Diffusion, you can check the official documentation of Stable Diffusion'
,
required
=
True
,
default
=
models
[
0
],
options
=
[
ToolParameterOption
(
value
=
i
,
label
=
I18nObject
(
en_US
=
i
,
zh_Hans
=
i
)
)
for
i
in
models
])
)
except
:
pass
return
parameters
api/services/tools_manage_service.py
View file @
0a21da5b
...
...
@@ -4,7 +4,7 @@ from typing import List, Tuple
from
core.tools.entities.common_entities
import
I18nObject
from
core.tools.entities.tool_bundle
import
ApiBasedToolBundle
from
core.tools.entities.tool_entities
import
(
ApiProviderAuthType
,
ApiProviderSchemaType
,
ToolCredentialsOption
,
ToolProviderCredentials
)
ToolProviderCredentials
,
ToolParameter
)
from
core.tools.entities.user_entities
import
UserTool
,
UserToolProvider
from
core.tools.errors
import
ToolNotFoundError
,
ToolProviderCredentialValidationError
,
ToolProviderNotFoundError
from
core.tools.provider.api_tool_provider
import
ApiBasedToolProviderController
...
...
@@ -69,15 +69,51 @@ class ToolManageService:
provider_controller
:
ToolProviderController
=
ToolManager
.
get_builtin_provider
(
provider
)
tools
=
provider_controller
.
get_tools
()
result
=
[
UserTool
(
tool_provider_configurations
=
ToolConfiguration
(
tenant_id
=
tenant_id
,
provider_controller
=
provider_controller
)
# check if user has added the provider
builtin_provider
:
BuiltinToolProvider
=
db
.
session
.
query
(
BuiltinToolProvider
)
.
filter
(
BuiltinToolProvider
.
tenant_id
==
tenant_id
,
BuiltinToolProvider
.
provider
==
provider
,
)
.
first
()
credentials
=
{}
if
builtin_provider
is
not
None
:
# get credentials
credentials
=
builtin_provider
.
credentials
credentials
=
tool_provider_configurations
.
decrypt_tool_credentials
(
credentials
)
result
=
[]
for
tool
in
tools
:
# fork tool runtime
tool
=
tool
.
fork_tool_runtime
(
meta
=
{
'credentials'
:
credentials
,
'tenant_id'
:
tenant_id
,
})
# get tool parameters
parameters
=
tool
.
parameters
or
[]
# get tool runtime parameters
runtime_parameters
=
tool
.
get_runtime_parameters
()
# override parameters
current_parameters
=
parameters
.
copy
()
for
runtime_parameter
in
runtime_parameters
:
found
=
False
for
index
,
parameter
in
enumerate
(
current_parameters
):
if
parameter
.
name
==
runtime_parameter
.
name
and
parameter
.
form
==
runtime_parameter
.
form
:
current_parameters
[
index
]
=
runtime_parameter
break
if
not
found
and
runtime_parameter
.
form
==
ToolParameter
.
ToolParameterForm
.
FORM
:
current_parameters
.
append
(
runtime_parameter
)
user_tool
=
UserTool
(
author
=
tool
.
identity
.
author
,
name
=
tool
.
identity
.
name
,
label
=
tool
.
identity
.
label
,
description
=
tool
.
description
.
human
,
parameters
=
tool
.
parameters
or
[]
)
for
tool
in
tools
]
parameters
=
current_parameters
)
result
.
append
(
user_tool
)
return
json
.
loads
(
serialize_base_model_array
(
result
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment