Unverified commit 5b24d712, authored by Charlie.Wei, committed by GitHub

Azure openai init (#1929)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
parent b8592ad4
import datetime import datetime
import json import json
import logging import logging
import time
from json import JSONDecodeError from json import JSONDecodeError
from typing import Optional, List, Dict, Tuple, Iterator from typing import Optional, List, Dict, Tuple, Iterator
...@@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S ...@@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType, FetchFrom
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
ConfigurateMethod
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.__base.model_provider import ModelProvider
...@@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr ...@@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
original_provider_configurate_methods = {}
class ProviderConfiguration(BaseModel): class ProviderConfiguration(BaseModel):
""" """
...@@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel): ...@@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel):
system_configuration: SystemConfiguration system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration custom_configuration: CustomConfiguration
def __init__(self, **data):
    """
    Build the configuration, then adjust the provider's configurate methods.

    The first time a provider name is seen, its current configurate methods
    are recorded in the module-level ``original_provider_configurate_methods``
    mapping. If that recorded list is exactly ``[CUSTOMIZABLE_MODEL]`` and at
    least one system quota configuration carries restricted models,
    ``PREDEFINED_MODEL`` is appended to the provider's configurate methods
    (unless it is already present).

    :param data: field values forwarded unchanged to the base model constructor
    """
    super().__init__(**data)

    provider_name = self.provider.provider

    # Snapshot the methods only once per provider name; later instances reuse it,
    # so the append below never leaks back into the recorded "original" list.
    if provider_name not in original_provider_configurate_methods:
        original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

    recorded_methods = original_provider_configurate_methods[provider_name]
    if recorded_methods != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
        return

    has_restricted_models = any(
        len(quota.restrict_models) > 0
        for quota in self.system_configuration.quota_configurations
    )
    if not has_restricted_models:
        return

    if ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
        self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
""" """
Get current credentials. Get current credentials.
...@@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel): ...@@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel):
if provider_record: if provider_record:
try: try:
original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {} original_credentials = json.loads(
provider_record.encrypted_config) if provider_record.encrypted_config else {}
except JSONDecodeError: except JSONDecodeError:
original_credentials = {} original_credentials = {}
...@@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel): ...@@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel):
if provider_model_record: if provider_model_record:
try: try:
original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} original_credentials = json.loads(
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
except JSONDecodeError: except JSONDecodeError:
original_credentials = {} original_credentials = {}
...@@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel): ...@@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel):
] ]
) )
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
should_use_custom_model = False
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
should_use_custom_model = True
for quota_configuration in self.system_configuration.quota_configurations: for quota_configuration in self.system_configuration.quota_configurations:
if self.system_configuration.current_quota_type != quota_configuration.quota_type: if self.system_configuration.current_quota_type != quota_configuration.quota_type:
continue continue
restrict_llms = quota_configuration.restrict_llms restrict_models = quota_configuration.restrict_models
if not restrict_llms: if len(restrict_models) == 0:
break break
if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials['base_model_name'] = restrict_model.base_model_name
try:
custom_model_schema = (
provider_instance.get_model_instance(restrict_model.model_type)
.get_customizable_model_schema_from_credentials(
restrict_model.model,
copy_credentials
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue
if not custom_model_schema:
continue
if custom_model_schema.model_type not in model_types:
continue
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
)
# if llm name not in restricted llm list, remove it # if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for m in provider_models: for m in provider_models:
if m.model_type == ModelType.LLM and m.model not in restrict_llms: if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
m.status = ModelStatus.NO_PERMISSION m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid: elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED m.status = ModelStatus.QUOTA_EXCEEDED
return provider_models return provider_models
def _get_custom_provider_models(self, def _get_custom_provider_models(self,
......
...@@ -21,6 +21,12 @@ class SystemConfigurationStatus(Enum): ...@@ -21,6 +21,12 @@ class SystemConfigurationStatus(Enum):
UNSUPPORTED = 'unsupported' UNSUPPORTED = 'unsupported'
class RestrictModel(BaseModel):
    """
    Model class for one entry of a quota's restricted-model list.
    """
    # Model name; quota handling compares provider model names against this value.
    model: str
    # Optional base model name; when set, it is copied into the credentials
    # under the 'base_model_name' key before a custom model schema is resolved.
    base_model_name: Optional[str] = None
    # Type of the model (e.g. LLM or text embedding).
    model_type: ModelType
class QuotaConfiguration(BaseModel): class QuotaConfiguration(BaseModel):
""" """
Model class for provider quota configuration. Model class for provider quota configuration.
...@@ -30,7 +36,7 @@ class QuotaConfiguration(BaseModel): ...@@ -30,7 +36,7 @@ class QuotaConfiguration(BaseModel):
quota_limit: int quota_limit: int
quota_used: int quota_used: int
is_valid: bool is_valid: bool
restrict_llms: list[str] = [] restrict_models: list[RestrictModel] = []
class SystemConfiguration(BaseModel): class SystemConfiguration(BaseModel):
......
...@@ -4,13 +4,14 @@ from typing import Optional ...@@ -4,13 +4,14 @@ from typing import Optional
from flask import Flask from flask import Flask
from pydantic import BaseModel from pydantic import BaseModel
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType from models.provider import ProviderQuotaType
class HostingQuota(BaseModel): class HostingQuota(BaseModel):
quota_type: ProviderQuotaType quota_type: ProviderQuotaType
restrict_llms: list[str] = [] restrict_models: list[RestrictModel] = []
class TrialHostingQuota(HostingQuota): class TrialHostingQuota(HostingQuota):
...@@ -47,10 +48,9 @@ class HostingConfiguration: ...@@ -47,10 +48,9 @@ class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {} provider_map: dict[str, HostingProvider] = {}
moderation_config: HostedModerationConfig = None moderation_config: HostedModerationConfig = None
def init_app(self, app: Flask): def init_app(self, app: Flask) -> None:
if app.config.get('EDITION') != 'CLOUD':
return
self.provider_map["azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai() self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic() self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax() self.provider_map["minimax"] = self.init_minimax()
...@@ -59,6 +59,47 @@ class HostingConfiguration: ...@@ -59,6 +59,47 @@ class HostingConfiguration:
self.moderation_config = self.init_moderation_config() self.moderation_config = self.init_moderation_config()
def init_azure_openai(self) -> HostingProvider:
    """
    Build the hosted Azure OpenAI provider from environment variables.

    The provider is enabled only when ``HOSTED_AZURE_OPENAI_ENABLED`` is set
    and equals ``'true'`` (case-insensitive). When enabled, credentials are
    read from ``HOSTED_AZURE_OPENAI_API_KEY`` and
    ``HOSTED_AZURE_OPENAI_API_BASE``, and a trial quota is attached unless
    ``HOSTED_AZURE_OPENAI_QUOTA_LIMIT`` (default ``"1000"``) is ``-1``.

    :return: hosting provider, enabled with credentials and quotas, or disabled
    """
    quota_unit = QuotaUnit.TIMES

    enabled_flag = os.environ.get("HOSTED_AZURE_OPENAI_ENABLED")
    if not enabled_flag or enabled_flag.lower() != 'true':
        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    credentials = {
        "openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"),
        "openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"),
        "base_model_name": "gpt-35-turbo"
    }

    quotas = []
    hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))

    # Same truth table as `limit != -1 or limit > 0`: only -1 skips the trial quota.
    if hosted_quota_limit != -1:
        # For every hosted entry the model name doubles as the base model name.
        llm_names = (
            "gpt-4",
            "gpt-4-32k",
            "gpt-4-1106-preview",
            "gpt-4-vision-preview",
            "gpt-35-turbo",
            "gpt-35-turbo-1106",
            "gpt-35-turbo-instruct",
            "gpt-35-turbo-16k",
            "text-davinci-003",
        )
        restrict_models = [
            RestrictModel(model=name, base_model_name=name, model_type=ModelType.LLM)
            for name in llm_names
        ]
        restrict_models.append(
            RestrictModel(
                model="text-embedding-ada-002",
                base_model_name="text-embedding-ada-002",
                model_type=ModelType.TEXT_EMBEDDING,
            )
        )
        quotas.append(
            TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=restrict_models,
            )
        )

    return HostingProvider(
        enabled=True,
        credentials=credentials,
        quota_unit=quota_unit,
        quotas=quotas
    )
def init_openai(self) -> HostingProvider: def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.TIMES quota_unit = QuotaUnit.TIMES
if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true': if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
...@@ -77,12 +118,12 @@ class HostingConfiguration: ...@@ -77,12 +118,12 @@ class HostingConfiguration:
if hosted_quota_limit != -1 or hosted_quota_limit > 0: if hosted_quota_limit != -1 or hosted_quota_limit > 0:
trial_quota = TrialHostingQuota( trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit, quota_limit=hosted_quota_limit,
restrict_llms=[ restrict_models=[
"gpt-3.5-turbo", RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
"gpt-3.5-turbo-1106", RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
"gpt-3.5-turbo-instruct", RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
"gpt-3.5-turbo-16k", RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
"text-davinci-003" RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
] ]
) )
quotas.append(trial_quota) quotas.append(trial_quota)
......
...@@ -144,7 +144,7 @@ class ModelInstance: ...@@ -144,7 +144,7 @@ class ModelInstance:
user=user user=user
) )
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \
-> str: -> str:
""" """
Invoke large language model Invoke large language model
...@@ -161,7 +161,8 @@ class ModelInstance: ...@@ -161,7 +161,8 @@ class ModelInstance:
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
file=file, file=file,
user=user user=user,
**params
) )
......
...@@ -32,7 +32,7 @@ class ModelType(Enum): ...@@ -32,7 +32,7 @@ class ModelType(Enum):
return cls.TEXT_EMBEDDING return cls.TEXT_EMBEDDING
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
return cls.RERANK return cls.RERANK
elif origin_model_type == cls.SPEECH2TEXT.value: elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
return cls.SPEECH2TEXT return cls.SPEECH2TEXT
elif origin_model_type == cls.MODERATION.value: elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION return cls.MODERATION
......
...@@ -2,7 +2,7 @@ from pydantic import BaseModel ...@@ -2,7 +2,7 @@ from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \ from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \
DefaultParameterName, PriceConfig DefaultParameterName, PriceConfig, ModelPropertyKey
from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
...@@ -502,8 +502,8 @@ EMBEDDING_BASE_MODELS = [ ...@@ -502,8 +502,8 @@ EMBEDDING_BASE_MODELS = [
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model_properties={ model_properties={
'context_size': 8097, ModelPropertyKey.CONTEXT_SIZE: 8097,
'max_chunks': 32, ModelPropertyKey.MAX_CHUNKS: 32,
}, },
pricing=PriceConfig( pricing=PriceConfig(
input=0.0001, input=0.0001,
......
...@@ -30,7 +30,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -30,7 +30,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
stream: bool = True, user: Optional[str] = None) \ stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]: -> Union[LLMResult, Generator]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model # chat model
...@@ -59,7 +59,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -59,7 +59,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:
model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get( model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
ModelPropertyKey.MODE) ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value: if model_mode == LLMMode.CHAT.value:
...@@ -79,7 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -79,7 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
if 'base_model_name' not in credentials: if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required') raise CredentialsValidateFailedError('Base Model Name is required')
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
if not ai_model_entity: if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
...@@ -109,8 +109,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -109,8 +109,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
return ai_model_entity.entity return ai_model_entity.entity if ai_model_entity else None
def _generate(self, model: str, credentials: dict, def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
......
...@@ -12,7 +12,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC ...@@ -12,7 +12,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
from core.helper import encrypter from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
ConfigurateMethod
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider from extensions import ext_hosting_provider
from extensions.ext_database import db from extensions.ext_database import db
...@@ -607,7 +608,7 @@ class ProviderManager: ...@@ -607,7 +608,7 @@ class ProviderManager:
quota_used=provider_record.quota_used, quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit, quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1,
restrict_llms=provider_quota.restrict_llms restrict_models=provider_quota.restrict_models
) )
quota_configurations.append(quota_configuration) quota_configurations.append(quota_configuration)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment