From dc9d9923e8c2d886ba8cf3c300d6c16dee9180eb Mon Sep 17 00:00:00 2001 From: Albert Sola Date: Thu, 18 Jun 2026 11:59:07 +0100 Subject: [PATCH] feat(auth): add account-scoped authentication provider [MPT-21532] Add AccountScopedAuthentication, an httpx.Auth provider that fetches account-scoped installation tokens and shares them across provider and client instances through a process-wide cache keyed by (secret, account_id). Token refreshes are serialized per account with double-checked locking, so concurrent callers trigger at most one token request; refresh happens proactively before the JWT exp and reactively on 401, and expired cache entries are evicted on write. Extract the shared token-fetch machinery into a new InstallationTokenAuthentication base and reparent ExtensionFrameworkAuthentication onto it with no behavior change. This brings the client to feature parity with the extension SDK's AccountTokenProvider and AccountScopedAsyncHTTPClient. Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/usage.md | 20 ++ mpt_api_client/__init__.py | 2 + mpt_api_client/auth/__init__.py | 9 +- mpt_api_client/auth/account_scoped.py | 171 +++++++++++++ mpt_api_client/auth/base.py | 90 +++++++ mpt_api_client/auth/extension_framework.py | 68 +----- pyproject.toml | 1 + tests/unit/auth/test_account_scoped.py | 267 +++++++++++++++++++++ 8 files changed, 562 insertions(+), 66 deletions(-) create mode 100644 mpt_api_client/auth/account_scoped.py create mode 100644 tests/unit/auth/test_account_scoped.py diff --git a/docs/usage.md b/docs/usage.md index 48edc5bf..8d5fe090 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -48,6 +48,12 @@ implementations are available: proactively once the token nears its JWT `exp` (default leeway 60s) and reactively on `401`. Pass `account_id` to request a token scoped to a specific account (`?account.id=`); use one provider instance per account scope. +- `AccountScopedAuthentication` — an always account-scoped token (`account_id` is required) + backed by a process-wide cache keyed by `(secret, account_id)`. Several provider or client + instances for the same account reuse a single cached token, and refreshes are serialized + per account, so concurrent callers trigger at most one token request. It refreshes + proactively (default leeway 60s) and reactively on `401`. Use this when many clients share + the same account scope or when many requests run concurrently. ## Instantiate The Client @@ -85,6 +91,20 @@ client = MPTClient.from_config( ) ``` +With an account-scoped token shared across clients and concurrent requests: + +```python +from mpt_api_client import MPTClient, AccountScopedAuthentication + +client = MPTClient.from_config( + authentication=AccountScopedAuthentication( + secret="", + account_id="", + ), + base_url="https://api.s1.show/public", +) +``` + `from_config` also accepts a `timeout` argument (HTTP request timeout in seconds, default `60.0`). ## Synchronous Usage Patterns diff --git a/mpt_api_client/__init__.py b/mpt_api_client/__init__.py index 7b1a89dd..bec2e959 100644 --- a/mpt_api_client/__init__.py +++ b/mpt_api_client/__init__.py @@ -1,4 +1,5 @@ from mpt_api_client.auth import ( + AccountScopedAuthentication, Authentication, BearerTokenAuthentication, ExtensionFrameworkAuthentication, @@ -7,6 +8,7 @@ from mpt_api_client.rql import RQLQuery __all__ = [ # noqa: WPS410 + "AccountScopedAuthentication", "AsyncMPTClient", "Authentication", "BearerTokenAuthentication", diff --git a/mpt_api_client/auth/__init__.py b/mpt_api_client/auth/__init__.py index bbe25d04..ebaa22d2 100644 --- a/mpt_api_client/auth/__init__.py +++ b/mpt_api_client/auth/__init__.py @@ -1,8 +1,15 @@ -from mpt_api_client.auth.base import Authentication, BearerTokenAuthentication +from mpt_api_client.auth.account_scoped import AccountScopedAuthentication +from mpt_api_client.auth.base import ( + Authentication, + BearerTokenAuthentication, + InstallationTokenAuthentication, +) from mpt_api_client.auth.extension_framework import ExtensionFrameworkAuthentication __all__ = [ # noqa: WPS410 + "AccountScopedAuthentication", "Authentication", "BearerTokenAuthentication", "ExtensionFrameworkAuthentication", + "InstallationTokenAuthentication", ] diff --git a/mpt_api_client/auth/account_scoped.py b/mpt_api_client/auth/account_scoped.py new file mode 100644 index 00000000..8de36304 --- /dev/null +++ b/mpt_api_client/auth/account_scoped.py @@ -0,0 +1,171 @@ +"""Account-scoped authentication for the MPT integration API. + +This provider fetches account-scoped installation tokens and shares them across instances +through a process-wide cache keyed by ``(secret, account_id)``. Token fetches are serialized +per account, so concurrent callers for the same account trigger at most one token request. +""" + +import asyncio +import datetime as dt +import threading +from collections.abc import AsyncGenerator, Generator +from dataclasses import dataclass +from typing import ClassVar, override + +import httpx + +from mpt_api_client.auth.base import InstallationTokenAuthentication +from mpt_api_client.exceptions import MPTError + +DEFAULT_TOKEN_VALIDITY_LEEWAY_SECONDS = 60 + +CacheKey = tuple[str, str] + + +@dataclass(frozen=True) +class _CachedToken: + """A cached account token together with its decoded expiry.""" + + token: str + expires_at: dt.datetime | None + + +class AccountScopedAuthentication(InstallationTokenAuthentication): # noqa: WPS214 + """Authenticate with an account-scoped token from a shared, concurrency-safe cache. + + Tokens are cached process-wide keyed by ``(secret, account_id)``, so several provider or + client instances for the same account reuse a single token. Refresh is serialized per + account through a lock with double-checked caching: concurrent callers for the same + account trigger at most one token request. Refresh happens proactively once the token is + within ``min_remaining_validity_seconds`` of its JWT ``exp`` claim, with a reactive + refresh on ``401 Unauthorized`` for tokens revoked before they expire. When the fetched + token carries no readable ``exp`` claim, proactive refresh is skipped and only the + reactive ``401`` path applies. + + The token call is delegated to :class:`InstallationsTokenService` (and its async + counterpart) over a dedicated client authenticated with the extension secret; that + client's base URL is supplied by the owning HTTP client through :meth:`configure`. + """ + + _token_cache: ClassVar[dict[CacheKey, _CachedToken]] = {} + _sync_locks: ClassVar[dict[CacheKey, threading.Lock]] = {} + _async_locks: ClassVar[dict[CacheKey, asyncio.Lock]] = {} + + def __init__( + self, + secret: str, + account_id: str, + min_remaining_validity_seconds: int = DEFAULT_TOKEN_VALIDITY_LEEWAY_SECONDS, + ) -> None: + """Initialize the provider. + + Args: + secret: Extension secret used to authenticate token requests. + account_id: Account the requested token is scoped to. + min_remaining_validity_seconds: Proactive refresh leeway before the JWT ``exp``. + """ + super().__init__(secret) + self._account_id = account_id + self._min_remaining_validity_seconds = min_remaining_validity_seconds + + @classmethod + def clear_cache(cls) -> None: + """Clear all cached account tokens and refresh locks.""" + cls._token_cache.clear() + cls._sync_locks.clear() + cls._async_locks.clear() + + @override + def sync_auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + """Attach an account-scoped token, refreshing it proactively and on 401.""" + token = self._token_sync() + request.headers["Authorization"] = f"Bearer {token}" + response = yield request + if response.status_code == httpx.codes.UNAUTHORIZED: + rejected = request.headers["Authorization"].removeprefix("Bearer ") + request.headers["Authorization"] = f"Bearer {self._token_sync(rejected)}" + yield request + + @override + async def async_auth_flow( + self, request: httpx.Request + ) -> AsyncGenerator[httpx.Request, httpx.Response]: + """Attach an account-scoped token, refreshing it proactively and on 401.""" + token = await self._token_async() + request.headers["Authorization"] = f"Bearer {token}" + response = yield request + if response.status_code == httpx.codes.UNAUTHORIZED: + rejected = request.headers["Authorization"].removeprefix("Bearer ") + refreshed = await self._token_async(rejected) + request.headers["Authorization"] = f"Bearer {refreshed}" + yield request + + @property + def _cache_key(self) -> CacheKey: + """Return the shared-cache key for this provider's scope.""" + return self._secret, self._account_id + + def _token_sync(self, rejected: str | None = None) -> str: + """Return a usable token, fetching one under a per-account lock when needed.""" + cached = self._token_cache.get(self._cache_key) + if self._is_usable(cached, rejected): + return cached.token # type: ignore[union-attr] + + lock = self._sync_locks.setdefault(self._cache_key, threading.Lock()) + with lock: + cached = self._token_cache.get(self._cache_key) + if self._is_usable(cached, rejected): + return cached.token # type: ignore[union-attr] + fetched = self._get_sync_service().token(self._account_id) + return self._store(fetched.token) + + async def _token_async(self, rejected: str | None = None) -> str: + """Return a usable token, fetching one under a per-account lock when needed.""" + cached = self._token_cache.get(self._cache_key) + if self._is_usable(cached, rejected): + return cached.token # type: ignore[union-attr] + + lock = self._async_locks.setdefault(self._cache_key, asyncio.Lock()) + async with lock: + cached = self._token_cache.get(self._cache_key) + if self._is_usable(cached, rejected): + return cached.token # type: ignore[union-attr] + fetched = await self._get_async_service().token(self._account_id) + return self._store(fetched.token) + + def _is_usable(self, cached: _CachedToken | None, rejected: str | None) -> bool: + """Return whether the cached token can be reused for the current request. + + A token is unusable when it is missing, when it equals a token the server just + rejected, or when it is within the proactive refresh leeway of its expiry. Tokens + without a readable ``exp`` are reused and rely on the reactive ``401`` path. + """ + if cached is None or cached.token == rejected: + return False + if cached.expires_at is None: + return True + threshold = dt.datetime.now(dt.UTC).timestamp() + self._min_remaining_validity_seconds + return cached.expires_at.timestamp() > threshold + + def _store(self, token: str | None) -> str: + """Cache a freshly fetched token, evicting expired entries, and return it.""" + if not token: + raise MPTError("Installations token endpoint returned an empty token.") + self._token_cache[self._cache_key] = _CachedToken(token, self._read_expiry(token)) + self._evict_expired() + return token + + def _evict_expired(self) -> None: + """Drop cache entries (and their locks) whose tokens have already expired.""" + now = dt.datetime.now(dt.UTC) + expired_keys = [ + key + for key, cached in self._token_cache.items() + if cached.expires_at is not None and cached.expires_at <= now + ] + for key in expired_keys: + self._token_cache.pop(key, None) + self._sync_locks.pop(key, None) + self._async_locks.pop(key, None) diff --git a/mpt_api_client/auth/base.py b/mpt_api_client/auth/base.py index 9abf2b5b..a1122faa 100644 --- a/mpt_api_client/auth/base.py +++ b/mpt_api_client/auth/base.py @@ -4,11 +4,27 @@ the sync and the async HTTP clients. """ +import datetime as dt from collections.abc import Generator from typing import override import httpx +from mpt_api_client.auth.jwt import ( + JWTClaimsError, + JWTFormatError, + decode_unverified_jwt_claims, +) +from mpt_api_client.exceptions import MPTError +from mpt_api_client.http import AsyncHTTPClient, HTTPClient +from mpt_api_client.resources.integration.installations_token import ( + AsyncInstallationsTokenService, + InstallationsTokenService, +) + +DEFAULT_TOKEN_CLIENT_TIMEOUT_SECONDS = 20.0 +DEFAULT_TOKEN_CLIENT_RETRIES = 5 + class Authentication(httpx.Auth): """Base class for MPT API authentication providers.""" @@ -38,3 +54,77 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re """Attach the bearer token to the outgoing request.""" request.headers["Authorization"] = f"Bearer {self._token}" yield request + + +class InstallationTokenAuthentication(Authentication): + """Base for providers backed by the integration installations token endpoint. + + Holds the extension secret, captures the owning client's configuration through + :meth:`configure`, and lazily builds a dedicated token client (authenticated with the + extension secret) that hosts :class:`InstallationsTokenService` and its async counterpart. + Subclasses implement the caching and ``auth_flow`` behavior. + """ + + def __init__(self, secret: str) -> None: + """Initialize the provider. + + Args: + secret: Extension secret used to authenticate token requests. + """ + self._secret = secret + self._base_url: str | None = None + self._timeout: float = DEFAULT_TOKEN_CLIENT_TIMEOUT_SECONDS + self._retries: int = DEFAULT_TOKEN_CLIENT_RETRIES + self._sync_service: InstallationsTokenService | None = None + self._async_service: AsyncInstallationsTokenService | None = None + + @override + def configure(self, *, base_url: str, timeout: float, retries: int) -> None: + """Store the owning client's configuration used to build the token client.""" + self._base_url = base_url + self._timeout = timeout + self._retries = retries + + def _get_sync_service(self) -> InstallationsTokenService: + """Return the cached sync token service, building it on first use.""" + if self._sync_service is None: + token_client = HTTPClient( + authentication=BearerTokenAuthentication(self._secret), + base_url=self._require_base_url(), + timeout=self._timeout, + retries=self._retries, + ) + self._sync_service = InstallationsTokenService(http_client=token_client) + return self._sync_service + + def _get_async_service(self) -> AsyncInstallationsTokenService: + """Return the cached async token service, building it on first use.""" + if self._async_service is None: + token_client = AsyncHTTPClient( + authentication=BearerTokenAuthentication(self._secret), + base_url=self._require_base_url(), + timeout=self._timeout, + retries=self._retries, + ) + self._async_service = AsyncInstallationsTokenService(http_client=token_client) + return self._async_service + + def _require_base_url(self) -> str: + """Return the configured base URL, raising when the provider is unconfigured.""" + if self._base_url is None: + raise MPTError( + f"{type(self).__name__} must be used with an MPT HTTPClient or AsyncHTTPClient; " + "the base URL was not configured.", + ) + return self._base_url + + def _read_expiry(self, token: str) -> dt.datetime | None: + """Read the ``exp`` claim from the token, ignoring tokens without one.""" + try: + claims = decode_unverified_jwt_claims(token) + except (JWTFormatError, JWTClaimsError): + return None + exp = claims.get("exp") + if not isinstance(exp, int): + return None + return dt.datetime.fromtimestamp(exp, tz=dt.UTC) diff --git a/mpt_api_client/auth/extension_framework.py b/mpt_api_client/auth/extension_framework.py index 5d52d484..2bfef5d7 100644 --- a/mpt_api_client/auth/extension_framework.py +++ b/mpt_api_client/auth/extension_framework.py @@ -11,19 +11,13 @@ import httpx -from mpt_api_client.auth.base import Authentication, BearerTokenAuthentication -from mpt_api_client.auth.jwt import JWTClaimsError, JWTFormatError, decode_unverified_jwt_claims +from mpt_api_client.auth.base import InstallationTokenAuthentication from mpt_api_client.exceptions import MPTError -from mpt_api_client.http import AsyncHTTPClient, HTTPClient -from mpt_api_client.resources.integration.installations_token import ( - AsyncInstallationsTokenService, - InstallationsTokenService, -) DEFAULT_TOKEN_VALIDITY_LEEWAY_SECONDS = 60 -class ExtensionFrameworkAuthentication(Authentication): +class ExtensionFrameworkAuthentication(InstallationTokenAuthentication): """Authenticate with a short-lived installation or account-scoped token. The token is fetched through the installations token service using the extension secret @@ -53,23 +47,11 @@ def __init__( account_id: When set, request a token scoped to this account. min_remaining_validity_seconds: Proactive refresh leeway before the JWT ``exp``. """ - self._secret = secret + super().__init__(secret) self._account_id = account_id self._min_remaining_validity_seconds = min_remaining_validity_seconds self._token: str | None = None self._expires_at: dt.datetime | None = None - self._base_url: str | None = None - self._timeout: float = 20.0 - self._retries: int = 5 - self._sync_service: InstallationsTokenService | None = None - self._async_service: AsyncInstallationsTokenService | None = None - - @override - def configure(self, *, base_url: str, timeout: float, retries: int) -> None: - """Store the owning client's configuration used to build the token client.""" - self._base_url = base_url - self._timeout = timeout - self._retries = retries @override def sync_auth_flow( @@ -109,39 +91,6 @@ async def _refresh_async(self) -> None: token = await self._get_async_service().token(self._account_id) self._store(token.token) - def _get_sync_service(self) -> InstallationsTokenService: - """Return the cached sync token service, building it on first use.""" - if self._sync_service is None: - token_client = HTTPClient( - authentication=BearerTokenAuthentication(self._secret), - base_url=self._require_base_url(), - timeout=self._timeout, - retries=self._retries, - ) - self._sync_service = InstallationsTokenService(http_client=token_client) - return self._sync_service - - def _get_async_service(self) -> AsyncInstallationsTokenService: - """Return the cached async token service, building it on first use.""" - if self._async_service is None: - token_client = AsyncHTTPClient( - authentication=BearerTokenAuthentication(self._secret), - base_url=self._require_base_url(), - timeout=self._timeout, - retries=self._retries, - ) - self._async_service = AsyncInstallationsTokenService(http_client=token_client) - return self._async_service - - def _require_base_url(self) -> str: - """Return the configured base URL, raising when the provider is unconfigured.""" - if self._base_url is None: - raise MPTError( - "ExtensionFrameworkAuthentication must be used with an MPT HTTPClient or " - "AsyncHTTPClient; the base URL was not configured.", - ) - return self._base_url - def _store(self, token: str | None) -> None: """Cache a freshly fetched token and its expiry.""" if not token: @@ -149,17 +98,6 @@ def _store(self, token: str | None) -> None: self._token = token self._expires_at = self._read_expiry(token) - def _read_expiry(self, token: str) -> dt.datetime | None: - """Read the ``exp`` claim from the token, ignoring tokens without one.""" - try: - claims = decode_unverified_jwt_claims(token) - except (JWTFormatError, JWTClaimsError): - return None - exp = claims.get("exp") - if not isinstance(exp, int): - return None - return dt.datetime.fromtimestamp(exp, tz=dt.UTC) - def _is_expired(self) -> bool: """Return whether the cached token is within the refresh leeway of expiry.""" if self._expires_at is None: diff --git a/pyproject.toml b/pyproject.toml index d6be511e..a4b67f46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,6 +157,7 @@ per-file-ignores = [ "tests/unit/resources/integration/*.py: WPS202 WPS210 WPS218 WPS453", "tests/unit/resources/integration/mixins/*.py: WPS453 WPS202", "tests/unit/auth/test_extension_framework.py: AAA01 AAA05 WPS118 WPS202 WPS204 WPS210 WPS218 WPS221 WPS430 WPS432 WPS453", + "tests/unit/auth/test_account_scoped.py: AAA01 AAA05 WPS118 WPS202 WPS204 WPS210 WPS218 WPS221 WPS430 WPS432 WPS437 WPS453", "tests/unit/resources/commerce/*.py: WPS202 WPS204", "tests/unit/resources/program/*.py: WPS202 WPS210 WPS218", "tests/unit/test_mpt_client.py: WPS235", diff --git a/tests/unit/auth/test_account_scoped.py b/tests/unit/auth/test_account_scoped.py new file mode 100644 index 00000000..ad6bacf5 --- /dev/null +++ b/tests/unit/auth/test_account_scoped.py @@ -0,0 +1,267 @@ +import asyncio +import base64 +import datetime as dt +import json + +import httpx +import pytest +import respx + +from mpt_api_client.auth import AccountScopedAuthentication +from mpt_api_client.exceptions import MPTAPIError, MPTError +from mpt_api_client.http import AsyncHTTPClient, HTTPClient +from tests.unit.conftest import API_URL + +SECRET = "extension-secret" +ACCOUNT_ID = "ACC-1" +TOKEN_URL = f"{API_URL}/public/v1/integration/installations/-/token" +ORDERS_URL = f"{API_URL}/orders" + + +def _jwt_with_exp(expires_at: dt.datetime, subject: str = "token") -> str: + def encode(payload: object) -> str: + raw = json.dumps(payload).encode("utf-8") + return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=") + + claims = {"exp": int(expires_at.timestamp()), "sub": subject} + return f"{encode({'alg': 'none'})}.{encode(claims)}.signature" + + +pytestmark = pytest.mark.usefixtures("_clear_account_cache") + + +@pytest.fixture +def _clear_account_cache(): + AccountScopedAuthentication.clear_cache() + yield + AccountScopedAuthentication.clear_cache() + + +@respx.mock +def test_account_scoped_applies_token_and_sends_account_id(): + token_route = respx.post(TOKEN_URL).mock( + return_value=httpx.Response(200, json={"token": "account-token"}) + ) + target_route = respx.get(ORDERS_URL).mock(return_value=httpx.Response(200, json={"data": []})) + client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + + client.request("GET", "/orders") # act + + token_request = token_route.calls.last.request + target_request = target_route.calls.last.request + assert token_request.headers["Authorization"] == f"Bearer {SECRET}" + assert token_request.url.params["account.id"] == ACCOUNT_ID + assert target_request.headers["Authorization"] == "Bearer account-token" + + +@respx.mock +def test_shared_cache_across_clients_fetches_token_once(): + token_route = respx.post(TOKEN_URL).mock( + return_value=httpx.Response(200, json={"token": "account-token"}) + ) + respx.get(ORDERS_URL).mock(return_value=httpx.Response(200, json={"data": []})) + first_client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + second_client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + + first_client.request("GET", "/orders") # populates the shared cache + + second_client.request("GET", "/orders") # act: reuses the cached token + + assert token_route.call_count == 1 + + +@respx.mock +def test_different_accounts_fetch_separate_tokens(): + token_route = respx.post(TOKEN_URL).mock( + return_value=httpx.Response(200, json={"token": "account-token"}) + ) + respx.get(ORDERS_URL).mock(return_value=httpx.Response(200, json={"data": []})) + first_client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id="ACC-1"), + ) + second_client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id="ACC-2"), + ) + + first_client.request("GET", "/orders") + + second_client.request("GET", "/orders") # act: different scope, separate token + + assert token_route.call_count == 2 + fetched_accounts = {call.request.url.params["account.id"] for call in token_route.calls} + assert fetched_accounts == {"ACC-1", "ACC-2"} + + +@respx.mock +async def test_serialized_refresh_fetches_token_once_under_concurrency(): + token_route = respx.post(TOKEN_URL).mock( + return_value=httpx.Response(200, json={"token": "account-token"}) + ) + respx.get(ORDERS_URL).mock(return_value=httpx.Response(200, json={"data": []})) + client = AsyncHTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + + await asyncio.gather(*(client.request("GET", "/orders") for _ in range(10))) # act + + assert token_route.call_count == 1 + + +@respx.mock +def test_refreshes_proactively_before_expiry(): + near_expiry = dt.datetime.now(dt.UTC) + dt.timedelta(seconds=10) + far_expiry = dt.datetime.now(dt.UTC) + dt.timedelta(hours=1) + token_route = respx.post(TOKEN_URL).mock( + side_effect=[ + httpx.Response(200, json={"token": _jwt_with_exp(near_expiry, "stale")}), + httpx.Response(200, json={"token": _jwt_with_exp(far_expiry, "fresh")}), + ] + ) + respx.get(ORDERS_URL).mock(return_value=httpx.Response(200, json={"data": []})) + client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + + client.request("GET", "/orders") # caches a near-expiry token + + client.request("GET", "/orders") # act: token within leeway -> proactive refresh + + assert token_route.call_count == 2 + + +@respx.mock +def test_reuses_unexpired_token(): + far_expiry = dt.datetime.now(dt.UTC) + dt.timedelta(hours=1) + token_route = respx.post(TOKEN_URL).mock( + return_value=httpx.Response(200, json={"token": _jwt_with_exp(far_expiry)}) + ) + respx.get(ORDERS_URL).mock(return_value=httpx.Response(200, json={"data": []})) + client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + + client.request("GET", "/orders") # caches a long-lived token + + client.request("GET", "/orders") # act + + assert token_route.call_count == 1 + + +@respx.mock +def test_reactive_refresh_on_unauthorized(): + token_route = respx.post(TOKEN_URL).mock( + side_effect=[ + httpx.Response(200, json={"token": "stale-token"}), + httpx.Response(200, json={"token": "fresh-token"}), + ] + ) + target_route = respx.get(ORDERS_URL).mock( + side_effect=[ + httpx.Response(401, json={"error": "expired"}), + httpx.Response(200, json={"data": []}), + ] + ) + client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + + response = client.request("GET", "/orders") # act + + target_request = target_route.calls.last.request + assert response.status_code == httpx.codes.OK + assert token_route.call_count == 2 + assert target_request.headers["Authorization"] == "Bearer fresh-token" + + +@respx.mock +def test_surfaces_repeated_unauthorized(): + respx.post(TOKEN_URL).mock(return_value=httpx.Response(200, json={"token": "any-token"})) + target_route = respx.get(ORDERS_URL).mock( + return_value=httpx.Response(401, json={"error": "nope"}) + ) + client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + + with pytest.raises(MPTAPIError) as exc_info: # act + client.request("GET", "/orders") + + assert exc_info.value.status_code == httpx.codes.UNAUTHORIZED + assert target_route.call_count == 2 # original + exactly one retry, then surfaced + + +@respx.mock +async def test_reactive_refresh_on_unauthorized_async(): + token_route = respx.post(TOKEN_URL).mock( + side_effect=[ + httpx.Response(200, json={"token": "stale-token"}), + httpx.Response(200, json={"token": "fresh-token"}), + ] + ) + target_route = respx.get(ORDERS_URL).mock( + side_effect=[ + httpx.Response(401, json={"error": "expired"}), + httpx.Response(200, json={"data": []}), + ] + ) + client = AsyncHTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + + response = await client.request("GET", "/orders") # act + + target_request = target_route.calls.last.request + assert response.status_code == httpx.codes.OK + assert token_route.call_count == 2 + assert target_request.headers["Authorization"] == "Bearer fresh-token" + + +@respx.mock +def test_rejects_empty_token(): + respx.post(TOKEN_URL).mock(return_value=httpx.Response(200, json={"token": None})) + client = HTTPClient( + base_url=API_URL, + authentication=AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID), + ) + + with pytest.raises(MPTError): # act + client.request("GET", "/orders") + + +def test_requires_configuration(): + provider = AccountScopedAuthentication(SECRET, account_id=ACCOUNT_ID) + auth_flow = provider.sync_auth_flow(httpx.Request("GET", ORDERS_URL)) + + with pytest.raises(MPTError): # act + next(auth_flow) + + +def test_store_evicts_expired_entries(): + past = dt.datetime.now(dt.UTC) - dt.timedelta(hours=1) + future = dt.datetime.now(dt.UTC) + dt.timedelta(hours=1) + stale_provider = AccountScopedAuthentication(SECRET, account_id="ACC-OLD") + fresh_provider = AccountScopedAuthentication(SECRET, account_id="ACC-NEW") + + stale_provider._store(_jwt_with_exp(past)) # noqa: SLF001 + fresh_provider._store(_jwt_with_exp(future)) # noqa: SLF001 ; act: triggers eviction + + cache = AccountScopedAuthentication._token_cache # noqa: SLF001 + assert (SECRET, "ACC-OLD") not in cache + assert (SECRET, "ACC-NEW") in cache