Unverified Commit 227f9fb7 authored by zxhlyh's avatar zxhlyh Committed by GitHub

Feat/api jwt (#1212)

parent c40ee7e6
...@@ -50,24 +50,6 @@ S3_REGION=your-region ...@@ -50,24 +50,6 @@ S3_REGION=your-region
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Cookie configuration
COOKIE_HTTPONLY=true
COOKIE_SAMESITE=None
COOKIE_SECURE=true
# Session configuration
SESSION_PERMANENT=true
SESSION_USE_SIGNER=true
## support redis, sqlalchemy
SESSION_TYPE=redis
# session redis configuration
SESSION_REDIS_HOST=localhost
SESSION_REDIS_PORT=6379
SESSION_REDIS_PASSWORD=difyai123456
SESSION_REDIS_DB=2
# Vector database configuration, support: weaviate, qdrant # Vector database configuration, support: weaviate, qdrant
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
......
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import os import os
from datetime import datetime, timedelta
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey from gevent import monkey
...@@ -12,12 +11,11 @@ import logging ...@@ -12,12 +11,11 @@ import logging
import json import json
import threading import threading
from flask import Flask, request, Response, session from flask import Flask, request, Response
import flask_login
from flask_cors import CORS from flask_cors import CORS
from core.model_providers.providers import hosted from core.model_providers.providers import hosted
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe ext_database, ext_storage, ext_mail, ext_stripe
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_login import login_manager from extensions.ext_login import login_manager
...@@ -27,12 +25,10 @@ from models import model, account, dataset, web, task, source, tool ...@@ -27,12 +25,10 @@ from models import model, account, dataset, web, task, source, tool
from events import event_handlers from events import event_handlers
# DO NOT REMOVE ABOVE # DO NOT REMOVE ABOVE
import core
from config import Config, CloudEditionConfig from config import Config, CloudEditionConfig
from commands import register_commands from commands import register_commands
from models.account import TenantAccountJoin, AccountStatus from services.account_service import AccountService
from models.model import Account, EndUser, App from libs.passport import PassportService
from services.account_service import TenantService
import warnings import warnings
warnings.simplefilter("ignore", ResourceWarning) warnings.simplefilter("ignore", ResourceWarning)
...@@ -85,81 +81,33 @@ def initialize_extensions(app): ...@@ -85,81 +81,33 @@ def initialize_extensions(app):
ext_redis.init_app(app) ext_redis.init_app(app)
ext_storage.init_app(app) ext_storage.init_app(app)
ext_celery.init_app(app) ext_celery.init_app(app)
ext_session.init_app(app)
ext_login.init_app(app) ext_login.init_app(app)
ext_mail.init_app(app) ext_mail.init_app(app)
ext_sentry.init_app(app) ext_sentry.init_app(app)
ext_stripe.init_app(app) ext_stripe.init_app(app)
def _create_tenant_for_account(account):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant
return tenant
# Flask-Login configuration # Flask-Login configuration
@login_manager.user_loader @login_manager.request_loader
def load_user(user_id): def load_user_from_request(request_from_flask_login):
"""Load user based on the user_id.""" """Load user based on the request."""
if request.blueprint == 'console': if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format # Check if the user_id contains a dot, indicating the old format
if '.' in user_id: auth_header = request.headers.get('Authorization', '')
tenant_id, account_id = user_id.split('.') if ' ' not in auth_header:
else: raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
account_id = user_id auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
account = db.session.query(Account).filter(Account.id == account_id).first() if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: decoded = PassportService().verify(auth_token)
raise Forbidden('Account is banned or closed.') user_id = decoded.get('user_id')
workspace_id = session.get('workspace_id') return AccountService.load_user(user_id)
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == workspace_id
).first()
if not tenant_account_join:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else:
account.current_tenant_id = workspace_id
else:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
current_time = datetime.utcnow()
# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()
# Log in the user with the updated user_id
flask_login.login_user(account, remember=True)
return account
else: else:
return None return None
@login_manager.unauthorized_handler @login_manager.unauthorized_handler
def unauthorized_handler(): def unauthorized_handler():
"""Handle unauthorized requests.""" """Handle unauthorized requests."""
...@@ -216,6 +164,7 @@ if app.config['TESTING']: ...@@ -216,6 +164,7 @@ if app.config['TESTING']:
@app.after_request @app.after_request
def after_request(response): def after_request(response):
"""Add Version headers to the response.""" """Add Version headers to the response."""
response.set_cookie('remember_token', '', expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION']) response.headers.add('X-Version', app.config['CURRENT_VERSION'])
response.headers.add('X-Env', app.config['DEPLOY_ENV']) response.headers.add('X-Env', app.config['DEPLOY_ENV'])
return response return response
......
...@@ -10,9 +10,6 @@ from extensions.ext_redis import redis_client ...@@ -10,9 +10,6 @@ from extensions.ext_redis import redis_client
dotenv.load_dotenv() dotenv.load_dotenv()
DEFAULTS = { DEFAULTS = {
'COOKIE_HTTPONLY': 'True',
'COOKIE_SECURE': 'True',
'COOKIE_SAMESITE': 'None',
'DB_USERNAME': 'postgres', 'DB_USERNAME': 'postgres',
'DB_PASSWORD': '', 'DB_PASSWORD': '',
'DB_HOST': 'localhost', 'DB_HOST': 'localhost',
...@@ -22,10 +19,6 @@ DEFAULTS = { ...@@ -22,10 +19,6 @@ DEFAULTS = {
'REDIS_PORT': '6379', 'REDIS_PORT': '6379',
'REDIS_DB': '0', 'REDIS_DB': '0',
'REDIS_USE_SSL': 'False', 'REDIS_USE_SSL': 'False',
'SESSION_REDIS_HOST': 'localhost',
'SESSION_REDIS_PORT': '6379',
'SESSION_REDIS_DB': '2',
'SESSION_REDIS_USE_SSL': 'False',
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize', 'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
'OAUTH_REDIRECT_INDEX_PATH': '/', 'OAUTH_REDIRECT_INDEX_PATH': '/',
'CONSOLE_WEB_URL': 'https://cloud.dify.ai', 'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
...@@ -36,9 +29,6 @@ DEFAULTS = { ...@@ -36,9 +29,6 @@ DEFAULTS = {
'STORAGE_TYPE': 'local', 'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage', 'STORAGE_LOCAL_PATH': 'storage',
'CHECK_UPDATE_URL': 'https://updates.dify.ai', 'CHECK_UPDATE_URL': 'https://updates.dify.ai',
'SESSION_TYPE': 'sqlalchemy',
'SESSION_PERMANENT': 'True',
'SESSION_USE_SIGNER': 'True',
'DEPLOY_ENV': 'PRODUCTION', 'DEPLOY_ENV': 'PRODUCTION',
'SQLALCHEMY_POOL_SIZE': 30, 'SQLALCHEMY_POOL_SIZE': 30,
'SQLALCHEMY_POOL_RECYCLE': 3600, 'SQLALCHEMY_POOL_RECYCLE': 3600,
...@@ -115,20 +105,6 @@ class Config: ...@@ -115,20 +105,6 @@ class Config:
# Alternatively you can set it with `SECRET_KEY` environment variable. # Alternatively you can set it with `SECRET_KEY` environment variable.
self.SECRET_KEY = get_env('SECRET_KEY') self.SECRET_KEY = get_env('SECRET_KEY')
# cookie settings
self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
self.PERMANENT_SESSION_LIFETIME = timedelta(days=7)
# session settings, only support sqlalchemy, redis
self.SESSION_TYPE = get_env('SESSION_TYPE')
self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT')
self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER')
# redis settings # redis settings
self.REDIS_HOST = get_env('REDIS_HOST') self.REDIS_HOST = get_env('REDIS_HOST')
self.REDIS_PORT = get_env('REDIS_PORT') self.REDIS_PORT = get_env('REDIS_PORT')
...@@ -137,14 +113,6 @@ class Config: ...@@ -137,14 +113,6 @@ class Config:
self.REDIS_DB = get_env('REDIS_DB') self.REDIS_DB = get_env('REDIS_DB')
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL') self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
# session redis settings
self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST')
self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT')
self.SESSION_REDIS_USERNAME = get_env('SESSION_REDIS_USERNAME')
self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD')
self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB')
self.SESSION_REDIS_USE_SSL = get_bool_env('SESSION_REDIS_USE_SSL')
# storage settings # storage settings
self.STORAGE_TYPE = get_env('STORAGE_TYPE') self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
......
...@@ -6,7 +6,6 @@ from flask_restful import Resource, reqparse ...@@ -6,7 +6,6 @@ from flask_restful import Resource, reqparse
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from libs.helper import email from libs.helper import email
from libs.password import valid_password from libs.password import valid_password
...@@ -37,12 +36,12 @@ class LoginApi(Resource): ...@@ -37,12 +36,12 @@ class LoginApi(Resource):
except Exception: except Exception:
pass pass
flask_login.login_user(account, remember=args['remember_me'])
AccountService.update_last_login(account, request) AccountService.update_last_login(account, request)
# todo: return the user info # todo: return the user info
token = AccountService.get_account_jwt_token(account)
return {'result': 'success'} return {'result': 'success', 'data': token}
class LogoutApi(Resource): class LogoutApi(Resource):
......
...@@ -2,9 +2,8 @@ import logging ...@@ -2,9 +2,8 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
import flask_login
import requests import requests
from flask import request, redirect, current_app, session from flask import request, redirect, current_app
from flask_restful import Resource from flask_restful import Resource
from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth
...@@ -75,12 +74,11 @@ class OAuthCallback(Resource): ...@@ -75,12 +74,11 @@ class OAuthCallback(Resource):
account.initialized_at = datetime.utcnow() account.initialized_at = datetime.utcnow()
db.session.commit() db.session.commit()
# login user
session.clear()
flask_login.login_user(account, remember=True)
AccountService.update_last_login(account, request) AccountService.update_last_login(account, request)
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success') token = AccountService.get_account_jwt_token(account)
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
......
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from functools import wraps from functools import wraps
import flask_login
from flask import request, current_app from flask import request, current_app
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
...@@ -58,9 +57,6 @@ class SetupApi(Resource): ...@@ -58,9 +57,6 @@ class SetupApi(Resource):
) )
setup() setup()
# Login
flask_login.login_user(account)
AccountService.update_last_login(account, request) AccountService.update_last_login(account, request)
return {'result': 'success'}, 201 return {'result': 'success'}, 201
......
import os import os
from functools import wraps from functools import wraps
import flask_login
from flask import current_app from flask import current_app
from flask import g from flask import g
from flask import has_request_context from flask import has_request_context
from flask import request from flask import request, session
from flask_login import user_logged_in from flask_login import user_logged_in
from flask_login.config import EXEMPT_METHODS from flask_login.config import EXEMPT_METHODS
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
......
import redis
from redis.connection import SSLConnection, Connection
from flask import request
from flask_session import Session, SqlAlchemySessionInterface, RedisSessionInterface
from flask_session.sessions import total_seconds
from itsdangerous import want_bytes
from extensions.ext_database import db
sess = Session()
def init_app(app):
sqlalchemy_session_interface = CustomSqlAlchemySessionInterface(
app,
db,
app.config.get('SESSION_SQLALCHEMY_TABLE', 'sessions'),
app.config.get('SESSION_KEY_PREFIX', 'session:'),
app.config.get('SESSION_USE_SIGNER', False),
app.config.get('SESSION_PERMANENT', True)
)
session_type = app.config.get('SESSION_TYPE')
if session_type == 'sqlalchemy':
app.session_interface = sqlalchemy_session_interface
elif session_type == 'redis':
connection_class = Connection
if app.config.get('SESSION_REDIS_USE_SSL', False):
connection_class = SSLConnection
sess_redis_client = redis.Redis()
sess_redis_client.connection_pool = redis.ConnectionPool(**{
'host': app.config.get('SESSION_REDIS_HOST', 'localhost'),
'port': app.config.get('SESSION_REDIS_PORT', 6379),
'username': app.config.get('SESSION_REDIS_USERNAME', None),
'password': app.config.get('SESSION_REDIS_PASSWORD', None),
'db': app.config.get('SESSION_REDIS_DB', 2),
'encoding': 'utf-8',
'encoding_errors': 'strict',
'decode_responses': False
}, connection_class=connection_class)
app.extensions['session_redis'] = sess_redis_client
app.session_interface = CustomRedisSessionInterface(
sess_redis_client,
app.config.get('SESSION_KEY_PREFIX', 'session:'),
app.config.get('SESSION_USE_SIGNER', False),
app.config.get('SESSION_PERMANENT', True)
)
class CustomSqlAlchemySessionInterface(SqlAlchemySessionInterface):
def __init__(
self,
app,
db,
table,
key_prefix,
use_signer=False,
permanent=True,
sequence=None,
autodelete=False,
):
if db is None:
from flask_sqlalchemy import SQLAlchemy
db = SQLAlchemy(app)
self.db = db
self.key_prefix = key_prefix
self.use_signer = use_signer
self.permanent = permanent
self.autodelete = autodelete
self.sequence = sequence
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
class Session(self.db.Model):
__tablename__ = table
if sequence:
id = self.db.Column( # noqa: A003, VNE003, A001
self.db.Integer, self.db.Sequence(sequence), primary_key=True
)
else:
id = self.db.Column( # noqa: A003, VNE003, A001
self.db.Integer, primary_key=True
)
session_id = self.db.Column(self.db.String(255), unique=True)
data = self.db.Column(self.db.LargeBinary)
expiry = self.db.Column(self.db.DateTime)
def __init__(self, session_id, data, expiry):
self.session_id = session_id
self.data = data
self.expiry = expiry
def __repr__(self):
return f"<Session data {self.data}>"
self.sql_session_model = Session
def save_session(self, *args, **kwargs):
if request.blueprint == 'service_api':
return
elif request.method == 'OPTIONS':
return
elif request.endpoint and request.endpoint == 'health':
return
return super().save_session(*args, **kwargs)
class CustomRedisSessionInterface(RedisSessionInterface):
def save_session(self, app, session, response):
if request.blueprint == 'service_api':
return
elif request.method == 'OPTIONS':
return
elif request.endpoint and request.endpoint == 'health':
return
if not self.should_set_cookie(app, session):
return
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
if not session:
if session.modified:
self.redis.delete(self.key_prefix + session.sid)
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return
# Modification case. There are upsides and downsides to
# emitting a set-cookie header each request. The behavior
# is controlled by the :meth:`should_set_cookie` method
# which performs a quick check to figure out if the cookie
# should be set or not. This is controlled by the
# SESSION_REFRESH_EACH_REQUEST config flag as well as
# the permanent flag on the session itself.
# if not self.should_set_cookie(app, session):
# return
conditional_cookie_kwargs = {}
httponly = self.get_cookie_httponly(app)
secure = self.get_cookie_secure(app)
if self.has_same_site_capability:
conditional_cookie_kwargs["samesite"] = self.get_cookie_samesite(app)
expires = self.get_expiration_time(app, session)
if session.permanent:
value = self.serializer.dumps(dict(session))
if value is not None:
self.redis.setex(
name=self.key_prefix + session.sid,
value=value,
time=total_seconds(app.permanent_session_lifetime),
)
if self.use_signer:
session_id = self._get_signer(app).sign(want_bytes(session.sid)).decode("utf-8")
else:
session_id = session.sid
response.set_cookie(
app.config["SESSION_COOKIE_NAME"],
session_id,
expires=expires,
httponly=httponly,
domain=domain,
path=path,
secure=secure,
**conditional_cookie_kwargs,
)
...@@ -4,11 +4,12 @@ import json ...@@ -4,11 +4,12 @@ import json
import logging import logging
import secrets import secrets
import uuid import uuid
from datetime import datetime from datetime import datetime, timedelta
from hashlib import sha256 from hashlib import sha256
from typing import Optional from typing import Optional
from flask import session from werkzeug.exceptions import Forbidden, Unauthorized
from flask import session, current_app
from sqlalchemy import func from sqlalchemy import func
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
...@@ -19,16 +20,82 @@ from services.errors.account import AccountLoginError, CurrentPasswordIncorrectE ...@@ -19,16 +20,82 @@ from services.errors.account import AccountLoginError, CurrentPasswordIncorrectE
from libs.helper import get_remote_ip from libs.helper import get_remote_ip
from libs.password import compare_password, hash_password from libs.password import compare_password, hash_password
from libs.rsa import generate_key_pair from libs.rsa import generate_key_pair
from libs.passport import PassportService
from models.account import * from models.account import *
from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task
def _create_tenant_for_account(account):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant
return tenant
class AccountService: class AccountService:
@staticmethod @staticmethod
def load_user(account_id: int) -> Account: def load_user(user_id: str) -> Account:
# todo: used by flask_login # todo: used by flask_login
pass if '.' in user_id:
tenant_id, account_id = user_id.split('.')
else:
account_id = user_id
account = db.session.query(Account).filter(Account.id == account_id).first()
if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')
workspace_id = session.get('workspace_id')
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == workspace_id
).first()
if not tenant_account_join:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else:
account.current_tenant_id = workspace_id
else:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
current_time = datetime.utcnow()
# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()
return account
@staticmethod
def get_account_jwt_token(account):
payload = {
"user_id": account.id,
"exp": datetime.utcnow() + timedelta(days=30),
"iss": current_app.config['EDITION'],
"sub": 'Console API Passport',
}
token = PassportService().issue(payload)
return token
@staticmethod @staticmethod
def authenticate(email: str, password: str) -> Account: def authenticate(email: str, password: str) -> Account:
......
...@@ -49,15 +49,6 @@ services: ...@@ -49,15 +49,6 @@ services:
REDIS_USE_SSL: 'false' REDIS_USE_SSL: 'false'
# use redis db 0 for redis cache # use redis db 0 for redis cache
REDIS_DB: 0 REDIS_DB: 0
# The configurations of session, Supported values are `sqlalchemy`. `redis`
SESSION_TYPE: redis
SESSION_REDIS_HOST: redis
SESSION_REDIS_PORT: 6379
SESSION_REDIS_USERNAME: ''
SESSION_REDIS_PASSWORD: difyai123456
SESSION_REDIS_USE_SSL: 'false'
# use redis db 2 for session store
SESSION_REDIS_DB: 2
# The configurations of celery broker. # The configurations of celery broker.
# Use redis as the broker, and redis db 1 for celery broker. # Use redis as the broker, and redis db 1 for celery broker.
CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1 CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1
...@@ -76,10 +67,6 @@ services: ...@@ -76,10 +67,6 @@ services:
# If you want to enable cross-origin support, # If you want to enable cross-origin support,
# you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`. # you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`.
# #
# For **production** purposes, please set `SameSite=Lax, Secure=true, HttpOnly=true`.
COOKIE_HTTPONLY: 'true'
COOKIE_SAMESITE: 'Lax'
COOKIE_SECURE: 'false'
# The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local` # The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local`
STORAGE_TYPE: local STORAGE_TYPE: local
# The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`. # The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`.
......
'use client' 'use client'
import { SWRConfig } from 'swr' import { SWRConfig } from 'swr'
import { useEffect, useState } from 'react'
import type { ReactNode } from 'react' import type { ReactNode } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
type SwrInitorProps = { type SwrInitorProps = {
children: ReactNode children: ReactNode
...@@ -9,13 +11,32 @@ type SwrInitorProps = { ...@@ -9,13 +11,32 @@ type SwrInitorProps = {
const SwrInitor = ({ const SwrInitor = ({
children, children,
}: SwrInitorProps) => { }: SwrInitorProps) => {
return ( const router = useRouter()
<SWRConfig value={{ const searchParams = useSearchParams()
shouldRetryOnError: false, const consoleToken = searchParams.get('console_token')
}}> const consoleTokenFromLocalStorage = localStorage?.getItem('console_token')
{children} const [init, setInit] = useState(false)
</SWRConfig>
) useEffect(() => {
if (!(consoleToken || consoleTokenFromLocalStorage))
router.replace('/signin')
if (consoleToken) {
localStorage?.setItem('console_token', consoleToken!)
router.replace('/apps', { forceOptimisticNavigation: false })
}
setInit(true)
}, [])
return init
? (
<SWRConfig value={{
shouldRetryOnError: false,
}}>
{children}
</SWRConfig>
)
: null
} }
export default SwrInitor export default SwrInitor
...@@ -8,6 +8,10 @@ import I18n from '@/context/i18n' ...@@ -8,6 +8,10 @@ import I18n from '@/context/i18n'
const Header = () => { const Header = () => {
const { locale, setLocaleOnClient } = useContext(I18n) const { locale, setLocaleOnClient } = useContext(I18n)
if (localStorage?.getItem('console_token'))
localStorage.removeItem('console_token')
return <div className='flex items-center justify-between p-6 w-full'> return <div className='flex items-center justify-between p-6 w-full'>
<div className={style.logo}></div> <div className={style.logo}></div>
<Select <Select
......
...@@ -89,7 +89,7 @@ const NormalForm = () => { ...@@ -89,7 +89,7 @@ const NormalForm = () => {
} }
try { try {
setIsLoading(true) setIsLoading(true)
await login({ const res = await login({
url: '/login', url: '/login',
body: { body: {
email, email,
...@@ -97,7 +97,8 @@ const NormalForm = () => { ...@@ -97,7 +97,8 @@ const NormalForm = () => {
remember_me: true, remember_me: true,
}, },
}) })
router.push('/apps') localStorage.setItem('console_token', res.data)
router.replace('/apps')
} }
finally { finally {
setIsLoading(false) setIsLoading(false)
......
...@@ -179,6 +179,10 @@ const baseFetch = <T>( ...@@ -179,6 +179,10 @@ const baseFetch = <T>(
} }
options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`) options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
} }
else {
const accessToken = localStorage.getItem('console_token') || ''
options.headers.set('Authorization', `Bearer ${accessToken}`)
}
if (deleteContentType) { if (deleteContentType) {
options.headers.delete('Content-Type') options.headers.delete('Content-Type')
...@@ -292,7 +296,9 @@ export const upload = (options: any): Promise<any> => { ...@@ -292,7 +296,9 @@ export const upload = (options: any): Promise<any> => {
const defaultOptions = { const defaultOptions = {
method: 'POST', method: 'POST',
url: `${API_PREFIX}/files/upload`, url: `${API_PREFIX}/files/upload`,
headers: {}, headers: {
Authorization: `Bearer ${localStorage.getItem('console_token') || ''}`,
},
data: {}, data: {},
} }
options = { options = {
......
...@@ -15,8 +15,8 @@ import type { ...@@ -15,8 +15,8 @@ import type {
} from '@/models/app' } from '@/models/app'
import type { BackendModel, ProviderMap } from '@/app/components/header/account-setting/model-page/declarations' import type { BackendModel, ProviderMap } from '@/app/components/header/account-setting/model-page/declarations'
export const login: Fetcher<CommonResponse, { url: string; body: Record<string, any> }> = ({ url, body }) => { export const login: Fetcher<CommonResponse & { data: string }, { url: string; body: Record<string, any> }> = ({ url, body }) => {
return post<CommonResponse>(url, { body }) return post(url, { body }) as Promise<CommonResponse & { data: string }>
} }
export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => { export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment