Unverified Commit 14493697 authored by Yeuoly's avatar Yeuoly

feat: replace string join with yarl

parent 72e936f7
...@@ -2,7 +2,7 @@ import io ...@@ -2,7 +2,7 @@ import io
import json import json
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from copy import deepcopy from copy import deepcopy
from os.path import join from yarl import URL
from typing import Any, Union from typing import Any, Union
from httpx import get, post from httpx import get, post
...@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):
# set model # set model
try: try:
url = join(base_url, 'sdapi/v1/options') url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
response = post(url, data=json.dumps({ response = post(url, data=json.dumps({
'sd_model_checkpoint': model 'sd_model_checkpoint': model
})) }))
...@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool): ...@@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool):
if not model: if not model:
raise ToolProviderCredentialValidationError('Please input model') raise ToolProviderCredentialValidationError('Please input model')
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=10) api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
if response.status_code != 200: response = get(url=api_url, timeout=10)
if response.status_code == 404:
# try draw a picture
self._invoke(
user_id='test',
tool_parameters={
'prompt': 'a cat',
'width': 1024,
'height': 1024,
'steps': 1,
'lora': '',
}
)
elif response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to get models') raise ToolProviderCredentialValidationError('Failed to get models')
else: else:
models = [d['model_name'] for d in response.json()] models = [d['model_name'] for d in response.json()]
...@@ -173,7 +186,8 @@ class StableDiffusionTool(BuiltinTool): ...@@ -173,7 +186,8 @@ class StableDiffusionTool(BuiltinTool):
base_url = self.runtime.credentials.get('base_url', None) base_url = self.runtime.credentials.get('base_url', None)
if not base_url: if not base_url:
return [] return []
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=10) api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
response = get(url=api_url, timeout=10)
if response.status_code != 200: if response.status_code != 200:
return [] return []
else: else:
...@@ -208,7 +222,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -208,7 +222,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['prompt'] = prompt draw_options['prompt'] = prompt
try: try:
url = join(base_url, 'sdapi/v1/img2img') url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
response = post(url, data=json.dumps(draw_options), timeout=120) response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200: if response.status_code != 200:
return self.create_text_message('Failed to generate image') return self.create_text_message('Failed to generate image')
...@@ -241,7 +255,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -241,7 +255,7 @@ class StableDiffusionTool(BuiltinTool):
draw_options['negative_prompt'] = negative_prompt draw_options['negative_prompt'] = negative_prompt
try: try:
url = join(base_url, 'sdapi/v1/txt2img') url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
response = post(url, data=json.dumps(draw_options), timeout=120) response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200: if response.status_code != 200:
return self.create_text_message('Failed to generate image') return self.create_text_message('Failed to generate image')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment