Unverified Commit a5d21f3b authored by Matri's avatar Matri Committed by GitHub

fix: shortening invite url (#1100)

Co-authored-by: 's avatarMatriQi <matri@aifi.io>
parent 7ba068c3
...@@ -16,26 +16,25 @@ from services.account_service import RegisterService ...@@ -16,26 +16,25 @@ from services.account_service import RegisterService
class ActivateCheckApi(Resource): class ActivateCheckApi(Resource):
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args') parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
parser.add_argument('email', type=email, required=True, nullable=False, location='args') parser.add_argument('email', type=email, required=False, nullable=True, location='args')
parser.add_argument('token', type=str, required=True, nullable=False, location='args') parser.add_argument('token', type=str, required=True, nullable=False, location='args')
args = parser.parse_args() args = parser.parse_args()
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token']) workspaceId = args['workspace_id']
reg_email = args['email']
token = args['token']
tenant = db.session.query(Tenant).filter( invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
Tenant.id == args['workspace_id'],
Tenant.status == 'normal'
).first()
return {'is_valid': account is not None, 'workspace_name': tenant.name} return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
class ActivateApi(Resource): class ActivateApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json') parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
parser.add_argument('email', type=email, required=True, nullable=False, location='json') parser.add_argument('email', type=email, required=False, nullable=True, location='json')
parser.add_argument('token', type=str, required=True, nullable=False, location='json') parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json') parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json') parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
...@@ -44,12 +43,13 @@ class ActivateApi(Resource): ...@@ -44,12 +43,13 @@ class ActivateApi(Resource):
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json') parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token']) invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
if account is None: if invitation is None:
raise AlreadyActivateError() raise AlreadyActivateError()
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token']) RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
account = invitation['account']
account.name = args['name'] account.name = args['name']
# generate password salt # generate password salt
......
...@@ -72,7 +72,7 @@ class MemberInviteEmailApi(Resource): ...@@ -72,7 +72,7 @@ class MemberInviteEmailApi(Resource):
invitation_results.append({ invitation_results.append({
'status': 'success', 'status': 'success',
'email': invitee_email, 'email': invitee_email,
'url': f'{console_web_url}/activate?workspace_id={current_user.current_tenant_id}&email={invitee_email}&token={token}' 'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
}) })
account = marshal(account, account_fields) account = marshal(account, account_fields)
account['role'] = role account['role'] = role
......
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import base64 import base64
import json
import logging import logging
import secrets import secrets
import uuid import uuid
...@@ -346,6 +347,10 @@ class TenantService: ...@@ -346,6 +347,10 @@ class TenantService:
class RegisterService: class RegisterService:
@classmethod
def _get_invitation_token_key(cls, token: str) -> str:
return f'member_invite:token:{token}'
@classmethod @classmethod
def register(cls, email, name, password: str = None, open_id: str = None, provider: str = None) -> Account: def register(cls, email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
db.session.begin_nested() db.session.begin_nested()
...@@ -401,7 +406,7 @@ class RegisterService: ...@@ -401,7 +406,7 @@ class RegisterService:
# send email # send email
send_invite_member_mail_task.delay( send_invite_member_mail_task.delay(
to=email, to=email,
token=cls.generate_invite_token(tenant, account), token=token,
inviter_name=inviter.name if inviter else 'Dify', inviter_name=inviter.name if inviter else 'Dify',
workspace_id=tenant.id, workspace_id=tenant.id,
workspace_name=tenant.name, workspace_name=tenant.name,
...@@ -412,21 +417,35 @@ class RegisterService: ...@@ -412,21 +417,35 @@ class RegisterService:
@classmethod @classmethod
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str: def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
token = str(uuid.uuid4()) token = str(uuid.uuid4())
email_hash = sha256(account.email.encode()).hexdigest() invitation_data = {
cache_key = 'member_invite_token:{}, {}:{}'.format(str(tenant.id), email_hash, token) 'account_id': account.id,
redis_client.setex(cache_key, 3600, str(account.id)) 'email': account.email,
'workspace_id': tenant.id,
}
redis_client.setex(
cls._get_invitation_token_key(token),
3600,
json.dumps(invitation_data)
)
return token return token
@classmethod @classmethod
def revoke_token(cls, workspace_id: str, email: str, token: str): def revoke_token(cls, workspace_id: str, email: str, token: str):
email_hash = sha256(email.encode()).hexdigest() if workspace_id and email:
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) email_hash = sha256(email.encode()).hexdigest()
redis_client.delete(cache_key) cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
redis_client.delete(cache_key)
else:
redis_client.delete(cls._get_invitation_token_key(token))
@classmethod @classmethod
def get_account_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[Account]: def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[Account]:
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
tenant = db.session.query(Tenant).filter( tenant = db.session.query(Tenant).filter(
Tenant.id == workspace_id, Tenant.id == invitation_data['workspace_id'],
Tenant.status == 'normal' Tenant.status == 'normal'
).first() ).first()
...@@ -435,30 +454,43 @@ class RegisterService: ...@@ -435,30 +454,43 @@ class RegisterService:
tenant_account = db.session.query(Account, TenantAccountJoin.role).join( tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id TenantAccountJoin, Account.id == TenantAccountJoin.account_id
).filter(Account.email == email, TenantAccountJoin.tenant_id == tenant.id).first() ).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first()
if not tenant_account: if not tenant_account:
return None return None
account_id = cls._get_account_id_by_invite_token(workspace_id, email, token)
if not account_id:
return None
account = tenant_account[0] account = tenant_account[0]
if not account: if not account:
return None return None
if account_id != str(account.id): if invitation_data['account_id'] != str(account.id):
return None return None
return account return {
'account': account,
'data': invitation_data,
'tenant': tenant,
}
@classmethod @classmethod
def _get_account_id_by_invite_token(cls, workspace_id: str, email: str, token: str) -> Optional[str]: def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[str]:
email_hash = sha256(email.encode()).hexdigest() if workspace_id is not None and email is not None:
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) email_hash = sha256(email.encode()).hexdigest()
account_id = redis_client.get(cache_key) cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}'
if not account_id: account_id = redis_client.get(cache_key)
return None
if not account_id:
return None
return {
'account_id': account_id.decode('utf-8'),
'email': email,
'workspace_id': workspace_id,
}
else:
data = redis_client.get(cls._get_invitation_token_key(token))
if not data:
return None
return account_id.decode('utf-8') invitation = json.loads(data)
return invitation
...@@ -31,8 +31,8 @@ const ActivateForm = () => { ...@@ -31,8 +31,8 @@ const ActivateForm = () => {
const checkParams = { const checkParams = {
url: '/activate/check', url: '/activate/check',
params: { params: {
workspace_id: workspaceID, ...workspaceID && { workspace_id: workspaceID },
email, ...email && { email },
token, token,
}, },
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment