"""Tests for the JWT helpers in ``apps.users.auth.jwt`` and the ``JWTBearer`` auth class."""

import uuid
from datetime import timedelta
from typing import Any, override

from django.core.handlers.wsgi import WSGIRequest
from django.test import RequestFactory, TestCase

from apps.users.auth.bearer import JWTBearer
from apps.users.auth.jwt import (
    TokenError,
    create_access_token,
    create_refresh_token,
    create_token_pair,
    decode_access_token,
    decode_refresh_token,
    decode_token,
)
from apps.users.models import User, UserRole

from .helpers import make_user


class JWTCreateTest(TestCase):
    """Smoke tests for the token factory functions."""

    def test_create_access_token(self) -> None:
        """An access token is a non-empty string."""
        token: str = create_access_token(uuid.uuid4(), "admin")
        self.assertIsInstance(token, str)
        # assertGreater reports the actual length on failure, unlike
        # assertTrue(len(token) > 0), which only reports "False is not true".
        self.assertGreater(len(token), 0)

    def test_create_refresh_token(self) -> None:
        """A refresh token is a non-empty string."""
        token: str = create_refresh_token(uuid.uuid4())
        self.assertIsInstance(token, str)
        self.assertGreater(len(token), 0)

    def test_create_token_pair(self) -> None:
        """The pair exposes string tokens under the "access" and "refresh" keys."""
        pair: dict[str, str] = create_token_pair(uuid.uuid4(), "viewer")
        self.assertIn("access", pair)
        self.assertIn("refresh", pair)
        self.assertIsInstance(pair["access"], str)
        self.assertIsInstance(pair["refresh"], str)


class JWTDecodeTest(TestCase):
    """Round-trip and failure-mode tests for the decode helpers."""

    @override
    def setUp(self) -> None:
        self.uid: uuid.UUID = uuid.uuid4()

    def test_decode_access_token(self) -> None:
        """Decoding an access token returns its subject, role and type claims."""
        token: str = create_access_token(self.uid, "experimenter")
        payload: dict[str, Any] = decode_access_token(token)
        self.assertEqual(payload["sub"], str(self.uid))
        self.assertEqual(payload["role"], "experimenter")
        self.assertEqual(payload["type"], "access")

    def test_decode_refresh_token(self) -> None:
        """Decoding a refresh token returns its subject and type claims."""
        token: str = create_refresh_token(self.uid)
        payload: dict[str, Any] = decode_refresh_token(token)
        self.assertEqual(payload["sub"], str(self.uid))
        self.assertEqual(payload["type"], "refresh")

    def test_decode_wrong_type_raises(self) -> None:
        """A refresh token is rejected by the access-token decoder."""
        token: str = create_refresh_token(self.uid)
        with self.assertRaises(TokenError):
            decode_access_token(token)

    def test_decode_expired_token_raises(self) -> None:
        """A token whose lifetime has already elapsed is rejected."""
        # Expire well in the past rather than by 1 second, so the token stays
        # expired even if the decoder is configured with a clock-skew leeway
        # (NOTE(review): leeway setting not visible from here).
        token: str = create_access_token(
            self.uid, "admin", lifetime=timedelta(minutes=-5)
        )
        with self.assertRaises(TokenError):
            decode_access_token(token)

    def test_decode_invalid_token_raises(self) -> None:
        """A malformed token string is rejected by the generic decoder."""
        with self.assertRaises(TokenError):
            decode_token("not.a.jwt")

    def test_extra_claims(self) -> None:
        """Claims passed via extra_claims survive the encode/decode round trip."""
        token: str = create_access_token(
            self.uid, "admin", extra_claims={"org": "lotty"}
        )
        payload: dict[str, Any] = decode_access_token(token)
        self.assertEqual(payload["org"], "lotty")


class JWTBearerTest(TestCase):
    """Tests for ``JWTBearer.authenticate`` using users created via ``make_user``."""

    @override
    def setUp(self) -> None:
        self.bearer: JWTBearer = JWTBearer()
        self.user: User = make_user(
            username="bearer_user",
            email="bearer@x.com",
            role=UserRole.ADMIN,
        )
        # Every test authenticates against the same bare GET request.
        self.request: WSGIRequest = RequestFactory().get("/")

    def test_valid_token_returns_user(self) -> None:
        """A valid access token for an existing user resolves to that user."""
        token: str = create_access_token(self.user.pk, self.user.role)
        result: User | None = self.bearer.authenticate(self.request, token)
        self.assertEqual(result, self.user)

    def test_invalid_token_returns_none(self) -> None:
        """A token that cannot be decoded yields None."""
        result: User | None = self.bearer.authenticate(self.request, "garbage")
        self.assertIsNone(result)

    def test_nonexistent_user_returns_none(self) -> None:
        """A valid token whose subject matches no user yields None."""
        token: str = create_access_token(uuid.uuid4(), "admin")
        result: User | None = self.bearer.authenticate(self.request, token)
        self.assertIsNone(result)

    def test_inactive_user_returns_none(self) -> None:
        """A valid token for a deactivated user yields None."""
        self.user.is_active = False
        self.user.save()
        token: str = create_access_token(self.user.pk, self.user.role)
        result: User | None = self.bearer.authenticate(self.request, token)
        self.assertIsNone(result)