Unverified Commit cdf7b249 authored by Yeuoly's avatar Yeuoly

refactor: deduct quota

parent 603ce8e5
from flask_restful import Resource
from controllers.console.setup import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import inner_api_only
class EnterpriseModelInvokeApi(Resource):
"""Model invoke API for enterprise edition"""
@setup_required
@inner_api_only
def post(self):
pass
api.add_resource(EnterpriseModelInvokeApi, '/model/invoke')
\ No newline at end of file
from core.entities.application_entities import ApplicationGenerateEntity
from core.entities.provider_entities import QuotaUnit
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider, ProviderType
from libs.deduct_quota import DeductQuotaManager
from models.model import Message
@message_was_created.connect
def handle(sender, **kwargs):
def handle(sender: Message, **kwargs):
message = sender
application_generate_entity: ApplicationGenerateEntity = kwargs.get('application_generate_entity')
model_config = application_generate_entity.app_orchestration_config_entity.model_config
provider_model_bundle = model_config.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = message.message_tokens + message.answer_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = 1
if 'gpt-4' in model_config.model:
used_quota = 20
else:
used_quota = 1
if used_quota is not None:
db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.tenant_id,
Provider.provider_name == model_config.provider,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + used_quota})
db.session.commit()
DeductQuotaManager.deduct_quota(
provider_model_bundle=provider_model_bundle,
model=model_config.model,
message_tokens=message.message_tokens,
answer_tokens=message.answer_tokens
)
\ No newline at end of file
from events.inner_event import model_was_invoked
from core.provider_manager import ProviderManager
from core.model_runtime.entities.model_entities import ModelType
from libs.deduct_quota import DeductQuotaManager
provider_manager = ProviderManager()
@model_was_invoked.connect
def handle(sender, **kwargs):
"""
Invoke model event handler, handle the quota deduction
:param sender: sender
:param kwargs: kwargs
:param tenant_id: tenant id
:param model_config: model config
:param provider: provider
:param model_type: model type
:param model: model
:param message_tokens: message tokens
:param answer_tokens: answer tokens
:return: None
"""
tenant_id = kwargs.get('tenant_id')
if not tenant_id:
return
model_config = kwargs.get('model_config', {})
if not model_config:
return
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id,
model_config.get('provider', ''),
ModelType.value_of(model_config.get('model_type', ''))
)
if not provider_model_bundle:
return
DeductQuotaManager.deduct_quota(
provider_model_bundle=provider_model_bundle,
model=model_config.get('model', ''),
message_tokens=kwargs.get('message_tokens', 0),
answer_tokens=kwargs.get('answer_tokens', 0)
)
\ No newline at end of file
from blinker import signal
# sender: model invoke
model_was_invoked = signal('model-was-invoked')
from core.entities.provider_configuration import ProviderModelBundle
from core.entities.provider_entities import QuotaUnit
from models.provider import Provider, ProviderType
from extensions.ext_database import db
class DeductQuotaManager:
@staticmethod
def deduct_quota(
provider_model_bundle: ProviderModelBundle,
model: str,
message_tokens: int,
answer_tokens: int
):
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = message_tokens + answer_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = 1
if 'gpt-4' in model:
used_quota = 20
else:
used_quota = 1
if used_quota is not None:
db.session.query(Provider).filter(
Provider.tenant_id == provider_configuration.tenant_id,
Provider.provider_name == provider_configuration.provider,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + used_quota})
db.session.commit()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment