Unverified Commit 6da5e541 authored by John Wang's avatar John Wang Committed by GitHub

Feat/open azure validate (#163)

parent 1c5f63de
......@@ -157,7 +157,7 @@ class ProviderTokenValidateApi(Resource):
args = parser.parse_args()
# todo: remove this when the provider is supported
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
......
......@@ -78,7 +78,7 @@ class AzureProvider(BaseProvider):
def get_token_type(self):
# TODO: change to dict when implemented
return lambda value: value
return dict
def config_validate(self, config: Union[dict | str]):
"""
......@@ -91,16 +91,34 @@ class AzureProvider(BaseProvider):
if 'openai_api_version' not in config:
config['openai_api_version'] = '2023-03-15-preview'
self.get_models(credentials=config)
models = self.get_models(credentials=config)
if not models:
raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
"'gpt-3.5-turbo', 'text-embedding-ada-002'.")
fixed_model_ids = [
'text-davinci-003',
'gpt-35-turbo',
'text-embedding-ada-002'
]
current_model_ids = [model['id'] for model in models]
missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
fixed_model_id not in current_model_ids]
if missing_model_ids:
raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
except AzureAuthenticationError:
raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Key.')
except requests.ConnectionError:
raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Base Endpoint.')
raise ValidateFailedError('Validation failed, please check your API Key.')
except (requests.ConnectionError, requests.RequestException):
raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
except AzureRequestFailedError as ex:
raise ValidateFailedError('Azure OpenAI Credentials validation failed, error: {}.'.format(str(ex)))
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
except Exception as ex:
logging.exception('Azure OpenAI Credentials validation failed')
raise ex
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
def get_encrypted_token(self, config: Union[dict | str]):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment