33from unittest .mock import patch
44
55import jwt
6+ import pytest
67from django .test import TestCase
7- from jwt import PyJWS , algorithms
8+ from jwt import PyJWS
9+ from jwt import __version__ as jwt_version
10+ from jwt import algorithms
811
9- from rest_framework_simplejwt .backends import TokenBackend
12+ from rest_framework_simplejwt .backends import JWK_CLIENT_AVAILABLE , TokenBackend
1013from rest_framework_simplejwt .exceptions import TokenBackendError
1114from rest_framework_simplejwt .utils import aware_utcnow , datetime_to_epoch , make_utc
1215from tests .keys import (
2831
2932LEEWAY = 100
3033
34+ IS_OLD_JWT = jwt_version == "1.7.1"
35+
3136
3237class TestTokenBackend (TestCase ):
3338 def setUp (self ):
@@ -159,7 +164,7 @@ def test_decode_with_expiry(self):
159164 def test_decode_with_invalid_sig (self ):
160165 self .payload ["exp" ] = aware_utcnow () - timedelta (seconds = 1 )
161166 for backend in self .backends :
162- with self .subTest ("Test decode with invalid sig for f {backend.algorithm}" ):
167+ with self .subTest (f "Test decode with invalid sig for { backend .algorithm } " ):
163168 payload = self .payload .copy ()
164169 payload ["exp" ] = aware_utcnow () + timedelta (days = 1 )
165170 token_1 = jwt .encode (
@@ -170,6 +175,10 @@ def test_decode_with_invalid_sig(self):
170175 payload , backend .signing_key , algorithm = backend .algorithm
171176 )
172177
178+ if IS_OLD_JWT :
179+ token_1 = token_1 .decode ("utf-8" )
180+ token_2 = token_2 .decode ("utf-8" )
181+
173182 token_2_payload = token_2 .rsplit ("." , 1 )[0 ]
174183 token_1_sig = token_1 .rsplit ("." , 1 )[- 1 ]
175184 invalid_token = token_2_payload + "." + token_1_sig
@@ -189,8 +198,12 @@ def test_decode_with_invalid_sig_no_verify(self):
189198 token_2 = jwt .encode (
190199 payload , backend .signing_key , algorithm = backend .algorithm
191200 )
192- # Payload copied
193- payload ["exp" ] = datetime_to_epoch (payload ["exp" ])
201+ if IS_OLD_JWT :
202+ token_1 = token_1 .decode ("utf-8" )
203+ token_2 = token_2 .decode ("utf-8" )
204+ else :
205+ # Payload copied
206+ payload ["exp" ] = datetime_to_epoch (payload ["exp" ])
194207
195208 token_2_payload = token_2 .rsplit ("." , 1 )[0 ]
196209 token_1_sig = token_1 .rsplit ("." , 1 )[- 1 ]
@@ -210,9 +223,13 @@ def test_decode_success(self):
210223 token = jwt .encode (
211224 self .payload , backend .signing_key , algorithm = backend .algorithm
212225 )
213- # Payload copied
214- payload = self .payload .copy ()
215- payload ["exp" ] = datetime_to_epoch (self .payload ["exp" ])
226+ if IS_OLD_JWT :
227+ token = token .decode ("utf-8" )
228+ payload = self .payload
229+ else :
230+ # Payload copied
231+ payload = self .payload .copy ()
232+ payload ["exp" ] = datetime_to_epoch (self .payload ["exp" ])
216233
217234 self .assertEqual (backend .decode (token ), payload )
218235
@@ -223,11 +240,18 @@ def test_decode_aud_iss_success(self):
223240 self .payload ["iss" ] = ISSUER
224241
225242 token = jwt .encode (self .payload , PRIVATE_KEY , algorithm = "RS256" )
226- # Payload copied
227- self .payload ["exp" ] = datetime_to_epoch (self .payload ["exp" ])
243+ if IS_OLD_JWT :
244+ token = token .decode ("utf-8" )
245+ else :
246+ # Payload copied
247+ self .payload ["exp" ] = datetime_to_epoch (self .payload ["exp" ])
228248
229249 self .assertEqual (self .aud_iss_token_backend .decode (token ), self .payload )
230250
251+ @pytest .mark .skipif (
252+ not JWK_CLIENT_AVAILABLE ,
253+ reason = "PyJWT 1.7.1 doesn't have JWK client" ,
254+ )
231255 def test_decode_rsa_aud_iss_jwk_success (self ):
232256 self .payload ["exp" ] = aware_utcnow () + timedelta (days = 1 )
233257 self .payload ["foo" ] = "baz"
@@ -261,6 +285,8 @@ def test_decode_rsa_aud_iss_jwk_success(self):
261285
262286 def test_decode_when_algorithm_not_available (self ):
263287 token = jwt .encode (self .payload , PRIVATE_KEY , algorithm = "RS256" )
288+ if IS_OLD_JWT :
289+ token = token .decode ("utf-8" )
264290
265291 pyjwt_without_rsa = PyJWS ()
266292 pyjwt_without_rsa .unregister_algorithm ("RS256" )
@@ -276,6 +302,8 @@ def _decode(jwt, key, algorithms, options, audience, issuer, leeway):
276302
277303 def test_decode_when_token_algorithm_does_not_match (self ):
278304 token = jwt .encode (self .payload , PRIVATE_KEY , algorithm = "RS256" )
305+ if IS_OLD_JWT :
306+ token = token .decode ("utf-8" )
279307
280308 with self .assertRaisesRegex (TokenBackendError , "Invalid algorithm specified" ):
281309 self .hmac_token_backend .decode (token )
0 commit comments