Unverified Commit a8f23ed7 authored by crazywoola's avatar crazywoola Committed by GitHub

Feat/move tenant id into db (#2341)

parent ecf94725
...@@ -8,7 +8,7 @@ from flask import current_app, request ...@@ -8,7 +8,7 @@ from flask import current_app, request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from libs.helper import email from libs.helper import email
from libs.password import valid_password from libs.password import valid_password
from services.account_service import AccountService, TenantService from services.account_service import AccountService
class LoginApi(Resource): class LoginApi(Resource):
...@@ -30,11 +30,6 @@ class LoginApi(Resource): ...@@ -30,11 +30,6 @@ class LoginApi(Resource):
except services.errors.account.AccountLoginError: except services.errors.account.AccountLoginError:
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401 return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
try:
TenantService.switch_tenant(account)
except Exception:
pass
AccountService.update_last_login(account, request) AccountService.update_last_login(account, request)
# todo: return the user info # todo: return the user info
...@@ -47,7 +42,6 @@ class LogoutApi(Resource): ...@@ -47,7 +42,6 @@ class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
flask.session.pop('workspace_id', None)
flask_login.logout_user() flask_login.logout_user()
return {'result': 'success'} return {'result': 'success'}
......
"""empty message
Revision ID: 16830a790f0f
Revises: 380c6aa5a70d
Create Date: 2024-02-01 08:21:31.111119
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '16830a790f0f'
down_revision = '380c6aa5a70d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op:
batch_op.add_column(sa.Column('current', sa.Boolean(), server_default=sa.text('false'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op:
batch_op.drop_column('current')
# ### end Alembic commands ###
import enum import enum
import json import json
from math import e
from typing import List from typing import List
from extensions.ext_database import db from extensions.ext_database import db
...@@ -155,6 +154,7 @@ class TenantAccountJoin(db.Model): ...@@ -155,6 +154,7 @@ class TenantAccountJoin(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False) tenant_id = db.Column(UUID, nullable=False)
account_id = db.Column(UUID, nullable=False) account_id = db.Column(UUID, nullable=False)
current = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
role = db.Column(db.String(16), nullable=False, server_default='normal') role = db.Column(db.String(16), nullable=False, server_default='normal')
invited_by = db.Column(UUID, nullable=True) invited_by = db.Column(UUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
......
...@@ -11,7 +11,7 @@ from typing import Any, Dict, Optional ...@@ -11,7 +11,7 @@ from typing import Any, Dict, Optional
from constants.languages import language_timezone_mapping, languages from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from flask import current_app, session from flask import current_app
from libs.helper import get_remote_ip from libs.helper import get_remote_ip
from libs.passport import PassportService from libs.passport import PassportService
from libs.password import compare_password, hash_password from libs.password import compare_password, hash_password
...@@ -23,7 +23,8 @@ from services.errors.account import (AccountAlreadyInTenantError, AccountLoginEr ...@@ -23,7 +23,8 @@ from services.errors.account import (AccountAlreadyInTenantError, AccountLoginEr
NoPermissionError, RoleAlreadyAssignedError, TenantNotFound) NoPermissionError, RoleAlreadyAssignedError, TenantNotFound)
from sqlalchemy import func from sqlalchemy import func
from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task
from werkzeug.exceptions import Forbidden, Unauthorized from werkzeug.exceptions import Forbidden
from sqlalchemy import exc
def _create_tenant_for_account(account) -> Tenant: def _create_tenant_for_account(account) -> Tenant:
...@@ -39,54 +40,33 @@ class AccountService: ...@@ -39,54 +40,33 @@ class AccountService:
@staticmethod @staticmethod
def load_user(user_id: str) -> Account: def load_user(user_id: str) -> Account:
# todo: used by flask_login account = Account.query.filter_by(id=user_id).first()
if '.' in user_id: if not account:
tenant_id, account_id = user_id.split('.') return None
else:
account_id = user_id
account = db.session.query(Account).filter(Account.id == account_id).first()
if account: if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.') raise Forbidden('Account is banned or closed.')
workspace_id = session.get('workspace_id') # init owner's tenant
if workspace_id: tenant_owner = TenantAccountJoin.query.filter_by(account_id=account.id, role='owner').first()
tenant_account_join = db.session.query(TenantAccountJoin).filter( if not tenant_owner:
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == workspace_id
).first()
if not tenant_account_join:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else:
account.current_tenant_id = workspace_id
else:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account) _create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
current_time = datetime.utcnow() current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
if current_tenant:
account.current_tenant_id = current_tenant.tenant_id
else:
account.current_tenant_id = tenant_owner.tenant_id
tenant_owner.current = True
db.session.commit()
# update last_active_at when last_active_at is more than 10 minutes ago if datetime.utcnow() - account.last_active_at > timedelta(minutes=10):
if current_time - account.last_active_at > timedelta(minutes=10): account.last_active_at = datetime.utcnow()
account.last_active_at = current_time
db.session.commit() db.session.commit()
return account return account
@staticmethod @staticmethod
def get_account_jwt_token(account): def get_account_jwt_token(account):
payload = { payload = {
...@@ -277,18 +257,21 @@ class TenantService: ...@@ -277,18 +257,21 @@ class TenantService:
@staticmethod @staticmethod
def switch_tenant(account: Account, tenant_id: int = None) -> None: def switch_tenant(account: Account, tenant_id: int = None) -> None:
"""Switch the current workspace for the account""" """Switch the current workspace for the account"""
if not tenant_id:
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id).first()
else:
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
# Check if the tenant exists and the account is a member of the tenant tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
if not tenant_account_join: if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
with db.session.begin():
try:
TenantAccountJoin.query.filter_by(account_id=account.id).update({'current': False})
tenant_account_join.current = True
db.session.commit()
# Set the current tenant for the account # Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id account.current_tenant_id = tenant_account_join.tenant_id
session['workspace_id'] = account.current_tenant.id except exc.SQLAlchemyError:
db.session.rollback()
raise
@staticmethod @staticmethod
def get_tenant_members(tenant: Tenant) -> List[Account]: def get_tenant_members(tenant: Tenant) -> List[Account]:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment