Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ All versions prior to 0.9.0 are untracked.

* Added support for ed25519 keys.
[#1377](https://github.com/sigstore/sigstore-python/pull/1377)
* api: `IdentityToken` now supports `client_id` for audience claim validation.
[#1402](https://github.com/sigstore/sigstore-python/pull/1402)


### Fixed

Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ extra:
- icon: fontawesome/brands/slack
link: https://sigstore.slack.com
- icon: fontawesome/brands/x-twitter
link: https://twitter.com/projectsigstore
link: https://twitter.com/projectsigstore
6 changes: 3 additions & 3 deletions sigstore/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def _sign_common(
# 3) Interactive OAuth flow
identity: IdentityToken | None
if args.identity_token:
identity = IdentityToken(args.identity_token)
identity = IdentityToken(args.identity_token, args.oidc_client_id)
else:
identity = _get_identity(args, trust_config)

Expand Down Expand Up @@ -1181,11 +1181,11 @@ def _get_identity(
) -> Optional[IdentityToken]:
token = None
if not args.oidc_disable_ambient_providers:
token = detect_credential()
token = detect_credential(args.oidc_client_id)

# Happy path: we've detected an ambient credential, so we can return early.
if token:
return IdentityToken(token)
return IdentityToken(token, args.oidc_client_id)

if args.oidc_issuer is not None:
issuer = Issuer(args.oidc_issuer)
Expand Down
16 changes: 9 additions & 7 deletions sigstore/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
"https://oauth2.sigstage.dev/auth": "email",
"https://token.actions.githubusercontent.com": "sub",
}
_DEFAULT_AUDIENCE = "sigstore"

_DEFAULT_CLIENT_ID = "sigstore"


class _OpenIDConfiguration(BaseModel):
Expand All @@ -66,7 +67,7 @@ class IdentityToken:
a sensible subject, issuer, and audience for Sigstore purposes.
"""

def __init__(self, raw_token: str) -> None:
def __init__(self, raw_token: str, client_id: str) -> None:
"""
Create a new `IdentityToken` from the given OIDC token.
"""
Expand All @@ -90,7 +91,7 @@ def __init__(self, raw_token: str) -> None:
# See: https://openid.net/specs/openid-connect-basic-1_0.html#IDToken
"require": ["aud", "sub", "iat", "exp", "iss"],
},
audience=_DEFAULT_AUDIENCE,
audience=client_id,
# NOTE: This leeway shouldn't be strictly necessary, but is
# included to preempt any (small) skew between the host
# and the originating IdP.
Expand Down Expand Up @@ -270,7 +271,7 @@ def __init__(self, base_url: str) -> None:

def identity_token( # nosec: B107
self,
client_id: str = "sigstore",
client_id: str = _DEFAULT_CLIENT_ID,
client_secret: str = "",
force_oob: bool = False,
) -> IdentityToken:
Expand Down Expand Up @@ -350,7 +351,7 @@ def identity_token( # nosec: B107
if token_error is not None:
raise IdentityError(f"Error response from token endpoint: {token_error}")

return IdentityToken(token_json["access_token"])
return IdentityToken(token_json["access_token"], client_id)


class IdentityError(Error):
Expand Down Expand Up @@ -402,9 +403,10 @@ def diagnostics(self) -> str:
"""


def detect_credential() -> Optional[str]:
def detect_credential(client_id: str = _DEFAULT_CLIENT_ID) -> Optional[str]:
"""Calls `id.detect_credential`, but wraps exceptions with our own exception type."""

try:
return cast(Optional[str], id.detect_credential(_DEFAULT_AUDIENCE))
return cast(Optional[str], id.detect_credential(client_id))
except id.IdentityError as exc:
IdentityError.raise_from_id(exc)
6 changes: 3 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
detect_credential,
)

from sigstore.oidc import _DEFAULT_AUDIENCE

_ASSETS = (Path(__file__).parent / "assets").resolve()
assert _ASSETS.is_dir()

TEST_CLIENT_ID = "sigstore"


@pytest.fixture
def asset():
Expand All @@ -44,7 +44,7 @@ def _has_oidc_id():
return True

try:
token = detect_credential(_DEFAULT_AUDIENCE)
token = detect_credential(TEST_CLIENT_ID)
if token is None:
return False
except GitHubOidcPermissionCredentialError:
Expand Down
6 changes: 3 additions & 3 deletions test/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from sigstore._internal.trust import ClientTrustConfig
from sigstore._utils import sha256_digest
from sigstore.models import Bundle
from sigstore.oidc import _DEFAULT_AUDIENCE, IdentityToken
from sigstore.oidc import IdentityToken
from sigstore.sign import SigningContext
from sigstore.verify.verifier import Verifier

Expand Down Expand Up @@ -209,7 +209,7 @@ def ctx_cls():
token = os.getenv(f"SIGSTORE_IDENTITY_TOKEN_{env}")
if not token:
# If the variable is not defined, try getting an ambient token.
token = detect_credential(_DEFAULT_AUDIENCE)
token = detect_credential()

return ctx_cls, IdentityToken(token)

Expand All @@ -230,7 +230,7 @@ def signer():
token = os.getenv("SIGSTORE_IDENTITY_TOKEN_staging")
if not token:
# If the variable is not defined, try getting an ambient token.
token = detect_credential(_DEFAULT_AUDIENCE)
token = detect_credential()

return signer, verifier, IdentityToken(token)

Expand Down
28 changes: 15 additions & 13 deletions test/unit/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

from sigstore import oidc

TEST_CLIENT_ID = "sigstore"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this (and the uses below) can be removed now that we have a proper kwarg default in place, right?



class TestIdentityToken:
def test_invalid_jwt(self):
with pytest.raises(
oidc.IdentityError, match="Identity token is malformed or missing claims"
):
oidc.IdentityToken("invalid jwt")
oidc.IdentityToken("invalid jwt", TEST_CLIENT_ID)

def test_missing_iss(self, dummy_jwt):
now = int(datetime.datetime.now().timestamp())
Expand All @@ -41,7 +43,7 @@ def test_missing_iss(self, dummy_jwt):
with pytest.raises(
oidc.IdentityError, match="Identity token is malformed or missing claims"
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

def test_missing_aud(self, dummy_jwt):
now = int(datetime.datetime.now().timestamp())
Expand All @@ -58,7 +60,7 @@ def test_missing_aud(self, dummy_jwt):
with pytest.raises(
oidc.IdentityError, match="Identity token is malformed or missing claims"
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

@pytest.mark.parametrize("aud", (None, "not-sigstore"))
def test_invalid_aud(self, dummy_jwt, aud):
Expand All @@ -77,7 +79,7 @@ def test_invalid_aud(self, dummy_jwt, aud):
with pytest.raises(
oidc.IdentityError, match="Identity token is malformed or missing claims"
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

def test_missing_iat(self, dummy_jwt):
now = int(datetime.datetime.now().timestamp())
Expand All @@ -94,7 +96,7 @@ def test_missing_iat(self, dummy_jwt):
with pytest.raises(
oidc.IdentityError, match="Identity token is malformed or missing claims"
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

@pytest.mark.parametrize("iat", (None, "not-an-int"))
def test_invalid_iat(self, dummy_jwt, iat):
Expand All @@ -113,7 +115,7 @@ def test_invalid_iat(self, dummy_jwt, iat):
with pytest.raises(
oidc.IdentityError, match="Identity token is malformed or missing claims"
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

def test_missing_nbf_ok(self, dummy_jwt):
now = int(datetime.datetime.now().timestamp())
Expand All @@ -127,7 +129,7 @@ def test_missing_nbf_ok(self, dummy_jwt):
}
)

assert oidc.IdentityToken(jwt) is not None
assert oidc.IdentityToken(jwt, TEST_CLIENT_ID) is not None

def test_invalid_nbf(self, dummy_jwt):
now = int(datetime.datetime.now().timestamp())
Expand All @@ -146,7 +148,7 @@ def test_invalid_nbf(self, dummy_jwt):
oidc.IdentityError,
match="Identity token is not within its validity period",
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

def test_missing_exp(self, dummy_jwt):
now = int(datetime.datetime.now().timestamp())
Expand All @@ -163,7 +165,7 @@ def test_missing_exp(self, dummy_jwt):
with pytest.raises(
oidc.IdentityError, match="Identity token is malformed or missing claims"
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

def test_invalid_exp(self, dummy_jwt):
now = int(datetime.datetime.now().timestamp())
Expand All @@ -182,7 +184,7 @@ def test_invalid_exp(self, dummy_jwt):
with pytest.raises(
oidc.IdentityError, match="Identity token is malformed or missing claims"
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

@pytest.mark.parametrize(
"iss", [k for k, v in oidc._KNOWN_OIDC_ISSUERS.items() if v != "sub"]
Expand All @@ -204,7 +206,7 @@ def test_missing_identity_claim(self, dummy_jwt, iss):
oidc.IdentityError,
match=r"Identity token is missing the required '.+' claim",
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

@pytest.mark.parametrize("fed", ("notadict", {"connector_id": 123}))
def test_invalid_federated_claims(self, dummy_jwt, fed):
Expand All @@ -226,7 +228,7 @@ def test_invalid_federated_claims(self, dummy_jwt, fed):
oidc.IdentityError,
match="unexpected claim type: federated_claims.*",
):
oidc.IdentityToken(jwt)
oidc.IdentityToken(jwt, TEST_CLIENT_ID)

@pytest.mark.parametrize(
("iss", "identity_claim", "identity_value", "fed_iss"),
Expand Down Expand Up @@ -263,7 +265,7 @@ def test_ok(self, dummy_jwt, iss, identity_claim, identity_value, fed_iss):
}
)

identity = oidc.IdentityToken(jwt)
identity = oidc.IdentityToken(jwt, TEST_CLIENT_ID)
assert identity.in_validity_period()
assert identity.identity == identity_value
assert identity.issuer == iss
Expand Down