Unverified Commit 95733796 authored by Yeuoly's avatar Yeuoly Committed by GitHub

fix: replace os.path.join with yarl (#2690)

parent 552f319b
from os import path
from threading import Lock from threading import Lock
from time import time from time import time
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session from requests.sessions import Session
from yarl import URL
class XinferenceModelExtraParameter: class XinferenceModelExtraParameter:
...@@ -55,7 +55,10 @@ class XinferenceHelper: ...@@ -55,7 +55,10 @@ class XinferenceHelper:
get xinference model extra parameter like model_format and model_handle_type get xinference model extra parameter like model_format and model_handle_type
""" """
url = path.join(server_url, 'v1/models', model_uid) if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
raise RuntimeError('model_uid is empty')
url = str(URL(server_url) / 'v1' / 'models' / model_uid)
# this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
session = Session() session = Session()
...@@ -66,7 +69,6 @@ class XinferenceHelper: ...@@ -66,7 +69,6 @@ class XinferenceHelper:
response = session.get(url, timeout=10) response = session.get(url, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e: except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200: if response.status_code != 200:
raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
......
...@@ -68,4 +68,5 @@ pydub~=0.25.1 ...@@ -68,4 +68,5 @@ pydub~=0.25.1
gmpy2~=2.1.5 gmpy2~=2.1.5
numexpr~=2.9.0 numexpr~=2.9.0
duckduckgo-search==4.4.3 duckduckgo-search==4.4.3
arxiv==2.1.0 arxiv==2.1.0
\ No newline at end of file yarl~=1.9.4
\ No newline at end of file
...@@ -32,68 +32,70 @@ class MockXinferenceClass(object): ...@@ -32,68 +32,70 @@ class MockXinferenceClass(object):
response = Response() response = Response()
if 'v1/models/' in url: if 'v1/models/' in url:
# get model uid # get model uid
model_uid = url.split('/')[-1] model_uid = url.split('/')[-1] or ''
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \ if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']: model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
response.status_code = 404 response.status_code = 404
response._content = b'{}'
return response return response
# check if url is valid # check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url): if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
response.status_code = 404 response.status_code = 404
response._content = b'{}'
return response return response
if model_uid in ['generate', 'chat']: if model_uid in ['generate', 'chat']:
response.status_code = 200 response.status_code = 200
response._content = b'''{ response._content = b'''{
"model_type": "LLM", "model_type": "LLM",
"address": "127.0.0.1:43877", "address": "127.0.0.1:43877",
"accelerators": [ "accelerators": [
"0", "0",
"1" "1"
], ],
"model_name": "chatglm3-6b", "model_name": "chatglm3-6b",
"model_lang": [ "model_lang": [
"en" "en"
], ],
"model_ability": [ "model_ability": [
"generate", "generate",
"chat" "chat"
], ],
"model_description": "latest chatglm3", "model_description": "latest chatglm3",
"model_format": "pytorch", "model_format": "pytorch",
"model_size_in_billions": 7, "model_size_in_billions": 7,
"quantization": "none", "quantization": "none",
"model_hub": "huggingface", "model_hub": "huggingface",
"revision": null, "revision": null,
"context_length": 2048, "context_length": 2048,
"replica": 1 "replica": 1
}''' }'''
return response return response
elif model_uid == 'embedding': elif model_uid == 'embedding':
response.status_code = 200 response.status_code = 200
response._content = b'''{ response._content = b'''{
"model_type": "embedding", "model_type": "embedding",
"address": "127.0.0.1:43877", "address": "127.0.0.1:43877",
"accelerators": [ "accelerators": [
"0", "0",
"1" "1"
], ],
"model_name": "bge", "model_name": "bge",
"model_lang": [ "model_lang": [
"en" "en"
], ],
"revision": null, "revision": null,
"max_tokens": 512 "max_tokens": 512
}''' }'''
return response return response
elif 'v1/cluster/auth' in url: elif 'v1/cluster/auth' in url:
response.status_code = 200 response.status_code = 200
response._content = b'''{ response._content = b'''{
"auth": true "auth": true
}''' }'''
return response return response
def _check_cluster_authenticated(self): def _check_cluster_authenticated(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment