Unverified Commit cae15013 authored by John Wang's avatar John Wang Committed by GitHub

fix: azure openai deployment list was deprecated suddenly (#611)

parent 52c84da0
...@@ -2,6 +2,7 @@ import json ...@@ -2,6 +2,7 @@ import json
import logging import logging
from typing import Optional, Union from typing import Optional, Union
import openai
import requests import requests
from core.llm.provider.base import BaseProvider from core.llm.provider.base import BaseProvider
...@@ -11,30 +12,37 @@ from models.provider import ProviderName ...@@ -11,30 +12,37 @@ from models.provider import ProviderName
class AzureProvider(BaseProvider): class AzureProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]: def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
credentials = self.get_credentials(model_id) if not credentials else credentials return []
url = "{}/openai/deployments?api-version={}".format(
str(credentials.get('openai_api_base')), def check_embedding_model(self, credentials: Optional[dict] = None):
str(credentials.get('openai_api_version')) credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
) try:
result = openai.Embedding.create(input=['test'],
headers = { engine='text-embedding-ada-0021',
"api-key": str(credentials.get('openai_api_key')), timeout=60,
"content-type": "application/json; charset=utf-8" api_key=str(credentials.get('openai_api_key')),
} api_base=str(credentials.get('openai_api_base')),
api_type='azure',
response = requests.get(url, headers=headers) api_version=str(credentials.get('openai_api_version')))["data"][0][
"embedding"]
if response.status_code == 200: except openai.error.AuthenticationError as e:
result = response.json() raise AzureAuthenticationError(str(e))
return [{ except openai.error.APIConnectionError as e:
'id': deployment['id'], raise AzureRequestFailedError(
'name': '{} ({})'.format(deployment['id'], deployment['model']) 'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`')
} for deployment in result['data'] if deployment['status'] == 'succeeded'] except openai.error.InvalidRequestError as e:
else: if e.http_status == 404:
if response.status_code == 401: raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
raise AzureAuthenticationError() "deployment name is exists in Azure AI")
else: else:
raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code)) raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
except openai.error.OpenAIError as e:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
if not isinstance(result, list):
raise AzureRequestFailedError('Failed to request Azure OpenAI.')
def get_credentials(self, model_id: Optional[str] = None) -> dict: def get_credentials(self, model_id: Optional[str] = None) -> dict:
""" """
...@@ -94,31 +102,11 @@ class AzureProvider(BaseProvider): ...@@ -94,31 +102,11 @@ class AzureProvider(BaseProvider):
if 'openai_api_version' not in config: if 'openai_api_version' not in config:
config['openai_api_version'] = '2023-03-15-preview' config['openai_api_version'] = '2023-03-15-preview'
models = self.get_models(credentials=config) self.check_embedding_model(credentials=config)
if not models:
raise ValidateFailedError("Please add deployments for "
"'gpt-3.5-turbo', 'text-embedding-ada-002' (required) "
"and 'gpt-4', 'gpt-35-turbo-16k', 'text-davinci-003' (optional).")
fixed_model_ids = [
'gpt-35-turbo',
'text-embedding-ada-002'
]
current_model_ids = [model['id'] for model in models]
missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
fixed_model_id not in current_model_ids]
if missing_model_ids:
raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
except ValidateFailedError as e: except ValidateFailedError as e:
raise e raise e
except AzureAuthenticationError: except AzureAuthenticationError:
raise ValidateFailedError('Validation failed, please check your API Key.') raise ValidateFailedError('Validation failed, please check your API Key.')
except (requests.ConnectionError, requests.RequestException):
raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
except AzureRequestFailedError as ex: except AzureRequestFailedError as ex:
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
except Exception as ex: except Exception as ex:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment