Unverified Commit f68b05d5 authored by John Wang's avatar John Wang Committed by GitHub

Feat: support azure openai for temporary (#101)

parent 3b3c604e
......@@ -47,6 +47,7 @@ DEFAULTS = {
'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai'
}
......@@ -181,6 +182,10 @@ class Config:
# You could disable it for compatibility with certain OpenAPI providers
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
class CloudEditionConfig(Config):
def __init__(self):
......
......@@ -82,29 +82,33 @@ class ProviderTokenApi(Resource):
args = parser.parse_args()
if not args['token']:
raise ValueError('Token is empty')
try:
ProviderService.validate_provider_configs(
if args['token']:
try:
ProviderService.validate_provider_configs(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider),
configs=args['token']
)
token_is_valid = True
except ValidateFailedError:
token_is_valid = False
base64_encrypted_token = ProviderService.get_encrypted_token(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider),
configs=args['token']
)
token_is_valid = True
except ValidateFailedError:
else:
base64_encrypted_token = None
token_is_valid = False
tenant = current_user.current_tenant
base64_encrypted_token = ProviderService.get_encrypted_token(
tenant=current_user.current_tenant,
provider_name=ProviderName(provider),
configs=args['token']
)
provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider,
provider_type=ProviderType.CUSTOM.value).first()
provider_model = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id,
Provider.provider_name == provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
# Only allow updating token for CUSTOM provider type
if provider_model:
......@@ -117,6 +121,16 @@ class ProviderTokenApi(Resource):
is_valid=token_is_valid)
db.session.add(provider_model)
if provider_model.is_valid:
other_providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id,
Provider.provider_name != provider,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
for other_provider in other_providers:
other_provider.is_valid = False
db.session.commit()
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
......
......@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
text: str,
engine: Optional[str] = None,
openai_api_key: Optional[str] = None,
text: str,
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[float]:
"""Get embedding.
......@@ -25,11 +26,12 @@ def get_embedding(
"""
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]
return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
float]:
"""Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils:
......@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
"embedding"
]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
list_of_text: List[str],
engine: Optional[str] = None,
openai_api_key: Optional[str] = None
list_of_text: List[str],
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[List[float]]:
"""Get embeddings.
......@@ -67,14 +70,14 @@ def get_embeddings(
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
"""Asynchronously get embeddings.
......@@ -90,7 +93,7 @@ async def aget_embeddings(
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
......@@ -98,19 +101,30 @@ async def aget_embeddings(
class OpenAIEmbedding(BaseEmbedding):
def __init__(
self,
mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
deployment_name: Optional[str] = None,
openai_api_key: Optional[str] = None,
**kwargs: Any,
self,
mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
deployment_name: Optional[str] = None,
openai_api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(**kwargs)
new_kwargs = {}
if 'embed_batch_size' in kwargs:
new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
if 'tokenizer' in kwargs:
new_kwargs['tokenizer'] = kwargs['tokenizer']
super().__init__(**new_kwargs)
self.mode = OpenAIEmbeddingMode(mode)
self.model = OpenAIEmbeddingModelType(model)
self.deployment_name = deployment_name
self.openai_api_key = openai_api_key
self.openai_api_type = kwargs.get('openai_api_type')
self.openai_api_version = kwargs.get('openai_api_version')
self.openai_api_base = kwargs.get('openai_api_base')
@handle_llm_exceptions
def _get_query_embedding(self, query: str) -> List[float]:
......@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _QUERY_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _QUERY_MODE_MODEL_DICT[key]
return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)
return get_embedding(query, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
......@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
return get_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
......@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings.
......@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
......@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
......@@ -33,8 +33,11 @@ class IndexBuilder:
max_chunk_overlap=20
)
provider = LLMBuilder.get_default_provider(tenant_id)
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=tenant_id,
model_provider=provider,
model_name='text-embedding-ada-002'
)
......
......@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType
class LLMBuilder:
......@@ -31,16 +36,23 @@ class LLMBuilder:
if model_name == 'fake':
return FakeListLLM(responses=[])
provider = cls.get_default_provider(tenant_id)
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
# llm_cls = StreamableAzureChatOpenAI
llm_cls = StreamableChatOpenAI
if provider == 'openai':
llm_cls = StreamableChatOpenAI
else:
llm_cls = StreamableAzureChatOpenAI
elif mode == 'completion':
llm_cls = StreamableOpenAI
if provider == 'openai':
llm_cls = StreamableOpenAI
else:
llm_cls = StreamableAzureOpenAI
else:
raise ValueError(f"model name {model_name} is not supported.")
model_credentials = cls.get_model_credentials(tenant_id, model_name)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
return llm_cls(
model_name=model_name,
......@@ -86,18 +98,31 @@ class LLMBuilder:
raise ValueError(f"model name {model_name} is not supported.")
@classmethod
def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
"""
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
Raises an exception if the model_name is not found or if the provider is not found.
"""
if not model_name:
raise Exception('model name not found')
#
# if model_name not in llm_constant.models:
# raise Exception('model {} not found'.format(model_name))
if model_name not in llm_constant.models:
raise Exception('model {} not found'.format(model_name))
model_provider = llm_constant.models[model_name]
# model_provider = llm_constant.models[model_name]
provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
return provider_service.get_credentials(model_name)
@classmethod
def get_default_provider(cls, tenant_id: str) -> str:
provider = BaseProvider.get_valid_provider(tenant_id)
if not provider:
raise ProviderTokenNotInitError()
if provider.provider_type == ProviderType.SYSTEM.value:
provider_name = 'openai'
else:
provider_name = provider.provider_name
return provider_name
......@@ -36,10 +36,9 @@ class AzureProvider(BaseProvider):
"""
Returns the API credentials for Azure OpenAI as a dictionary.
"""
encrypted_config = self.get_provider_api_key(model_id=model_id)
config = json.loads(encrypted_config)
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
config['deployment_name'] = model_id
config['deployment_name'] = model_id.replace('.', '')
return config
def get_provider_name(self):
......@@ -51,12 +50,11 @@ class AzureProvider(BaseProvider):
"""
try:
config = self.get_provider_api_key()
config = json.loads(config)
except:
config = {
'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview',
'openai_api_base': 'https://foo.microsoft.com/bar',
'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
'openai_api_key': ''
}
......@@ -65,7 +63,7 @@ class AzureProvider(BaseProvider):
config = {
'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview',
'openai_api_base': 'https://foo.microsoft.com/bar',
'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
'openai_api_key': ''
}
......
......@@ -14,7 +14,7 @@ class BaseProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
"""
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
......@@ -43,23 +43,35 @@ class BaseProvider(ABC):
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
providers = db.session.query(Provider).filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.get_provider_name().value
).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
@classmethod
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id
)
if provider_name:
query = query.filter(Provider.provider_name == provider_name)
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
custom_provider = None
system_provider = None
for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
custom_provider = provider
elif provider.provider_type == ProviderType.SYSTEM.value:
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
system_provider = provider
if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
if custom_provider:
return custom_provider
elif system_provider and system_provider.is_valid:
elif system_provider:
return system_provider
else:
return None
......@@ -80,7 +92,7 @@ class BaseProvider(ABC):
try:
config = self.get_provider_api_key()
except:
config = 'THIS-IS-A-MOCK-TOKEN'
config = ''
if obfuscated:
return self.obfuscated_token(config)
......
import requests
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List
from typing import Optional, List, Dict, Any
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableAzureChatOpenAI(AzureChatOpenAI):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
**super()._default_params,
"engine": self.deployment_name,
"api_type": self.openai_api_type,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages.
......
import os
from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableAzureOpenAI(AzureOpenAI):
openai_api_type: str = "azure"
openai_api_version: str = ""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
values["client"] = openai.Completion
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")
return values
@property
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": self.openai_api_type,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": self.openai_api_type,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
return super().generate(prompts, stop)
@handle_llm_exceptions_async
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
return await super().agenerate(prompts, stop)
import os
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List
from typing import Optional, List, Dict, Any
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableChatOpenAI(ChatOpenAI):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
**super()._default_params,
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages.
......
import os
from langchain.schema import LLMResult
from typing import Optional, List
from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableOpenAI(OpenAI):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
values["client"] = openai.Completion
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")
return values
@property
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
......
......@@ -20,7 +20,7 @@ const AzureProvider = ({
const [token, setToken] = useState(provider.token as ProviderAzureToken || {})
const handleFocus = () => {
if (token === provider.token) {
token.azure_api_key = ''
token.openai_api_key = ''
setToken({...token})
onTokenChange({...token})
}
......@@ -35,31 +35,17 @@ const AzureProvider = ({
<div className='px-4 py-3'>
<ProviderInput
className='mb-4'
name={t('common.provider.azure.resourceName')}
placeholder={t('common.provider.azure.resourceNamePlaceholder')}
value={token.azure_api_base}
onChange={(v) => handleChange('azure_api_base', v)}
/>
<ProviderInput
className='mb-4'
name={t('common.provider.azure.deploymentId')}
placeholder={t('common.provider.azure.deploymentIdPlaceholder')}
value={token.azure_api_type}
onChange={v => handleChange('azure_api_type', v)}
/>
<ProviderInput
className='mb-4'
name={t('common.provider.azure.apiVersion')}
placeholder={t('common.provider.azure.apiVersionPlaceholder')}
value={token.azure_api_version}
onChange={v => handleChange('azure_api_version', v)}
name={t('common.provider.azure.apiBase')}
placeholder={t('common.provider.azure.apiBasePlaceholder')}
value={token.openai_api_base}
onChange={(v) => handleChange('openai_api_base', v)}
/>
<ProviderValidateTokenInput
className='mb-4'
name={t('common.provider.azure.apiKey')}
placeholder={t('common.provider.azure.apiKeyPlaceholder')}
value={token.azure_api_key}
onChange={v => handleChange('azure_api_key', v)}
value={token.openai_api_key}
onChange={v => handleChange('openai_api_key', v)}
onFocus={handleFocus}
onValidatedStatus={onValidatedStatus}
providerName={provider.provider_name}
......@@ -72,4 +58,4 @@ const AzureProvider = ({
)
}
export default AzureProvider
\ No newline at end of file
export default AzureProvider
......@@ -33,12 +33,12 @@ const ProviderItem = ({
const { notify } = useContext(ToastContext)
const [token, setToken] = useState<ProviderAzureToken | string>(
provider.provider_name === 'azure_openai'
? { azure_api_base: '', azure_api_type: '', azure_api_version: '', azure_api_key: '' }
? { openai_api_base: '', openai_api_key: '' }
: ''
)
const id = `${provider.provider_name}-${provider.provider_type}`
const isOpen = id === activeId
const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.azure_api_key : provider.token
const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.openai_api_key : provider.token
const comingSoon = false
const isValid = provider.is_valid
......@@ -135,4 +135,4 @@ const ProviderItem = ({
)
}
export default ProviderItem
\ No newline at end of file
export default ProviderItem
......@@ -148,12 +148,8 @@ const translation = {
editKey: 'Edit',
invalidApiKey: 'Invalid API key',
azure: {
resourceName: 'Resource Name',
resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
deploymentId: 'Deployment ID',
deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
apiVersion: 'API Version',
apiVersionPlaceholder: 'The API version to use for this operation.',
apiBase: 'API Base',
apiBasePlaceholder: 'The API Base URL of your Azure OpenAI Resource.',
apiKey: 'API Key',
apiKeyPlaceholder: 'Enter your API key here',
helpTip: 'Learn Azure OpenAI Service',
......
......@@ -149,14 +149,10 @@ const translation = {
editKey: '编辑',
invalidApiKey: '无效的 API 密钥',
azure: {
resourceName: 'Resource Name',
resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
deploymentId: 'Deployment ID',
deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
apiVersion: 'API Version',
apiVersionPlaceholder: 'The API version to use for this operation.',
apiBase: 'API Base',
apiBasePlaceholder: '输入您的 Azure OpenAI API Base 地址',
apiKey: 'API Key',
apiKeyPlaceholder: 'Enter your API key here',
apiKeyPlaceholder: '输入你的 API 密钥',
helpTip: '了解 Azure OpenAI Service',
},
openaiHosted: {
......
......@@ -55,10 +55,8 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l
}
export type ProviderAzureToken = {
azure_api_base: string
azure_api_key: string
azure_api_type: string
azure_api_version: string
openai_api_base: string
openai_api_key: string
}
export type Provider = {
provider_name: string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment