Unverified Commit 5b24d712 authored by Charlie.Wei's avatar Charlie.Wei Committed by GitHub

Azure openai init (#1929)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: 's avatarcrazywoola <427733928@qq.com>
Co-authored-by: 's avatarcrazywoola <100913391+crazywoola@users.noreply.github.com>
parent b8592ad4
import datetime
import json
import logging
import time
from json import JSONDecodeError
from typing import Optional, List, Dict, Tuple, Iterator
......@@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.entities.model_entities import ModelType, FetchFrom
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
ConfigurateMethod
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
......@@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr
logger = logging.getLogger(__name__)
original_provider_configurate_methods = {}
class ProviderConfiguration(BaseModel):
"""
......@@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel):
system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration
def __init__(self, **data):
super().__init__(**data)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if (any([len(quota_configuration.restrict_models) > 0
for quota_configuration in self.system_configuration.quota_configurations])
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
"""
Get current credentials.
......@@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel):
if provider_record:
try:
original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
original_credentials = json.loads(
provider_record.encrypted_config) if provider_record.encrypted_config else {}
except JSONDecodeError:
original_credentials = {}
......@@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel):
if provider_model_record:
try:
original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
original_credentials = json.loads(
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
except JSONDecodeError:
original_credentials = {}
......@@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel):
]
)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
should_use_custom_model = False
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
should_use_custom_model = True
for quota_configuration in self.system_configuration.quota_configurations:
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
continue
restrict_llms = quota_configuration.restrict_llms
if not restrict_llms:
restrict_models = quota_configuration.restrict_models
if len(restrict_models) == 0:
break
if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials['base_model_name'] = restrict_model.base_model_name
try:
custom_model_schema = (
provider_instance.get_model_instance(restrict_model.model_type)
.get_customizable_model_schema_from_credentials(
restrict_model.model,
copy_credentials
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue
if not custom_model_schema:
continue
if custom_model_schema.model_type not in model_types:
continue
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
)
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for m in provider_models:
if m.model_type == ModelType.LLM and m.model not in restrict_llms:
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
def _get_custom_provider_models(self,
......
......@@ -21,6 +21,12 @@ class SystemConfigurationStatus(Enum):
UNSUPPORTED = 'unsupported'
class RestrictModel(BaseModel):
model: str
base_model_name: Optional[str] = None
model_type: ModelType
class QuotaConfiguration(BaseModel):
"""
Model class for provider quota configuration.
......@@ -30,7 +36,7 @@ class QuotaConfiguration(BaseModel):
quota_limit: int
quota_used: int
is_valid: bool
restrict_llms: list[str] = []
restrict_models: list[RestrictModel] = []
class SystemConfiguration(BaseModel):
......
......@@ -4,13 +4,14 @@ from typing import Optional
from flask import Flask
from pydantic import BaseModel
from core.entities.provider_entities import QuotaUnit
from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
class HostingQuota(BaseModel):
quota_type: ProviderQuotaType
restrict_llms: list[str] = []
restrict_models: list[RestrictModel] = []
class TrialHostingQuota(HostingQuota):
......@@ -47,10 +48,9 @@ class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
moderation_config: HostedModerationConfig = None
def init_app(self, app: Flask):
if app.config.get('EDITION') != 'CLOUD':
return
def init_app(self, app: Flask) -> None:
self.provider_map["azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax()
......@@ -59,6 +59,47 @@ class HostingConfiguration:
self.moderation_config = self.init_moderation_config()
def init_azure_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.TIMES
if os.environ.get("HOSTED_AZURE_OPENAI_ENABLED") and os.environ.get("HOSTED_AZURE_OPENAI_ENABLED").lower() == 'true':
credentials = {
"openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"),
"openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"),
"base_model_name": "gpt-35-turbo"
}
quotas = []
hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_models=[
RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
]
)
quotas.append(trial_quota)
return HostingProvider(
enabled=True,
credentials=credentials,
quota_unit=quota_unit,
quotas=quotas
)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.TIMES
if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
......@@ -77,12 +118,12 @@ class HostingConfiguration:
if hosted_quota_limit != -1 or hosted_quota_limit > 0:
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_llms=[
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-16k",
"text-davinci-003"
restrict_models=[
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
)
quotas.append(trial_quota)
......
......@@ -144,7 +144,7 @@ class ModelInstance:
user=user
)
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \
-> str:
"""
Invoke large language model
......@@ -161,7 +161,8 @@ class ModelInstance:
model=self.model,
credentials=self.credentials,
file=file,
user=user
user=user,
**params
)
......
......@@ -32,7 +32,7 @@ class ModelType(Enum):
return cls.TEXT_EMBEDDING
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
return cls.RERANK
elif origin_model_type == cls.SPEECH2TEXT.value:
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
return cls.SPEECH2TEXT
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION
......
......@@ -2,7 +2,7 @@ from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \
DefaultParameterName, PriceConfig
DefaultParameterName, PriceConfig, ModelPropertyKey
from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
......@@ -502,8 +502,8 @@ EMBEDDING_BASE_MODELS = [
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
'context_size': 8097,
'max_chunks': 32,
ModelPropertyKey.CONTEXT_SIZE: 8097,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.0001,
......
......@@ -30,7 +30,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
......@@ -59,7 +59,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get(
model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value:
......@@ -79,7 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
......@@ -109,8 +109,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
return ai_model_entity.entity
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
return ai_model_entity.entity if ai_model_entity else None
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
......
......@@ -12,7 +12,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
ConfigurateMethod
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
......@@ -607,7 +608,7 @@ class ProviderManager:
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1,
restrict_llms=provider_quota.restrict_llms
restrict_models=provider_quota.restrict_models
)
quota_configurations.append(quota_configuration)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment