Unverified Commit 27e67848 authored by Yeuoly's avatar Yeuoly Committed by GitHub

Feat: AIPPT & DynamicToolParamter (#2725)

parent 70525653
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
- azuredalle - azuredalle
- stablediffusion - stablediffusion
- webscraper - webscraper
- aippt
- youtube - youtube
- wolframalpha - wolframalpha
- maths - maths
......
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class AIPPTProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
identity:
author: Dify
name: aippt
label:
en_US: AIPPT
zh_Hans: AIPPT
description:
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
icon: icon.png
credentials_for_provider:
aippt_access_key:
type: secret-input
required: true
label:
en_US: AIPPT API key
zh_Hans: AIPPT API key
pt_BR: AIPPT API key
help:
en_US: Please input your AIPPT API key
zh_Hans: 请输入你的 AIPPT API key
pt_BR: Please input your AIPPT API key
placeholder:
en_US: Please input your AIPPT API key
zh_Hans: 请输入你的 AIPPT API key
pt_BR: Please input your AIPPT API key
url: https://www.aippt.cn
aippt_secret_key:
type: secret-input
required: true
label:
en_US: AIPPT Secret key
zh_Hans: AIPPT Secret key
pt_BR: AIPPT Secret key
help:
en_US: Please input your AIPPT Secret key
zh_Hans: 请输入你的 AIPPT Secret key
pt_BR: Please input your AIPPT Secret key
placeholder:
en_US: Please input your AIPPT Secret key
zh_Hans: 请输入你的 AIPPT Secret key
pt_BR: Please input your AIPPT Secret key
This diff is collapsed.
identity:
name: aippt
author: Dify
label:
en_US: AIPPT
zh_Hans: AIPPT
description:
human:
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you.
parameters:
- name: title
type: string
required: true
label:
en_US: Title
zh_Hans: 标题
human_description:
en_US: The title of the PPT.
zh_Hans: PPT的标题。
llm_description: The title of the PPT, which will be used to generate the PPT outline.
form: llm
- name: outline
type: string
required: false
label:
en_US: Outline
zh_Hans: 大纲
human_description:
en_US: The outline of the PPT
zh_Hans: PPT的大纲
llm_description: The outline of the PPT, which will be used to generate the PPT content. provide it if you have.
form: llm
- name: llm
type: select
required: true
label:
en_US: LLM model
zh_Hans: 生成大纲的LLM
options:
- value: aippt
label:
en_US: AIPPT default model
zh_Hans: AIPPT默认模型
- value: wenxin
label:
en_US: Wenxin ErnieBot
zh_Hans: 文心一言
default: aippt
human_description:
en_US: The LLM model used for generating PPT outline.
zh_Hans: 用于生成PPT大纲的LLM模型。
form: form
...@@ -2,11 +2,11 @@ import io ...@@ -2,11 +2,11 @@ import io
import json import json
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from copy import deepcopy from copy import deepcopy
from os.path import join
from typing import Any, Union from typing import Any, Union
from httpx import get, post from httpx import get, post
from PIL import Image from PIL import Image
from yarl import URL
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
...@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):
# set model # set model
try: try:
url = join(base_url, 'sdapi/v1/options') url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
response = post(url, data=json.dumps({ response = post(url, data=json.dumps({
'sd_model_checkpoint': model 'sd_model_checkpoint': model
})) }))
...@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool): ...@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool):
if not model: if not model:
raise ToolProviderCredentialValidationError('Please input model') raise ToolProviderCredentialValidationError('Please input model')
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120) api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
if response.status_code != 200: response = get(url=api_url, timeout=10)
if response.status_code == 404:
# try draw a picture
self._invoke(
user_id='test',
tool_parameters={
'prompt': 'a cat',
'width': 1024,
'height': 1024,
'steps': 1,
'lora': '',
}
)
elif response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to get models') raise ToolProviderCredentialValidationError('Failed to get models')
else: else:
models = [d['model_name'] for d in response.json()] models = [d['model_name'] for d in response.json()]
...@@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool): ...@@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e: except Exception as e:
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
def get_sd_models(self) -> list[str]:
"""
get sd models
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
if response.status_code != 200:
return []
else:
return [d['model_name'] for d in response.json()]
except Exception as e:
return []
def img2img(self, base_url: str, lora: str, image_binary: bytes, def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str, prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \ width: int, height: int, steps: int) \
...@@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['prompt'] = prompt draw_options['prompt'] = prompt
try: try:
url = join(base_url, 'sdapi/v1/img2img') url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
response = post(url, data=json.dumps(draw_options), timeout=120) response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200: if response.status_code != 200:
return self.create_text_message('Failed to generate image') return self.create_text_message('Failed to generate image')
...@@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['negative_prompt'] = negative_prompt draw_options['negative_prompt'] = negative_prompt
try: try:
url = join(base_url, 'sdapi/v1/txt2img') url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
response = post(url, data=json.dumps(draw_options), timeout=120) response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200: if response.status_code != 200:
return self.create_text_message('Failed to generate image') return self.create_text_message('Failed to generate image')
...@@ -270,4 +300,28 @@ class StableDiffusionTool(BuiltinTool): ...@@ -270,4 +300,28 @@ class StableDiffusionTool(BuiltinTool):
) for i in self.list_default_image_variables()]) ) for i in self.list_default_image_variables()])
) )
if self.runtime.credentials:
try:
models = self.get_sd_models()
if len(models) != 0:
parameters.append(
ToolParameter(name='model',
label=I18nObject(en_US='Model', zh_Hans='Model'),
human_description=I18nObject(
en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
required=True,
default=models[0],
options=[ToolParameterOption(
value=i,
label=I18nObject(en_US=i, zh_Hans=i)
) for i in models])
)
except:
pass
return parameters return parameters
...@@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import ( ...@@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ApiProviderSchemaType, ApiProviderSchemaType,
ToolCredentialsOption, ToolCredentialsOption,
ToolParameter,
ToolProviderCredentials, ToolProviderCredentials,
) )
from core.tools.entities.user_entities import UserTool, UserToolProvider from core.tools.entities.user_entities import UserTool, UserToolProvider
...@@ -73,15 +74,52 @@ class ToolManageService: ...@@ -73,15 +74,52 @@ class ToolManageService:
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools() tools = provider_controller.get_tools()
result = [ tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
UserTool( # check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
result = []
for tool in tools:
# fork tool runtime
tool = tool.fork_tool_runtime(meta={
'credentials': credentials,
'tenant_id': tenant_id,
})
# get tool parameters
parameters = tool.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters()
# override parameters
current_parameters = parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
user_tool = UserTool(
author=tool.identity.author, author=tool.identity.author,
name=tool.identity.name, name=tool.identity.name,
label=tool.identity.label, label=tool.identity.label,
description=tool.description.human, description=tool.description.human,
parameters=tool.parameters or [] parameters=current_parameters
) for tool in tools )
] result.append(user_tool)
return json.loads( return json.loads(
serialize_base_model_array(result) serialize_base_model_array(result)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment