diff --git a/.github/workflows/tls_codec.yml b/.github/workflows/tls_codec.yml index f66ea38b2..c2130551e 100644 --- a/.github/workflows/tls_codec.yml +++ b/.github/workflows/tls_codec.yml @@ -36,7 +36,7 @@ jobs: toolchain: ${{ matrix.rust }} targets: ${{ matrix.target }} - uses: RustCrypto/actions/cargo-hack-install@master - - run: cargo hack build --target ${{ matrix.target }} --feature-powerset --exclude-features std,default,derive,serde,arbitrary + - run: cargo hack build --target ${{ matrix.target }} --feature-powerset --exclude-features std,default,serde,arbitrary minimal-versions: uses: RustCrypto/actions/.github/workflows/minimal-versions.yml@master @@ -70,4 +70,4 @@ jobs: - run: ${{ matrix.deps }} - uses: RustCrypto/actions/cargo-hack-install@master - run: cargo hack test --feature-powerset - - run: cargo hack test -p tls_codec_derive --test encode\* --test decode\* + - run: cargo hack test -p tls_codec_derive --feature-powerset --test encode\* --test decode\* diff --git a/tls_codec/Cargo.toml b/tls_codec/Cargo.toml index 8f66ca057..6e30fe3bb 100644 --- a/tls_codec/Cargo.toml +++ b/tls_codec/Cargo.toml @@ -28,10 +28,10 @@ regex = "1.8" [features] default = [ "std" ] arbitrary = [ "std", "dep:arbitrary" ] -derive = [ "std", "tls_codec_derive" ] +derive = [ "tls_codec_derive" ] serde = [ "std", "dep:serde" ] mls = [] # In MLS variable length vectors are limited compared to QUIC. -std = [] +std = [ "tls_codec_derive?/std" ] [[bench]] name = "tls_vec" diff --git a/tls_codec/derive/Cargo.toml b/tls_codec/derive/Cargo.toml index 42771f553..6e3c10396 100644 --- a/tls_codec/derive/Cargo.toml +++ b/tls_codec/derive/Cargo.toml @@ -21,3 +21,7 @@ proc-macro2 = "1.0" [dev-dependencies] tls_codec = { path = "../" } trybuild = "1" + +[features] +default = [ "std" ] +std = [] diff --git a/tls_codec/derive/src/lib.rs b/tls_codec/derive/src/lib.rs index 4dfffa4e0..dbb48b617 100644 --- a/tls_codec/derive/src/lib.rs +++ b/tls_codec/derive/src/lib.rs @@ -698,7 +698,7 @@ fn impl_tls_size(parsed_ast: TlsStruct) -> TokenStream2 { let field_len = match self { #(#field_arms)* }; - std::mem::size_of::<#repr>() + field_len + core::mem::size_of::<#repr>() + field_len } } @@ -740,6 +740,7 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr SerializeVariant::Write => { quote! { impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_serialize(&self, writer: &mut W) -> core::result::Result { let mut written = 0usize; #( @@ -760,6 +761,7 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr } impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_serialize(&self, writer: &mut W) -> core::result::Result { tls_codec::Serialize::tls_serialize(*self, writer) } @@ -850,6 +852,7 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr SerializeVariant::Write => { quote! { impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_serialize(&self, writer: &mut W) -> core::result::Result { #discriminant_constants match self { @@ -859,6 +862,7 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr } impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_serialize(&self, writer: &mut W) -> core::result::Result { tls_codec::Serialize::tls_serialize(*self, writer) } @@ -909,6 +913,7 @@ fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 { let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); quote! { impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_deserialize(bytes: &mut R) -> core::result::Result { Ok(Self { #(#members: #prefixes::tls_deserialize(bytes)?,)* @@ -948,6 +953,7 @@ fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 { quote! { impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause { #[allow(non_upper_case_globals)] + #[cfg(feature = "std")] fn tls_deserialize(bytes: &mut R) -> core::result::Result { #discriminant_constants let discriminant = <#repr as tls_codec::Deserialize>::tls_deserialize(bytes)?; diff --git a/tls_codec/derive/tests/decode.rs b/tls_codec/derive/tests/decode.rs index 7717b8c12..107624e23 100644 --- a/tls_codec/derive/tests/decode.rs +++ b/tls_codec/derive/tests/decode.rs @@ -1,3 +1,4 @@ +#![cfg(feature = "std")] use tls_codec::{ Deserialize, Error, Serialize, Size, TlsSliceU16, TlsVecU16, TlsVecU32, TlsVecU8, VLBytes, }; diff --git a/tls_codec/derive/tests/decode_bytes.rs b/tls_codec/derive/tests/decode_bytes.rs new file mode 100644 index 000000000..4a3267f01 --- /dev/null +++ b/tls_codec/derive/tests/decode_bytes.rs @@ -0,0 +1,322 @@ +use tls_codec::{DeserializeBytes, SerializeBytes, Size}; +use tls_codec_derive::{TlsDeserializeBytes, TlsSerializeBytes, TlsSize}; + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, PartialEq, Debug)] +#[repr(u16)] +pub enum ExtensionType { + Reserved = 0, + Capabilities = 1, + Lifetime = 2, + KeyId = 3, + ParentHash = 4, + RatchetTree = 5, + SomethingElse = 500, +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +pub struct ExtensionStruct { + extension_type: ExtensionType, + extension_data: Vec, + additional_data: Option>, +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +pub struct TupleStruct(ExtensionStruct, u8); + +#[derive(TlsSerializeBytes, TlsSize, Debug, Clone)] +struct SomeValue { + val: Vec, +} + +#[test] +fn simple_enum() { + let serialized = ExtensionType::KeyId.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, ExtensionType::KeyId); + assert_eq!(rest, []); + let serialized = ExtensionType::SomethingElse.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, ExtensionType::SomethingElse); + assert_eq!(rest, []); +} + +#[test] +fn simple_struct() { + let extension = ExtensionStruct { + extension_type: ExtensionType::KeyId, + extension_data: vec![1, 2, 3, 4, 5], + additional_data: None, + }; + let serialized = extension.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, extension); + assert_eq!(rest, []); +} + +#[test] +fn tuple_struct() { + let ext = ExtensionStruct { + extension_type: ExtensionType::KeyId, + extension_data: vec![1, 2, 3, 4, 5], + additional_data: None, + }; + let x = TupleStruct(ext, 6); + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[test] +fn byte_arrays() { + let x = [0u8, 1, 2, 3]; + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = <[u8; 4] as DeserializeBytes>::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +struct Custom { + #[tls_codec(with = "custom")] + values: Vec, + a: u8, +} + +mod custom { + use tls_codec::{DeserializeBytes, SerializeBytes, Size}; + + pub fn tls_serialized_len(v: &[u8]) -> usize { + v.tls_serialized_len() + } + + pub fn tls_serialize(v: &[u8]) -> Result, tls_codec::Error> { + v.tls_serialize() + } + + pub fn tls_deserialize( + bytes: &[u8], + ) -> Result<(T, &[u8]), tls_codec::Error> { + ::tls_deserialize(bytes) + } +} + +#[test] +fn custom() { + let x = Custom { + values: vec![0, 1, 2], + a: 3, + }; + let serialized = x.tls_serialize().unwrap(); + assert_eq!(vec![3, 0, 1, 2, 3], serialized); + let (deserialized, rest) = ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[repr(u8)] +enum EnumWithTupleVariant { + A(u8, u32), +} + +#[test] +fn enum_with_tuple_variant() { + let x = EnumWithTupleVariant::A(3, 4); + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[repr(u8)] +enum EnumWithStructVariant { + A { foo: u8, bar: u32 }, +} + +#[test] +fn enum_with_struct_variant() { + let x = EnumWithStructVariant::A { foo: 3, bar: 4 }; + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[repr(u16)] +enum EnumWithDataAndDiscriminant { + #[tls_codec(discriminant = 3)] + A(u8), + B, +} + +#[test] +fn enum_with_data_and_discriminant() { + let x = EnumWithDataAndDiscriminant::A(4); + let serialized = x.tls_serialize().unwrap(); + + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[test] +fn discriminant_is_incremented_implicitly() { + let x = EnumWithDataAndDiscriminant::B; + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +mod discriminant { + pub mod test { + pub mod constant { + pub const TEST_CONST: u8 = 3; + } + pub mod enum_val { + pub enum Test { + Potato = 0x0004, + } + } + } +} + +#[derive(Debug, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] +#[repr(u16)] +enum EnumWithDataAndConstDiscriminant { + #[tls_codec(discriminant = "discriminant::test::constant::TEST_CONST")] + A(u8), + #[tls_codec(discriminant = "discriminant::test::enum_val::Test::Potato")] + B, + #[tls_codec(discriminant = 12)] + C, +} + +#[test] +fn enum_with_data_and_const_discriminant() { + let x = EnumWithDataAndConstDiscriminant::A(4); + let serialized = x.tls_serialize().unwrap(); + assert_eq!(vec![0, 3, 4], serialized); + let (deserialized, rest) = + ::tls_deserialize(&serialized) + .unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); + + let x = EnumWithDataAndConstDiscriminant::B; + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized) + .unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); + + let x = EnumWithDataAndConstDiscriminant::C; + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized) + .unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[repr(u8)] +enum EnumWithCustomSerializedField { + A(#[tls_codec(with = "custom")] Vec), +} + +#[test] +fn enum_with_custom_serialized_field() { + let x = EnumWithCustomSerializedField::A(vec![1, 2, 3]); + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[test] +fn that_skip_attribute_on_struct_works() { + fn test(test: T, expected: T) + where + T: std::fmt::Debug + PartialEq + SerializeBytes + Size, + { + let serialized = test.tls_serialize().unwrap(); + let (deserialized, rest) = ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, expected); + assert_eq!(rest, []); + } + + #[derive(Debug, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] + struct StructWithSkip1 { + #[tls_codec(skip)] + a: u8, + b: u8, + c: u8, + } + + #[derive(Debug, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] + struct StructWithSkip2 { + a: u8, + #[tls_codec(skip)] + b: u8, + c: u8, + } + + #[derive(Debug, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] + struct StructWithSkip3 { + a: u8, + b: u8, + #[tls_codec(skip)] + c: u8, + } + + test( + StructWithSkip1 { + a: 123, + b: 13, + c: 42, + }, + StructWithSkip1 { + a: Default::default(), + b: 13, + c: 42, + }, + ); + test( + StructWithSkip2 { + a: 123, + b: 13, + c: 42, + }, + StructWithSkip2 { + a: 123, + b: Default::default(), + c: 42, + }, + ); + test( + StructWithSkip3 { + a: 123, + b: 13, + c: 42, + }, + StructWithSkip3 { + a: 123, + b: 13, + c: Default::default(), + }, + ); +} diff --git a/tls_codec/derive/tests/encode.rs b/tls_codec/derive/tests/encode.rs index e4047678e..84b760c91 100644 --- a/tls_codec/derive/tests/encode.rs +++ b/tls_codec/derive/tests/encode.rs @@ -1,3 +1,4 @@ +#![cfg(feature = "std")] use tls_codec::{SecretTlsVecU16, Serialize, Size, TlsSliceU16, TlsVecU16, TlsVecU32}; use tls_codec_derive::{TlsSerialize, TlsSize};