Unverified Commit 27e67848 authored by Yeuoly's avatar Yeuoly Committed by GitHub

Feat: AIPPT & DynamicToolParamter (#2725)

parent 70525653
......@@ -9,6 +9,7 @@
- azuredalle
- stablediffusion
- webscraper
- aippt
- youtube
- wolframalpha
- maths
......
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class AIPPTProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
identity:
author: Dify
name: aippt
label:
en_US: AIPPT
zh_Hans: AIPPT
description:
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
icon: icon.png
credentials_for_provider:
aippt_access_key:
type: secret-input
required: true
label:
en_US: AIPPT API key
zh_Hans: AIPPT API key
pt_BR: AIPPT API key
help:
en_US: Please input your AIPPT API key
zh_Hans: 请输入你的 AIPPT API key
pt_BR: Please input your AIPPT API key
placeholder:
en_US: Please input your AIPPT API key
zh_Hans: 请输入你的 AIPPT API key
pt_BR: Please input your AIPPT API key
url: https://www.aippt.cn
aippt_secret_key:
type: secret-input
required: true
label:
en_US: AIPPT Secret key
zh_Hans: AIPPT Secret key
pt_BR: AIPPT Secret key
help:
en_US: Please input your AIPPT Secret key
zh_Hans: 请输入你的 AIPPT Secret key
pt_BR: Please input your AIPPT Secret key
placeholder:
en_US: Please input your AIPPT Secret key
zh_Hans: 请输入你的 AIPPT Secret key
pt_BR: Please input your AIPPT Secret key
from base64 import b64encode
from hashlib import sha1
from hmac import new as hmac_new
from json import loads as json_loads
from threading import Lock
from time import sleep, time
from typing import Any
from httpx import get, post
from requests import get as requests_get
from yarl import URL
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.tool.builtin_tool import BuiltinTool
class AIPPTGenerateTool(BuiltinTool):
"""
A tool for generating a ppt
"""
_api_base_url = URL('https://co.aippt.cn/api')
_api_token_cache = {}
_api_token_cache_lock = Lock()
_task = {}
_task_type_map = {
'auto': 1,
'markdown': 7,
}
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
Invokes the AIPPT generate tool with the given user ID and tool parameters.
Args:
user_id (str): The ID of the user invoking the tool.
tool_parameters (dict[str, Any]): The parameters for the tool
Returns:
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
"""
title = tool_parameters.get('title', '')
if not title:
return self.create_text_message('Please provide a title for the ppt')
model = tool_parameters.get('model', 'aippt')
if not model:
return self.create_text_message('Please provide a model for the ppt')
outline = tool_parameters.get('outline', '')
# create task
task_id = self._create_task(
type=self._task_type_map['auto' if not outline else 'markdown'],
title=title,
content=outline,
user_id=user_id
)
# get suit
color = tool_parameters.get('color')
style = tool_parameters.get('style')
if color == '__default__':
color_id = ''
else:
color_id = int(color.split('-')[1])
if style == '__default__':
style_id = ''
else:
style_id = int(style.split('-')[1])
suit_id = self._get_suit(style_id=style_id, colour_id=color_id)
# generate outline
if not outline:
self._generate_outline(
task_id=task_id,
model=model,
user_id=user_id
)
# generate content
self._generate_content(
task_id=task_id,
model=model,
user_id=user_id
)
# generate ppt
_, ppt_url = self._generate_ppt(
task_id=task_id,
suit_id=suit_id,
user_id=user_id
)
return self.create_text_message('''the ppt has been created successfully,'''
f'''the ppt url is {ppt_url}'''
'''please give the ppt url to user and direct user to download it.''')
def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
"""
Create a task
:param type: the task type
:param title: the task title
:param content: the task content
:return: the task ID
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = post(
str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'),
headers=headers,
files={
'type': ('', str(type)),
'title': ('', title),
'content': ('', content)
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to create task: {response.get("msg")}')
return response.get('data', {}).get('id')
def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \
self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline'
api_url %= {'task_id': task_id}
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = requests_get(
url=api_url,
headers=headers,
stream=True,
timeout=(10, 60)
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
outline = ''
for chunk in response.iter_lines(delimiter=b'\n\n'):
if not chunk:
continue
event = ''
lines = chunk.decode('utf-8').split('\n')
for line in lines:
if line.startswith('event:'):
event = line[6:]
elif line.startswith('data:'):
data = line[5:]
if event == 'message':
try:
data = json_loads(data)
outline += data.get('content', '')
except Exception as e:
pass
elif event == 'close':
break
elif event == 'error' or event == 'filter':
raise Exception(f'Failed to generate outline: {data}')
return outline
def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \
self._api_base_url / 'ai' / 'chat' / 'wx' / 'content'
api_url %= {'task_id': task_id}
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = requests_get(
url=api_url,
headers=headers,
stream=True,
timeout=(10, 60)
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
if model == 'aippt':
content = ''
for chunk in response.iter_lines(delimiter=b'\n\n'):
if not chunk:
continue
event = ''
lines = chunk.decode('utf-8').split('\n')
for line in lines:
if line.startswith('event:'):
event = line[6:]
elif line.startswith('data:'):
data = line[5:]
if event == 'message':
try:
data = json_loads(data)
content += data.get('content', '')
except Exception as e:
pass
elif event == 'close':
break
elif event == 'error' or event == 'filter':
raise Exception(f'Failed to generate content: {data}')
return content
elif model == 'wenxin':
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to generate content: {response.get("msg")}')
return response.get('data', '')
return ''
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
"""
Generate a ppt
:param task_id: the task ID
:param suit_id: the suit ID
:return: the cover url of the ppt and the ppt url
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = post(
str(self._api_base_url / 'design' / 'v2' / 'save'),
headers=headers,
data={
'task_id': task_id,
'template_id': suit_id
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
id = response.get('data', {}).get('id')
cover_url = response.get('data', {}).get('cover_url')
response = post(
str(self._api_base_url / 'download' / 'export' / 'file'),
headers=headers,
data={
'id': id,
'format': 'ppt',
'files_to_zip': False,
'edit': True
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
export_code = response.get('data')
if not export_code:
raise Exception('Failed to generate ppt, the export code is empty')
current_iteration = 0
while current_iteration < 50:
# get ppt url
response = post(
str(self._api_base_url / 'download' / 'export' / 'file' / 'result'),
headers=headers,
data={
'task_key': export_code
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
if response.get('msg') == '导出中':
current_iteration += 1
sleep(2)
continue
ppt_url = response.get('data', [])
if len(ppt_url) == 0:
raise Exception('Failed to generate ppt, the ppt url is empty')
return cover_url, ppt_url[0]
raise Exception('Failed to generate ppt, the export is timeout')
@classmethod
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
"""
Get API token
:param credentials: the credentials
:return: the API token
"""
access_key = credentials['aippt_access_key']
secret_key = credentials['aippt_secret_key']
cache_key = f'{access_key}#@#{user_id}'
with cls._api_token_cache_lock:
# clear expired tokens
now = time()
for key in list(cls._api_token_cache.keys()):
if cls._api_token_cache[key]['expire'] < now:
del cls._api_token_cache[key]
if cache_key in cls._api_token_cache:
return cls._api_token_cache[cache_key]['token']
# get token
headers = {
'x-api-key': access_key,
'x-timestamp': str(int(now)),
'x-signature': cls._calculate_sign(access_key, secret_key, int(now))
}
param = {
'uid': user_id,
'channel': ''
}
response = get(
str(cls._api_base_url / 'grant' / 'token'),
params=param,
headers=headers
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
token = response.get('data', {}).get('token')
expire = response.get('data', {}).get('time_expire')
with cls._api_token_cache_lock:
cls._api_token_cache[cache_key] = {
'token': token,
'expire': now + expire
}
return token
@classmethod
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
return b64encode(
hmac_new(
key=secret_key.encode('utf-8'),
msg=f'GET@/api/grant/token/@{timestamp}'.encode(),
digestmod=sha1
).digest()
).decode('utf-8')
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
"""
Get styles
:param credentials: the credentials
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id)
}
response = get(
str(self._api_base_url / 'template_component' / 'suit' / 'select'),
headers=headers
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
colors = [{
'id': f'id-{item.get("id")}',
'name': item.get('name'),
'en_name': item.get('en_name', item.get('name')),
} for item in response.get('data', {}).get('colour') or []]
styles = [{
'id': f'id-{item.get("id")}',
'name': item.get('title'),
} for item in response.get('data', {}).get('suit_style') or []]
return colors, styles
def _get_suit(self, style_id: int, colour_id: int) -> int:
"""
Get suit
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__')
}
response = get(
str(self._api_base_url / 'template_component' / 'suit' / 'search'),
headers=headers,
params={
'style_id': style_id,
'colour_id': colour_id,
'page': 1,
'page_size': 1
}
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
response = response.json()
if response.get('code') != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
if len(response.get('data', {}).get('list') or []) > 0:
return response.get('data', {}).get('list')[0].get('id')
raise Exception('Failed to get suit, the suit does not exist, please check the style and color')
def get_runtime_parameters(self) -> list[ToolParameter]:
"""
Get runtime parameters
Override this method to add runtime parameters to the tool.
"""
try:
colors, styles = self.get_styles(user_id='__dify_system__')
except Exception as e:
colors, styles = [
{'id': -1, 'name': '__default__'}
], [
{'id': -1, 'name': '__default__'}
]
return [
ToolParameter(
name='color',
label=I18nObject(zh_Hans='颜色', en_US='Color'),
human_description=I18nObject(zh_Hans='颜色', en_US='Color'),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=False,
default=colors[0]['id'],
options=[
ToolParameterOption(
value=color['id'],
label=I18nObject(zh_Hans=color['name'], en_US=color['en_name'])
) for color in colors
]
),
ToolParameter(
name='style',
label=I18nObject(zh_Hans='风格', en_US='Style'),
human_description=I18nObject(zh_Hans='风格', en_US='Style'),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=False,
default=styles[0]['id'],
options=[
ToolParameterOption(
value=style['id'],
label=I18nObject(zh_Hans=style['name'], en_US=style['name'])
) for style in styles
]
),
]
\ No newline at end of file
identity:
name: aippt
author: Dify
label:
en_US: AIPPT
zh_Hans: AIPPT
description:
human:
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you.
parameters:
- name: title
type: string
required: true
label:
en_US: Title
zh_Hans: 标题
human_description:
en_US: The title of the PPT.
zh_Hans: PPT的标题。
llm_description: The title of the PPT, which will be used to generate the PPT outline.
form: llm
- name: outline
type: string
required: false
label:
en_US: Outline
zh_Hans: 大纲
human_description:
en_US: The outline of the PPT
zh_Hans: PPT的大纲
llm_description: The outline of the PPT, which will be used to generate the PPT content. provide it if you have.
form: llm
- name: llm
type: select
required: true
label:
en_US: LLM model
zh_Hans: 生成大纲的LLM
options:
- value: aippt
label:
en_US: AIPPT default model
zh_Hans: AIPPT默认模型
- value: wenxin
label:
en_US: Wenxin ErnieBot
zh_Hans: 文心一言
default: aippt
human_description:
en_US: The LLM model used for generating PPT outline.
zh_Hans: 用于生成PPT大纲的LLM模型。
form: form
......@@ -2,11 +2,11 @@ import io
import json
from base64 import b64decode, b64encode
from copy import deepcopy
from os.path import join
from typing import Any, Union
from httpx import get, post
from PIL import Image
from yarl import URL
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
......@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):
# set model
try:
url = join(base_url, 'sdapi/v1/options')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
response = post(url, data=json.dumps({
'sd_model_checkpoint': model
}))
......@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool):
if not model:
raise ToolProviderCredentialValidationError('Please input model')
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
if response.status_code != 200:
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
if response.status_code == 404:
# try draw a picture
self._invoke(
user_id='test',
tool_parameters={
'prompt': 'a cat',
'width': 1024,
'height': 1024,
'steps': 1,
'lora': '',
}
)
elif response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to get models')
else:
models = [d['model_name'] for d in response.json()]
......@@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e:
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
def get_sd_models(self) -> list[str]:
"""
get sd models
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
if response.status_code != 200:
return []
else:
return [d['model_name'] for d in response.json()]
except Exception as e:
return []
def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \
......@@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['prompt'] = prompt
try:
url = join(base_url, 'sdapi/v1/img2img')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
......@@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['negative_prompt'] = negative_prompt
try:
url = join(base_url, 'sdapi/v1/txt2img')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
......@@ -270,4 +300,28 @@ class StableDiffusionTool(BuiltinTool):
) for i in self.list_default_image_variables()])
)
if self.runtime.credentials:
try:
models = self.get_sd_models()
if len(models) != 0:
parameters.append(
ToolParameter(name='model',
label=I18nObject(en_US='Model', zh_Hans='Model'),
human_description=I18nObject(
en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
required=True,
default=models[0],
options=[ToolParameterOption(
value=i,
label=I18nObject(en_US=i, zh_Hans=i)
) for i in models])
)
except:
pass
return parameters
......@@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ApiProviderSchemaType,
ToolCredentialsOption,
ToolParameter,
ToolProviderCredentials,
)
from core.tools.entities.user_entities import UserTool, UserToolProvider
......@@ -73,15 +74,52 @@ class ToolManageService:
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools()
result = [
UserTool(
tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
result = []
for tool in tools:
# fork tool runtime
tool = tool.fork_tool_runtime(meta={
'credentials': credentials,
'tenant_id': tenant_id,
})
# get tool parameters
parameters = tool.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters()
# override parameters
current_parameters = parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
user_tool = UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
parameters=tool.parameters or []
) for tool in tools
]
parameters=current_parameters
)
result.append(user_tool)
return json.loads(
serialize_base_model_array(result)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment