In Part 4 of this series, we looked at authentication and authorization. In this part, we are going to go in depth on how to revoke your tokens so that they can no longer be used to access your endpoints.
Before we embark on this, let’s first see how to create refresh tokens for our endpoints.
Refresh Tokens
Suppose your access token gets stolen by an attacker. He may use it to access your protected endpoints in your stead. To combat this, we incorporate refresh tokens into our app. This method works by giving access tokens a much shorter lifespan than refresh tokens. Then, if the evil attacker lays hold of your precious access token, the damage he can do is limited: he only has a small window to cause you any significant heartache before the token expires.
When the access token expires, as it frequently will, the refresh token is used to generate a brand new access token for any subsequent requests to your endpoints. When that access token expires in turn, the long-lived refresh token generates another one, and so on, until the refresh token itself expires and the user has to sign in again.
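To make the cycle concrete, here is a toy model of it in plain Python. It does not use Flask-JWT-Extended at all; the `Token` class and the `issue` and `refresh_access` helpers are hypothetical, and the 15-minute and 30-day lifetimes merely mirror Flask-JWT-Extended's default expiry settings.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Tuple

# Hypothetical lifetimes; they mirror Flask-JWT-Extended's defaults
ACCESS_LIFETIME = timedelta(minutes=15)
REFRESH_LIFETIME = timedelta(days=30)


@dataclass(frozen=True)
class Token:
    """Toy stand-in for a JWT: owner, kind ("access" or "refresh") and expiry."""

    identity: str
    kind: str
    expires_at: datetime

    def is_valid(self, now: datetime) -> bool:
        return now < self.expires_at


def issue(identity: str, now: datetime) -> Tuple[Token, Token]:
    """Issue an (access, refresh) pair, as sign-up and sign-in do."""
    return (
        Token(identity, "access", now + ACCESS_LIFETIME),
        Token(identity, "refresh", now + REFRESH_LIFETIME),
    )


def refresh_access(refresh: Token, now: datetime) -> Token:
    """Trade a still-valid refresh token for a brand new access token."""
    if refresh.kind != "refresh" or not refresh.is_valid(now):
        raise PermissionError("a valid, unexpired refresh token is required")
    return Token(refresh.identity, "access", now + ACCESS_LIFETIME)


# One hour after sign-in, the original (possibly stolen) access token is dead,
# but the refresh token can still mint a new one.
signed_in = datetime(2024, 1, 1, tzinfo=timezone.utc)
access, refresh = issue("jane@example.com", signed_in)
an_hour_later = signed_in + timedelta(hours=1)
print(access.is_valid(an_hour_later))  # False
new_access = refresh_access(refresh, an_hour_later)
print(new_access.is_valid(an_hour_later))  # True
```

A stolen access token is only useful for minutes. The refresh token is the long-lived secret, which is why it is only ever sent to the refresh endpoint.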
Let’s modify our `auth.py` module to include the following code.
```python
"""
app.api.v1.auth
~~~~~~~~~~~~~~
Authentication views
"""
from flask_jwt_extended import create_access_token, create_refresh_token
from flask_restful import Resource, reqparse

from app.models import User
from app import db
from .common.utils import valid_email, valid_password
from .common.errors import raise_error

parser = reqparse.RequestParser()
parser.add_argument('email', type=str)
parser.add_argument('password', type=str)


class SignUP(Resource):
    def post(self):
        args = parser.parse_args()
        email = args.get('email') or ''
        password = args.get('password') or ''

        # validate input data
        if not valid_email(email):
            return raise_error(400, "Invalid email format")
        if not valid_password(password):
            return raise_error(400, "Invalid password. Should be at least 5 "
                                    "characters long and include a number and a special "
                                    "character")

        user = User.query.filter_by(email=email).first()
        if user is not None:
            return raise_error(400, "User already exists")

        #: set username to be same as email if it's not provided
        user = User(email=email, username=email)
        user.set_password(password)
        db.session.add(user)
        db.session.commit()

        ###################################################
        ###################################################
        #: Create both access and refresh tokens
        access_token = create_access_token(identity=email)
        refresh_token = create_refresh_token(identity=email)
        ###################################################
        ###################################################

        data = {}
        data['access_token'] = access_token
        data['refresh_token'] = refresh_token
        data['user'] = user.serialize
        response = {
            "status": 201,
            "data": [data]
        }
        return response, 201


class SignIn(Resource):
    def post(self):
        args = parser.parse_args()
        email = args.get('email', None)
        password = args.get('password', None)

        if email is None:
            return raise_error(400, "Missing 'email' in body")
        if password is None:
            return raise_error(400, "Missing 'password' in body")

        user = User.query.filter_by(email=email).first()
        if user is None or not user.check_password(password):
            return raise_error(401, "Bad email or password")

        #########################################################
        #########################################################
        # Create our JWTs
        access_token = create_access_token(identity=email)
        refresh_token = create_refresh_token(identity=email)
        #########################################################
        #########################################################

        data = {}
        data['access_token'] = access_token
        data['refresh_token'] = refresh_token
        data['user'] = user.serialize
        response = {
            "status": 200,
            "data": [data]
        }
        return response
```
We have imported the `create_refresh_token` function from `flask_jwt_extended`, which works much like its cousin `create_access_token`, except that the refresh token has a much longer expiration period. Both tokens are returned in our response. The access token is used for as long as it is valid, after which the refresh token is used to generate a new access token.
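Both lifetimes can be tuned. Flask-JWT-Extended reads them from the `JWT_ACCESS_TOKEN_EXPIRES` and `JWT_REFRESH_TOKEN_EXPIRES` configuration keys, which default to 15 minutes and 30 days respectively. Here is a minimal sketch, assuming your app loads its settings from a config object. The `Config` class name and the values chosen are illustrative.

```python
from datetime import timedelta


class Config:
    """Hypothetical settings object for the app's JWT lifetimes."""

    # Short-lived access tokens limit how long a stolen token stays useful
    JWT_ACCESS_TOKEN_EXPIRES = timedelta(minutes=15)
    # Long-lived refresh tokens spare users from signing in again too often
    JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
```

In Flask you would load it with `app.config.from_object(Config)` before initialising the `JWTManager`.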
Next, we need a new endpoint that issues a fresh access token in exchange for a refresh token. Add the following code to the `app/api/v1/__init__.py` file.
```python
# app.api.v1.__init__
# ..previous code
from .auth import RefreshToken

api.add_resource(
    RefreshToken,
    '/auth/refresh',
)
```
The `RefreshToken` resource is defined as follows:
```python
"""
app.api.v1.auth
~~~~~~~~~~~~~~
Authentication views
"""
# ...previous code
from flask_jwt_extended import jwt_refresh_token_required, get_jwt_identity
# ... previous code


class RefreshToken(Resource):
    """
    Creates a new access token
    """

    @jwt_refresh_token_required
    def post(self):
        """
        Returns a new access token
        """
        current_user = get_jwt_identity()
        new_token = create_access_token(identity=current_user)
        # add_token_to_database() is not defined yet; it belongs to the token
        # blacklist built later in this part. Once it exists, uncomment this
        # line (and import current_app from flask) to store the new token:
        # add_token_to_database(new_token, current_app.config['JWT_IDENTITY_CLAIM'])
        return {
            'status': 200,
            'data': [{'access_token': new_token}]
        }
```
The `jwt_refresh_token_required` decorator ensures that only a valid refresh token can be used to access this endpoint. The `get_jwt_identity` function returns the identity stored in the refresh token that was presented, and that identity is then used to create a new access token. (In Flask-JWT-Extended 4.x, this decorator has been replaced by `@jwt_required(refresh=True)`.)
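How does the decorator tell a refresh token apart from an access token? Flask-JWT-Extended records the kind of token in a `type` claim inside the JWT payload. The helper below is a debugging sketch that base64url-decodes the payload segment of a JWT using only the standard library. The name `peek_claims` is hypothetical. The helper does not verify the signature, so never use it to make authorization decisions.

```python
import base64
import json


def peek_claims(token: str) -> dict:
    """Return the *unverified* payload claims of a JWT (header.payload.signature)."""
    payload_segment = token.split(".")[1]
    # JWTs strip base64 padding; restore it before decoding
    padding = "=" * (-len(payload_segment) % 4)
    decoded = base64.urlsafe_b64decode(payload_segment + padding)
    return json.loads(decoded)
```

Calling `peek_claims(refresh_token)["type"]` on a token issued by our API should give `"refresh"`, and `"access"` for an access token.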
The following is the test for our refresh token endpoint.
```python
"""
tests.v1.test_auth
~~~~~~~~~~~~~~~~~~
Tests for authentication
"""
import json

# ... prev code
from .util import make_token_header
# ... previous code


def test_token_refresh(client, auth):
    auth.signup()
    refresh_token = auth.refresh_token
    access_token = auth.access_token
    bad_token = auth.refresh_token + '@'

    refresh_token_header = make_token_header(refresh_token)
    access_token_header = make_token_header(access_token)
    bad_token_header = make_token_header(bad_token)

    # A valid refresh token gets a new access token
    response = client.post('/auth/refresh', headers=refresh_token_header)
    assert response.status_code == 200
    data = json.loads(response.data.decode('utf-8'))['data'][0]
    assert 'access_token' in data

    # Anything else (an access token, a malformed token) is rejected with 422
    response = client.post('/auth/refresh', headers=access_token_header)
    assert response.status_code == 422
    response = client.post('/auth/refresh', headers=bad_token_header)
    assert response.status_code == 422
```
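The `make_token_header` helper comes from the tests' `util` module, which isn't shown here. Here is a minimal version, assuming Flask-JWT-Extended's default header settings (`JWT_HEADER_NAME` of `Authorization` and `JWT_HEADER_TYPE` of `Bearer`):

```python
def make_token_header(token: str) -> dict:
    """Build request headers carrying a JWT in the default 'Authorization: Bearer' form."""
    return {"Authorization": f"Bearer {token}"}
```

If you change either setting in your app config, the helper has to change with it.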
Remember to update the `AuthActions` class in the `conftest` module so that it also exposes a refresh token.
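The original `conftest` module isn't shown in this part, so here is a sketch of what the updated class might look like. The endpoint paths, default credentials and method names are assumptions; adjust them to match your own `conftest`. The test above only relies on `signup()` and the `access_token` and `refresh_token` attributes.

```python
import json


class AuthActions:
    """Hypothetical conftest helper that signs a test user up and keeps both tokens."""

    def __init__(self, client, email="test@example.com", password="pass@123"):
        self._client = client
        self.email = email
        # Chosen to satisfy valid_password: 5+ chars, a number, a special character
        self.password = password
        self.access_token = None
        self.refresh_token = None

    def signup(self):
        return self._authenticate("/auth/signup")  # assumed route

    def signin(self):
        return self._authenticate("/auth/signin")  # assumed route

    def _authenticate(self, url):
        response = self._client.post(
            url, data={"email": self.email, "password": self.password}
        )
        # SignUP and SignIn both answer with {"status": ..., "data": [{...}]}
        if response.status_code in (200, 201):
            data = json.loads(response.data.decode("utf-8"))["data"][0]
            self.access_token = data["access_token"]
            self.refresh_token = data["refresh_token"]
        return response
```

An `auth` fixture would then simply return `AuthActions(client)`.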
Run `pytest` and make sure all tests pass.
Now on to the more interesting part of this section: revoking tokens.
Blacklist and Token Revoking
As I explained at the start of this section, revoking a token simply means making it unusable. There isn’t much to how this is accomplished.
The concept is simple.
First, we need a storage location for our tokens. This can be either a traditional database or an in-memory store such as Redis. Both have their pros and cons; I will mostly focus on the database.
This storage location is what is referred to as the blacklist. It stores user tokens along with any metadata about them. Each time a token is used to access an endpoint, it is checked against the blacklist to see whether it has been revoked.
If it has, access is denied. It is, as it were, in the dreaded blacklist.
And that’s the crux of the entire process. Pretty simple, right?
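Before bringing in the library, here is the whole idea in a few lines of standard-library Python, using an in-memory SQLite database as the storage location. The table layout and the helper names (`add_token`, `revoke_token`, `is_token_revoked`) are hypothetical and not the models we will build for the app. Tokens are keyed by their `jti` (JWT ID) claim, a unique identifier that Flask-JWT-Extended puts in every token it creates.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE token_blacklist ("
    " jti TEXT PRIMARY KEY,"            # unique token id from the JWT's jti claim
    " identity TEXT NOT NULL,"          # whom the token was issued to
    " revoked INTEGER NOT NULL DEFAULT 0)"
)


def add_token(jti: str, identity: str) -> None:
    """Record a freshly issued token as not revoked."""
    with conn:  # commits on success
        conn.execute(
            "INSERT INTO token_blacklist (jti, identity) VALUES (?, ?)", (jti, identity)
        )


def revoke_token(jti: str) -> None:
    """Flag a stored token as revoked, e.g. when the user logs out."""
    with conn:
        conn.execute("UPDATE token_blacklist SET revoked = 1 WHERE jti = ?", (jti,))


def is_token_revoked(jti: str) -> bool:
    """The per-request check. Tokens we never recorded are treated as revoked."""
    row = conn.execute(
        "SELECT revoked FROM token_blacklist WHERE jti = ?", (jti,)
    ).fetchone()
    return row is None or bool(row[0])
```

Treating unknown tokens as revoked is the conservative choice: a token that was never recorded cannot be vouched for.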
Flask-JWT-Extended has a host of features that make implementing this process as painless as possible. This is accomplished using… you got it, decorators.
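Stripped of the framework, the decorator idea is just a wrapper that runs the blacklist check before the view does. The plain-Python sketch below is not Flask-JWT-Extended's API. The `requires_unrevoked_token` decorator, the `revoked_jtis` set and the convention of passing the `jti` as the first argument are hypothetical, chosen to keep the example self-contained.

```python
import functools

# Hypothetical in-memory blacklist of revoked token ids
revoked_jtis = {"stolen-token-id"}


class TokenRevokedError(Exception):
    """Raised when a request presents a token that has been revoked."""


def requires_unrevoked_token(view):
    """Reject the call if the presented token's jti is on the blacklist."""

    @functools.wraps(view)
    def wrapper(jti, *args, **kwargs):
        if jti in revoked_jtis:
            raise TokenRevokedError(f"token {jti!r} has been revoked")
        return view(jti, *args, **kwargs)

    return wrapper


@requires_unrevoked_token
def protected_endpoint(jti):
    return {"status": 200, "message": "access granted"}
```

In the real app, the library pulls the `jti` out of the incoming request for you. In Flask-JWT-Extended 3.x, you enable the check with the `JWT_BLACKLIST_ENABLED` setting and register it with the `@jwt.token_in_blacklist_loader` callback. In 4.x, this callback is named `token_in_blocklist_loader`.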
Let’s flesh out these ideas by writing code.
As usual, our very first step will be to write tests for the features we are going to implement.