Unverified Commit 6da5e541 authored by John Wang's avatar John Wang Committed by GitHub

Feat/open azure validate (#163)

parent 1c5f63de
...@@ -157,7 +157,7 @@ class ProviderTokenValidateApi(Resource): ...@@ -157,7 +157,7 @@ class ProviderTokenValidateApi(Resource):
args = parser.parse_args() args = parser.parse_args()
# todo: remove this when the provider is supported # todo: remove this when the provider is supported
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]: ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'} return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
......
...@@ -78,7 +78,7 @@ class AzureProvider(BaseProvider): ...@@ -78,7 +78,7 @@ class AzureProvider(BaseProvider):
def get_token_type(self): def get_token_type(self):
# TODO: change to dict when implemented # TODO: change to dict when implemented
return lambda value: value return dict
def config_validate(self, config: Union[dict | str]): def config_validate(self, config: Union[dict | str]):
""" """
...@@ -91,16 +91,34 @@ class AzureProvider(BaseProvider): ...@@ -91,16 +91,34 @@ class AzureProvider(BaseProvider):
if 'openai_api_version' not in config: if 'openai_api_version' not in config:
config['openai_api_version'] = '2023-03-15-preview' config['openai_api_version'] = '2023-03-15-preview'
self.get_models(credentials=config) models = self.get_models(credentials=config)
if not models:
raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
"'gpt-3.5-turbo', 'text-embedding-ada-002'.")
fixed_model_ids = [
'text-davinci-003',
'gpt-35-turbo',
'text-embedding-ada-002'
]
current_model_ids = [model['id'] for model in models]
missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
fixed_model_id not in current_model_ids]
if missing_model_ids:
raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
except AzureAuthenticationError: except AzureAuthenticationError:
raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Key.') raise ValidateFailedError('Validation failed, please check your API Key.')
except requests.ConnectionError: except (requests.ConnectionError, requests.RequestException):
raise ValidateFailedError('Azure OpenAI Credentials validation failed, please check your API Base Endpoint.') raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
except AzureRequestFailedError as ex: except AzureRequestFailedError as ex:
raise ValidateFailedError('Azure OpenAI Credentials validation failed, error: {}.'.format(str(ex))) raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
except Exception as ex: except Exception as ex:
logging.exception('Azure OpenAI Credentials validation failed') logging.exception('Azure OpenAI Credentials validation failed')
raise ex raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
def get_encrypted_token(self, config: Union[dict | str]): def get_encrypted_token(self, config: Union[dict | str]):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment