From f80288b08cf4189bf272f606e70fe53846b7e7db Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Sun, 2 Jul 2023 20:28:51 +0530 Subject: [PATCH 01/14] support no_std for tls_codec_derive --- tls_codec/Cargo.toml | 2 +- tls_codec/derive/Cargo.toml | 4 ++++ tls_codec/derive/src/lib.rs | 8 +++++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tls_codec/Cargo.toml b/tls_codec/Cargo.toml index 8f66ca057..bce671748 100644 --- a/tls_codec/Cargo.toml +++ b/tls_codec/Cargo.toml @@ -31,7 +31,7 @@ arbitrary = [ "std", "dep:arbitrary" ] derive = [ "std", "tls_codec_derive" ] serde = [ "std", "dep:serde" ] mls = [] # In MLS variable length vectors are limited compared to QUIC. -std = [] +std = [ "tls_codec_derive?/std" ] [[bench]] name = "tls_vec" diff --git a/tls_codec/derive/Cargo.toml b/tls_codec/derive/Cargo.toml index 42771f553..de63660ed 100644 --- a/tls_codec/derive/Cargo.toml +++ b/tls_codec/derive/Cargo.toml @@ -21,3 +21,7 @@ proc-macro2 = "1.0" [dev-dependencies] tls_codec = { path = "../" } trybuild = "1" + +[features] +default = [ "std" ] +std = [] \ No newline at end of file diff --git a/tls_codec/derive/src/lib.rs b/tls_codec/derive/src/lib.rs index 4dfffa4e0..dbb48b617 100644 --- a/tls_codec/derive/src/lib.rs +++ b/tls_codec/derive/src/lib.rs @@ -698,7 +698,7 @@ fn impl_tls_size(parsed_ast: TlsStruct) -> TokenStream2 { let field_len = match self { #(#field_arms)* }; - std::mem::size_of::<#repr>() + field_len + core::mem::size_of::<#repr>() + field_len } } @@ -740,6 +740,7 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr SerializeVariant::Write => { quote! { impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_serialize(&self, writer: &mut W) -> core::result::Result { let mut written = 0usize; #( @@ -760,6 +761,7 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr } impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_serialize(&self, writer: &mut W) -> core::result::Result { tls_codec::Serialize::tls_serialize(*self, writer) } @@ -850,6 +852,7 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr SerializeVariant::Write => { quote! { impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_serialize(&self, writer: &mut W) -> core::result::Result { #discriminant_constants match self { @@ -859,6 +862,7 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr } impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_serialize(&self, writer: &mut W) -> core::result::Result { tls_codec::Serialize::tls_serialize(*self, writer) } @@ -909,6 +913,7 @@ fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 { let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); quote! { impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause { + #[cfg(feature = "std")] fn tls_deserialize(bytes: &mut R) -> core::result::Result { Ok(Self { #(#members: #prefixes::tls_deserialize(bytes)?,)* @@ -948,6 +953,7 @@ fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 { quote! { impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause { #[allow(non_upper_case_globals)] + #[cfg(feature = "std")] fn tls_deserialize(bytes: &mut R) -> core::result::Result { #discriminant_constants let discriminant = <#repr as tls_codec::Deserialize>::tls_deserialize(bytes)?; From 1c70c7c018ba6562cd6fc602fe63012168f6c19b Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Sun, 2 Jul 2023 21:11:50 +0530 Subject: [PATCH 02/14] remove std feature from derive feature --- tls_codec/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tls_codec/Cargo.toml b/tls_codec/Cargo.toml index bce671748..6e30fe3bb 100644 --- a/tls_codec/Cargo.toml +++ b/tls_codec/Cargo.toml @@ -28,7 +28,7 @@ regex = "1.8" [features] default = [ "std" ] arbitrary = [ "std", "dep:arbitrary" ] -derive = [ "std", "tls_codec_derive" ] +derive = [ "tls_codec_derive" ] serde = [ "std", "dep:serde" ] mls = [] # In MLS variable length vectors are limited compared to QUIC. std = [ "tls_codec_derive?/std" ] From 640fa2ae8037d577f295904ee6a60514a329a54a Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Tue, 4 Jul 2023 17:08:34 +0530 Subject: [PATCH 03/14] fix tests --- tls_codec/derive/tests/decode.rs | 143 ++++++++++++++++++++++--------- tls_codec/derive/tests/encode.rs | 71 ++++++++++++--- 2 files changed, 158 insertions(+), 56 deletions(-) diff --git a/tls_codec/derive/tests/decode.rs b/tls_codec/derive/tests/decode.rs index 7717b8c12..f26310bc0 100644 --- a/tls_codec/derive/tests/decode.rs +++ b/tls_codec/derive/tests/decode.rs @@ -1,11 +1,14 @@ -use tls_codec::{ - Deserialize, Error, Serialize, Size, TlsSliceU16, TlsVecU16, TlsVecU32, TlsVecU8, VLBytes, -}; -use tls_codec_derive::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize}; - -#[derive( - TlsDeserialize, TlsDeserializeBytes, Debug, PartialEq, Clone, Copy, TlsSize, TlsSerialize, -)] +use tls_codec::{TlsVecU16, TlsVecU32, TlsVecU8, VLBytes}; +use tls_codec_derive::{TlsDeserializeBytes, TlsSize}; + +#[cfg(feature = "std")] +use tls_codec::{Deserialize, Error, Serialize, Size, TlsSliceU16}; + +#[cfg(feature = "std")] +use tls_codec_derive::{TlsDeserialize, TlsSerialize}; + +#[derive(TlsDeserializeBytes, Debug, PartialEq, Clone, Copy, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] #[repr(u16)] pub enum ExtensionType { Reserved = 0, @@ -23,31 +26,35 @@ impl Default for ExtensionType { } } -#[derive( - TlsDeserialize, TlsDeserializeBytes, Debug, PartialEq, TlsSerialize, TlsSize, Clone, Default, -)] +#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize, Clone, Default)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] pub struct ExtensionStruct { extension_type: ExtensionType, extension_data: TlsVecU32, } -#[derive(TlsDeserialize, TlsDeserializeBytes, Debug, PartialEq, TlsSize, TlsSerialize)] +#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] pub struct ExtensionTypeVec { data: TlsVecU8, } -#[derive(TlsDeserialize, TlsDeserializeBytes, Debug, PartialEq, TlsSize, TlsSerialize)] +#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] pub struct ArrayWrap { data: [u8; 8], } -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] pub struct TupleStruct1(ExtensionStruct); -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] pub struct TupleStruct(ExtensionStruct, u8); #[test] +#[cfg(feature = "std")] fn tuple_struct() { let ext = ExtensionStruct { extension_type: ExtensionType::KeyId, @@ -89,6 +96,7 @@ fn tuple_struct() { } #[test] +#[cfg(feature = "std")] fn simple_enum() { let b = &[0u8, 5] as &[u8]; let mut b_reader = b; @@ -116,6 +124,7 @@ fn simple_enum() { } #[test] +#[cfg(feature = "std")] fn deserialize_tls_vec() { let long_vector = vec![ExtensionStruct::default(); 3000]; let serialized_long_vec = TlsSliceU16(&long_vector).tls_serialize_detached().unwrap(); @@ -138,6 +147,7 @@ fn deserialize_tls_vec() { } #[test] +#[cfg(feature = "std")] fn byte_arrays() { let x = [0u8, 1, 2, 3]; let serialized = x.tls_serialize_detached().unwrap(); @@ -156,6 +166,7 @@ fn byte_arrays() { } #[test] +#[cfg(feature = "std")] fn simple_struct() { let mut b = &[0u8, 3, 0, 0, 0, 5, 1, 2, 3, 4, 5] as &[u8]; let extension = ExtensionStruct { @@ -178,50 +189,60 @@ fn simple_struct() { assert_eq!(extension, deserialized); } -#[derive(TlsDeserialize, TlsDeserializeBytes, Clone, TlsSize, PartialEq)] +#[derive(TlsDeserializeBytes, Clone, TlsSize, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsDeserialize))] struct DeserializeOnlyStruct(u16); // KAT from MLS -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] #[repr(u8)] enum ProtocolVersion { Reserved = 0, Mls10 = 1, } -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct CipherSuite(u16); -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct HPKEPublicKey(TlsVecU16); -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct CredentialType(u16); -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct SignatureScheme(u16); -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct BasicCredential { identity: TlsVecU16, signature_scheme: SignatureScheme, signature_key: TlsVecU16, } -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct Credential { credential_type: CredentialType, credential: BasicCredential, } -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct Extension { extension_type: ExtensionType, extension_data: TlsVecU32, } -#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct KeyPackage { version: ProtocolVersion, cipher_suite: CipherSuite, @@ -232,6 +253,7 @@ struct KeyPackage { } #[test] +#[cfg(feature = "std")] fn kat_mls_key_package() { let key_package_bytes = &[ 0x01u8, 0x00, 0x01, 0x00, 0x20, 0xF2, 0xBC, 0xD8, 0x95, 0x19, 0xDD, 0x1D, 0x06, 0x9F, 0x8B, @@ -260,7 +282,8 @@ fn kat_mls_key_package() { ); } -#[derive(Debug, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)] +#[derive(Debug, PartialEq, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct Custom { #[tls_codec(with = "custom")] values: Vec, @@ -268,23 +291,30 @@ struct Custom { } mod custom { + use tls_codec::{Size, TlsByteSliceU32}; + + #[cfg(feature = "std")] use std::io::{Read, Write}; - use tls_codec::{Deserialize, Serialize, Size, TlsByteSliceU32, TlsByteVecU32}; + #[cfg(feature = "std")] + use tls_codec::{Deserialize, Serialize, TlsByteVecU32}; pub fn tls_serialized_len(v: &[u8]) -> usize { TlsByteSliceU32(v).tls_serialized_len() } + #[cfg(feature = "std")] pub fn tls_serialize(v: &[u8], writer: &mut W) -> Result { TlsByteSliceU32(v).tls_serialize(writer) } + #[cfg(feature = "std")] pub fn tls_deserialize(bytes: &mut R) -> Result, tls_codec::Error> { Ok(TlsByteVecU32::tls_deserialize(bytes)?.into_vec()) } } -#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSerialize, TlsSize)] +#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] struct CustomBytes { #[tls_codec(with = "custom_bytes")] values: Vec, @@ -292,13 +322,18 @@ struct CustomBytes { } mod custom_bytes { + use tls_codec::{DeserializeBytes, Size, TlsByteSliceU32, TlsByteVecU32}; + + #[cfg(feature = "std")] use std::io::Write; - use tls_codec::{DeserializeBytes, Serialize, Size, TlsByteSliceU32, TlsByteVecU32}; + #[cfg(feature = "std")] + use tls_codec::Serialize; pub fn tls_serialized_len(v: &[u8]) -> usize { TlsByteSliceU32(v).tls_serialized_len() } + #[cfg(feature = "std")] pub fn tls_serialize(v: &[u8], writer: &mut W) -> Result { TlsByteSliceU32(v).tls_serialize(writer) } @@ -307,9 +342,11 @@ mod custom_bytes { let (vec, remainder) = TlsByteVecU32::tls_deserialize(bytes)?; Ok((vec.into_vec(), remainder)) } + // pub fn tls_deserialize(bytes: &mut R) -> Result<(Vec, &R), tls_codec::Error> {} } #[test] +#[cfg(feature = "std")] fn custom() { let x = Custom { values: vec![0, 1, 2], @@ -320,13 +357,15 @@ fn custom() { assert_eq!(x, deserialized); } -#[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)] +#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] #[repr(u8)] enum EnumWithTupleVariant { A(u8, u32), } #[test] +#[cfg(feature = "std")] fn enum_with_tuple_variant() { let x = EnumWithTupleVariant::A(3, 4); let serialized = x.tls_serialize_detached().unwrap(); @@ -334,13 +373,15 @@ fn enum_with_tuple_variant() { assert_eq!(deserialized, x); } -#[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)] +#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] #[repr(u8)] enum EnumWithStructVariant { A { foo: u8, bar: u32 }, } #[test] +#[cfg(feature = "std")] fn enum_with_struct_variant() { let x = EnumWithStructVariant::A { foo: 3, bar: 4 }; let serialized = x.tls_serialize_detached().unwrap(); @@ -348,7 +389,8 @@ fn enum_with_struct_variant() { assert_eq!(deserialized, x); } -#[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)] +#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] #[repr(u16)] enum EnumWithDataAndDiscriminant { #[tls_codec(discriminant = 3)] @@ -357,6 +399,7 @@ enum EnumWithDataAndDiscriminant { } #[test] +#[cfg(feature = "std")] fn enum_with_data_and_discriminant() { for x in [ EnumWithDataAndDiscriminant::A(4), @@ -382,7 +425,8 @@ mod discriminant { } } -#[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)] +#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] #[repr(u16)] enum EnumWithDataAndConstDiscriminant { #[tls_codec(discriminant = "discriminant::test::constant::TEST_CONST")] @@ -394,6 +438,7 @@ enum EnumWithDataAndConstDiscriminant { } #[test] +#[cfg(feature = "std")] fn enum_with_data_and_const_discriminant() { for x in [ EnumWithDataAndConstDiscriminant::A(4), @@ -407,13 +452,15 @@ fn enum_with_data_and_const_discriminant() { } } -#[derive(Debug, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)] +#[cfg(feature = "std")] +#[derive(Debug, PartialEq, TlsSize, TlsSerialize, TlsDeserialize)] #[repr(u8)] enum EnumWithCustomSerializedField { A(#[tls_codec(with = "custom")] Vec), } #[test] +#[cfg(feature = "std")] fn enum_with_custom_serialized_field() { let x = EnumWithCustomSerializedField::A(vec![1, 2, 3]); let serialized = x.tls_serialize_detached().unwrap(); @@ -421,24 +468,28 @@ fn enum_with_custom_serialized_field() { assert_eq!(deserialized, x); } -#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSerialize, TlsSize)] +#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] #[repr(u8)] enum EnumWithCustomSerializedFieldBytes { A(#[tls_codec(with = "custom_bytes")] Vec), } // Variable length vectors -#[derive(Debug, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)] +#[derive(Debug, PartialEq, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct MyContainer { value: Vec, } -#[derive(Debug, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)] +#[derive(Debug, PartialEq, TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct MyByteContainer { value: VLBytes, } #[test] +#[cfg(feature = "std")] fn simple_variable_length_struct() { let val = MyContainer { value: vec![1, 2, 3, 4, 5, 6, 7, 8, 9], @@ -456,6 +507,7 @@ fn simple_variable_length_struct() { } #[test] +#[cfg(feature = "std")] fn that_skip_attribute_on_struct_works() { fn test(test: &[u8], expected: T) where @@ -469,7 +521,8 @@ fn that_skip_attribute_on_struct_works() { assert_eq!(expected, got); } - #[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSize)] + #[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] + #[cfg_attr(feature = "std", derive(TlsDeserialize))] struct StructWithSkip1 { #[tls_codec(skip)] a: u8, @@ -477,7 +530,8 @@ fn that_skip_attribute_on_struct_works() { c: u8, } - #[derive(Debug, PartialEq, TlsDeserialize, TlsSize)] + #[derive(Debug, PartialEq, TlsSize)] + #[cfg_attr(feature = "std", derive(TlsDeserialize))] struct StructWithSkip2 { a: u8, #[tls_codec(skip)] @@ -485,7 +539,8 @@ fn that_skip_attribute_on_struct_works() { c: u8, } - #[derive(Debug, PartialEq, TlsDeserialize, TlsSize)] + #[derive(Debug, PartialEq, TlsSize)] + #[cfg_attr(feature = "std", derive(TlsDeserialize))] struct StructWithSkip3 { a: u8, b: u8, @@ -499,8 +554,10 @@ fn that_skip_attribute_on_struct_works() { } #[test] +#[cfg(feature = "std")] fn generic_struct() { - #[derive(PartialEq, Eq, Debug, TlsSize, TlsSerialize, TlsDeserialize)] + #[derive(PartialEq, Eq, Debug, TlsSize)] + #[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] struct GenericStruct where T: Size + Serialize + Deserialize, @@ -518,7 +575,8 @@ fn generic_struct() { assert_eq!(deserialized, insta); } -#[derive(TlsDeserialize, TlsSerialize, TlsSize)] +#[cfg(feature = "std")] +#[derive(TlsSerialize, TlsDeserialize, TlsSize)] #[repr(u16)] enum TypeWithUnknowns { First = 1, @@ -526,6 +584,7 @@ enum TypeWithUnknowns { } #[test] +#[cfg(feature = "std")] fn type_with_unknowns() { let incoming = [0x00u8, 0x03]; // This must be parsed into TypeWithUnknowns into an unknown let deserialized = TypeWithUnknowns::tls_deserialize_exact(incoming); diff --git a/tls_codec/derive/tests/encode.rs b/tls_codec/derive/tests/encode.rs index e4047678e..926ce0220 100644 --- a/tls_codec/derive/tests/encode.rs +++ b/tls_codec/derive/tests/encode.rs @@ -1,7 +1,14 @@ -use tls_codec::{SecretTlsVecU16, Serialize, Size, TlsSliceU16, TlsVecU16, TlsVecU32}; -use tls_codec_derive::{TlsSerialize, TlsSize}; +use tls_codec::{SecretTlsVecU16, Serialize, TlsSliceU16, TlsVecU16, TlsVecU32}; +use tls_codec_derive::TlsSize; -#[derive(TlsSerialize, TlsSize, Debug)] +#[cfg(feature = "std")] +use tls_codec::Size; + +#[cfg(feature = "std")] +use tls_codec_derive::TlsSerialize; + +#[derive(TlsSize, Debug)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] #[repr(u16)] pub enum ExtensionType { Reserved = 0, @@ -13,32 +20,38 @@ pub enum ExtensionType { SomethingElse = 500, } -#[derive(TlsSerialize, TlsSize, Debug)] +#[derive(TlsSize, Debug)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] pub struct ExtensionStruct { extension_type: ExtensionType, extension_data: TlsVecU32, additional_data: Option>, } -#[derive(TlsSerialize, TlsSize, Debug)] +#[derive(TlsSize, Debug)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] pub struct TupleStruct(ExtensionStruct, u8); -#[derive(TlsSerialize, TlsSize, Debug)] +#[derive(TlsSize, Debug)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] pub struct StructWithLifetime<'a> { value: &'a TlsVecU16, } -#[derive(TlsSerialize, TlsSize, Debug, Clone)] +#[derive(TlsSize, Debug, Clone)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] struct SomeValue { val: TlsVecU16, } -#[derive(TlsSerialize, TlsSize)] +#[derive(TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] pub struct StructWithDoubleLifetime<'a, 'b> { value: &'a TlsSliceU16<'a, &'b SomeValue>, } #[test] +#[cfg(feature = "std")] fn lifetime_struct() { let value: TlsVecU16 = vec![7u8; 33].into(); let s = StructWithLifetime { value: &value }; @@ -60,6 +73,7 @@ fn lifetime_struct() { } #[test] +#[cfg(feature = "std")] fn simple_enum() { let serialized = ExtensionType::KeyId.tls_serialize_detached().unwrap(); assert_eq!(vec![0, 3], serialized); @@ -70,6 +84,7 @@ fn simple_enum() { } #[test] +#[cfg(feature = "std")] fn simple_struct() { let extension = ExtensionStruct { extension_type: ExtensionType::KeyId, @@ -81,6 +96,7 @@ fn simple_struct() { } #[test] +#[cfg(feature = "std")] fn tuple_struct() { let ext = ExtensionStruct { extension_type: ExtensionType::KeyId, @@ -100,6 +116,7 @@ fn byte_arrays() { } #[test] +#[cfg(feature = "std")] fn lifetimes() { let x = vec![1, 2, 3, 4].into(); let s = StructWithLifetime { value: &x }; @@ -113,7 +130,8 @@ fn lifetimes() { assert_eq!(vec![0, 4, 1, 2, 3, 4], serialized); } -#[derive(TlsSerialize, TlsSize)] +#[derive(TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] struct Custom { #[tls_codec(with = "custom")] values: Vec, @@ -121,19 +139,26 @@ struct Custom { } mod custom { + use tls_codec::{Size, TlsByteSliceU32}; + + #[cfg(feature = "std")] use std::io::Write; - use tls_codec::{Serialize, Size, TlsByteSliceU32}; + + #[cfg(feature = "std")] + use tls_codec::Serialize; pub fn tls_serialized_len(v: &[u8]) -> usize { TlsByteSliceU32(v).tls_serialized_len() } + #[cfg(feature = "std")] pub fn tls_serialize(v: &[u8], writer: &mut W) -> Result { TlsByteSliceU32(v).tls_serialize(writer) } } #[test] +#[cfg(feature = "std")] fn custom() { let x = Custom { values: vec![0, 1, 2], @@ -143,7 +168,8 @@ fn custom() { assert_eq!(vec![0, 0, 0, 3, 0, 1, 2, 3], serialized); } -#[derive(TlsSerialize, TlsSize)] +#[derive(TlsSize)] +#[cfg_attr(feature = "std", derive(TlsSerialize))] struct OptionalMemberRef<'a> { optional_member: Option<&'a u32>, ref_optional_member: &'a Option<&'a u32>, @@ -151,6 +177,7 @@ struct OptionalMemberRef<'a> { } #[test] +#[cfg(feature = "std")] fn optional_member() { let m = 6; let v = vec![1, 2, 3]; @@ -163,6 +190,7 @@ fn optional_member() { assert_eq!(vec![1, 0, 0, 0, 6, 0, 0, 6, 0, 1, 0, 2, 0, 3], serialized); } +#[cfg(feature = "std")] #[derive(TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithTupleVariant { @@ -170,12 +198,14 @@ enum EnumWithTupleVariant { } #[test] +#[cfg(feature = "std")] fn enum_with_tuple_variant() { let x = EnumWithTupleVariant::A(3, 4); let serialized = x.tls_serialize_detached().unwrap(); assert_eq!(vec![0, 3, 0, 0, 0, 4], serialized); } +#[cfg(feature = "std")] #[derive(TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithStructVariant { @@ -183,12 +213,14 @@ enum EnumWithStructVariant { } #[test] +#[cfg(feature = "std")] fn enum_with_struct_variant() { let x = EnumWithStructVariant::A { foo: 3, bar: 4 }; let serialized = x.tls_serialize_detached().unwrap(); assert_eq!(vec![0, 3, 0, 0, 0, 4], serialized); } +#[cfg(feature = "std")] #[derive(TlsSerialize, TlsSize)] #[repr(u16)] enum EnumWithDataAndDiscriminant { @@ -198,6 +230,7 @@ enum EnumWithDataAndDiscriminant { } #[test] +#[cfg(feature = "std")] fn enum_with_data_and_discriminant() { let x = EnumWithDataAndDiscriminant::A(4); let serialized = x.tls_serialize_detached().unwrap(); @@ -205,12 +238,14 @@ fn enum_with_data_and_discriminant() { } #[test] +#[cfg(feature = "std")] fn discriminant_is_incremented_implicitly() { let x = EnumWithDataAndDiscriminant::B; let serialized = x.tls_serialize_detached().unwrap(); assert_eq!(vec![0, 4], serialized); } +#[cfg(feature = "std")] mod discriminant { pub mod test { pub mod constant { @@ -224,6 +259,7 @@ mod discriminant { } } +#[cfg(feature = "std")] #[derive(Debug, PartialEq, TlsSerialize, TlsSize)] #[repr(u16)] enum EnumWithDataAndConstDiscriminant { @@ -236,6 +272,7 @@ enum EnumWithDataAndConstDiscriminant { } #[test] +#[cfg(feature = "std")] fn enum_with_data_and_const_discriminant() { let serialized = EnumWithDataAndConstDiscriminant::A(4) .tls_serialize_detached() @@ -251,6 +288,7 @@ fn enum_with_data_and_const_discriminant() { assert_eq!(vec![0, 12], serialized); } +#[cfg(feature = "std")] #[derive(TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithCustomSerializedField { @@ -258,6 +296,7 @@ enum EnumWithCustomSerializedField { } #[test] +#[cfg(feature = "std")] fn enum_with_custom_serialized_field() { let x = EnumWithCustomSerializedField::A(vec![1, 2, 3]); let serialized = x.tls_serialize_detached().unwrap(); @@ -265,6 +304,7 @@ fn enum_with_custom_serialized_field() { } #[test] +#[cfg(feature = "std")] fn that_skip_attribute_on_struct_works() { fn test(test: T, expected: &[u8]) where @@ -277,7 +317,8 @@ fn that_skip_attribute_on_struct_works() { assert_eq!(test.tls_serialize_detached().unwrap(), expected); } - #[derive(Debug, PartialEq, TlsSerialize, TlsSize)] + #[derive(Debug, PartialEq, TlsSize)] + #[cfg_attr(feature = "std", derive(TlsSerialize))] struct StructWithSkip1 { #[tls_codec(skip)] a: u8, @@ -285,7 +326,8 @@ fn that_skip_attribute_on_struct_works() { c: u8, } - #[derive(Debug, PartialEq, TlsSerialize, TlsSize)] + #[derive(Debug, PartialEq, TlsSize)] + #[cfg_attr(feature = "std", derive(TlsSerialize))] struct StructWithSkip2 { a: u8, #[tls_codec(skip)] @@ -293,7 +335,8 @@ fn that_skip_attribute_on_struct_works() { c: u8, } - #[derive(Debug, PartialEq, TlsSerialize, TlsSize)] + #[derive(Debug, PartialEq, TlsSize)] + #[cfg_attr(feature = "std", derive(TlsSerialize))] struct StructWithSkip3 { a: u8, b: u8, From d602da3bb59c19703faf6294260abb3d294d0662 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Tue, 4 Jul 2023 17:15:50 +0530 Subject: [PATCH 04/14] test all features for tls_codec_derive --- .github/workflows/tls_codec.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tls_codec.yml b/.github/workflows/tls_codec.yml index f66ea38b2..87e59bb17 100644 --- a/.github/workflows/tls_codec.yml +++ b/.github/workflows/tls_codec.yml @@ -70,4 +70,4 @@ jobs: - run: ${{ matrix.deps }} - uses: RustCrypto/actions/cargo-hack-install@master - run: cargo hack test --feature-powerset - - run: cargo hack test -p tls_codec_derive --test encode\* --test decode\* + - run: cargo hack test -p tls_codec_derive --feature-powerset --test encode\* --test decode\* From c5b2e6112bb3acfdae12e9abbc725e23d4e0975b Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Tue, 4 Jul 2023 18:32:04 +0530 Subject: [PATCH 05/14] Revert "fix tests" This reverts commit 640fa2ae8037d577f295904ee6a60514a329a54a. --- tls_codec/derive/tests/decode.rs | 143 +++++++++---------------------- tls_codec/derive/tests/encode.rs | 71 +++------------ 2 files changed, 56 insertions(+), 158 deletions(-) diff --git a/tls_codec/derive/tests/decode.rs b/tls_codec/derive/tests/decode.rs index f26310bc0..7717b8c12 100644 --- a/tls_codec/derive/tests/decode.rs +++ b/tls_codec/derive/tests/decode.rs @@ -1,14 +1,11 @@ -use tls_codec::{TlsVecU16, TlsVecU32, TlsVecU8, VLBytes}; -use tls_codec_derive::{TlsDeserializeBytes, TlsSize}; - -#[cfg(feature = "std")] -use tls_codec::{Deserialize, Error, Serialize, Size, TlsSliceU16}; - -#[cfg(feature = "std")] -use tls_codec_derive::{TlsDeserialize, TlsSerialize}; - -#[derive(TlsDeserializeBytes, Debug, PartialEq, Clone, Copy, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +use tls_codec::{ + Deserialize, Error, Serialize, Size, TlsSliceU16, TlsVecU16, TlsVecU32, TlsVecU8, VLBytes, +}; +use tls_codec_derive::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize}; + +#[derive( + TlsDeserialize, TlsDeserializeBytes, Debug, PartialEq, Clone, Copy, TlsSize, TlsSerialize, +)] #[repr(u16)] pub enum ExtensionType { Reserved = 0, @@ -26,35 +23,31 @@ impl Default for ExtensionType { } } -#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize, Clone, Default)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive( + TlsDeserialize, TlsDeserializeBytes, Debug, PartialEq, TlsSerialize, TlsSize, Clone, Default, +)] pub struct ExtensionStruct { extension_type: ExtensionType, extension_data: TlsVecU32, } -#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsDeserialize, TlsDeserializeBytes, Debug, PartialEq, TlsSize, TlsSerialize)] pub struct ExtensionTypeVec { data: TlsVecU8, } -#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsDeserialize, TlsDeserializeBytes, Debug, PartialEq, TlsSize, TlsSerialize)] pub struct ArrayWrap { data: [u8; 8], } -#[derive(TlsDeserializeBytes, TlsSize, Debug, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] pub struct TupleStruct1(ExtensionStruct); -#[derive(TlsDeserializeBytes, TlsSize, Debug, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] pub struct TupleStruct(ExtensionStruct, u8); #[test] -#[cfg(feature = "std")] fn tuple_struct() { let ext = ExtensionStruct { extension_type: ExtensionType::KeyId, @@ -96,7 +89,6 @@ fn tuple_struct() { } #[test] -#[cfg(feature = "std")] fn simple_enum() { let b = &[0u8, 5] as &[u8]; let mut b_reader = b; @@ -124,7 +116,6 @@ fn simple_enum() { } #[test] -#[cfg(feature = "std")] fn deserialize_tls_vec() { let long_vector = vec![ExtensionStruct::default(); 3000]; let serialized_long_vec = TlsSliceU16(&long_vector).tls_serialize_detached().unwrap(); @@ -147,7 +138,6 @@ fn deserialize_tls_vec() { } #[test] -#[cfg(feature = "std")] fn byte_arrays() { let x = [0u8, 1, 2, 3]; let serialized = x.tls_serialize_detached().unwrap(); @@ -166,7 +156,6 @@ fn byte_arrays() { } #[test] -#[cfg(feature = "std")] fn simple_struct() { let mut b = &[0u8, 3, 0, 0, 0, 5, 1, 2, 3, 4, 5] as &[u8]; let extension = ExtensionStruct { @@ -189,60 +178,50 @@ fn simple_struct() { assert_eq!(extension, deserialized); } -#[derive(TlsDeserializeBytes, Clone, TlsSize, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsDeserialize))] +#[derive(TlsDeserialize, TlsDeserializeBytes, Clone, TlsSize, PartialEq)] struct DeserializeOnlyStruct(u16); // KAT from MLS -#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] #[repr(u8)] enum ProtocolVersion { Reserved = 0, Mls10 = 1, } -#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] struct CipherSuite(u16); -#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] struct HPKEPublicKey(TlsVecU16); -#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] struct CredentialType(u16); -#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] struct SignatureScheme(u16); -#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] struct BasicCredential { identity: TlsVecU16, signature_scheme: SignatureScheme, signature_key: TlsVecU16, } -#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] struct Credential { credential_type: CredentialType, credential: BasicCredential, } -#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] struct Extension { extension_type: ExtensionType, extension_data: TlsVecU32, } -#[derive(TlsDeserializeBytes, TlsSize, Clone, PartialEq)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, Clone, PartialEq)] struct KeyPackage { version: ProtocolVersion, cipher_suite: CipherSuite, @@ -253,7 +232,6 @@ struct KeyPackage { } #[test] -#[cfg(feature = "std")] fn kat_mls_key_package() { let key_package_bytes = &[ 0x01u8, 0x00, 0x01, 0x00, 0x20, 0xF2, 0xBC, 0xD8, 0x95, 0x19, 0xDD, 0x1D, 0x06, 0x9F, 0x8B, @@ -282,8 +260,7 @@ fn kat_mls_key_package() { ); } -#[derive(Debug, PartialEq, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(Debug, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)] struct Custom { #[tls_codec(with = "custom")] values: Vec, @@ -291,30 +268,23 @@ struct Custom { } mod custom { - use tls_codec::{Size, TlsByteSliceU32}; - - #[cfg(feature = "std")] use std::io::{Read, Write}; - #[cfg(feature = "std")] - use tls_codec::{Deserialize, Serialize, TlsByteVecU32}; + use tls_codec::{Deserialize, Serialize, Size, TlsByteSliceU32, TlsByteVecU32}; pub fn tls_serialized_len(v: &[u8]) -> usize { TlsByteSliceU32(v).tls_serialized_len() } - #[cfg(feature = "std")] pub fn tls_serialize(v: &[u8], writer: &mut W) -> Result { TlsByteSliceU32(v).tls_serialize(writer) } - #[cfg(feature = "std")] pub fn tls_deserialize(bytes: &mut R) -> Result, tls_codec::Error> { Ok(TlsByteVecU32::tls_deserialize(bytes)?.into_vec()) } } -#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSerialize, TlsSize)] struct CustomBytes { #[tls_codec(with = "custom_bytes")] values: Vec, @@ -322,18 +292,13 @@ struct CustomBytes { } mod custom_bytes { - use tls_codec::{DeserializeBytes, Size, TlsByteSliceU32, TlsByteVecU32}; - - #[cfg(feature = "std")] use std::io::Write; - #[cfg(feature = "std")] - use tls_codec::Serialize; + use tls_codec::{DeserializeBytes, Serialize, Size, TlsByteSliceU32, TlsByteVecU32}; pub fn tls_serialized_len(v: &[u8]) -> usize { TlsByteSliceU32(v).tls_serialized_len() } - #[cfg(feature = "std")] pub fn tls_serialize(v: &[u8], writer: &mut W) -> Result { TlsByteSliceU32(v).tls_serialize(writer) } @@ -342,11 +307,9 @@ mod custom_bytes { let (vec, remainder) = TlsByteVecU32::tls_deserialize(bytes)?; Ok((vec.into_vec(), remainder)) } - // pub fn tls_deserialize(bytes: &mut R) -> Result<(Vec, &R), tls_codec::Error> {} } #[test] -#[cfg(feature = "std")] fn custom() { let x = Custom { values: vec![0, 1, 2], @@ -357,15 +320,13 @@ fn custom() { assert_eq!(x, deserialized); } -#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithTupleVariant { A(u8, u32), } #[test] -#[cfg(feature = "std")] fn enum_with_tuple_variant() { let x = EnumWithTupleVariant::A(3, 4); let serialized = x.tls_serialize_detached().unwrap(); @@ -373,15 +334,13 @@ fn enum_with_tuple_variant() { assert_eq!(deserialized, x); } -#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithStructVariant { A { foo: u8, bar: u32 }, } #[test] -#[cfg(feature = "std")] fn enum_with_struct_variant() { let x = EnumWithStructVariant::A { foo: 3, bar: 4 }; let serialized = x.tls_serialize_detached().unwrap(); @@ -389,8 +348,7 @@ fn enum_with_struct_variant() { assert_eq!(deserialized, x); } -#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)] #[repr(u16)] enum EnumWithDataAndDiscriminant { #[tls_codec(discriminant = 3)] @@ -399,7 +357,6 @@ enum EnumWithDataAndDiscriminant { } #[test] -#[cfg(feature = "std")] fn enum_with_data_and_discriminant() { for x in [ EnumWithDataAndDiscriminant::A(4), @@ -425,8 +382,7 @@ mod discriminant { } } -#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize)] #[repr(u16)] enum EnumWithDataAndConstDiscriminant { #[tls_codec(discriminant = "discriminant::test::constant::TEST_CONST")] @@ -438,7 +394,6 @@ enum EnumWithDataAndConstDiscriminant { } #[test] -#[cfg(feature = "std")] fn enum_with_data_and_const_discriminant() { for x in [ EnumWithDataAndConstDiscriminant::A(4), @@ -452,15 +407,13 @@ fn enum_with_data_and_const_discriminant() { } } -#[cfg(feature = "std")] -#[derive(Debug, PartialEq, TlsSize, TlsSerialize, TlsDeserialize)] +#[derive(Debug, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithCustomSerializedField { A(#[tls_codec(with = "custom")] Vec), } #[test] -#[cfg(feature = "std")] fn enum_with_custom_serialized_field() { let x = EnumWithCustomSerializedField::A(vec![1, 2, 3]); let serialized = x.tls_serialize_detached().unwrap(); @@ -468,28 +421,24 @@ fn enum_with_custom_serialized_field() { assert_eq!(deserialized, x); } -#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithCustomSerializedFieldBytes { A(#[tls_codec(with = "custom_bytes")] Vec), } // Variable length vectors -#[derive(Debug, PartialEq, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(Debug, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)] struct MyContainer { value: Vec, } -#[derive(Debug, PartialEq, TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] +#[derive(Debug, PartialEq, TlsDeserialize, TlsSerialize, TlsSize)] struct MyByteContainer { value: VLBytes, } #[test] -#[cfg(feature = "std")] fn simple_variable_length_struct() { let val = MyContainer { value: vec![1, 2, 3, 4, 5, 6, 7, 8, 9], @@ -507,7 +456,6 @@ fn simple_variable_length_struct() { } #[test] -#[cfg(feature = "std")] fn that_skip_attribute_on_struct_works() { fn test(test: &[u8], expected: T) where @@ -521,8 +469,7 @@ fn that_skip_attribute_on_struct_works() { assert_eq!(expected, got); } - #[derive(Debug, PartialEq, TlsDeserializeBytes, TlsSize)] - #[cfg_attr(feature = "std", derive(TlsDeserialize))] + #[derive(Debug, PartialEq, TlsDeserialize, TlsDeserializeBytes, TlsSize)] struct StructWithSkip1 { #[tls_codec(skip)] a: u8, @@ -530,8 +477,7 @@ fn that_skip_attribute_on_struct_works() { c: u8, } - #[derive(Debug, PartialEq, TlsSize)] - #[cfg_attr(feature = "std", derive(TlsDeserialize))] + #[derive(Debug, PartialEq, TlsDeserialize, TlsSize)] struct StructWithSkip2 { a: u8, #[tls_codec(skip)] @@ -539,8 +485,7 @@ fn that_skip_attribute_on_struct_works() { c: u8, } - #[derive(Debug, PartialEq, TlsSize)] - #[cfg_attr(feature = "std", derive(TlsDeserialize))] + #[derive(Debug, PartialEq, TlsDeserialize, TlsSize)] struct StructWithSkip3 { a: u8, b: u8, @@ -554,10 +499,8 @@ fn that_skip_attribute_on_struct_works() { } #[test] -#[cfg(feature = "std")] fn generic_struct() { - #[derive(PartialEq, Eq, Debug, TlsSize)] - #[cfg_attr(feature = "std", derive(TlsSerialize, TlsDeserialize))] + #[derive(PartialEq, Eq, Debug, TlsSize, TlsSerialize, TlsDeserialize)] struct GenericStruct where T: Size + Serialize + Deserialize, @@ -575,8 +518,7 @@ fn generic_struct() { assert_eq!(deserialized, insta); } -#[cfg(feature = "std")] -#[derive(TlsSerialize, TlsDeserialize, TlsSize)] +#[derive(TlsDeserialize, TlsSerialize, TlsSize)] #[repr(u16)] enum TypeWithUnknowns { First = 1, @@ -584,7 +526,6 @@ enum TypeWithUnknowns { } #[test] -#[cfg(feature = "std")] fn type_with_unknowns() { let incoming = [0x00u8, 0x03]; // This must be parsed into TypeWithUnknowns into an unknown let deserialized = TypeWithUnknowns::tls_deserialize_exact(incoming); diff --git a/tls_codec/derive/tests/encode.rs b/tls_codec/derive/tests/encode.rs index 926ce0220..e4047678e 100644 --- a/tls_codec/derive/tests/encode.rs +++ b/tls_codec/derive/tests/encode.rs @@ -1,14 +1,7 @@ -use tls_codec::{SecretTlsVecU16, Serialize, TlsSliceU16, TlsVecU16, TlsVecU32}; -use tls_codec_derive::TlsSize; +use tls_codec::{SecretTlsVecU16, Serialize, Size, TlsSliceU16, TlsVecU16, TlsVecU32}; +use tls_codec_derive::{TlsSerialize, TlsSize}; -#[cfg(feature = "std")] -use tls_codec::Size; - -#[cfg(feature = "std")] -use tls_codec_derive::TlsSerialize; - -#[derive(TlsSize, Debug)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(TlsSerialize, TlsSize, Debug)] #[repr(u16)] pub enum ExtensionType { Reserved = 0, @@ -20,38 +13,32 @@ pub enum ExtensionType { SomethingElse = 500, } -#[derive(TlsSize, Debug)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(TlsSerialize, TlsSize, Debug)] pub struct ExtensionStruct { extension_type: ExtensionType, extension_data: TlsVecU32, additional_data: Option>, } -#[derive(TlsSize, Debug)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(TlsSerialize, TlsSize, Debug)] pub struct TupleStruct(ExtensionStruct, u8); -#[derive(TlsSize, Debug)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(TlsSerialize, TlsSize, Debug)] pub struct StructWithLifetime<'a> { value: &'a TlsVecU16, } -#[derive(TlsSize, Debug, Clone)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(TlsSerialize, TlsSize, Debug, Clone)] struct SomeValue { val: TlsVecU16, } -#[derive(TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(TlsSerialize, TlsSize)] pub struct StructWithDoubleLifetime<'a, 'b> { value: &'a TlsSliceU16<'a, &'b SomeValue>, } #[test] -#[cfg(feature = "std")] fn lifetime_struct() { let value: TlsVecU16 = vec![7u8; 33].into(); let s = StructWithLifetime { value: &value }; @@ -73,7 +60,6 @@ fn lifetime_struct() { } #[test] -#[cfg(feature = "std")] fn simple_enum() { let serialized = ExtensionType::KeyId.tls_serialize_detached().unwrap(); assert_eq!(vec![0, 3], serialized); @@ -84,7 +70,6 @@ fn simple_enum() { } #[test] -#[cfg(feature = "std")] fn simple_struct() { let extension = ExtensionStruct { extension_type: ExtensionType::KeyId, @@ -96,7 +81,6 @@ fn simple_struct() { } #[test] -#[cfg(feature = "std")] fn tuple_struct() { let ext = ExtensionStruct { extension_type: ExtensionType::KeyId, @@ -116,7 +100,6 @@ fn byte_arrays() { } #[test] -#[cfg(feature = "std")] fn lifetimes() { let x = vec![1, 2, 3, 4].into(); let s = StructWithLifetime { value: &x }; @@ -130,8 +113,7 @@ fn lifetimes() { assert_eq!(vec![0, 4, 1, 2, 3, 4], serialized); } -#[derive(TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(TlsSerialize, TlsSize)] struct Custom { #[tls_codec(with = "custom")] values: Vec, @@ -139,26 +121,19 @@ struct Custom { } mod custom { - use tls_codec::{Size, TlsByteSliceU32}; - - #[cfg(feature = "std")] use std::io::Write; - - #[cfg(feature = "std")] - use tls_codec::Serialize; + use tls_codec::{Serialize, Size, TlsByteSliceU32}; pub fn tls_serialized_len(v: &[u8]) -> usize { TlsByteSliceU32(v).tls_serialized_len() } - #[cfg(feature = "std")] pub fn tls_serialize(v: &[u8], writer: &mut W) -> Result { TlsByteSliceU32(v).tls_serialize(writer) } } #[test] -#[cfg(feature = "std")] fn custom() { let x = Custom { values: vec![0, 1, 2], @@ -168,8 +143,7 @@ fn custom() { assert_eq!(vec![0, 0, 0, 3, 0, 1, 2, 3], serialized); } -#[derive(TlsSize)] -#[cfg_attr(feature = "std", derive(TlsSerialize))] +#[derive(TlsSerialize, TlsSize)] struct OptionalMemberRef<'a> { optional_member: Option<&'a u32>, ref_optional_member: &'a Option<&'a u32>, @@ -177,7 +151,6 @@ struct OptionalMemberRef<'a> { } #[test] -#[cfg(feature = "std")] fn optional_member() { let m = 6; let v = vec![1, 2, 3]; @@ -190,7 +163,6 @@ fn optional_member() { assert_eq!(vec![1, 0, 0, 0, 6, 0, 0, 6, 0, 1, 0, 2, 0, 3], serialized); } -#[cfg(feature = "std")] #[derive(TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithTupleVariant { @@ -198,14 +170,12 @@ enum EnumWithTupleVariant { } #[test] -#[cfg(feature = "std")] fn enum_with_tuple_variant() { let x = EnumWithTupleVariant::A(3, 4); let serialized = x.tls_serialize_detached().unwrap(); assert_eq!(vec![0, 3, 0, 0, 0, 4], serialized); } -#[cfg(feature = "std")] #[derive(TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithStructVariant { @@ -213,14 +183,12 @@ enum EnumWithStructVariant { } #[test] -#[cfg(feature = "std")] fn enum_with_struct_variant() { let x = EnumWithStructVariant::A { foo: 3, bar: 4 }; let serialized = x.tls_serialize_detached().unwrap(); assert_eq!(vec![0, 3, 0, 0, 0, 4], serialized); } -#[cfg(feature = "std")] #[derive(TlsSerialize, TlsSize)] #[repr(u16)] enum EnumWithDataAndDiscriminant { @@ -230,7 +198,6 @@ enum EnumWithDataAndDiscriminant { } #[test] -#[cfg(feature = "std")] fn enum_with_data_and_discriminant() { let x = EnumWithDataAndDiscriminant::A(4); let serialized = x.tls_serialize_detached().unwrap(); @@ -238,14 +205,12 @@ fn enum_with_data_and_discriminant() { } #[test] -#[cfg(feature = "std")] fn discriminant_is_incremented_implicitly() { let x = EnumWithDataAndDiscriminant::B; let serialized = x.tls_serialize_detached().unwrap(); assert_eq!(vec![0, 4], serialized); } -#[cfg(feature = "std")] mod discriminant { pub mod test { pub mod constant { @@ -259,7 +224,6 @@ mod discriminant { } } -#[cfg(feature = "std")] #[derive(Debug, PartialEq, TlsSerialize, TlsSize)] #[repr(u16)] enum EnumWithDataAndConstDiscriminant { @@ -272,7 +236,6 @@ enum EnumWithDataAndConstDiscriminant { } #[test] -#[cfg(feature = "std")] fn enum_with_data_and_const_discriminant() { let serialized = EnumWithDataAndConstDiscriminant::A(4) .tls_serialize_detached() @@ -288,7 +251,6 @@ fn enum_with_data_and_const_discriminant() { assert_eq!(vec![0, 12], serialized); } -#[cfg(feature = "std")] #[derive(TlsSerialize, TlsSize)] #[repr(u8)] enum EnumWithCustomSerializedField { @@ -296,7 +258,6 @@ enum EnumWithCustomSerializedField { } #[test] -#[cfg(feature = "std")] fn enum_with_custom_serialized_field() { let x = EnumWithCustomSerializedField::A(vec![1, 2, 3]); let serialized = x.tls_serialize_detached().unwrap(); @@ -304,7 +265,6 @@ fn enum_with_custom_serialized_field() { } #[test] -#[cfg(feature = "std")] fn that_skip_attribute_on_struct_works() { fn test(test: T, expected: &[u8]) where @@ -317,8 +277,7 @@ fn that_skip_attribute_on_struct_works() { assert_eq!(test.tls_serialize_detached().unwrap(), expected); } - #[derive(Debug, PartialEq, TlsSize)] - #[cfg_attr(feature = "std", derive(TlsSerialize))] + #[derive(Debug, PartialEq, TlsSerialize, TlsSize)] struct StructWithSkip1 { #[tls_codec(skip)] a: u8, @@ -326,8 +285,7 @@ fn that_skip_attribute_on_struct_works() { c: u8, } - #[derive(Debug, PartialEq, TlsSize)] - #[cfg_attr(feature = "std", derive(TlsSerialize))] + #[derive(Debug, PartialEq, TlsSerialize, TlsSize)] struct StructWithSkip2 { a: u8, #[tls_codec(skip)] @@ -335,8 +293,7 @@ fn that_skip_attribute_on_struct_works() { c: u8, } - #[derive(Debug, PartialEq, TlsSize)] - #[cfg_attr(feature = "std", derive(TlsSerialize))] + #[derive(Debug, PartialEq, TlsSerialize, TlsSize)] struct StructWithSkip3 { a: u8, b: u8, From da63de873fd55688b22b51419e16b0127af51bd1 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Tue, 4 Jul 2023 18:36:39 +0530 Subject: [PATCH 06/14] run encode/decode tests only when std feature is enabled --- tls_codec/derive/tests/decode.rs | 1 + tls_codec/derive/tests/encode.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/tls_codec/derive/tests/decode.rs b/tls_codec/derive/tests/decode.rs index 7717b8c12..107624e23 100644 --- a/tls_codec/derive/tests/decode.rs +++ b/tls_codec/derive/tests/decode.rs @@ -1,3 +1,4 @@ +#![cfg(feature = "std")] use tls_codec::{ Deserialize, Error, Serialize, Size, TlsSliceU16, TlsVecU16, TlsVecU32, TlsVecU8, VLBytes, }; diff --git a/tls_codec/derive/tests/encode.rs b/tls_codec/derive/tests/encode.rs index e4047678e..84b760c91 100644 --- a/tls_codec/derive/tests/encode.rs +++ b/tls_codec/derive/tests/encode.rs @@ -1,3 +1,4 @@ +#![cfg(feature = "std")] use tls_codec::{SecretTlsVecU16, Serialize, Size, TlsSliceU16, TlsVecU16, TlsVecU32}; use tls_codec_derive::{TlsSerialize, TlsSize}; From a0e0de3cec3432d355f62b58e039e999d17f3a91 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Tue, 4 Jul 2023 19:00:55 +0530 Subject: [PATCH 07/14] add DeserializeBytes tests --- tls_codec/derive/tests/decode_bytes.rs | 94 ++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tls_codec/derive/tests/decode_bytes.rs diff --git a/tls_codec/derive/tests/decode_bytes.rs b/tls_codec/derive/tests/decode_bytes.rs new file mode 100644 index 000000000..0b76e8252 --- /dev/null +++ b/tls_codec/derive/tests/decode_bytes.rs @@ -0,0 +1,94 @@ +use tls_codec::{TlsVecU16, TlsVecU32, TlsVecU8}; +use tls_codec_derive::{TlsDeserializeBytes, TlsSize}; + +#[derive(TlsDeserializeBytes, Debug, PartialEq, Clone, Copy, TlsSize)] +#[repr(u16)] +pub enum ExtensionType { + Reserved = 0, + Capabilities = 1, + Lifetime = 2, + KeyId = 3, + ParentHash = 4, + RatchetTree = 5, + SomethingElse = 500, +} + +impl Default for ExtensionType { + fn default() -> Self { + Self::Reserved + } +} + +#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize, Clone, Default)] +pub struct ExtensionStruct { + extension_type: ExtensionType, + extension_data: TlsVecU32, +} + +#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize)] +pub struct ExtensionTypeVec { + data: TlsVecU8, +} + +#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize)] +pub struct ArrayWrap { + data: [u8; 8], +} + +#[derive(TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +pub struct TupleStruct1(ExtensionStruct); + +#[derive(TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +pub struct TupleStruct(ExtensionStruct, u8); + +#[test] +fn tuple_struct() { + let ext = ExtensionStruct { + extension_type: ExtensionType::KeyId, + extension_data: TlsVecU32::from_slice(&[1, 2, 3, 4, 5]), + }; + let t1 = TupleStruct1(ext.clone()); + let serialized_t1 = vec![0, 3, 0, 0, 0, 5, 1, 2, 3, 4, 5]; //t1.tls_serialize_detached().unwrap(); + println!("{:?}", serialized_t1); + let (deserialized_bytes_t1, _remainder) = + ::tls_deserialize(serialized_t1.as_slice()) + .unwrap(); + assert_eq!(t1, deserialized_bytes_t1); +} + +#[test] +fn simple_enum() { + let b = &[0u8, 5] as &[u8]; + let (deserialized_bytes, _remainder) = + ::tls_deserialize(b).unwrap(); + assert_eq!(ExtensionType::RatchetTree, deserialized_bytes); + + let mut b = &[0u8, 5, 1, 244, 0, 1] as &[u8]; + let variants = [ + ExtensionType::RatchetTree, + ExtensionType::SomethingElse, + ExtensionType::Capabilities, + ]; + for variant in variants.iter() { + let (deserialized_bytes, remainder) = + ::tls_deserialize(b).unwrap(); + b = remainder; + assert_eq!(variant, &deserialized_bytes); + } +} + +#[test] +fn deserialize_tls_vec() { + let long_vector = vec![ExtensionStruct::default(); 4]; + let serialized_long_vec = [ + 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]; + let (deserialized_long_vec_bytes, _remainder): (Vec, &[u8]) = + as tls_codec::DeserializeBytes>::tls_deserialize( + serialized_long_vec.as_slice(), + ) + .map(|(v, r)| (v.into(), r)) + .unwrap(); + assert_eq!(long_vector.len(), deserialized_long_vec_bytes.len()); + assert_eq!(long_vector, deserialized_long_vec_bytes); +} From 710378c4685914e8c8adf72351e69142033677b1 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Tue, 4 Jul 2023 19:10:45 +0530 Subject: [PATCH 08/14] add one more tuple_struct test --- tls_codec/derive/tests/decode_bytes.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tls_codec/derive/tests/decode_bytes.rs b/tls_codec/derive/tests/decode_bytes.rs index 0b76e8252..c9ee9316a 100644 --- a/tls_codec/derive/tests/decode_bytes.rs +++ b/tls_codec/derive/tests/decode_bytes.rs @@ -48,12 +48,19 @@ fn tuple_struct() { extension_data: TlsVecU32::from_slice(&[1, 2, 3, 4, 5]), }; let t1 = TupleStruct1(ext.clone()); - let serialized_t1 = vec![0, 3, 0, 0, 0, 5, 1, 2, 3, 4, 5]; //t1.tls_serialize_detached().unwrap(); + let serialized_t1 = vec![0, 3, 0, 0, 0, 5, 1, 2, 3, 4, 5]; println!("{:?}", serialized_t1); let (deserialized_bytes_t1, _remainder) = ::tls_deserialize(serialized_t1.as_slice()) .unwrap(); assert_eq!(t1, deserialized_bytes_t1); + + let t2 = TupleStruct(ext, 5); + let serialized_t2 = vec![0, 3, 0, 0, 0, 5, 1, 2, 3, 4, 5, 5]; + let (deserialized_bytes_t2, _remainder) = + ::tls_deserialize(serialized_t2.as_slice()) + .unwrap(); + assert_eq!(t2, deserialized_bytes_t2); } #[test] From 5217fdc54eb9eca84c6461fa50fe3ecaf3bc0821 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Tue, 4 Jul 2023 19:33:37 +0530 Subject: [PATCH 09/14] add another test --- tls_codec/derive/tests/decode_bytes.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tls_codec/derive/tests/decode_bytes.rs b/tls_codec/derive/tests/decode_bytes.rs index c9ee9316a..852f4d38d 100644 --- a/tls_codec/derive/tests/decode_bytes.rs +++ b/tls_codec/derive/tests/decode_bytes.rs @@ -99,3 +99,27 @@ fn deserialize_tls_vec() { assert_eq!(long_vector.len(), deserialized_long_vec_bytes.len()); assert_eq!(long_vector, deserialized_long_vec_bytes); } + +#[test] +fn byte_arrays() { + let x = [0u8, 1, 2, 3]; + let serialized = [0, 1, 2, 3]; + assert_eq!(x.to_vec(), serialized); + + let (deserialized, rest) = + <[u8; 4] as tls_codec::DeserializeBytes>::tls_deserialize(&mut serialized.as_slice()) + .unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); + + let x = [0u8, 1, 2, 3, 7, 6, 5, 4]; + let w = ArrayWrap { data: x }; + let serialized = [0, 1, 2, 3, 7, 6, 5, 4]; + assert_eq!(x.to_vec(), serialized); + + let (deserialized, rest) = + ::tls_deserialize(&mut serialized.as_slice()) + .unwrap(); + assert_eq!(deserialized, w); + assert_eq!(rest, []); +} From 2f7b38bcbb47b1b1becfa318793b55b043104246 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Tue, 4 Jul 2023 20:00:41 +0530 Subject: [PATCH 10/14] add proper DeserializeBytes tests --- tls_codec/derive/tests/decode_bytes.rs | 353 +++++++++++++++++++------ 1 file changed, 275 insertions(+), 78 deletions(-) diff --git a/tls_codec/derive/tests/decode_bytes.rs b/tls_codec/derive/tests/decode_bytes.rs index 852f4d38d..4a3267f01 100644 --- a/tls_codec/derive/tests/decode_bytes.rs +++ b/tls_codec/derive/tests/decode_bytes.rs @@ -1,7 +1,7 @@ -use tls_codec::{TlsVecU16, TlsVecU32, TlsVecU8}; -use tls_codec_derive::{TlsDeserializeBytes, TlsSize}; +use tls_codec::{DeserializeBytes, SerializeBytes, Size}; +use tls_codec_derive::{TlsDeserializeBytes, TlsSerializeBytes, TlsSize}; -#[derive(TlsDeserializeBytes, Debug, PartialEq, Clone, Copy, TlsSize)] +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, PartialEq, Debug)] #[repr(u16)] pub enum ExtensionType { Reserved = 0, @@ -13,113 +13,310 @@ pub enum ExtensionType { SomethingElse = 500, } -impl Default for ExtensionType { - fn default() -> Self { - Self::Reserved - } -} - -#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize, Clone, Default)] +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] pub struct ExtensionStruct { extension_type: ExtensionType, - extension_data: TlsVecU32, + extension_data: Vec, + additional_data: Option>, } -#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize)] -pub struct ExtensionTypeVec { - data: TlsVecU8, -} +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +pub struct TupleStruct(ExtensionStruct, u8); -#[derive(TlsDeserializeBytes, Debug, PartialEq, TlsSize)] -pub struct ArrayWrap { - data: [u8; 8], +#[derive(TlsSerializeBytes, TlsSize, Debug, Clone)] +struct SomeValue { + val: Vec, } -#[derive(TlsDeserializeBytes, TlsSize, Debug, PartialEq)] -pub struct TupleStruct1(ExtensionStruct); +#[test] +fn simple_enum() { + let serialized = ExtensionType::KeyId.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, ExtensionType::KeyId); + assert_eq!(rest, []); + let serialized = ExtensionType::SomethingElse.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, ExtensionType::SomethingElse); + assert_eq!(rest, []); +} -#[derive(TlsDeserializeBytes, TlsSize, Debug, PartialEq)] -pub struct TupleStruct(ExtensionStruct, u8); +#[test] +fn simple_struct() { + let extension = ExtensionStruct { + extension_type: ExtensionType::KeyId, + extension_data: vec![1, 2, 3, 4, 5], + additional_data: None, + }; + let serialized = extension.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, extension); + assert_eq!(rest, []); +} #[test] fn tuple_struct() { let ext = ExtensionStruct { extension_type: ExtensionType::KeyId, - extension_data: TlsVecU32::from_slice(&[1, 2, 3, 4, 5]), + extension_data: vec![1, 2, 3, 4, 5], + additional_data: None, }; - let t1 = TupleStruct1(ext.clone()); - let serialized_t1 = vec![0, 3, 0, 0, 0, 5, 1, 2, 3, 4, 5]; - println!("{:?}", serialized_t1); - let (deserialized_bytes_t1, _remainder) = - ::tls_deserialize(serialized_t1.as_slice()) - .unwrap(); - assert_eq!(t1, deserialized_bytes_t1); - - let t2 = TupleStruct(ext, 5); - let serialized_t2 = vec![0, 3, 0, 0, 0, 5, 1, 2, 3, 4, 5, 5]; - let (deserialized_bytes_t2, _remainder) = - ::tls_deserialize(serialized_t2.as_slice()) - .unwrap(); - assert_eq!(t2, deserialized_bytes_t2); + let x = TupleStruct(ext, 6); + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); } #[test] -fn simple_enum() { - let b = &[0u8, 5] as &[u8]; - let (deserialized_bytes, _remainder) = - ::tls_deserialize(b).unwrap(); - assert_eq!(ExtensionType::RatchetTree, deserialized_bytes); - - let mut b = &[0u8, 5, 1, 244, 0, 1] as &[u8]; - let variants = [ - ExtensionType::RatchetTree, - ExtensionType::SomethingElse, - ExtensionType::Capabilities, - ]; - for variant in variants.iter() { - let (deserialized_bytes, remainder) = - ::tls_deserialize(b).unwrap(); - b = remainder; - assert_eq!(variant, &deserialized_bytes); +fn byte_arrays() { + let x = [0u8, 1, 2, 3]; + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = <[u8; 4] as DeserializeBytes>::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +struct Custom { + #[tls_codec(with = "custom")] + values: Vec, + a: u8, +} + +mod custom { + use tls_codec::{DeserializeBytes, SerializeBytes, Size}; + + pub fn tls_serialized_len(v: &[u8]) -> usize { + v.tls_serialized_len() + } + + pub fn tls_serialize(v: &[u8]) -> Result, tls_codec::Error> { + v.tls_serialize() + } + + pub fn tls_deserialize( + bytes: &[u8], + ) -> Result<(T, &[u8]), tls_codec::Error> { + ::tls_deserialize(bytes) } } #[test] -fn deserialize_tls_vec() { - let long_vector = vec![ExtensionStruct::default(); 4]; - let serialized_long_vec = [ - 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]; - let (deserialized_long_vec_bytes, _remainder): (Vec, &[u8]) = - as tls_codec::DeserializeBytes>::tls_deserialize( - serialized_long_vec.as_slice(), - ) - .map(|(v, r)| (v.into(), r)) - .unwrap(); - assert_eq!(long_vector.len(), deserialized_long_vec_bytes.len()); - assert_eq!(long_vector, deserialized_long_vec_bytes); +fn custom() { + let x = Custom { + values: vec![0, 1, 2], + a: 3, + }; + let serialized = x.tls_serialize().unwrap(); + assert_eq!(vec![3, 0, 1, 2, 3], serialized); + let (deserialized, rest) = ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[repr(u8)] +enum EnumWithTupleVariant { + A(u8, u32), } #[test] -fn byte_arrays() { - let x = [0u8, 1, 2, 3]; - let serialized = [0, 1, 2, 3]; - assert_eq!(x.to_vec(), serialized); +fn enum_with_tuple_variant() { + let x = EnumWithTupleVariant::A(3, 4); + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[repr(u8)] +enum EnumWithStructVariant { + A { foo: u8, bar: u32 }, +} + +#[test] +fn enum_with_struct_variant() { + let x = EnumWithStructVariant::A { foo: 3, bar: 4 }; + let serialized = x.tls_serialize().unwrap(); let (deserialized, rest) = - <[u8; 4] as tls_codec::DeserializeBytes>::tls_deserialize(&mut serialized.as_slice()) + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[repr(u16)] +enum EnumWithDataAndDiscriminant { + #[tls_codec(discriminant = 3)] + A(u8), + B, +} + +#[test] +fn enum_with_data_and_discriminant() { + let x = EnumWithDataAndDiscriminant::A(4); + let serialized = x.tls_serialize().unwrap(); + + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[test] +fn discriminant_is_incremented_implicitly() { + let x = EnumWithDataAndDiscriminant::B; + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +mod discriminant { + pub mod test { + pub mod constant { + pub const TEST_CONST: u8 = 3; + } + pub mod enum_val { + pub enum Test { + Potato = 0x0004, + } + } + } +} + +#[derive(Debug, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] +#[repr(u16)] +enum EnumWithDataAndConstDiscriminant { + #[tls_codec(discriminant = "discriminant::test::constant::TEST_CONST")] + A(u8), + #[tls_codec(discriminant = "discriminant::test::enum_val::Test::Potato")] + B, + #[tls_codec(discriminant = 12)] + C, +} + +#[test] +fn enum_with_data_and_const_discriminant() { + let x = EnumWithDataAndConstDiscriminant::A(4); + let serialized = x.tls_serialize().unwrap(); + assert_eq!(vec![0, 3, 4], serialized); + let (deserialized, rest) = + ::tls_deserialize(&serialized) .unwrap(); assert_eq!(deserialized, x); assert_eq!(rest, []); - let x = [0u8, 1, 2, 3, 7, 6, 5, 4]; - let w = ArrayWrap { data: x }; - let serialized = [0, 1, 2, 3, 7, 6, 5, 4]; - assert_eq!(x.to_vec(), serialized); + let x = EnumWithDataAndConstDiscriminant::B; + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized) + .unwrap(); + assert_eq!(deserialized, x); + assert_eq!(rest, []); + let x = EnumWithDataAndConstDiscriminant::C; + let serialized = x.tls_serialize().unwrap(); let (deserialized, rest) = - ::tls_deserialize(&mut serialized.as_slice()) + ::tls_deserialize(&serialized) .unwrap(); - assert_eq!(deserialized, w); + assert_eq!(deserialized, x); + assert_eq!(rest, []); +} + +#[derive(TlsSerializeBytes, TlsDeserializeBytes, TlsSize, Debug, PartialEq)] +#[repr(u8)] +enum EnumWithCustomSerializedField { + A(#[tls_codec(with = "custom")] Vec), +} + +#[test] +fn enum_with_custom_serialized_field() { + let x = EnumWithCustomSerializedField::A(vec![1, 2, 3]); + let serialized = x.tls_serialize().unwrap(); + let (deserialized, rest) = + ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, x); assert_eq!(rest, []); } + +#[test] +fn that_skip_attribute_on_struct_works() { + fn test(test: T, expected: T) + where + T: std::fmt::Debug + PartialEq + SerializeBytes + Size, + { + let serialized = test.tls_serialize().unwrap(); + let (deserialized, rest) = ::tls_deserialize(&serialized).unwrap(); + assert_eq!(deserialized, expected); + assert_eq!(rest, []); + } + + #[derive(Debug, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] + struct StructWithSkip1 { + #[tls_codec(skip)] + a: u8, + b: u8, + c: u8, + } + + #[derive(Debug, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] + struct StructWithSkip2 { + a: u8, + #[tls_codec(skip)] + b: u8, + c: u8, + } + + #[derive(Debug, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] + struct StructWithSkip3 { + a: u8, + b: u8, + #[tls_codec(skip)] + c: u8, + } + + test( + StructWithSkip1 { + a: 123, + b: 13, + c: 42, + }, + StructWithSkip1 { + a: Default::default(), + b: 13, + c: 42, + }, + ); + test( + StructWithSkip2 { + a: 123, + b: 13, + c: 42, + }, + StructWithSkip2 { + a: 123, + b: Default::default(), + c: 42, + }, + ); + test( + StructWithSkip3 { + a: 123, + b: 13, + c: 42, + }, + StructWithSkip3 { + a: 123, + b: 13, + c: Default::default(), + }, + ); +} From 86607eb3e5e64b3f1d990f7d0902d27f9032f7f7 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Fri, 21 Jul 2023 12:13:36 +0530 Subject: [PATCH 11/14] Fix formatting Co-authored-by: Franziskus Kiefer --- tls_codec/derive/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tls_codec/derive/Cargo.toml b/tls_codec/derive/Cargo.toml index de63660ed..6e3c10396 100644 --- a/tls_codec/derive/Cargo.toml +++ b/tls_codec/derive/Cargo.toml @@ -24,4 +24,4 @@ trybuild = "1" [features] default = [ "std" ] -std = [] \ No newline at end of file +std = [] From 7c9d3d77c9bf814d909027bf580371a349d99a23 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Fri, 21 Jul 2023 12:34:58 +0530 Subject: [PATCH 12/14] include `derive` feature when building for no_std targets --- .github/workflows/tls_codec.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tls_codec.yml b/.github/workflows/tls_codec.yml index 87e59bb17..c2130551e 100644 --- a/.github/workflows/tls_codec.yml +++ b/.github/workflows/tls_codec.yml @@ -36,7 +36,7 @@ jobs: toolchain: ${{ matrix.rust }} targets: ${{ matrix.target }} - uses: RustCrypto/actions/cargo-hack-install@master - - run: cargo hack build --target ${{ matrix.target }} --feature-powerset --exclude-features std,default,derive,serde,arbitrary + - run: cargo hack build --target ${{ matrix.target }} --feature-powerset --exclude-features std,default,serde,arbitrary minimal-versions: uses: RustCrypto/actions/.github/workflows/minimal-versions.yml@master From ac5c718f38e1f58b7f3671125ea581ebd3d66552 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Fri, 21 Jul 2023 12:49:24 +0530 Subject: [PATCH 13/14] enable serde feature in no_std environments --- .github/workflows/tls_codec.yml | 2 +- tls_codec/Cargo.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tls_codec.yml b/.github/workflows/tls_codec.yml index c2130551e..d2190ac7f 100644 --- a/.github/workflows/tls_codec.yml +++ b/.github/workflows/tls_codec.yml @@ -36,7 +36,7 @@ jobs: toolchain: ${{ matrix.rust }} targets: ${{ matrix.target }} - uses: RustCrypto/actions/cargo-hack-install@master - - run: cargo hack build --target ${{ matrix.target }} --feature-powerset --exclude-features std,default,serde,arbitrary + - run: cargo hack build --target ${{ matrix.target }} --feature-powerset --exclude-features std,default,arbitrary minimal-versions: uses: RustCrypto/actions/.github/workflows/minimal-versions.yml@master diff --git a/tls_codec/Cargo.toml b/tls_codec/Cargo.toml index 6e30fe3bb..977a0c831 100644 --- a/tls_codec/Cargo.toml +++ b/tls_codec/Cargo.toml @@ -29,9 +29,9 @@ regex = "1.8" default = [ "std" ] arbitrary = [ "std", "dep:arbitrary" ] derive = [ "tls_codec_derive" ] -serde = [ "std", "dep:serde" ] +serde = [] mls = [] # In MLS variable length vectors are limited compared to QUIC. -std = [ "tls_codec_derive?/std" ] +std = [ "tls_codec_derive?/std", "serde?/std" ] [[bench]] name = "tls_vec" From aa9acf7b76bb5c739d3ae6963e9bf723ff915f2d Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Fri, 21 Jul 2023 13:01:19 +0530 Subject: [PATCH 14/14] Revert "enable serde feature in no_std environments" This reverts commit ac5c718f38e1f58b7f3671125ea581ebd3d66552. --- .github/workflows/tls_codec.yml | 2 +- tls_codec/Cargo.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tls_codec.yml b/.github/workflows/tls_codec.yml index d2190ac7f..c2130551e 100644 --- a/.github/workflows/tls_codec.yml +++ b/.github/workflows/tls_codec.yml @@ -36,7 +36,7 @@ jobs: toolchain: ${{ matrix.rust }} targets: ${{ matrix.target }} - uses: RustCrypto/actions/cargo-hack-install@master - - run: cargo hack build --target ${{ matrix.target }} --feature-powerset --exclude-features std,default,arbitrary + - run: cargo hack build --target ${{ matrix.target }} --feature-powerset --exclude-features std,default,serde,arbitrary minimal-versions: uses: RustCrypto/actions/.github/workflows/minimal-versions.yml@master diff --git a/tls_codec/Cargo.toml b/tls_codec/Cargo.toml index 977a0c831..6e30fe3bb 100644 --- a/tls_codec/Cargo.toml +++ b/tls_codec/Cargo.toml @@ -29,9 +29,9 @@ regex = "1.8" default = [ "std" ] arbitrary = [ "std", "dep:arbitrary" ] derive = [ "tls_codec_derive" ] -serde = [] +serde = [ "std", "dep:serde" ] mls = [] # In MLS variable length vectors are limited compared to QUIC. -std = [ "tls_codec_derive?/std", "serde?/std" ] +std = [ "tls_codec_derive?/std" ] [[bench]] name = "tls_vec"