import binascii
import functools
import inspect
import os
import uuid
import sanic_jwt as jwt
from sanic.exceptions import Unauthorized
from tinydb import Query
async def _call_handler(handler, request, *args, **kwargs):
    """Invoke ``handler`` and resolve its result, whether sync or async.

    The handler is called with ``request`` followed by ``*args`` and
    ``**kwargs``.  If the call produces an awaitable, it is awaited and the
    awaited value is returned; otherwise the value is returned unchanged.
    """
    outcome = handler(request, *args, **kwargs)
    if not inspect.isawaitable(outcome):
        return outcome
    return await outcome
def with_user():
    """Build a decorator that hands the request's user to the handler.

    The wrapped handler receives an extra ``user`` keyword argument holding
    ``request.get('user')``, which is ``None`` when nothing is stored under
    that key.  The handler may be a plain function or a coroutine function.
    """
    def inject_user(f):
        @functools.wraps(f)
        async def handler_with_user(request, *args, **kwargs):
            # Look the user up on the request and forward it by keyword.
            return await _call_handler(
                f, request, *args, user=request.get('user'), **kwargs)
        return handler_with_user
    return inject_user
def authorized():
    """Build a decorator that rejects requests without a user.

    For every method except ``OPTIONS`` the wrapped handler only runs when
    ``request.get('user')`` is not ``None``; otherwise
    ``Unauthorized('Not authorized')`` is raised.  ``OPTIONS`` requests are
    passed to the handler without this check.
    """
    def require_user(f):
        @functools.wraps(f)
        async def guarded(request, *args, **kwargs):
            must_check = request.method != 'OPTIONS'
            # Short-circuit keeps the user lookup off the OPTIONS path.
            if must_check and request.get('user') is None:
                raise Unauthorized('Not authorized')
            return await _call_handler(f, request, *args, **kwargs)
        return guarded
    return require_user
def _get_secret(homeserver):
    """Return the secret key used for JWT, creating it on first use.

    The key is kept in the ``secret`` table of ``homeserver.db``.  When that
    table is empty, 24 random bytes are generated, hex-encoded and inserted
    as a new row.  The ``key`` of the first stored row is returned.
    """
    secret_table = homeserver.db.table('secret')
    if len(secret_table) == 0:
        # hexlify() returns bytes; keep the key as text so it can be written
        # by JSON-based storage, which cannot serialize bytes.
        new_key = binascii.hexlify(os.urandom(24)).decode('ascii')
        secret_table.insert({'key': new_key})
    return secret_table.all()[0]['key']
def _jwt_authenticate(homeserver, request):
    """Check the login credentials sent in the JSON body of ``request``.

    The body is expected to be a JSON object with ``user`` and ``password``
    and an optional ``device``; a random UUID string is used when no device
    is given.

    Returns:
        A dict with ``user_id`` (``"<user>:<device>"``) and ``user``.

    Raises:
        Unauthorized: with ``'Invalid credentials'`` when the body is not a
            JSON object or ``homeserver.user_service.login`` rejects the
            credentials.
    """
    body = request.json
    if not isinstance(body, dict):
        # The body is not guaranteed to be a JSON object (e.g. None for a
        # missing body, or a JSON list); reject it as a failed login instead
        # of crashing on ``.get``.
        raise Unauthorized('Invalid credentials')

    user = body.get('user')
    password = body.get('password')
    device = body.get('device') or str(uuid.uuid4())

    if not homeserver.user_service.login(user, password):
        raise Unauthorized('Invalid credentials')
    return dict(user_id=f'{user}:{device}', user=user)
def _handle_authorization_middleware(homeserver, request):
    """Resolve the calling user and store it in ``request['user']``.

    ``OPTIONS`` requests are skipped entirely.  Otherwise the first matching
    source is used:

    1. ``Authorization: Basic ...`` -- the base64 ``user:password`` pair is
       checked with ``homeserver.user_service.login``.
    2. ``Authorization: Bearer ...`` -- the token is checked through
       ``request.app.auth``; the user is the part of the payload's
       ``user_id`` before the first ``:``.
    3. An ``access_token`` query argument -- decoded and verified through
       ``request.app.auth``, user extracted as above.
    4. No credentials at all and no ``login`` entry (or a ``None`` one) in
       ``homeserver.config`` -- the user is set to ``'user'``.

    When none of these apply the request is left untouched.

    Raises:
        Unauthorized: for malformed or rejected Basic credentials, for an
            invalid Bearer token on any path other than the refresh
            endpoint, and for an ``access_token`` that fails to decode.
    """
    # pylint: disable=too-many-locals
    if request.method == 'OPTIONS':
        return

    auth_header = request.headers.get('Authorization')
    if auth_header and auth_header.startswith('Basic'):
        raw = auth_header[5:].strip()
        try:
            value = binascii.a2b_base64(raw).decode('utf-8')
        except ValueError:
            # Covers binascii.Error (bad base64), UnicodeDecodeError (payload
            # is not UTF-8) and a non-ASCII header value.  Malformed client
            # input is an authentication failure, not a server error.
            raise Unauthorized('Invalid credentials') from None
        # Only the first ':' separates the user from the password, so the
        # password itself may contain colons.
        user, _, password = value.partition(':')
        if not homeserver.user_service.login(user, password):
            raise Unauthorized('Invalid credentials')
        request['user'] = user
    elif auth_header and auth_header.startswith('Bearer'):
        auth_config = homeserver.app.auth.config
        refresh_path = auth_config.url_prefix() + auth_config.path_to_refresh()
        # pylint: disable=protected-access
        is_valid, status, reason = request.app.auth._check_authentication(request, [], {})
        if is_valid:
            payload = request.app.auth.extract_payload(request)
            request['user'], *_ = payload['user_id'].split(':')
        elif request.path != refresh_path:
            # An invalid token is tolerated on the refresh endpoint only;
            # there the request continues without a user being set.
            raise Unauthorized(reason, status_code=status)
    elif 'access_token' in request.raw_args:
        access_token = request.raw_args['access_token']
        try:
            # pylint: disable=protected-access
            payload = request.app.auth._decode(access_token, verify=True)
            request['user'], *_ = payload['user_id'].split(':')
        except jwt.exceptions.SanicJWTException as e:
            raise Unauthorized(list(e.args))
    elif 'login' not in homeserver.config or homeserver.config['login'] is None:
        # No credentials were sent and no login is configured: fall back to
        # the fixed user name.
        request['user'] = 'user'
def setup(homeserver, app):
    """Install JWT authentication and the authorization middleware on ``app``.

    The JWT secret comes from ``_get_secret``; refresh tokens are kept in the
    ``refresh_tokens`` table of ``homeserver.db``, one row per ``user_id``.
    The JWT endpoints are registered under ``/v1/auth`` and
    ``_handle_authorization_middleware`` is added as request middleware.
    """
    jwt_secret = _get_secret(homeserver)
    token_table = homeserver.db.table('refresh_tokens')
    token_query = Query()

    def authenticate(request):
        # Delegate the credential check to the module-level helper.
        return _jwt_authenticate(homeserver, request)

    def store_refresh_token(user_id, refresh_token, *args, **kwargs):
        # Insert or replace the single row kept for this user_id.
        record = {
            'value': refresh_token,
            'user_id': user_id,
        }
        token_table.upsert(record, token_query.user_id == user_id)

    def retrieve_refresh_token(request, user_id, *args, **kwargs):
        stored = token_table.get(token_query.user_id == user_id)
        if not stored:
            return None
        return stored['value']

    def retrieve_user(request, payload, *args, **kwargs):
        # The payload itself is handed back as the user.
        return payload

    # initialize the jwt module
    jwt_options = {
        'secret': jwt_secret,
        'authenticate': authenticate,
        'url_prefix': '/v1/auth',
        'refresh_token_enabled': True,
        'store_refresh_token': store_refresh_token,
        'retrieve_refresh_token': retrieve_refresh_token,
        'retrieve_user': retrieve_user,
    }
    jwt.Initialize(app, **jwt_options)

    register_request_middleware = app.middleware('request')
    register_request_middleware(
        functools.partial(_handle_authorization_middleware, homeserver))