Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def refresh_credentials(self):
return
except AccessTokenNotFoundError:
logging.warning("Failed to refresh token. Kicking off a full authorization flow.")
KeyringStore.delete(self._endpoint)

self._creds = self._auth_client.get_creds_from_remote()
KeyringStore.store(self._creds)
Expand Down Expand Up @@ -325,7 +324,6 @@ def refresh_credentials(self):
return
except (AuthenticationError, AuthenticationPending):
logging.warning("Failed to refresh token. Kicking off a full authorization flow.")
KeyringStore.delete(self._endpoint)

"""Fall back to device flow"""
resp = token_client.get_device_code(
Expand All @@ -341,20 +339,16 @@ def refresh_credentials(self):
full_uri = f"{resp.verification_uri}?user_code={resp.user_code}"
text = f"To Authenticate, navigate in a browser to the following URL: {click.style(full_uri, fg='blue', underline=True)}"
click.secho(text)
try:
token, refresh_token, expires_in = token_client.poll_token_endpoint(
resp,
self._token_endpoint,
client_id=self._client_id,
audience=self._audience,
scopes=self._scopes,
http_proxy_url=self._http_proxy_url,
verify=self._verify,
)
self._creds = Credentials(
access_token=token, refresh_token=refresh_token, expires_in=expires_in, for_endpoint=self._endpoint
)
KeyringStore.store(self._creds)
except Exception:
KeyringStore.delete(self._endpoint)
raise
token, refresh_token, expires_in = token_client.poll_token_endpoint(
resp,
self._token_endpoint,
client_id=self._client_id,
audience=self._audience,
scopes=self._scopes,
http_proxy_url=self._http_proxy_url,
verify=self._verify,
)
self._creds = Credentials(
access_token=token, refresh_token=refresh_token, expires_in=expires_in, for_endpoint=self._endpoint
)
KeyringStore.store(self._creds)
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def refresh_credentials(self):
return
except AccessTokenNotFoundError:
logging.warning("Failed to refresh token. Kicking off a full authorization flow.")
KeyringStore.delete(self._endpoint)

self._creds = self._auth_client.get_creds_from_remote()
KeyringStore.store(self._creds)
Expand Down
26 changes: 25 additions & 1 deletion tests/flytekit/unit/clients/auth/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
PKCEAuthenticator,
StaticClientConfigStore,
)
from flytekit.clients.auth.exceptions import AuthenticationError, AccessTokenNotFoundError
from flytekit.clients.auth.exceptions import AccessTokenNotFoundError, AuthenticationError
from flytekit.clients.auth.keyring import Credentials
from flytekit.clients.auth.token_client import DeviceCodeResponse

Expand Down Expand Up @@ -51,6 +51,27 @@ def test_pkce_authenticator(mock_refresh: MagicMock, mock_get_creds: MagicMock,
mock_refresh.assert_called()


@patch("flytekit.clients.auth.authenticator.KeyringStore")
@patch("flytekit.clients.auth.auth_client.AuthorizationClient.get_creds_from_remote")
@patch("flytekit.clients.auth.auth_client.AuthorizationClient.refresh_access_token")
def test_pkce_refresh_failure_keeps_cached_credentials(
mock_refresh: MagicMock, mock_get_creds: MagicMock, mock_keyring: MagicMock
):
# A cached token exists, but refreshing it fails.
mock_keyring.retrieve.return_value = Credentials("access", "refresh", ENDPOINT)
mock_refresh.side_effect = AccessTokenNotFoundError("expired")

authn = PKCEAuthenticator(ENDPOINT, static_cfg_store, verify=False)
authn.refresh_credentials()

# A failed refresh must not delete the cached credentials. The authenticator falls back to a
# full authorization flow and overwrites them via store(), so the delete was both redundant and
# harmful -- a transient failure would force an unnecessary re-login.
mock_keyring.delete.assert_not_called()
mock_get_creds.assert_called_once()
mock_keyring.store.assert_called()


@patch("subprocess.run")
def test_command_authenticator(mock_subprocess: MagicMock):
with pytest.raises(AuthenticationError):
Expand Down Expand Up @@ -130,6 +151,9 @@ def test_device_flow_authenticator(mock_refresh: MagicMock, poll_mock: MagicMock
device_mock.assert_called()
mock_refresh.assert_called()

# The failed refresh must not have deleted the cached credentials.
mock_keyring.delete.assert_not_called()

@patch("flytekit.clients.auth.authenticator.KeyringStore")
@patch("flytekit.clients.auth.token_client.get_device_code")
@patch("flytekit.clients.auth.token_client.poll_token_endpoint")
Expand Down
Loading