@@ -38,14 +38,8 @@ func (h ProtectedHeader) MarshalCBOR() ([]byte, error) {
3838 if len (h ) == 0 {
3939 encoded = []byte {}
4040 } else {
41- err := validateHeaderLabel ( h )
41+ err := validateHeaderParameters ( h , true )
4242 if err != nil {
43- return nil , err
44- }
45- if err = h .ensureCritical (); err != nil {
46- return nil , err
47- }
48- if err = ensureHeaderIV (h ); err != nil {
4943 return nil , fmt .Errorf ("protected header: %w" , err )
5044 }
5145 encoded , err = encMode .Marshal (map [interface {}]interface {}(h ))
@@ -85,10 +79,7 @@ func (h *ProtectedHeader) UnmarshalCBOR(data []byte) error {
8579 return err
8680 }
8781 candidate := ProtectedHeader (header )
88- if err := candidate .ensureCritical (); err != nil {
89- return err
90- }
91- if err := ensureHeaderIV (candidate ); err != nil {
82+ if err := validateHeaderParameters (candidate , true ); err != nil {
9283 return fmt .Errorf ("protected header: %w" , err )
9384 }
9485
@@ -140,29 +131,28 @@ func (h ProtectedHeader) Critical() ([]interface{}, error) {
140131 if ! ok {
141132 return nil , nil
142133 }
143- criticalLabels , ok := value .([]interface {})
144- if ! ok {
145- return nil , errors .New ("invalid crit header" )
146- }
147- // if present, the array MUST have at least one value in it.
148- if len (criticalLabels ) == 0 {
149- return nil , errors .New ("empty crit header" )
134+ err := ensureCritical (value , h )
135+ if err != nil {
136+ return nil , err
150137 }
151- return criticalLabels , nil
138+ return value .([] interface {}) , nil
152139}
153140
154141// ensureCritical ensures all critical headers are present in the protected bucket.
155- func (h ProtectedHeader ) ensureCritical () error {
156- labels , err := h .Critical ()
157- if err != nil {
158- return err
142+ func ensureCritical (value interface {}, headers map [interface {}]interface {}) error {
143+ labels , ok := value .([]interface {})
144+ if ! ok {
145+ return errors .New ("invalid crit header" )
146+ }
147+ // if present, the array MUST have at least one value in it.
148+ if len (labels ) == 0 {
149+ return errors .New ("empty crit header" )
159150 }
160151 for _ , label := range labels {
161- _ , ok := normalizeLabel (label )
162- if ! ok {
163- return fmt .Errorf ("critical header label: require int / tstr type, got '%T': %v" , label , label )
152+ if ! canInt (label ) && ! canTstr (label ) {
153+ return fmt .Errorf ("require int / tstr type, got '%T': %v" , label , label )
164154 }
165- if _ , ok := h [label ]; ! ok {
155+ if _ , ok := headers [label ]; ! ok {
166156 return fmt .Errorf ("missing critical header: %v" , label )
167157 }
168158 }
@@ -179,13 +169,7 @@ func (h UnprotectedHeader) MarshalCBOR() ([]byte, error) {
179169 if len (h ) == 0 {
180170 return []byte {0xa0 }, nil
181171 }
182- if err := validateHeaderLabel (h ); err != nil {
183- return nil , err
184- }
185- if err := ensureNoCritical (h ); err != nil {
186- return nil , fmt .Errorf ("unprotected header: %w" , err )
187- }
188- if err := ensureHeaderIV (h ); err != nil {
172+ if err := validateHeaderParameters (h , false ); err != nil {
189173 return nil , fmt .Errorf ("unprotected header: %w" , err )
190174 }
191175 return encMode .Marshal (map [interface {}]interface {}(h ))
@@ -214,10 +198,7 @@ func (h *UnprotectedHeader) UnmarshalCBOR(data []byte) error {
214198 if err := decMode .Unmarshal (data , & header ); err != nil {
215199 return err
216200 }
217- if err := ensureNoCritical (header ); err != nil {
218- return fmt .Errorf ("unprotected header: %w" , err )
219- }
220- if err := ensureHeaderIV (header ); err != nil {
201+ if err := validateHeaderParameters (header , false ); err != nil {
221202 return fmt .Errorf ("unprotected header: %w" , err )
222203 }
223204 * h = header
@@ -397,48 +378,108 @@ func hasLabel(h map[interface{}]interface{}, label interface{}) bool {
397378 return ok
398379}
399380
400- // ensureHeaderIV ensures IV and Partial IV are not both present in the header.
401- //
402- // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
403- func ensureHeaderIV (h map [interface {}]interface {}) error {
404- if hasLabel (h , HeaderLabelIV ) && hasLabel (h , HeaderLabelPartialIV ) {
405- return errors .New ("IV and PartialIV parameters must not both be present" )
406- }
407- return nil
408- }
409-
410- // ensureNoCritical ensures crit parameter is not present in the header.
411- //
412- // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
413- func ensureNoCritical (h map [interface {}]interface {}) error {
414- if hasLabel (h , HeaderLabelCritical ) {
415- return errors .New ("unexpected crit parameter found" )
416- }
417- return nil
418- }
419-
420- // validateHeaderLabel validates if all header labels are integers or strings.
421- //
422- // label = int / tstr
423- //
424- // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4
425- func validateHeaderLabel (h map [interface {}]interface {}) error {
426- existing := make (map [interface {}]struct {})
427- for label := range h {
428- var ok bool
429- label , ok = normalizeLabel (label )
381+ // validateHeaderParameters validates all headers conform to the spec.
382+ func validateHeaderParameters (h map [interface {}]interface {}, protected bool ) error {
383+ existing := make (map [interface {}]struct {}, len (h ))
384+ for label , value := range h {
385+ // Validate that all header labels are integers or strings.
386+ // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4
387+ label , ok := normalizeLabel (label )
430388 if ! ok {
431- return errors .New ("cbor: header label: require int / tstr type" )
389+ return errors .New ("header label: require int / tstr type" )
432390 }
391+
392+ // Validate that there are no duplicated labels.
393+ // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3
433394 if _ , ok := existing [label ]; ok {
434- return fmt .Errorf ("cbor: header label: duplicated label: %v" , label )
395+ return fmt .Errorf ("header label: duplicated label: %v" , label )
435396 } else {
436397 existing [label ] = struct {}{}
437398 }
399+
400+ // Validate the generic parameters.
401+ // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
402+ switch label {
403+ case HeaderLabelAlgorithm :
404+ _ , is_alg := value .(Algorithm )
405+ if ! is_alg && ! canInt (value ) && ! canTstr (value ) {
406+ return errors .New ("header parameter: alg: require int / tstr type" )
407+ }
408+ case HeaderLabelCritical :
409+ if ! protected {
410+ return errors .New ("header parameter: crit: not allowed" )
411+ }
412+ if err := ensureCritical (value , h ); err != nil {
413+ return fmt .Errorf ("header parameter: crit: %w" , err )
414+ }
415+ case HeaderLabelContentType :
416+ if ! canTstr (value ) && ! canUint (value ) {
417+ return errors .New ("header parameter: content type: require tstr / uint type" )
418+ }
419+ case HeaderLabelKeyID :
420+ if ! canBstr (value ) {
421+ return errors .New ("header parameter: kid: require bstr type" )
422+ }
423+ case HeaderLabelIV :
424+ if ! canBstr (value ) {
425+ return errors .New ("header parameter: IV: require bstr type" )
426+ }
427+ if hasLabel (h , HeaderLabelPartialIV ) {
428+ return errors .New ("header parameter: IV and PartialIV: parameters must not both be present" )
429+ }
430+ case HeaderLabelPartialIV :
431+ if ! canBstr (value ) {
432+ return errors .New ("header parameter: Partial IV: require bstr type" )
433+ }
434+ if hasLabel (h , HeaderLabelIV ) {
435+ return errors .New ("header parameter: IV and PartialIV: parameters must not both be present" )
436+ }
437+ }
438438 }
439439 return nil
440440}
441441
442+ // canUint reports whether v can be used as a CBOR uint type.
443+ func canUint (v interface {}) bool {
444+ switch v := v .(type ) {
445+ case uint , uint8 , uint16 , uint32 , uint64 :
446+ return true
447+ case int :
448+ return v >= 0
449+ case int8 :
450+ return v >= 0
451+ case int16 :
452+ return v >= 0
453+ case int32 :
454+ return v >= 0
455+ case int64 :
456+ return v >= 0
457+ }
458+ return false
459+ }
460+
461+ // canInt reports whether v can be used as a CBOR int type.
462+ func canInt (v interface {}) bool {
463+ switch v .(type ) {
464+ case int , int8 , int16 , int32 , int64 ,
465+ uint , uint8 , uint16 , uint32 , uint64 :
466+ return true
467+ }
468+ return false
469+ }
470+
471+ // canTstr reports whether v can be used as a CBOR tstr type.
472+ func canTstr (v interface {}) bool {
473+ _ , ok := v .(string )
474+ return ok
475+ }
476+
477+ // canBstr reports whether v can be used as a CBOR bstr type.
478+ func canBstr (v interface {}) bool {
479+ _ , ok := v .([]byte )
480+ return ok
481+ }
482+
442483// normalizeLabel tries to cast label into a int64 or a string.
443484// Returns (nil, false) if the label type is not valid.
444485func normalizeLabel (label interface {}) (interface {}, bool ) {
0 commit comments