diff --git a/der/src/header.rs b/der/src/header.rs index 58cbe40a7..22a6df053 100644 --- a/der/src/header.rs +++ b/der/src/header.rs @@ -1,6 +1,8 @@ //! ASN.1 DER headers. -use crate::{Decode, DerOrd, Encode, Error, ErrorKind, Length, Reader, Result, Tag, Writer}; +use crate::{ + Decode, DerOrd, Encode, EncodingRules, Error, ErrorKind, Length, Reader, Result, Tag, Writer, +}; use core::cmp::Ordering; /// ASN.1 DER headers: tag + length component of TLV-encoded values @@ -34,6 +36,7 @@ impl<'a> Decode<'a> for Header { type Error = Error; fn decode>(reader: &mut R) -> Result
{ + let is_constructed = Tag::peek_is_constructed(reader)?; let tag = Tag::decode(reader)?; let length = Length::decode(reader).map_err(|e| { @@ -44,6 +47,11 @@ impl<'a> Decode<'a> for Header { } })?; + if length.is_indefinite() && !is_constructed { + debug_assert_eq!(reader.encoding_rules(), EncodingRules::Ber); + return Err(reader.error(ErrorKind::IndefiniteLength)); + } + Ok(Self { tag, length }) } } diff --git a/der/src/length.rs b/der/src/length.rs index a08693b42..a6627e33f 100644 --- a/der/src/length.rs +++ b/der/src/length.rs @@ -18,18 +18,27 @@ use core::{ const INDEFINITE_LENGTH_OCTET: u8 = 0b10000000; // 0x80 /// ASN.1-encoded length. -#[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd, Ord)] -pub struct Length(u32); +#[derive(Copy, Clone, Default, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub struct Length { + /// Inner length as a `u32`. Note that the decoder and encoder also support a maximum length + /// of 32-bits. + inner: u32, + + /// Flag bit which specifies whether the length was indeterminate when decoding ASN.1 BER. + /// + /// This should always be false when working with DER. + indefinite: bool, +} impl Length { /// Length of `0` - pub const ZERO: Self = Self(0); + pub const ZERO: Self = Self::new(0); /// Length of `1` - pub const ONE: Self = Self(1); + pub const ONE: Self = Self::new(1); /// Maximum length (`u32::MAX`). - pub const MAX: Self = Self(u32::MAX); + pub const MAX: Self = Self::new(u32::MAX); /// Maximum number of octets in a DER encoding of a [`Length`] using the /// rules implemented by this crate. @@ -38,8 +47,11 @@ impl Length { /// Create a new [`Length`] for any value which fits inside of a [`u16`]. /// /// This function is const-safe and therefore useful for [`Length`] constants. - pub const fn new(value: u16) -> Self { - Self(value as u32) + pub const fn new(value: u32) -> Self { + Self { + inner: value, + indefinite: false, + } } /// Create a new [`Length`] for any value which fits inside the length type. @@ -50,14 +62,18 @@ impl Length { if len > (u32::MAX as usize) { Err(Error::from_kind(ErrorKind::Overflow)) } else { - Ok(Length(len as u32)) + Ok(Self::new(len as u32)) } } /// Is this length equal to zero? pub const fn is_zero(self) -> bool { - let value = self.0; - value == 0 + self.inner == 0 + } + + /// Was this length decoded from an indefinite length when decoding BER? + pub(crate) const fn is_indefinite(self) -> bool { + self.indefinite } /// Get the length of DER Tag-Length-Value (TLV) encoded data if `self` @@ -68,12 +84,12 @@ impl Length { /// Perform saturating addition of two lengths. pub fn saturating_add(self, rhs: Self) -> Self { - Self(self.0.saturating_add(rhs.0)) + Self::new(self.inner.saturating_add(rhs.inner)) } /// Perform saturating subtraction of two lengths. pub fn saturating_sub(self, rhs: Self) -> Self { - Self(self.0.saturating_sub(rhs.0)) + Self::new(self.inner.saturating_sub(rhs.inner)) } /// Get initial octet of the encoded length (if one is required). @@ -89,7 +105,7 @@ impl Length { /// > most significant bit; /// > c) the value 11111111₂ shall not be used. fn initial_octet(self) -> Option { - match self.0 { + match self.inner { 0x80..=0xFF => Some(0x81), 0x100..=0xFFFF => Some(0x82), 0x10000..=0xFFFFFF => Some(0x83), @@ -103,10 +119,10 @@ impl Add for Length { type Output = Result; fn add(self, other: Self) -> Result { - self.0 - .checked_add(other.0) + self.inner + .checked_add(other.inner) .ok_or_else(|| ErrorKind::Overflow.into()) - .map(Self) + .map(Self::new) } } @@ -154,10 +170,10 @@ impl Sub for Length { type Output = Result; fn sub(self, other: Length) -> Result { - self.0 - .checked_sub(other.0) + self.inner + .checked_sub(other.inner) .ok_or_else(|| ErrorKind::Overflow.into()) - .map(Self) + .map(Self::new) } } @@ -171,25 +187,25 @@ impl Sub for Result { impl From for Length { fn from(len: u8) -> Length { - Length(len.into()) + Length::new(len.into()) } } impl From for Length { fn from(len: u16) -> Length { - Length(len.into()) + Length::new(len.into()) } } impl From for Length { fn from(len: u32) -> Length { - Length(len) + Length::new(len) } } impl From for u32 { fn from(length: Length) -> u32 { - length.0 + length.inner } } @@ -205,7 +221,7 @@ impl TryFrom for usize { type Error = Error; fn try_from(len: Length) -> Result { - len.0.try_into().map_err(|_| ErrorKind::Overflow.into()) + len.inner.try_into().map_err(|_| ErrorKind::Overflow.into()) } } @@ -259,12 +275,12 @@ impl<'a> Decode<'a> for Length { impl Encode for Length { fn encoded_len(&self) -> Result { - match self.0 { - 0..=0x7F => Ok(Length(1)), - 0x80..=0xFF => Ok(Length(2)), - 0x100..=0xFFFF => Ok(Length(3)), - 0x10000..=0xFFFFFF => Ok(Length(4)), - 0x1000000..=0xFFFFFFFF => Ok(Length(5)), + match self.inner { + 0..=0x7F => Ok(Length::new(1)), + 0x80..=0xFF => Ok(Length::new(2)), + 0x100..=0xFFFF => Ok(Length::new(3)), + 0x10000..=0xFFFFFF => Ok(Length::new(4)), + 0x1000000..=0xFFFFFFFF => Ok(Length::new(5)), } } @@ -274,7 +290,7 @@ impl Encode for Length { writer.write_byte(tag_byte)?; // Strip leading zeroes - match self.0.to_be_bytes() { + match self.inner.to_be_bytes() { [0, 0, 0, byte] => writer.write_byte(byte), [0, 0, bytes @ ..] => writer.write(&bytes), [0, bytes @ ..] => writer.write(&bytes), @@ -282,7 +298,7 @@ impl Encode for Length { } } #[allow(clippy::cast_possible_truncation)] - None => writer.write_byte(self.0 as u8), + None => writer.write_byte(self.inner as u8), } } } @@ -302,9 +318,19 @@ impl DerOrd for Length { } } +impl fmt::Debug for Length { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.indefinite { + write!(f, "Length({self} [indefinite])") + } else { + f.debug_tuple("Length").field(&self.inner).finish() + } + } +} + impl fmt::Display for Length { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) + self.inner.fmt(f) } } @@ -313,7 +339,7 @@ impl fmt::Display for Length { #[cfg(feature = "arbitrary")] impl<'a> arbitrary::Arbitrary<'a> for Length { fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - Ok(Self(u.arbitrary()?)) + Ok(Self::new(u.arbitrary()?)) } fn size_hint(depth: usize) -> (usize, Option) { @@ -362,11 +388,15 @@ fn decode_indefinite_length<'a, R: Reader<'a>>(reader: &mut R) -> Result // Read the length byte and ensure it's zero (i.e. the full EOC is `00 00`) let length_byte = reader.read_byte()?; - if length_byte == 0 { - return current_pos - start_pos; - } else { + + if length_byte != 0 { return Err(reader.error(ErrorKind::IndefiniteLength)); } + + // Compute how much we read and flag the decoded length as indefinite + let mut ret = (current_pos - start_pos)?; + ret.indefinite = true; + return Ok(ret); } let header = Header::decode(reader)?; @@ -492,6 +522,9 @@ mod tests { 27 F0 F0 00 00 00 00" ); + // Ensure the indefinite bit isn't set when decoding DER + assert!(!Length::from_der(&[0x00]).unwrap().indefinite); + let mut reader = SliceReader::new_with_encoding_rules(&EXAMPLE_BER, EncodingRules::Ber).unwrap(); @@ -501,6 +534,7 @@ mod tests { // Decode indefinite length let length = Length::decode(&mut reader).unwrap(); + assert!(length.indefinite); // Decoding the length should leave the position at the end of the indefinite length octet let pos = usize::try_from(reader.position()).unwrap(); @@ -530,6 +564,7 @@ mod tests { // Parse the inner indefinite length let length = Length::decode(&mut reader).unwrap(); + assert!(length.indefinite); assert_eq!(usize::try_from(length).unwrap(), 18); } } diff --git a/der/src/tag.rs b/der/src/tag.rs index b9bb60a33..351105296 100644 --- a/der/src/tag.rs +++ b/der/src/tag.rs @@ -171,6 +171,12 @@ impl Tag { Self::decode(&mut reader.clone()) } + /// Peek at whether the next byte in the reader has the constructed bit set. + pub(crate) fn peek_is_constructed<'a>(reader: &impl Reader<'a>) -> Result { + let octet = reader.clone().read_byte()?; + Ok(octet & CONSTRUCTED_FLAG != 0) + } + /// Returns true if given context-specific (or any given class) tag number matches the peeked tag. pub(crate) fn peek_matches<'a, R: Reader<'a>>( reader: &mut R, @@ -400,7 +406,7 @@ impl Encode for Tag { let length = if number <= 30 { Length::ONE } else { - Length::new(number.ilog2() as u16 / 7 + 2) + Length::new(number.ilog2() / 7 + 2) }; Ok(length) diff --git a/der_derive/src/bitstring.rs b/der_derive/src/bitstring.rs index f48835cca..f7c97d884 100644 --- a/der_derive/src/bitstring.rs +++ b/der_derive/src/bitstring.rs @@ -165,7 +165,7 @@ impl DeriveBitString { impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause { fn value_len(&self) -> der::Result { - Ok(der::Length::new(#max_expected_bytes + 1)) + Ok(der::Length::new(#max_expected_bytes as u32 + 1)) } fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> { diff --git a/pkcs5/src/pbes2/kdf/salt.rs b/pkcs5/src/pbes2/kdf/salt.rs index d29204e3e..94e975dff 100644 --- a/pkcs5/src/pbes2/kdf/salt.rs +++ b/pkcs5/src/pbes2/kdf/salt.rs @@ -33,7 +33,7 @@ impl Salt { Ok(Self { inner, - length: Length::new(slice.len() as u16), + length: Length::new(slice.len() as u32), }) }