From ccf65aae23db62de22ed81ece82fdc44ac333f84 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Sat, 20 Jun 2026 16:20:54 -0700 Subject: [PATCH] fix(auth): don't delete cached credentials when token refresh fails The authenticators call KeyringStore.delete(endpoint) when a token refresh fails, then immediately fall back to a full authorization flow that calls KeyringStore.store(...). store() overwrites each key in place, so the delete is redundant -- and harmful: a transient refresh failure wipes the last-good cached tokens (including a still-valid refresh_token), forcing an unnecessary browser re-login. Remove the premature delete() from the PKCE and device-code refresh paths and from the identity-aware-proxy authenticator. At the device-flow poll site the surrounding try/except existed only to delete-and-re-raise, so it is dropped. Signed-off-by: 1fanwang <1fannnw@gmail.com> --- flytekit/clients/auth/authenticator.py | 32 ++++++++----------- .../identity_aware_proxy/cli.py | 1 - .../unit/clients/auth/test_authenticator.py | 26 ++++++++++++++- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index eeacee07d0..d3ee5f0905 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -157,7 +157,6 @@ def refresh_credentials(self): return except AccessTokenNotFoundError: logging.warning("Failed to refresh token. Kicking off a full authorization flow.") - KeyringStore.delete(self._endpoint) self._creds = self._auth_client.get_creds_from_remote() KeyringStore.store(self._creds) @@ -325,7 +324,6 @@ def refresh_credentials(self): return except (AuthenticationError, AuthenticationPending): logging.warning("Failed to refresh token. Kicking off a full authorization flow.") - KeyringStore.delete(self._endpoint) """Fall back to device flow""" resp = token_client.get_device_code( @@ -341,20 +339,16 @@ def refresh_credentials(self): full_uri = f"{resp.verification_uri}?user_code={resp.user_code}" text = f"To Authenticate, navigate in a browser to the following URL: {click.style(full_uri, fg='blue', underline=True)}" click.secho(text) - try: - token, refresh_token, expires_in = token_client.poll_token_endpoint( - resp, - self._token_endpoint, - client_id=self._client_id, - audience=self._audience, - scopes=self._scopes, - http_proxy_url=self._http_proxy_url, - verify=self._verify, - ) - self._creds = Credentials( - access_token=token, refresh_token=refresh_token, expires_in=expires_in, for_endpoint=self._endpoint - ) - KeyringStore.store(self._creds) - except Exception: - KeyringStore.delete(self._endpoint) - raise + token, refresh_token, expires_in = token_client.poll_token_endpoint( + resp, + self._token_endpoint, + client_id=self._client_id, + audience=self._audience, + scopes=self._scopes, + http_proxy_url=self._http_proxy_url, + verify=self._verify, + ) + self._creds = Credentials( + access_token=token, refresh_token=refresh_token, expires_in=expires_in, for_endpoint=self._endpoint + ) + KeyringStore.store(self._creds) diff --git a/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py b/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py index d1fcbdef21..69394dd2b6 100644 --- a/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py +++ b/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py @@ -93,7 +93,6 @@ def refresh_credentials(self): return except AccessTokenNotFoundError: logging.warning("Failed to refresh token. Kicking off a full authorization flow.") - KeyringStore.delete(self._endpoint) self._creds = self._auth_client.get_creds_from_remote() KeyringStore.store(self._creds) diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 46954d9286..edb84b52db 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -12,7 +12,7 @@ PKCEAuthenticator, StaticClientConfigStore, ) -from flytekit.clients.auth.exceptions import AuthenticationError, AccessTokenNotFoundError +from flytekit.clients.auth.exceptions import AccessTokenNotFoundError, AuthenticationError from flytekit.clients.auth.keyring import Credentials from flytekit.clients.auth.token_client import DeviceCodeResponse @@ -51,6 +51,27 @@ def test_pkce_authenticator(mock_refresh: MagicMock, mock_get_creds: MagicMock, mock_refresh.assert_called() +@patch("flytekit.clients.auth.authenticator.KeyringStore") +@patch("flytekit.clients.auth.auth_client.AuthorizationClient.get_creds_from_remote") +@patch("flytekit.clients.auth.auth_client.AuthorizationClient.refresh_access_token") +def test_pkce_refresh_failure_keeps_cached_credentials( + mock_refresh: MagicMock, mock_get_creds: MagicMock, mock_keyring: MagicMock +): + # A cached token exists, but refreshing it fails. + mock_keyring.retrieve.return_value = Credentials("access", "refresh", ENDPOINT) + mock_refresh.side_effect = AccessTokenNotFoundError("expired") + + authn = PKCEAuthenticator(ENDPOINT, static_cfg_store, verify=False) + authn.refresh_credentials() + + # A failed refresh must not delete the cached credentials. The authenticator falls back to a + # full authorization flow and overwrites them via store(), so the delete was both redundant and + # harmful -- a transient failure would force an unnecessary re-login. + mock_keyring.delete.assert_not_called() + mock_get_creds.assert_called_once() + mock_keyring.store.assert_called() + + @patch("subprocess.run") def test_command_authenticator(mock_subprocess: MagicMock): with pytest.raises(AuthenticationError): @@ -130,6 +151,9 @@ def test_device_flow_authenticator(mock_refresh: MagicMock, poll_mock: MagicMock device_mock.assert_called() mock_refresh.assert_called() + # The failed refresh must not have deleted the cached credentials. + mock_keyring.delete.assert_not_called() + @patch("flytekit.clients.auth.authenticator.KeyringStore") @patch("flytekit.clients.auth.token_client.get_device_code") @patch("flytekit.clients.auth.token_client.poll_token_endpoint")