diff --git a/wfe2/wfe.go b/wfe2/wfe.go index b2d0a8bba0e..9e3b79d359c 100644 --- a/wfe2/wfe.go +++ b/wfe2/wfe.go @@ -424,7 +424,7 @@ func (wfe *WebFrontEndImpl) Handler(stats prometheus.Registerer) http.Handler { // Endpoint for draft-aaron-ari if features.Enabled(features.ServeRenewalInfo) { - wfe.HandleFunc(m, renewalInfoPath, wfe.RenewalInfo, "GET") + wfe.HandleFunc(m, renewalInfoPath, wfe.RenewalInfo, "GET", "POST") } // Non-ACME endpoints @@ -2257,6 +2257,11 @@ func (wfe *WebFrontEndImpl) RenewalInfo(ctx context.Context, logEvent *web.Reque return } + if request.Method == http.MethodPost { + wfe.UpdateRenewal(ctx, logEvent, response, request) + return + } + if len(request.URL.Path) == 0 { wfe.sendError(response, logEvent, probs.NotFound("Must specify a request path"), nil) return @@ -2266,30 +2271,26 @@ func (wfe *WebFrontEndImpl) RenewalInfo(ctx context.Context, logEvent *web.Reque // the base64url-encoded DER CertID sequence. der, err := base64.RawURLEncoding.DecodeString(request.URL.Path) if err != nil { - wfe.sendError(response, logEvent, probs.Malformed("Path was not base64url-encoded"), nil) + wfe.sendError(response, logEvent, probs.Malformed("Path was not base64url-encoded or had padding"), err) return } var id certID rest, err := asn1.Unmarshal(der, &id) if err != nil || len(rest) != 0 { - wfe.sendError(response, logEvent, probs.Malformed("Path was not a DER-encoded CertID sequence"), nil) + wfe.sendError(response, logEvent, probs.Malformed("Path was not a DER-encoded CertID sequence"), err) return } // Verify that the hash algorithm is SHA-256, so people don't use SHA-1 here. if !id.HashAlgorithm.Algorithm.Equal(asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1}) { - wfe.sendError(response, logEvent, probs.Malformed("Request used hash algorithm other than SHA-256"), nil) + wfe.sendError(response, logEvent, probs.Malformed("Request used hash algorithm other than SHA-256"), err) return } // We can do all of our processing based just on the serial, because Boulder // does not re-use the same serial across multiple issuers. serial := core.SerialToString(id.SerialNumber) - if !core.ValidSerial(serial) { - wfe.sendError(response, logEvent, probs.NotFound("Certificate not found"), nil) - return - } logEvent.Extra["RequestedSerial"] = serial setDefaultRetryAfterHeader := func(response http.ResponseWriter) { @@ -2356,6 +2357,72 @@ func (wfe *WebFrontEndImpl) RenewalInfo(ctx context.Context, logEvent *web.Reque time.Unix(0, cert.Expires).UTC())) } +// UpdateRenewal is used by the client to inform the server that they have +// replaced the certificate in question, so it can be safely revoked. All +// requests must be authenticated to the account which ordered the cert. +func (wfe *WebFrontEndImpl) UpdateRenewal(ctx context.Context, logEvent *web.RequestEvent, response http.ResponseWriter, request *http.Request) { + if !features.Enabled(features.ServeRenewalInfo) { + wfe.sendError(response, logEvent, probs.NotFound("Feature not enabled"), nil) + return + } + + body, _, acct, prob := wfe.validPOSTForAccount(request, ctx, logEvent) + addRequesterHeader(response, logEvent.Requester) + if prob != nil { + // validPOSTForAccount handles its own setting of logEvent.Errors + wfe.sendError(response, logEvent, prob, nil) + return + } + + var updateRenewalRequest struct { + CertID string `json:"certID"` + Replaced bool `json:"replaced"` + } + err := json.Unmarshal(body, &updateRenewalRequest) + if err != nil { + wfe.sendError(response, logEvent, probs.Malformed("Unable to unmarshal RenewalInfo POST request body"), err) + return + } + + der, err := base64.RawURLEncoding.DecodeString(updateRenewalRequest.CertID) + if err != nil { + wfe.sendError(response, logEvent, probs.Malformed("certID was not base64url-encoded or contained padding"), err) + return + } + + var id certID + rest, err := asn1.Unmarshal(der, &id) + if err != nil || len(rest) != 0 { + wfe.sendError(response, logEvent, probs.Malformed("certID was not a DER-encoded CertID ASN.1 sequence"), err) + return + } + + if !id.HashAlgorithm.Algorithm.Equal(asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1}) { + wfe.sendError(response, logEvent, probs.Malformed("Decoded CertID used a hashAlgorithm other than SHA-256"), err) + return + } + + // We can do all of our processing based just on the serial, because Boulder + // does not re-use the same serial across multiple issuers. + serial := core.SerialToString(id.SerialNumber) + logEvent.Extra["RequestedSerial"] = serial + + metadata, err := wfe.sa.GetSerialMetadata(ctx, &sapb.Serial{Serial: serial}) + if err != nil { + wfe.sendError(response, logEvent, probs.NotFound("Certificate not found"), err) + return + } + + if acct.ID != metadata.RegistrationID { + wfe.sendError(response, logEvent, probs.Unauthorized("Account ID doesn't match ID for certificate"), err) + return + } + + // TODO(#6732): Write the replaced status to persistent storage. + + response.WriteHeader(http.StatusOK) +} + func extractRequesterIP(req *http.Request) (net.IP, error) { ip := net.ParseIP(req.Header.Get("X-Real-IP")) if ip != nil { diff --git a/wfe2/wfe_test.go b/wfe2/wfe_test.go index 3bb9fd67fa4..b958e89cce7 100644 --- a/wfe2/wfe_test.go +++ b/wfe2/wfe_test.go @@ -3459,12 +3459,14 @@ func TestARI(t *testing.T) { msa := newMockSAWithCert(t, wfe.sa) wfe.sa = msa + err := features.Set(map[string]bool{"ServeRenewalInfo": true}) + test.AssertNotError(t, err, "setting feature flag") + defer features.Reset() + makeGet := func(path, endpoint string) (*http.Request, *web.RequestEvent) { return &http.Request{URL: &url.URL{Path: path}, Method: "GET"}, &web.RequestEvent{Endpoint: endpoint, Extra: map[string]interface{}{}} } - _ = features.Set(map[string]bool{"ServeRenewalInfo": true}) - defer features.Reset() // Load the certificate and its issuer. cert, err := core.LoadCert("../test/hierarchy/ee-r3.cert.pem") @@ -3586,12 +3588,14 @@ func TestIncidentARI(t *testing.T) { expectSerialString := core.SerialToString(big.NewInt(12345)) wfe.sa = newMockSAWithIncident(wfe.sa, []string{expectSerialString}) + err := features.Set(map[string]bool{"ServeRenewalInfo": true}) + test.AssertNotError(t, err, "setting feature flag") + defer features.Reset() + makeGet := func(path, endpoint string) (*http.Request, *web.RequestEvent) { return &http.Request{URL: &url.URL{Path: path}, Method: "GET"}, &web.RequestEvent{Endpoint: endpoint, Extra: map[string]interface{}{}} } - _ = features.Set(map[string]bool{"ServeRenewalInfo": true}) - defer features.Reset() idBytes, err := asn1.Marshal(certID{ pkix.AlgorithmIdentifier{ // SHA256 @@ -3620,6 +3624,157 @@ func TestIncidentARI(t *testing.T) { test.AssertEquals(t, ri.SuggestedWindow.End.Before(wfe.clk.Now()), true) } +type mockSAWithSerialMetadata struct { + sapb.StorageAuthorityReadOnlyClient + serial string + regID int64 +} + +// GetSerialMetadata returns fake metadata if it recognizes the given serial. +func (sa *mockSAWithSerialMetadata) GetSerialMetadata(_ context.Context, req *sapb.Serial, _ ...grpc.CallOption) (*sapb.SerialMetadata, error) { + if req.Serial != sa.serial { + return nil, berrors.NotFoundError("metadata for certificate with serial %q not found", req.Serial) + } + + return &sapb.SerialMetadata{ + Serial: sa.serial, + RegistrationID: sa.regID, + }, nil +} + +// TestUpdateARI tests that requests for real certs issued to the correct regID +// are accepted, while all others result in errors. +func TestUpdateARI(t *testing.T) { + wfe, _, signer := setupWFE(t) + + err := features.Set(map[string]bool{"ServeRenewalInfo": true}) + test.AssertNotError(t, err, "setting feature flag") + defer features.Reset() + + makePost := func(regID int64, body string) *http.Request { + signedURL := fmt.Sprintf("http://localhost%s", renewalInfoPath) + _, _, jwsBody := signer.byKeyID(regID, nil, signedURL, body) + return makePostRequestWithPath(renewalInfoPath, jwsBody) + } + + type jsonReq struct { + CertID string `json:"certID"` + Replaced bool `json:"replaced"` + } + + // Load a cert, its issuer, and use OCSP to compute issuer name/key hashes. + cert, err := core.LoadCert("../test/hierarchy/ee-r3.cert.pem") + test.AssertNotError(t, err, "failed to load test certificate") + issuer, err := core.LoadCert("../test/hierarchy/int-r3.cert.pem") + test.AssertNotError(t, err, "failed to load test issuer") + ocspReqBytes, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256}) + test.AssertNotError(t, err, "failed to create ocsp request") + ocspReq, err := ocsp.ParseRequest(ocspReqBytes) + test.AssertNotError(t, err, "failed to parse ocsp request") + + // Set up the mock SA. + msa := mockSAWithSerialMetadata{wfe.sa, core.SerialToString(cert.SerialNumber), 1} + wfe.sa = &msa + + // An empty POST should result in an error. + req := makePost(1, "") + responseWriter := httptest.NewRecorder() + wfe.UpdateRenewal(ctx, newRequestEvent(), responseWriter, req) + test.AssertEquals(t, responseWriter.Code, http.StatusBadRequest) + + // Non-certID base64 should result in an error. + req = makePost(1, "aGVsbG8gd29ybGQK") // $ echo "hello world" | base64 + responseWriter = httptest.NewRecorder() + wfe.UpdateRenewal(ctx, newRequestEvent(), responseWriter, req) + test.AssertEquals(t, responseWriter.Code, http.StatusBadRequest) + + // Non-sha256 hash algorithm should result in an error. + idBytes, err := asn1.Marshal(certID{ + pkix.AlgorithmIdentifier{ // definitely not SHA256 + Algorithm: asn1.ObjectIdentifier{1, 2, 3, 4, 5}, + Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, + }, + ocspReq.IssuerNameHash, + ocspReq.IssuerKeyHash, + cert.SerialNumber, + }) + test.AssertNotError(t, err, "failed to marshal certID") + body, err := json.Marshal(jsonReq{ + CertID: base64.RawURLEncoding.EncodeToString(idBytes), + Replaced: true, + }) + test.AssertNotError(t, err, "failed to marshal request body") + req = makePost(1, string(body)) + responseWriter = httptest.NewRecorder() + wfe.UpdateRenewal(ctx, newRequestEvent(), responseWriter, req) + test.AssertEquals(t, responseWriter.Code, http.StatusBadRequest) + + // Unrecognized serial should result in an error. + idBytes, err = asn1.Marshal(certID{ + pkix.AlgorithmIdentifier{ // SHA256 + Algorithm: asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1}, + Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, + }, + ocspReq.IssuerNameHash, + ocspReq.IssuerKeyHash, + big.NewInt(12345), + }) + test.AssertNotError(t, err, "failed to marshal certID") + body, err = json.Marshal(jsonReq{ + CertID: base64.RawURLEncoding.EncodeToString(idBytes), + Replaced: true, + }) + test.AssertNotError(t, err, "failed to marshal request body") + req = makePost(1, string(body)) + responseWriter = httptest.NewRecorder() + wfe.UpdateRenewal(ctx, newRequestEvent(), responseWriter, req) + test.AssertEquals(t, responseWriter.Code, http.StatusNotFound) + + // Recognized serial but owned by the wrong account should result in an error. + msa.regID = 2 + idBytes, err = asn1.Marshal(certID{ + pkix.AlgorithmIdentifier{ // SHA256 + Algorithm: asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1}, + Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, + }, + ocspReq.IssuerNameHash, + ocspReq.IssuerKeyHash, + cert.SerialNumber, + }) + test.AssertNotError(t, err, "failed to marshal certID") + body, err = json.Marshal(jsonReq{ + CertID: base64.RawURLEncoding.EncodeToString(idBytes), + Replaced: true, + }) + test.AssertNotError(t, err, "failed to marshal request body") + req = makePost(1, string(body)) + responseWriter = httptest.NewRecorder() + wfe.UpdateRenewal(ctx, newRequestEvent(), responseWriter, req) + test.AssertEquals(t, responseWriter.Code, http.StatusForbidden) + + // Recognized serial and owned by the right account should work. + msa.regID = 1 + idBytes, err = asn1.Marshal(certID{ + pkix.AlgorithmIdentifier{ // SHA256 + Algorithm: asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1}, + Parameters: asn1.RawValue{Tag: 5 /* ASN.1 NULL */}, + }, + ocspReq.IssuerNameHash, + ocspReq.IssuerKeyHash, + cert.SerialNumber, + }) + test.AssertNotError(t, err, "failed to marshal certID") + body, err = json.Marshal(jsonReq{ + CertID: base64.RawURLEncoding.EncodeToString(idBytes), + Replaced: true, + }) + test.AssertNotError(t, err, "failed to marshal request body") + req = makePost(1, string(body)) + responseWriter = httptest.NewRecorder() + wfe.UpdateRenewal(ctx, newRequestEvent(), responseWriter, req) + test.AssertEquals(t, responseWriter.Code, http.StatusOK) +} + func TestOldTLSInbound(t *testing.T) { wfe, _, _ := setupWFE(t) req := &http.Request{