Unverified Commit 14493697 authored by Yeuoly's avatar Yeuoly

feat: replace string join with yarl

parent 72e936f7
......@@ -2,7 +2,7 @@ import io
import json
from base64 import b64decode, b64encode
from copy import deepcopy
from os.path import join
from yarl import URL
from typing import Any, Union
from httpx import get, post
......@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):
# set model
try:
url = join(base_url, 'sdapi/v1/options')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
response = post(url, data=json.dumps({
'sd_model_checkpoint': model
}))
......@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool):
if not model:
raise ToolProviderCredentialValidationError('Please input model')
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=10)
if response.status_code != 200:
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
if response.status_code == 404:
# try draw a picture
self._invoke(
user_id='test',
tool_parameters={
'prompt': 'a cat',
'width': 1024,
'height': 1024,
'steps': 1,
'lora': '',
}
)
elif response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to get models')
else:
models = [d['model_name'] for d in response.json()]
......@@ -173,7 +186,8 @@ class StableDiffusionTool(BuiltinTool):
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=10)
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
if response.status_code != 200:
return []
else:
......@@ -208,7 +222,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['prompt'] = prompt
try:
url = join(base_url, 'sdapi/v1/img2img')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
......@@ -241,7 +255,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['negative_prompt'] = negative_prompt
try:
url = join(base_url, 'sdapi/v1/txt2img')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment