Commit 461655a8 authored by John Wang's avatar John Wang

feat: completed sync-anthropic-hosted-providers command

parent 8226c765
...@@ -18,7 +18,8 @@ from models.model import Account ...@@ -18,7 +18,8 @@ from models.model import Account
import secrets import secrets
import base64 import base64
from models.provider import Provider from models.provider import Provider, ProviderName
from services.provider_service import ProviderService
@click.command('reset-password', help='Reset the account password.') @click.command('reset-password', help='Reset the account password.')
...@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes(): ...@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green')) click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0
page = 1
while True:
try:
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for tenant in tenants:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
count += 1
except Exception as e:
click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
def register_commands(app): def register_commands(app):
app.cli.add_command(reset_password) app.cli.add_command(reset_password)
app.cli.add_command(reset_email) app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes) app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes) app.cli.add_command(recreate_all_dataset_indexes)
app.cli.add_command(sync_anthropic_hosted_providers)
...@@ -50,7 +50,9 @@ DEFAULTS = { ...@@ -50,7 +50,9 @@ DEFAULTS = {
'PDF_PREVIEW': 'True', 'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO', 'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False', 'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai' 'DEFAULT_LLM_PROVIDER': 'openai',
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
} }
...@@ -193,6 +195,9 @@ class Config: ...@@ -193,6 +195,9 @@ class Config:
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY') self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY') self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
# By default it is False # By default it is False
# You could disable it for compatibility with certain OpenAPI providers # You could disable it for compatibility with certain OpenAPI providers
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION') self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
......
...@@ -3,6 +3,7 @@ import base64 ...@@ -3,6 +3,7 @@ import base64
import json import json
import logging import logging
from flask import current_app
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, abort from flask_restful import Resource, reqparse, abort
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
...@@ -203,7 +204,19 @@ class ProviderSystemApi(Resource): ...@@ -203,7 +204,19 @@ class ProviderSystemApi(Resource):
provider_model.is_valid = args['is_enabled'] provider_model.is_valid = args['is_enabled']
db.session.commit() db.session.commit()
elif not provider_model: elif not provider_model:
ProviderService.create_system_provider(tenant, provider, args['is_enabled']) if provider == ProviderName.OPENAI.value:
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
elif provider == ProviderName.ANTHROPIC.value:
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
else:
quota_limit = 0
ProviderService.create_system_provider(
tenant,
provider,
quota_limit,
args['is_enabled']
)
else: else:
abort(403) abort(403)
......
from flask import current_app
from events.tenant_event import tenant_was_updated from events.tenant_event import tenant_was_updated
from models.provider import ProviderName
from services.provider_service import ProviderService from services.provider_service import ProviderService
...@@ -6,4 +9,16 @@ from services.provider_service import ProviderService ...@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs): def handle(sender, **kwargs):
tenant = sender tenant = sender
if tenant.status == 'normal': if tenant.status == 'normal':
ProviderService.create_system_provider(tenant) ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
from flask import current_app
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from models.provider import ProviderName
from services.provider_service import ProviderService from services.provider_service import ProviderService
...@@ -6,4 +9,16 @@ from services.provider_service import ProviderService ...@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs): def handle(sender, **kwargs):
tenant = sender tenant = sender
if tenant.status == 'normal': if tenant.status == 'normal':
ProviderService.create_system_provider(tenant) ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
...@@ -73,7 +73,7 @@ class ProviderService: ...@@ -73,7 +73,7 @@ class ProviderService:
return llm_provider_service.get_encrypted_token(configs) return llm_provider_service.get_encrypted_token(configs)
@staticmethod @staticmethod
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
is_valid: bool = True): is_valid: bool = True):
if current_app.config['EDITION'] != 'CLOUD': if current_app.config['EDITION'] != 'CLOUD':
return return
...@@ -90,7 +90,7 @@ class ProviderService: ...@@ -90,7 +90,7 @@ class ProviderService:
provider_name=provider_name, provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value, provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value, quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=200, quota_limit=quota_limit,
encrypted_config='', encrypted_config='',
is_valid=is_valid, is_valid=is_valid,
) )
......
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Tenant from models.account import Tenant
from models.provider import Provider, ProviderType from models.provider import Provider, ProviderType, ProviderName
class WorkspaceService: class WorkspaceService:
...@@ -33,7 +33,7 @@ class WorkspaceService: ...@@ -33,7 +33,7 @@ class WorkspaceService:
if provider.is_valid and provider.encrypted_config: if provider.is_valid and provider.encrypted_config:
custom_provider = provider custom_provider = provider
elif provider.provider_type == ProviderType.SYSTEM.value: elif provider.provider_type == ProviderType.SYSTEM.value:
if provider.is_valid: if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
system_provider = provider system_provider = provider
if system_provider and not custom_provider: if system_provider and not custom_provider:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment