diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index bc609e42e..15b5cfbce 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -239,8 +239,8 @@ def get_token_expiry_date(self) -> AirbyteDateTime: def _has_access_token_been_initialized(self) -> bool: return self._access_token is not None - def set_token_expiry_date(self, value: Union[str, int]) -> None: - self._token_expiry_date = self._parse_token_expiration_date(value) + def set_token_expiry_date(self, value: AirbyteDateTime) -> None: + self._token_expiry_date = value def get_assertion_name(self) -> str: return self.assertion_name diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index b0afeca6e..108055f1d 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None: headers = self.get_refresh_request_headers() return headers if headers else None - def refresh_access_token(self) -> Tuple[str, Union[str, int]]: + def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]: """ Returns the refresh token and its expiration datetime @@ -148,6 +148,14 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]: # PRIVATE METHODS # ---------------- + def _default_token_expiry_date(self) -> AirbyteDateTime: + """ + Returns the default token expiry date + """ + # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration + default_token_expiry_duration_hours = 1 # 1 hour + return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours) + def _wrap_refresh_token_exception( self, exception: requests.exceptions.RequestException ) -> bool: @@ -257,14 +265,10 @@ def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime: """ - Return the expiration datetime of the refresh token + Parse a string or integer token expiration date into a datetime object :return: expiration datetime """ - if not value and not self.token_has_expired(): - # No expiry token was provided but the previous one is not expired so it's fine - return self.get_token_expiry_date() - if self.token_expiry_is_time_of_expiration: if not self.token_expiry_date_format: raise ValueError( @@ -308,17 +312,30 @@ def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any: """ return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name()) - def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any: + def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime: """ Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data. + If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date. + Args: response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date. Returns: - str: The extracted token_expiry_date. + The extracted token_expiry_date or None if not found. """ - return self._find_and_get_value_from_response(response_data, self.get_expires_in_name()) + expires_in = self._find_and_get_value_from_response( + response_data, self.get_expires_in_name() + ) + if expires_in is not None: + return self._parse_token_expiration_date(expires_in) + + # expires_in is None + existing_expiry_date = self.get_token_expiry_date() + if existing_expiry_date and not self.token_has_expired(): + return existing_expiry_date + + return self._default_token_expiry_date() def _find_and_get_value_from_response( self, @@ -344,7 +361,7 @@ def _find_and_get_value_from_response( """ if current_depth > max_depth: # this is needed to avoid an inf loop, possible with a very deep nesting observed. - message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response." + message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response." raise ResponseKeysMaxRecurtionReached( internal_message=message, message=message, failure_type=FailureType.config_error ) @@ -441,7 +458,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime: """Expiration date of the access token""" @abstractmethod - def set_token_expiry_date(self, value: Union[str, int]) -> None: + def set_token_expiry_date(self, value: AirbyteDateTime) -> None: """Setter for access token expiration date""" @abstractmethod diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 2ff2f60e9..0ca6f6b3a 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -120,8 +120,8 @@ def get_grant_type(self) -> str: def get_token_expiry_date(self) -> AirbyteDateTime: return self._token_expiry_date - def set_token_expiry_date(self, value: Union[str, int]) -> None: - self._token_expiry_date = self._parse_token_expiration_date(value) + def set_token_expiry_date(self, value: AirbyteDateTime) -> None: + self._token_expiry_date = value @property def token_expiry_is_time_of_expiration(self) -> bool: @@ -316,26 +316,6 @@ def token_has_expired(self) -> bool: """Returns True if the token is expired""" return ab_datetime_now() > self.get_token_expiry_date() - @staticmethod - def get_new_token_expiry_date( - access_token_expires_in: str, - token_expiry_date_format: str | None = None, - ) -> AirbyteDateTime: - """ - Calculate the new token expiry date based on the provided expiration duration or format. - - Args: - access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format. - token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None. - - Returns: - AirbyteDateTime: The calculated expiry date of the access token. - """ - if token_expiry_date_format: - return ab_datetime_parse(access_token_expires_in) - else: - return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in)) - def get_access_token(self) -> str: """Retrieve new access and refresh token if the access token has expired. The new refresh token is persisted with the set_refresh_token function @@ -346,16 +326,13 @@ def get_access_token(self) -> str: new_access_token, access_token_expires_in, new_refresh_token = ( self.refresh_access_token() ) - new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date( - access_token_expires_in, self._token_expiry_date_format - ) self.access_token = new_access_token self.set_refresh_token(new_refresh_token) - self.set_token_expiry_date(new_token_expiry_date) + self.set_token_expiry_date(access_token_expires_in) self._emit_control_message() return self.access_token - def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override] + def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override] """ Refreshes the access token by making a handled request and extracting the necessary token information. diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index c54b9982f..077aa4573 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -203,6 +203,7 @@ def test_error_on_refresh_token_grant_without_refresh_token(self): grant_type="refresh_token", ) + @freezegun.freeze_time("2022-01-01") def test_refresh_access_token(self, mocker): oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", @@ -225,13 +226,15 @@ def test_refresh_access_token(self, mocker): resp, "json", return_value={"access_token": "access_token", "expires_in": 1000} ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) - token = oauth.refresh_access_token() + access_token, token_expiry_date = oauth.refresh_access_token() - assert ("access_token", 1000) == token + assert access_token == "access_token" + assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000) filtered = filter_secrets("access_token") assert filtered == "****" + @freezegun.freeze_time("2022-01-01") def test_refresh_access_token_when_headers_provided(self, mocker): expected_headers = { "Authorization": "Bearer some_access_token", @@ -256,9 +259,10 @@ def test_refresh_access_token_when_headers_provided(self, mocker): mocked_request = mocker.patch.object( requests, "request", side_effect=mock_request, autospec=True ) - token = oauth.refresh_access_token() + access_token, token_expiry_date = oauth.refresh_access_token() - assert ("access_token", 1000) == token + assert access_token == "access_token" + assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000) assert mocked_request.call_args.kwargs["headers"] == expected_headers @@ -314,6 +318,7 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp( assert isinstance(oauth._token_expiry_date, AirbyteDateTime) assert oauth.get_token_expiry_date() == ab_datetime_parse(expected_date) + @freezegun.freeze_time("2022-01-01") def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_fetch_access_token( self, ) -> None: @@ -335,12 +340,65 @@ def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_ url="https://refresh_endpoint.com/", body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token", ), - HttpResponse(body=json.dumps({"access_token": "new_access_token"})), + HttpResponse( + body=json.dumps({"access_token": "new_access_token", "expires_in": 1000}) + ), ) oauth.get_access_token() assert oauth.access_token == "new_access_token" - assert oauth._token_expiry_date == expiry_date + assert oauth._token_expiry_date == ab_datetime_now() + timedelta(seconds=1000) + + @freezegun.freeze_time("2022-01-01") + @pytest.mark.parametrize( + "initial_expiry_date_delta, expected_new_expiry_date_delta, expected_access_token", + [ + (timedelta(days=1), timedelta(days=1), "some_access_token"), + (timedelta(days=-1), timedelta(hours=1), "new_access_token"), + (None, timedelta(hours=1), "new_access_token"), + ], + ids=[ + "initial_expiry_date_in_future", + "initial_expiry_date_in_past", + "no_initial_expiry_date", + ], + ) + def test_no_expiry_date_provided_by_auth_server( + self, + initial_expiry_date_delta, + expected_new_expiry_date_delta, + expected_access_token, + ) -> None: + initial_expiry_date = ( + ab_datetime_now().add(initial_expiry_date_delta).isoformat() + if initial_expiry_date_delta + else None + ) + expected_new_expiry_date = ab_datetime_now().add(expected_new_expiry_date_delta) + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="https://refresh_endpoint.com/", + client_id="some_client_id", + client_secret="some_client_secret", + token_expiry_date=initial_expiry_date, + access_token_value="some_access_token", + refresh_token="some_refresh_token", + config={}, + parameters={}, + grant_type="client", + ) + + with HttpMocker() as http_mocker: + http_mocker.post( + HttpRequest( + url="https://refresh_endpoint.com/", + body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token", + ), + HttpResponse(body=json.dumps({"access_token": "new_access_token"})), + ) + oauth.get_access_token() + + assert oauth.access_token == expected_access_token + assert oauth._token_expiry_date == expected_new_expiry_date @pytest.mark.parametrize( "expires_in_response, token_expiry_date_format", @@ -443,6 +501,7 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next assert "access_token" == token assert oauth.get_token_expiry_date() == ab_datetime_parse(next_day) + @freezegun.freeze_time("2022-01-01") def test_profile_assertion(self, mocker): with HttpMocker() as http_mocker: jwt = JwtAuthenticator( @@ -477,7 +536,7 @@ def test_profile_assertion(self, mocker): token = oauth.refresh_access_token() - assert ("access_token", 1000) == token + assert ("access_token", ab_datetime_now().add(timedelta(seconds=1000))) == token filtered = filter_secrets("access_token") assert filtered == "****" diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index d756931c8..dbfc0ac86 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -4,7 +4,7 @@ import json import logging -from datetime import timedelta, timezone +from datetime import timedelta from typing import Optional, Union from unittest.mock import Mock @@ -83,6 +83,7 @@ def test_multiple_token_authenticator(): assert {"Authorization": "Bearer token1"} == header3 +@freezegun.freeze_time("2022-01-01") class TestOauth2Authenticator: """ Test class for OAuth2Authenticator. @@ -104,8 +105,9 @@ def test_get_auth_header_fresh(self, mocker): refresh_token=TestOauth2Authenticator.refresh_token, ) + expires_in = ab_datetime_now().add(timedelta(seconds=1000)) mocker.patch.object( - Oauth2Authenticator, "refresh_access_token", return_value=("access_token", 1000) + Oauth2Authenticator, "refresh_access_token", return_value=("access_token", expires_in) ) header = oauth.get_auth_header() assert {"Authorization": "Bearer access_token"} == header @@ -121,15 +123,15 @@ def test_get_auth_header_expired(self, mocker): refresh_token=TestOauth2Authenticator.refresh_token, ) - expire_immediately = 0 + already_expired = ab_datetime_now() - timedelta(seconds=100) mocker.patch.object( Oauth2Authenticator, "refresh_access_token", - return_value=("access_token_1", expire_immediately), + return_value=("access_token_1", already_expired), ) oauth.get_auth_header() # Set the first expired token. - valid_100_secs = 100 + valid_100_secs = ab_datetime_now() + timedelta(seconds=100) mocker.patch.object( Oauth2Authenticator, "refresh_access_token", @@ -251,6 +253,20 @@ def test_refresh_access_token(self, mocker): }, ) + oauth_with_expired_token = Oauth2Authenticator( + token_refresh_endpoint="https://refresh_endpoint.com", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + scopes=["scope1", "scope2"], + token_expiry_date=ab_datetime_now() - timedelta(days=3), + refresh_request_body={ + "custom_field": "in_outbound_request", + "another_field": "exists_in_body", + "scopes": ["no_override"], + }, + ) + resp.status_code = 200 mocker.patch.object( resp, "json", return_value={"access_token": "access_token", "expires_in": 1000} @@ -258,8 +274,9 @@ def test_refresh_access_token(self, mocker): mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, int) - assert ("access_token", 1000) == (token, expires_in) + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=1000)) + assert token == "access_token" # Test with expires_in as str(int) mocker.patch.object( @@ -267,8 +284,9 @@ def test_refresh_access_token(self, mocker): ) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, str) - assert ("access_token", "2000") == (token, expires_in) + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=2000)) + assert token == "access_token" # Test with expires_in as datetime(str) mocker.patch.object( @@ -276,10 +294,30 @@ def test_refresh_access_token(self, mocker): "json", return_value={"access_token": "access_token", "expires_in": "2022-04-24T00:00:00Z"}, ) + # This should raise a ValueError because the token_expiry_is_time_of_expiration is False by default + with pytest.raises(ValueError): + token, expires_in = oauth.refresh_access_token() + + # Test with no expires_in + mocker.patch.object( + resp, + "json", + return_value={"access_token": "access_token"}, + ) + + # Since the initialized token is not expired (now + 3 days), we don't expect the expiration date to be updated token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, str) - assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in) + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(days=3)) + assert token == "access_token" + + # Since the initialized token is expired (now - 3 days), we expect the expiration date to be updated to the default value (now + 1 hour) + token, expires_in = oauth_with_expired_token.refresh_access_token() + + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(hours=1)) + assert token == "access_token" # Test with nested access_token and expires_in as str(int) mocker.patch.object( @@ -289,8 +327,9 @@ def test_refresh_access_token(self, mocker): ) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, str) - assert ("access_token_nested", "2001") == (token, expires_in) + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=2001)) + assert token == "access_token_nested" # Test with multiple nested levels access_token and expires_in as str(int) mocker.patch.object( @@ -317,8 +356,9 @@ def test_refresh_access_token(self, mocker): ) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, str) - assert ("access_token_deeply_nested", "2002") == (token, expires_in) + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=2002)) + assert token == "access_token_deeply_nested" # Test with max nested levels access_token and expires_in as str(int) mocker.patch.object( @@ -348,7 +388,7 @@ def test_refresh_access_token(self, mocker): ) with pytest.raises(ResponseKeysMaxRecurtionReached) as exc_info: oauth.refresh_access_token() - error_message = "The maximum level of recursion is reached. Couldn't find the speficied `access_token` in the response." + error_message = "The maximum level of recursion is reached. Couldn't find the specified `access_token` in the response." assert exc_info.value.internal_message == error_message assert exc_info.value.message == error_message assert exc_info.value.failure_type == FailureType.config_error @@ -377,8 +417,9 @@ def test_refresh_access_token_when_headers_provided(self, mocker): ) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, int) - assert ("access_token", 1000) == (token, expires_in) + assert isinstance(expires_in, AirbyteDateTime) + assert expires_in == ab_datetime_now().add(timedelta(seconds=1000)) + assert token == "access_token" assert mocked_request.call_args.kwargs["headers"] == expected_headers @@ -393,10 +434,18 @@ def test_refresh_access_token_when_headers_provided(self, mocker): "YYYY-MM-DDTHH:mm:ss.SSSSSSZ", AirbyteDateTime(year=2022, month=2, day=12), ), + (None, None, AirbyteDateTime(year=2022, month=1, day=1, hour=1)), + (None, "YYYY-MM-DD", AirbyteDateTime(year=2022, month=1, day=1, hour=1)), + ], + ids=[ + "seconds", + "string_of_seconds", + "simple_date", + "simple_datetime", + "default_behavior", + "default_behavior_with_format", ], - ids=["seconds", "string_of_seconds", "simple_date", "simple_datetime"], ) - @freezegun.freeze_time("2022-01-01") def test_parse_refresh_token_lifespan( self, mocker, @@ -427,14 +476,11 @@ def test_parse_refresh_token_lifespan( return_value={"access_token": "access_token", "expires_in": expires_in_response}, ) mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) - token, expire_in = oauth.refresh_access_token() - expires_datetime = oauth._parse_token_expiration_date(expire_in) + token, expires_datetime = oauth.refresh_access_token() assert isinstance(expires_datetime, AirbyteDateTime) - assert ("access_token", expected_token_expiry_date) == ( - token, - expires_datetime, - ) + assert expires_datetime == expected_token_expiry_date + assert token == "access_token" @pytest.mark.usefixtures("mock_sleep") @pytest.mark.parametrize("error_code", (429, 500, 502, 504)) @@ -454,8 +500,9 @@ def test_refresh_access_token_retry(self, error_code, requests_mock): ], ) token, expires_in = oauth.refresh_access_token() - assert isinstance(expires_in, int) - assert (token, expires_in) == ("token", 10) + assert isinstance(expires_in, AirbyteDateTime) + assert token == "token" + assert expires_in == ab_datetime_now().add(timedelta(seconds=10)) assert requests_mock.call_count == 3 def test_auth_call_method(self, mocker): @@ -466,8 +513,9 @@ def test_auth_call_method(self, mocker): refresh_token=TestOauth2Authenticator.refresh_token, ) + expires_in = ab_datetime_now().add(timedelta(seconds=1000)) mocker.patch.object( - Oauth2Authenticator, "refresh_access_token", return_value=("access_token", 1000) + Oauth2Authenticator, "refresh_access_token", return_value=("access_token", expires_in) ) prepared_request = requests.PreparedRequest() prepared_request.headers = {} @@ -531,6 +579,7 @@ def test_refresh_access_token_wrapped( assert exc_info.value.failure_type == FailureType.config_error +@freezegun.freeze_time("2022-12-31") class TestSingleUseRefreshTokenOauth2Authenticator: @pytest.fixture def connector_config(self): @@ -551,7 +600,7 @@ def invalid_connector_config(self): def test_init(self, connector_config): authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, - token_refresh_endpoint="foobar", + token_refresh_endpoint="https://refresh_endpoint.com", client_id=connector_config["credentials"]["client_id"], client_secret=connector_config["credentials"]["client_secret"], ) @@ -561,7 +610,6 @@ def test_init(self, connector_config): connector_config["credentials"]["token_expiry_date"] ) - @freezegun.freeze_time("2022-12-31") @pytest.mark.parametrize( "test_name, expires_in_value, expiry_date_format, expected_expiry_date", [ @@ -582,14 +630,26 @@ def test_given_no_message_repository_get_access_token( ): authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, - token_refresh_endpoint="foobar", + token_refresh_endpoint="https://refresh_endpoint.com", client_id=connector_config["credentials"]["client_id"], client_secret=connector_config["credentials"]["client_secret"], token_expiry_date_format=expiry_date_format, + token_expiry_is_time_of_expiration=bool(expiry_date_format), ) - authenticator.refresh_access_token = mocker.Mock( - return_value=("new_access_token", expires_in_value, "new_refresh_token") + + # Mock the response from the refresh token endpoint + resp.status_code = 200 + mocker.patch.object( + resp, + "json", + return_value={ + authenticator.get_access_token_name(): "new_access_token", + authenticator.get_expires_in_name(): expires_in_value, + authenticator.get_refresh_token_name(): "new_refresh_token", + }, ) + mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) + authenticator.token_has_expired = mocker.Mock(return_value=True) access_token = authenticator.get_access_token() captured = capsys.readouterr() @@ -614,15 +674,26 @@ def test_given_message_repository_when_get_access_token_then_emit_message( message_repository = Mock() authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, - token_refresh_endpoint="foobar", + token_refresh_endpoint="https://refresh_endpoint.com", client_id=connector_config["credentials"]["client_id"], client_secret=connector_config["credentials"]["client_secret"], + token_expiry_is_time_of_expiration=True, token_expiry_date_format="YYYY-MM-DD", message_repository=message_repository, ) - authenticator.refresh_access_token = mocker.Mock( - return_value=("new_access_token", "2023-04-04", "new_refresh_token") + # Mock the response from the refresh token endpoint + resp.status_code = 200 + mocker.patch.object( + resp, + "json", + return_value={ + authenticator.get_access_token_name(): "new_access_token", + authenticator.get_expires_in_name(): "2023-04-04", + authenticator.get_refresh_token_name(): "new_refresh_token", + }, ) + mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) + authenticator.token_has_expired = mocker.Mock(return_value=True) authenticator.get_access_token() @@ -683,21 +754,27 @@ def test_given_message_repository_when_get_access_token_then_log_request( def test_refresh_access_token(self, mocker, connector_config): authenticator = SingleUseRefreshTokenOauth2Authenticator( connector_config, - token_refresh_endpoint="foobar", + token_refresh_endpoint="https://refresh_endpoint.com", client_id=connector_config["credentials"]["client_id"], client_secret=connector_config["credentials"]["client_secret"], ) - authenticator._make_handled_request = mocker.Mock( + # Mock the response from the refresh token endpoint + resp.status_code = 200 + mocker.patch.object( + resp, + "json", return_value={ authenticator.get_access_token_name(): "new_access_token", authenticator.get_expires_in_name(): "42", authenticator.get_refresh_token_name(): "new_refresh_token", - } + }, ) + mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) + assert authenticator.refresh_access_token() == ( "new_access_token", - "42", + ab_datetime_now().add(timedelta(seconds=42)), "new_refresh_token", )