"""Authorization support for caspia.homeserver (module ``caspia.homeserver.authorization``)."""

import binascii
import functools
import inspect
import os
import uuid

import sanic_jwt as jwt
from sanic.exceptions import Unauthorized
from tinydb import Query


async def _call_handler(handler, request, *args, **kwargs):
    result = handler(request, *args, **kwargs)
    if inspect.isawaitable(result):
        return await result
    else:
        return result


def with_user():
    """Pass an authenticated user as an argument (decorator).

    Returns a decorator whose wrapped handler additionally receives a
    ``user=`` keyword argument, read from ``request.get('user')`` (``None``
    when no user has been attached to the request).
    """
    def wrap(handler):
        @functools.wraps(handler)
        async def inject_user(request, *args, **kwargs):
            current_user = request.get('user')
            return await _call_handler(
                handler, request, *args, user=current_user, **kwargs)
        return inject_user
    return wrap
def authorized():
    """Make sure user is authorized (decorator).

    Returns a decorator that raises ``Unauthorized('Not authorized')`` when
    ``request.get('user')`` is ``None``. ``OPTIONS`` requests (presumably
    CORS preflight) are passed through without the check.
    """
    def wrap(handler):
        @functools.wraps(handler)
        async def guarded(request, *args, **kwargs):
            needs_check = request.method != 'OPTIONS'
            if needs_check and request.get('user') is None:
                raise Unauthorized('Not authorized')
            return await _call_handler(handler, request, *args, **kwargs)
        return guarded
    return wrap
def _get_secret(homeserver): """Get secret key used for JWT.""" secret_table = homeserver.db.table('secret') if not len(secret_table): secret_table.insert({'key': binascii.hexlify(os.urandom(24))}) return secret_table.all()[0]['key'] def _jwt_authenticate(homeserver, request): user = request.json.get('user') password = request.json.get('password') device = request.json.get('device') or str(uuid.uuid4()) success = homeserver.user_service.login(user, password) if success: return dict(user_id=f'{user}:{device}', user=user) else: raise Unauthorized('Invalid credentials') def _handle_authorization_middleware(homeserver, request): # pylint: disable=too-many-locals if request.method == 'OPTIONS': return auth_header = request.headers.get('Authorization') if auth_header and auth_header.startswith('Basic'): raw = auth_header[5:].strip() value = binascii.a2b_base64(raw).decode('utf-8') user, *password = value.split(':') password = ':'.join(password) if homeserver.user_service.login(user, password): request['user'] = user else: raise Unauthorized('Invalid credentials') elif auth_header and auth_header.startswith('Bearer'): auth_config = homeserver.app.auth.config refresh_path = auth_config.url_prefix() + auth_config.path_to_refresh() # pylint: disable=protected-access is_valid, status, reason = request.app.auth._check_authentication(request, [], {}) if is_valid: payload = request.app.auth.extract_payload(request) request['user'], *_ = payload['user_id'].split(':') elif request.path == refresh_path: pass else: raise Unauthorized(reason, status_code=status) elif 'access_token' in request.raw_args: access_token = request.raw_args['access_token'] try: # pylint: disable=protected-access payload = request.app.auth._decode(access_token, verify=True) request['user'], *_ = payload['user_id'].split(':') except jwt.exceptions.SanicJWTException as e: raise Unauthorized(list(e.args)) elif 'login' not in homeserver.config or homeserver.config['login'] is None: request['user'] = 'user'
def setup(homeserver, app):
    """Configure sanic_jwt on *app* and register the authorization middleware.

    The JWT signing secret comes from ``_get_secret``. Refresh tokens are kept
    in the ``refresh_tokens`` table of ``homeserver.db``, one record per
    ``user_id`` (stored via upsert). The decoded JWT payload itself is used
    as the "user" object.
    """
    secret = _get_secret(homeserver)
    tokens = homeserver.db.table('refresh_tokens')
    Token = Query()

    def by_user(user_id):
        # TinyDB condition selecting the record that belongs to user_id.
        return Token.user_id == user_id

    def authenticate(request):
        return _jwt_authenticate(homeserver, request)

    def store_refresh_token(user_id, refresh_token, *args, **kwargs):
        record = {'value': refresh_token, 'user_id': user_id}
        tokens.upsert(record, by_user(user_id))

    def retrieve_refresh_token(request, user_id, *args, **kwargs):
        record = tokens.get(by_user(user_id))
        if record:
            return record['value']
        return None

    def retrieve_user(request, payload, *args, **kwargs):
        return payload

    # initialize the jwt module
    jwt.Initialize(
        app,
        secret=secret,
        authenticate=authenticate,
        url_prefix='/v1/auth',
        refresh_token_enabled=True,
        store_refresh_token=store_refresh_token,
        retrieve_refresh_token=retrieve_refresh_token,
        retrieve_user=retrieve_user,
    )
    middleware = functools.partial(_handle_authorization_middleware, homeserver)
    app.middleware('request')(middleware)