Unverified Commit 0a21da5b authored by Yeuoly's avatar Yeuoly

feat: dynamic tool parameters

parent 70992609
...@@ -164,6 +164,22 @@ class StableDiffusionTool(BuiltinTool): ...@@ -164,6 +164,22 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e: except Exception as e:
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
def get_sd_models(self) -> List[str]:
"""
get sd models
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
if response.status_code != 200:
return []
else:
return [d['model_name'] for d in response.json()]
except Exception as e:
return []
def img2img(self, base_url: str, lora: str, image_binary: bytes, def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str, prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \ width: int, height: int, steps: int) \
...@@ -269,4 +285,28 @@ class StableDiffusionTool(BuiltinTool): ...@@ -269,4 +285,28 @@ class StableDiffusionTool(BuiltinTool):
) for i in self.list_default_image_variables()]) ) for i in self.list_default_image_variables()])
) )
if self.runtime.credentials:
try:
models = self.get_sd_models()
if len(models) != 0:
parameters.append(
ToolParameter(name='model',
label=I18nObject(en_US='Model', zh_Hans='Model'),
human_description=I18nObject(
en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
required=True,
default=models[0],
options=[ToolParameterOption(
value=i,
label=I18nObject(en_US=i, zh_Hans=i)
) for i in models])
)
except:
pass
return parameters return parameters
...@@ -4,7 +4,7 @@ from typing import List, Tuple ...@@ -4,7 +4,7 @@ from typing import List, Tuple
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import (ApiProviderAuthType, ApiProviderSchemaType, ToolCredentialsOption, from core.tools.entities.tool_entities import (ApiProviderAuthType, ApiProviderSchemaType, ToolCredentialsOption,
ToolProviderCredentials) ToolProviderCredentials, ToolParameter)
from core.tools.entities.user_entities import UserTool, UserToolProvider from core.tools.entities.user_entities import UserTool, UserToolProvider
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
...@@ -69,15 +69,51 @@ class ToolManageService: ...@@ -69,15 +69,51 @@ class ToolManageService:
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools() tools = provider_controller.get_tools()
result = [ tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
UserTool( # check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
result = []
for tool in tools:
# fork tool runtime
tool = tool.fork_tool_runtime(meta={
'credentials': credentials,
'tenant_id': tenant_id,
})
# get tool parameters
parameters = tool.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters()
# override parameters
current_parameters = parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
user_tool = UserTool(
author=tool.identity.author, author=tool.identity.author,
name=tool.identity.name, name=tool.identity.name,
label=tool.identity.label, label=tool.identity.label,
description=tool.description.human, description=tool.description.human,
parameters=tool.parameters or [] parameters=current_parameters
) for tool in tools )
] result.append(user_tool)
return json.loads( return json.loads(
serialize_base_model_array(result) serialize_base_model_array(result)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment