Unverified Commit f54eb351 authored by Yeuoly's avatar Yeuoly

Merge branch 'feat/support-dynamic-tool-parameters' into feat/ai-ppt-tool

parents f25cec26 dbcf03ef
......@@ -153,7 +153,7 @@ class StableDiffusionTool(BuiltinTool):
if not model:
raise ToolProviderCredentialValidationError('Please input model')
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=10)
if response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to get models')
else:
......@@ -165,6 +165,22 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e:
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
def get_sd_models(self) -> List[str]:
"""
get sd models
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=10)
if response.status_code != 200:
return []
else:
return [d['model_name'] for d in response.json()]
except Exception as e:
return []
def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \
......@@ -270,4 +286,28 @@ class StableDiffusionTool(BuiltinTool):
) for i in self.list_default_image_variables()])
)
if self.runtime.credentials:
try:
models = self.get_sd_models()
if len(models) != 0:
parameters.append(
ToolParameter(name='model',
label=I18nObject(en_US='Model', zh_Hans='Model'),
human_description=I18nObject(
en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
required=True,
default=models[0],
options=[ToolParameterOption(
value=i,
label=I18nObject(en_US=i, zh_Hans=i)
) for i in models])
)
except:
pass
return parameters
......@@ -10,6 +10,7 @@ from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolCredentialsOption,
ToolProviderCredentials,
ToolParameter
)
from core.tools.entities.user_entities import UserTool, UserToolProvider
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
......@@ -73,15 +74,52 @@ class ToolManageService:
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools()
result = [
UserTool(
tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
result = []
for tool in tools:
# fork tool runtime
tool = tool.fork_tool_runtime(meta={
'credentials': credentials,
'tenant_id': tenant_id,
})
# get tool parameters
parameters = tool.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters()
# override parameters
current_parameters = parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
user_tool = UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
parameters=tool.parameters or []
) for tool in tools
]
parameters=current_parameters
)
result.append(user_tool)
return json.loads(
serialize_base_model_array(result)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment