Skip to content

Commit 99e6660

Browse files
committed
validate generic headers
Signed-off-by: qmuntal <qmuntaldiaz@microsoft.com>
1 parent 877c58e commit 99e6660

File tree

5 files changed

+232
-191
lines changed

5 files changed

+232
-191
lines changed

example_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func ExampleSignMessage() {
1818
// create a signature holder
1919
sigHolder := cose.NewSignature()
2020
sigHolder.Headers.Protected.SetAlgorithm(cose.AlgorithmES512)
21-
sigHolder.Headers.Unprotected[cose.HeaderLabelKeyID] = 1
21+
sigHolder.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1")
2222

2323
// create message to be signed
2424
msgToSign := cose.NewSignMessage()
@@ -84,7 +84,7 @@ func ExampleSign1Message() {
8484
msgToSign := cose.NewSign1Message()
8585
msgToSign.Payload = []byte("hello world")
8686
msgToSign.Headers.Protected.SetAlgorithm(cose.AlgorithmES512)
87-
msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = 1
87+
msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1")
8888

8989
// create a signer
9090
privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
@@ -157,7 +157,7 @@ func ExampleSign1() {
157157
cose.HeaderLabelAlgorithm: cose.AlgorithmES512,
158158
},
159159
Unprotected: cose.UnprotectedHeader{
160-
cose.HeaderLabelKeyID: 1,
160+
cose.HeaderLabelKeyID: []byte("1"),
161161
},
162162
}
163163
sig, err := cose.Sign1(rand.Reader, signer, headers, []byte("hello world"), nil)

headers.go

Lines changed: 111 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,8 @@ func (h ProtectedHeader) MarshalCBOR() ([]byte, error) {
3838
if len(h) == 0 {
3939
encoded = []byte{}
4040
} else {
41-
err := validateHeaderLabel(h)
41+
err := validateHeaderParameters(h, true)
4242
if err != nil {
43-
return nil, err
44-
}
45-
if err = h.ensureCritical(); err != nil {
46-
return nil, err
47-
}
48-
if err = ensureHeaderIV(h); err != nil {
4943
return nil, fmt.Errorf("protected header: %w", err)
5044
}
5145
encoded, err = encMode.Marshal(map[interface{}]interface{}(h))
@@ -85,10 +79,7 @@ func (h *ProtectedHeader) UnmarshalCBOR(data []byte) error {
8579
return err
8680
}
8781
candidate := ProtectedHeader(header)
88-
if err := candidate.ensureCritical(); err != nil {
89-
return err
90-
}
91-
if err := ensureHeaderIV(candidate); err != nil {
82+
if err := validateHeaderParameters(candidate, true); err != nil {
9283
return fmt.Errorf("protected header: %w", err)
9384
}
9485

@@ -140,29 +131,28 @@ func (h ProtectedHeader) Critical() ([]interface{}, error) {
140131
if !ok {
141132
return nil, nil
142133
}
143-
criticalLabels, ok := value.([]interface{})
144-
if !ok {
145-
return nil, errors.New("invalid crit header")
146-
}
147-
// if present, the array MUST have at least one value in it.
148-
if len(criticalLabels) == 0 {
149-
return nil, errors.New("empty crit header")
134+
err := ensureCritical(value, h)
135+
if err != nil {
136+
return nil, err
150137
}
151-
return criticalLabels, nil
138+
return value.([]interface{}), nil
152139
}
153140

154141
// ensureCritical ensures all critical headers are present in the protected bucket.
155-
func (h ProtectedHeader) ensureCritical() error {
156-
labels, err := h.Critical()
157-
if err != nil {
158-
return err
142+
func ensureCritical(value interface{}, headers map[interface{}]interface{}) error {
143+
labels, ok := value.([]interface{})
144+
if !ok {
145+
return errors.New("invalid crit header")
146+
}
147+
// if present, the array MUST have at least one value in it.
148+
if len(labels) == 0 {
149+
return errors.New("empty crit header")
159150
}
160151
for _, label := range labels {
161-
_, ok := normalizeLabel(label)
162-
if !ok {
163-
return fmt.Errorf("critical header label: require int / tstr type, got '%T': %v", label, label)
152+
if !canInt(label) && !canTstr(label) {
153+
return fmt.Errorf("require int / tstr type, got '%T': %v", label, label)
164154
}
165-
if _, ok := h[label]; !ok {
155+
if _, ok := headers[label]; !ok {
166156
return fmt.Errorf("missing critical header: %v", label)
167157
}
168158
}
@@ -179,13 +169,7 @@ func (h UnprotectedHeader) MarshalCBOR() ([]byte, error) {
179169
if len(h) == 0 {
180170
return []byte{0xa0}, nil
181171
}
182-
if err := validateHeaderLabel(h); err != nil {
183-
return nil, err
184-
}
185-
if err := ensureNoCritical(h); err != nil {
186-
return nil, fmt.Errorf("unprotected header: %w", err)
187-
}
188-
if err := ensureHeaderIV(h); err != nil {
172+
if err := validateHeaderParameters(h, false); err != nil {
189173
return nil, fmt.Errorf("unprotected header: %w", err)
190174
}
191175
return encMode.Marshal(map[interface{}]interface{}(h))
@@ -214,10 +198,7 @@ func (h *UnprotectedHeader) UnmarshalCBOR(data []byte) error {
214198
if err := decMode.Unmarshal(data, &header); err != nil {
215199
return err
216200
}
217-
if err := ensureNoCritical(header); err != nil {
218-
return fmt.Errorf("unprotected header: %w", err)
219-
}
220-
if err := ensureHeaderIV(header); err != nil {
201+
if err := validateHeaderParameters(header, false); err != nil {
221202
return fmt.Errorf("unprotected header: %w", err)
222203
}
223204
*h = header
@@ -397,48 +378,108 @@ func hasLabel(h map[interface{}]interface{}, label interface{}) bool {
397378
return ok
398379
}
399380

400-
// ensureHeaderIV ensures IV and Partial IV are not both present in the header.
401-
//
402-
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
403-
func ensureHeaderIV(h map[interface{}]interface{}) error {
404-
if hasLabel(h, HeaderLabelIV) && hasLabel(h, HeaderLabelPartialIV) {
405-
return errors.New("IV and PartialIV parameters must not both be present")
406-
}
407-
return nil
408-
}
409-
410-
// ensureNoCritical ensures crit parameter is not present in the header.
411-
//
412-
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
413-
func ensureNoCritical(h map[interface{}]interface{}) error {
414-
if hasLabel(h, HeaderLabelCritical) {
415-
return errors.New("unexpected crit parameter found")
416-
}
417-
return nil
418-
}
419-
420-
// validateHeaderLabel validates if all header labels are integers or strings.
421-
//
422-
// label = int / tstr
423-
//
424-
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4
425-
func validateHeaderLabel(h map[interface{}]interface{}) error {
426-
existing := make(map[interface{}]struct{})
427-
for label := range h {
428-
var ok bool
429-
label, ok = normalizeLabel(label)
381+
// validateHeaderParameters validates all headers conform to the spec.
382+
func validateHeaderParameters(h map[interface{}]interface{}, protected bool) error {
383+
existing := make(map[interface{}]struct{}, len(h))
384+
for label, value := range h {
385+
// Validate that all header labels are integers or strings.
386+
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4
387+
label, ok := normalizeLabel(label)
430388
if !ok {
431-
return errors.New("cbor: header label: require int / tstr type")
389+
return errors.New("header label: require int / tstr type")
432390
}
391+
392+
// Validate that there are no duplicated labels.
393+
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3
433394
if _, ok := existing[label]; ok {
434-
return fmt.Errorf("cbor: header label: duplicated label: %v", label)
395+
return fmt.Errorf("header label: duplicated label: %v", label)
435396
} else {
436397
existing[label] = struct{}{}
437398
}
399+
400+
// Validate the generic parameters.
401+
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
402+
switch label {
403+
case HeaderLabelAlgorithm:
404+
_, is_alg := value.(Algorithm)
405+
if !is_alg && !canInt(value) && !canTstr(value) {
406+
return errors.New("header parameter: alg: require int / tstr type")
407+
}
408+
case HeaderLabelCritical:
409+
if !protected {
410+
return errors.New("header parameter: crit: not allowed")
411+
}
412+
if err := ensureCritical(value, h); err != nil {
413+
return fmt.Errorf("header parameter: crit: %w", err)
414+
}
415+
case HeaderLabelContentType:
416+
if !canTstr(value) && !canUint(value) {
417+
return errors.New("header parameter: content type: require tstr / uint type")
418+
}
419+
case HeaderLabelKeyID:
420+
if !canBstr(value) {
421+
return errors.New("header parameter: kid: require bstr type")
422+
}
423+
case HeaderLabelIV:
424+
if !canBstr(value) {
425+
return errors.New("header parameter: IV: require bstr type")
426+
}
427+
if hasLabel(h, HeaderLabelPartialIV) {
428+
return errors.New("header parameter: IV and PartialIV: parameters must not both be present")
429+
}
430+
case HeaderLabelPartialIV:
431+
if !canBstr(value) {
432+
return errors.New("header parameter: Partial IV: require bstr type")
433+
}
434+
if hasLabel(h, HeaderLabelIV) {
435+
return errors.New("header parameter: IV and PartialIV: parameters must not both be present")
436+
}
437+
}
438438
}
439439
return nil
440440
}
441441

442+
// canUint reports whether v can be used as a CBOR uint type.
443+
func canUint(v interface{}) bool {
444+
switch v := v.(type) {
445+
case uint, uint8, uint16, uint32, uint64:
446+
return true
447+
case int:
448+
return v >= 0
449+
case int8:
450+
return v >= 0
451+
case int16:
452+
return v >= 0
453+
case int32:
454+
return v >= 0
455+
case int64:
456+
return v >= 0
457+
}
458+
return false
459+
}
460+
461+
// canInt reports whether v can be used as a CBOR int type.
462+
func canInt(v interface{}) bool {
463+
switch v.(type) {
464+
case int, int8, int16, int32, int64,
465+
uint, uint8, uint16, uint32, uint64:
466+
return true
467+
}
468+
return false
469+
}
470+
471+
// canTstr reports whether v can be used as a CBOR tstr type.
472+
func canTstr(v interface{}) bool {
473+
_, ok := v.(string)
474+
return ok
475+
}
476+
477+
// canBstr reports whether v can be used as a CBOR bstr type.
478+
func canBstr(v interface{}) bool {
479+
_, ok := v.([]byte)
480+
return ok
481+
}
482+
442483
// normalizeLabel tries to cast label into a int64 or a string.
443484
// Returns (nil, false) if the label type is not valid.
444485
func normalizeLabel(label interface{}) (interface{}, bool) {

0 commit comments

Comments
 (0)