Unverified Commit f68b05d5 authored by John Wang's avatar John Wang Committed by GitHub

Feat: support azure openai for temporary (#101)

parent 3b3c604e
...@@ -47,6 +47,7 @@ DEFAULTS = { ...@@ -47,6 +47,7 @@ DEFAULTS = {
'PDF_PREVIEW': 'True', 'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO', 'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False', 'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai'
} }
...@@ -181,6 +182,10 @@ class Config: ...@@ -181,6 +182,10 @@ class Config:
# You could disable it for compatibility with certain OpenAPI providers # You could disable it for compatibility with certain OpenAPI providers
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION') self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
class CloudEditionConfig(Config): class CloudEditionConfig(Config):
def __init__(self): def __init__(self):
......
...@@ -82,29 +82,33 @@ class ProviderTokenApi(Resource): ...@@ -82,29 +82,33 @@ class ProviderTokenApi(Resource):
args = parser.parse_args() args = parser.parse_args()
if not args['token']: if args['token']:
raise ValueError('Token is empty') try:
ProviderService.validate_provider_configs(
try: tenant=current_user.current_tenant,
ProviderService.validate_provider_configs( provider_name=ProviderName(provider),
configs=args['token']
)
token_is_valid = True
except ValidateFailedError:
token_is_valid = False
base64_encrypted_token = ProviderService.get_encrypted_token(
tenant=current_user.current_tenant, tenant=current_user.current_tenant,
provider_name=ProviderName(provider), provider_name=ProviderName(provider),
configs=args['token'] configs=args['token']
) )
token_is_valid = True else:
except ValidateFailedError: base64_encrypted_token = None
token_is_valid = False token_is_valid = False
tenant = current_user.current_tenant tenant = current_user.current_tenant
base64_encrypted_token = ProviderService.get_encrypted_token( provider_model = db.session.query(Provider).filter(
tenant=current_user.current_tenant, Provider.tenant_id == tenant.id,
provider_name=ProviderName(provider), Provider.provider_name == provider,
configs=args['token'] Provider.provider_type == ProviderType.CUSTOM.value
) ).first()
provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider,
provider_type=ProviderType.CUSTOM.value).first()
# Only allow updating token for CUSTOM provider type # Only allow updating token for CUSTOM provider type
if provider_model: if provider_model:
...@@ -117,6 +121,16 @@ class ProviderTokenApi(Resource): ...@@ -117,6 +121,16 @@ class ProviderTokenApi(Resource):
is_valid=token_is_valid) is_valid=token_is_valid)
db.session.add(provider_model) db.session.add(provider_model)
if provider_model.is_valid:
other_providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id,
Provider.provider_name != provider,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
for other_provider in other_providers:
other_provider.is_valid = False
db.session.commit() db.session.commit()
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
......
...@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except ...@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding( def get_embedding(
text: str, text: str,
engine: Optional[str] = None, engine: Optional[str] = None,
openai_api_key: Optional[str] = None, api_key: Optional[str] = None,
**kwargs
) -> List[float]: ) -> List[float]:
"""Get embedding. """Get embedding.
...@@ -25,11 +26,12 @@ def get_embedding( ...@@ -25,11 +26,12 @@ def get_embedding(
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"] return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]: async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
float]:
"""Asynchronously get embedding. """Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils: NOTE: Copied from OpenAI's embedding utils:
...@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key ...@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key
# replace newlines, which can negatively affect performance. # replace newlines, which can negatively affect performance.
text = text.replace("\n", " ") text = text.replace("\n", " ")
return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][ return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
"embedding" "embedding"
] ]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings( def get_embeddings(
list_of_text: List[str], list_of_text: List[str],
engine: Optional[str] = None, engine: Optional[str] = None,
openai_api_key: Optional[str] = None api_key: Optional[str] = None,
**kwargs
) -> List[List[float]]: ) -> List[List[float]]:
"""Get embeddings. """Get embeddings.
...@@ -67,14 +70,14 @@ def get_embeddings( ...@@ -67,14 +70,14 @@ def get_embeddings(
# replace newlines, which can negatively affect performance. # replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text] list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data] return [d["embedding"] for d in data]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings( async def aget_embeddings(
list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]: ) -> List[List[float]]:
"""Asynchronously get embeddings. """Asynchronously get embeddings.
...@@ -90,7 +93,7 @@ async def aget_embeddings( ...@@ -90,7 +93,7 @@ async def aget_embeddings(
# replace newlines, which can negatively affect performance. # replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text] list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data] return [d["embedding"] for d in data]
...@@ -98,19 +101,30 @@ async def aget_embeddings( ...@@ -98,19 +101,30 @@ async def aget_embeddings(
class OpenAIEmbedding(BaseEmbedding): class OpenAIEmbedding(BaseEmbedding):
def __init__( def __init__(
self, self,
mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE, mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002, model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
deployment_name: Optional[str] = None, deployment_name: Optional[str] = None,
openai_api_key: Optional[str] = None, openai_api_key: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Init params.""" """Init params."""
super().__init__(**kwargs) new_kwargs = {}
if 'embed_batch_size' in kwargs:
new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
if 'tokenizer' in kwargs:
new_kwargs['tokenizer'] = kwargs['tokenizer']
super().__init__(**new_kwargs)
self.mode = OpenAIEmbeddingMode(mode) self.mode = OpenAIEmbeddingMode(mode)
self.model = OpenAIEmbeddingModelType(model) self.model = OpenAIEmbeddingModelType(model)
self.deployment_name = deployment_name self.deployment_name = deployment_name
self.openai_api_key = openai_api_key self.openai_api_key = openai_api_key
self.openai_api_type = kwargs.get('openai_api_type')
self.openai_api_version = kwargs.get('openai_api_version')
self.openai_api_base = kwargs.get('openai_api_base')
@handle_llm_exceptions @handle_llm_exceptions
def _get_query_embedding(self, query: str) -> List[float]: def _get_query_embedding(self, query: str) -> List[float]:
...@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding): ...@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _QUERY_MODE_MODEL_DICT: if key not in _QUERY_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}") raise ValueError(f"Invalid mode, model combination: {key}")
engine = _QUERY_MODE_MODEL_DICT[key] engine = _QUERY_MODE_MODEL_DICT[key]
return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key) return get_embedding(query, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embedding(self, text: str) -> List[float]: def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding.""" """Get text embedding."""
...@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding): ...@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _TEXT_MODE_MODEL_DICT: if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}") raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key] engine = _TEXT_MODE_MODEL_DICT[key]
return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key) return get_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
async def _aget_text_embedding(self, text: str) -> List[float]: async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding.""" """Asynchronously get text embedding."""
...@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding): ...@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _TEXT_MODE_MODEL_DICT: if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}") raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key] engine = _TEXT_MODE_MODEL_DICT[key]
return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key) return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings. """Get text embeddings.
...@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding): ...@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _TEXT_MODE_MODEL_DICT: if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}") raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key] engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key) embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings return embeddings
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
...@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding): ...@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding):
if key not in _TEXT_MODE_MODEL_DICT: if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}") raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key] engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key) embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings return embeddings
...@@ -33,8 +33,11 @@ class IndexBuilder: ...@@ -33,8 +33,11 @@ class IndexBuilder:
max_chunk_overlap=20 max_chunk_overlap=20
) )
provider = LLMBuilder.get_default_provider(tenant_id)
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
model_provider=provider,
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )
......
...@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager ...@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM from langchain.llms.fake import FakeListLLM
from core.constant import llm_constant from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType
class LLMBuilder: class LLMBuilder:
...@@ -31,16 +36,23 @@ class LLMBuilder: ...@@ -31,16 +36,23 @@ class LLMBuilder:
if model_name == 'fake': if model_name == 'fake':
return FakeListLLM(responses=[]) return FakeListLLM(responses=[])
provider = cls.get_default_provider(tenant_id)
mode = cls.get_mode_by_model(model_name) mode = cls.get_mode_by_model(model_name)
if mode == 'chat': if mode == 'chat':
# llm_cls = StreamableAzureChatOpenAI if provider == 'openai':
llm_cls = StreamableChatOpenAI llm_cls = StreamableChatOpenAI
else:
llm_cls = StreamableAzureChatOpenAI
elif mode == 'completion': elif mode == 'completion':
llm_cls = StreamableOpenAI if provider == 'openai':
llm_cls = StreamableOpenAI
else:
llm_cls = StreamableAzureOpenAI
else: else:
raise ValueError(f"model name {model_name} is not supported.") raise ValueError(f"model name {model_name} is not supported.")
model_credentials = cls.get_model_credentials(tenant_id, model_name) model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
return llm_cls( return llm_cls(
model_name=model_name, model_name=model_name,
...@@ -86,18 +98,31 @@ class LLMBuilder: ...@@ -86,18 +98,31 @@ class LLMBuilder:
raise ValueError(f"model name {model_name} is not supported.") raise ValueError(f"model name {model_name} is not supported.")
@classmethod @classmethod
def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict: def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
""" """
Returns the API credentials for the given tenant_id and model_name, based on the model's provider. Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
Raises an exception if the model_name is not found or if the provider is not found. Raises an exception if the model_name is not found or if the provider is not found.
""" """
if not model_name: if not model_name:
raise Exception('model name not found') raise Exception('model name not found')
#
# if model_name not in llm_constant.models:
# raise Exception('model {} not found'.format(model_name))
if model_name not in llm_constant.models: # model_provider = llm_constant.models[model_name]
raise Exception('model {} not found'.format(model_name))
model_provider = llm_constant.models[model_name]
provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider) provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
return provider_service.get_credentials(model_name) return provider_service.get_credentials(model_name)
@classmethod
def get_default_provider(cls, tenant_id: str) -> str:
provider = BaseProvider.get_valid_provider(tenant_id)
if not provider:
raise ProviderTokenNotInitError()
if provider.provider_type == ProviderType.SYSTEM.value:
provider_name = 'openai'
else:
provider_name = provider.provider_name
return provider_name
...@@ -36,10 +36,9 @@ class AzureProvider(BaseProvider): ...@@ -36,10 +36,9 @@ class AzureProvider(BaseProvider):
""" """
Returns the API credentials for Azure OpenAI as a dictionary. Returns the API credentials for Azure OpenAI as a dictionary.
""" """
encrypted_config = self.get_provider_api_key(model_id=model_id) config = self.get_provider_api_key(model_id=model_id)
config = json.loads(encrypted_config)
config['openai_api_type'] = 'azure' config['openai_api_type'] = 'azure'
config['deployment_name'] = model_id config['deployment_name'] = model_id.replace('.', '')
return config return config
def get_provider_name(self): def get_provider_name(self):
...@@ -51,12 +50,11 @@ class AzureProvider(BaseProvider): ...@@ -51,12 +50,11 @@ class AzureProvider(BaseProvider):
""" """
try: try:
config = self.get_provider_api_key() config = self.get_provider_api_key()
config = json.loads(config)
except: except:
config = { config = {
'openai_api_type': 'azure', 'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview', 'openai_api_version': '2023-03-15-preview',
'openai_api_base': 'https://foo.microsoft.com/bar', 'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
'openai_api_key': '' 'openai_api_key': ''
} }
...@@ -65,7 +63,7 @@ class AzureProvider(BaseProvider): ...@@ -65,7 +63,7 @@ class AzureProvider(BaseProvider):
config = { config = {
'openai_api_type': 'azure', 'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview', 'openai_api_version': '2023-03-15-preview',
'openai_api_base': 'https://foo.microsoft.com/bar', 'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
'openai_api_key': '' 'openai_api_key': ''
} }
......
...@@ -14,7 +14,7 @@ class BaseProvider(ABC): ...@@ -14,7 +14,7 @@ class BaseProvider(ABC):
def __init__(self, tenant_id: str): def __init__(self, tenant_id: str):
self.tenant_id = tenant_id self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str: def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
""" """
Returns the decrypted API key for the given tenant_id and provider_name. Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
...@@ -43,23 +43,35 @@ class BaseProvider(ABC): ...@@ -43,23 +43,35 @@ class BaseProvider(ABC):
Returns the Provider instance for the given tenant_id and provider_name. Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
""" """
providers = db.session.query(Provider).filter( return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.get_provider_name().value @classmethod
).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id
)
if provider_name:
query = query.filter(Provider.provider_name == provider_name)
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
custom_provider = None custom_provider = None
system_provider = None system_provider = None
for provider in providers: for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value: if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
custom_provider = provider custom_provider = provider
elif provider.provider_type == ProviderType.SYSTEM.value: elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
system_provider = provider system_provider = provider
if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config: if custom_provider:
return custom_provider return custom_provider
elif system_provider and system_provider.is_valid: elif system_provider:
return system_provider return system_provider
else: else:
return None return None
...@@ -80,7 +92,7 @@ class BaseProvider(ABC): ...@@ -80,7 +92,7 @@ class BaseProvider(ABC):
try: try:
config = self.get_provider_api_key() config = self.get_provider_api_key()
except: except:
config = 'THIS-IS-A-MOCK-TOKEN' config = ''
if obfuscated: if obfuscated:
return self.obfuscated_token(config) return self.obfuscated_token(config)
......
import requests
from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List from typing import Optional, List, Dict, Any
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableAzureChatOpenAI(AzureChatOpenAI): class StreamableAzureChatOpenAI(AzureChatOpenAI):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
**super()._default_params,
"engine": self.deployment_name,
"api_type": self.openai_api_type,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}
def get_messages_tokens(self, messages: List[BaseMessage]) -> int: def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages. """Get the number of tokens in a list of messages.
......
import os
from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableAzureOpenAI(AzureOpenAI):
openai_api_type: str = "azure"
openai_api_version: str = ""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
values["client"] = openai.Completion
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")
return values
@property
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": self.openai_api_type,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": self.openai_api_type,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
return super().generate(prompts, stop)
@handle_llm_exceptions_async
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
return await super().agenerate(prompts, stop)
import os
from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from typing import Optional, List from typing import Optional, List, Dict, Any
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableChatOpenAI(ChatOpenAI): class StreamableChatOpenAI(ChatOpenAI):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
**super()._default_params,
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}
def get_messages_tokens(self, messages: List[BaseMessage]) -> int: def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages. """Get the number of tokens in a list of messages.
......
import os
from langchain.schema import LLMResult from langchain.schema import LLMResult
from typing import Optional, List from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI from langchain import OpenAI
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableOpenAI(OpenAI): class StreamableOpenAI(OpenAI):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import openai
values["client"] = openai.Completion
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")
return values
@property
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions @handle_llm_exceptions
def generate( def generate(
self, prompts: List[str], stop: Optional[List[str]] = None self, prompts: List[str], stop: Optional[List[str]] = None
......
...@@ -20,7 +20,7 @@ const AzureProvider = ({ ...@@ -20,7 +20,7 @@ const AzureProvider = ({
const [token, setToken] = useState(provider.token as ProviderAzureToken || {}) const [token, setToken] = useState(provider.token as ProviderAzureToken || {})
const handleFocus = () => { const handleFocus = () => {
if (token === provider.token) { if (token === provider.token) {
token.azure_api_key = '' token.openai_api_key = ''
setToken({...token}) setToken({...token})
onTokenChange({...token}) onTokenChange({...token})
} }
...@@ -35,31 +35,17 @@ const AzureProvider = ({ ...@@ -35,31 +35,17 @@ const AzureProvider = ({
<div className='px-4 py-3'> <div className='px-4 py-3'>
<ProviderInput <ProviderInput
className='mb-4' className='mb-4'
name={t('common.provider.azure.resourceName')} name={t('common.provider.azure.apiBase')}
placeholder={t('common.provider.azure.resourceNamePlaceholder')} placeholder={t('common.provider.azure.apiBasePlaceholder')}
value={token.azure_api_base} value={token.openai_api_base}
onChange={(v) => handleChange('azure_api_base', v)} onChange={(v) => handleChange('openai_api_base', v)}
/>
<ProviderInput
className='mb-4'
name={t('common.provider.azure.deploymentId')}
placeholder={t('common.provider.azure.deploymentIdPlaceholder')}
value={token.azure_api_type}
onChange={v => handleChange('azure_api_type', v)}
/>
<ProviderInput
className='mb-4'
name={t('common.provider.azure.apiVersion')}
placeholder={t('common.provider.azure.apiVersionPlaceholder')}
value={token.azure_api_version}
onChange={v => handleChange('azure_api_version', v)}
/> />
<ProviderValidateTokenInput <ProviderValidateTokenInput
className='mb-4' className='mb-4'
name={t('common.provider.azure.apiKey')} name={t('common.provider.azure.apiKey')}
placeholder={t('common.provider.azure.apiKeyPlaceholder')} placeholder={t('common.provider.azure.apiKeyPlaceholder')}
value={token.azure_api_key} value={token.openai_api_key}
onChange={v => handleChange('azure_api_key', v)} onChange={v => handleChange('openai_api_key', v)}
onFocus={handleFocus} onFocus={handleFocus}
onValidatedStatus={onValidatedStatus} onValidatedStatus={onValidatedStatus}
providerName={provider.provider_name} providerName={provider.provider_name}
...@@ -72,4 +58,4 @@ const AzureProvider = ({ ...@@ -72,4 +58,4 @@ const AzureProvider = ({
) )
} }
export default AzureProvider export default AzureProvider
\ No newline at end of file
...@@ -33,12 +33,12 @@ const ProviderItem = ({ ...@@ -33,12 +33,12 @@ const ProviderItem = ({
const { notify } = useContext(ToastContext) const { notify } = useContext(ToastContext)
const [token, setToken] = useState<ProviderAzureToken | string>( const [token, setToken] = useState<ProviderAzureToken | string>(
provider.provider_name === 'azure_openai' provider.provider_name === 'azure_openai'
? { azure_api_base: '', azure_api_type: '', azure_api_version: '', azure_api_key: '' } ? { openai_api_base: '', openai_api_key: '' }
: '' : ''
) )
const id = `${provider.provider_name}-${provider.provider_type}` const id = `${provider.provider_name}-${provider.provider_type}`
const isOpen = id === activeId const isOpen = id === activeId
const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.azure_api_key : provider.token const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.openai_api_key : provider.token
const comingSoon = false const comingSoon = false
const isValid = provider.is_valid const isValid = provider.is_valid
...@@ -135,4 +135,4 @@ const ProviderItem = ({ ...@@ -135,4 +135,4 @@ const ProviderItem = ({
) )
} }
export default ProviderItem export default ProviderItem
\ No newline at end of file
...@@ -148,12 +148,8 @@ const translation = { ...@@ -148,12 +148,8 @@ const translation = {
editKey: 'Edit', editKey: 'Edit',
invalidApiKey: 'Invalid API key', invalidApiKey: 'Invalid API key',
azure: { azure: {
resourceName: 'Resource Name', apiBase: 'API Base',
resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.', apiBasePlaceholder: 'The API Base URL of your Azure OpenAI Resource.',
deploymentId: 'Deployment ID',
deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
apiVersion: 'API Version',
apiVersionPlaceholder: 'The API version to use for this operation.',
apiKey: 'API Key', apiKey: 'API Key',
apiKeyPlaceholder: 'Enter your API key here', apiKeyPlaceholder: 'Enter your API key here',
helpTip: 'Learn Azure OpenAI Service', helpTip: 'Learn Azure OpenAI Service',
......
...@@ -149,14 +149,10 @@ const translation = { ...@@ -149,14 +149,10 @@ const translation = {
editKey: '编辑', editKey: '编辑',
invalidApiKey: '无效的 API 密钥', invalidApiKey: '无效的 API 密钥',
azure: { azure: {
resourceName: 'Resource Name', apiBase: 'API Base',
resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.', apiBasePlaceholder: '输入您的 Azure OpenAI API Base 地址',
deploymentId: 'Deployment ID',
deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
apiVersion: 'API Version',
apiVersionPlaceholder: 'The API version to use for this operation.',
apiKey: 'API Key', apiKey: 'API Key',
apiKeyPlaceholder: 'Enter your API key here', apiKeyPlaceholder: '输入你的 API 密钥',
helpTip: '了解 Azure OpenAI Service', helpTip: '了解 Azure OpenAI Service',
}, },
openaiHosted: { openaiHosted: {
......
...@@ -55,10 +55,8 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l ...@@ -55,10 +55,8 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l
} }
export type ProviderAzureToken = { export type ProviderAzureToken = {
azure_api_base: string openai_api_base: string
azure_api_key: string openai_api_key: string
azure_api_type: string
azure_api_version: string
} }
export type Provider = { export type Provider = {
provider_name: string provider_name: string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment