119 lines
4.0 KiB
Python
119 lines
4.0 KiB
Python
import uuid
|
|
from datetime import timedelta
|
|
from typing import Any, override
|
|
|
|
from django.core.handlers.wsgi import WSGIRequest
|
|
from django.test import RequestFactory, TestCase
|
|
|
|
from apps.users.auth.bearer import JWTBearer
|
|
from apps.users.auth.jwt import (
|
|
TokenError,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
create_token_pair,
|
|
decode_access_token,
|
|
decode_refresh_token,
|
|
decode_token,
|
|
)
|
|
from apps.users.models import User, UserRole
|
|
|
|
from .helpers import make_user
|
|
|
|
|
|
class JWTCreateTest(TestCase):
    """Smoke tests for the JWT creation helpers.

    Only the shape of each helper's return value is checked here; the
    tokens are not decoded in this class.
    """

    def test_create_access_token(self) -> None:
        # An access token built for a random user id and the "admin" role
        # must come back as a non-empty string.
        token: str = create_access_token(uuid.uuid4(), "admin")
        self.assertIsInstance(token, str)
        # assertGreater reports both operands on failure, whereas
        # assertTrue(len(token) > 0) would only say "False is not true".
        self.assertGreater(len(token), 0)

    def test_create_refresh_token(self) -> None:
        # A refresh token is created from the user id alone (no role).
        token: str = create_refresh_token(uuid.uuid4())
        self.assertIsInstance(token, str)
        # Same non-empty check as the access-token test, for parity.
        self.assertGreater(len(token), 0)

    def test_create_token_pair(self) -> None:
        # The pair must expose both tokens under the "access" and
        # "refresh" keys.
        pair: dict[str, str] = create_token_pair(uuid.uuid4(), "viewer")
        self.assertIn("access", pair)
        self.assertIn("refresh", pair)
|
|
|
|
|
|
class JWTDecodeTest(TestCase):
    """Round-trip and failure-path tests for the JWT decode helpers."""

    @override
    def setUp(self) -> None:
        # Every test gets its own freshly generated subject id.
        self.uid: uuid.UUID = uuid.uuid4()

    def test_decode_access_token(self) -> None:
        # Encode, decode, then check the claims this test relies on.
        encoded: str = create_access_token(self.uid, "experimenter")
        claims: dict[str, Any] = decode_access_token(encoded)
        # The subject is compared against the string form of the UUID.
        self.assertEqual(claims["sub"], str(self.uid))
        self.assertEqual(claims["role"], "experimenter")
        self.assertEqual(claims["type"], "access")

    def test_decode_refresh_token(self) -> None:
        # Same round trip for a refresh token, which is created without a role.
        encoded: str = create_refresh_token(self.uid)
        claims: dict[str, Any] = decode_refresh_token(encoded)
        self.assertEqual(claims["sub"], str(self.uid))
        self.assertEqual(claims["type"], "refresh")

    def test_decode_wrong_type_raises(self) -> None:
        # Feeding a refresh token to the access-token decoder is expected
        # to raise TokenError.
        refresh: str = create_refresh_token(self.uid)
        self.assertRaises(TokenError, decode_access_token, refresh)

    def test_decode_expired_token_raises(self) -> None:
        # A negative lifetime is used so the token is created already expired.
        already_expired = timedelta(seconds=-1)
        stale: str = create_access_token(self.uid, "admin", lifetime=already_expired)
        self.assertRaises(TokenError, decode_access_token, stale)

    def test_decode_invalid_token_raises(self) -> None:
        # A string that is not a real JWT is expected to raise TokenError.
        self.assertRaises(TokenError, decode_token, "not.a.jwt")

    def test_extra_claims(self) -> None:
        # Claims passed via extra_claims are expected to be readable from
        # the decoded payload.
        extras = {"org": "lotty"}
        encoded: str = create_access_token(self.uid, "admin", extra_claims=extras)
        claims: dict[str, Any] = decode_access_token(encoded)
        self.assertEqual(claims["org"], "lotty")
|
|
|
|
|
|
class JWTBearerTest(TestCase):
    """Tests for JWTBearer.authenticate against a user created in setUp."""

    @override
    def setUp(self) -> None:
        self.bearer = JWTBearer()
        # Admin user that the token-based tests authenticate as.
        self.user: User = make_user(
            username="bearer_user",
            email="bearer@x.com",
            role=UserRole.ADMIN,
        )

    def _authenticate(self, token: str) -> User | None:
        # Shared step: build a GET "/" request with RequestFactory and hand
        # it to the bearer together with the raw token string.
        request: WSGIRequest = RequestFactory().get("/")
        return self.bearer.authenticate(request, token)

    def test_valid_token_returns_user(self) -> None:
        # A token issued for the stored user's pk and role should
        # authenticate as that user.
        access: str = create_access_token(self.user.pk, self.user.role)
        outcome: User | None = self._authenticate(access)
        self.assertEqual(outcome, self.user)

    def test_invalid_token_returns_none(self) -> None:
        # A string that is not a token should yield None.
        outcome: User | None = self._authenticate("garbage")
        self.assertIsNone(outcome)

    def test_nonexistent_user_returns_none(self) -> None:
        # The token carries a random UUID that was never saved as a user.
        access: str = create_access_token(uuid.uuid4(), "admin")
        outcome: User | None = self._authenticate(access)
        self.assertIsNone(outcome)

    def test_inactive_user_returns_none(self) -> None:
        # Deactivate and persist the user before issuing the token.
        self.user.is_active = False
        self.user.save()
        access: str = create_access_token(self.user.pk, self.user.role)
        outcome: User | None = self._authenticate(access)
        self.assertIsNone(outcome)
|