Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/byteorder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
//!
//! ```rust,edition2021
//! # #[cfg(feature = "derive")] { // This example uses derives, and won't compile without them
//! use zerocopy::{IntoBytes, ByteSlice, FromBytes, FromZeros, NoCell, Ref, Unaligned};
//! use zerocopy::{IntoBytes, ByteSlice, FromBytes, NoCell, Ref, Unaligned};
//! use zerocopy::byteorder::network_endian::U16;
//!
//! #[derive(FromZeros, FromBytes, IntoBytes, NoCell, Unaligned)]
//! #[derive(FromBytes, IntoBytes, NoCell, Unaligned)]
//! #[repr(C)]
//! struct UdpHeader {
//! src_port: U16,
Expand Down Expand Up @@ -357,7 +357,7 @@ example of how it can be used for parsing UDP packets.
[`IntoBytes`]: crate::IntoBytes
[`Unaligned`]: crate::Unaligned"),
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
#[cfg_attr(any(feature = "derive", test), derive(KnownLayout, NoCell, TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned))]
#[cfg_attr(any(feature = "derive", test), derive(KnownLayout, NoCell, FromBytes, IntoBytes, Unaligned))]
#[repr(transparent)]
pub struct $name<O>([u8; $bytes], PhantomData<O>);
}
Expand Down
12 changes: 5 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5766,9 +5766,7 @@ mod tests {
//
// This is used to test the custom derives of our traits. The `[u8]` type
// gets a hand-rolled impl, so it doesn't exercise our custom derives.
#[derive(
Debug, Eq, PartialEq, TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned, NoCell,
)]
#[derive(Debug, Eq, PartialEq, FromBytes, IntoBytes, Unaligned, NoCell)]
#[repr(transparent)]
struct Unsized([u8]);

Expand Down Expand Up @@ -7896,7 +7894,7 @@ mod tests {
assert_eq!(too_many_bytes[0], 123);
}

#[derive(Debug, Eq, PartialEq, TryFromBytes, FromZeros, FromBytes, IntoBytes, NoCell)]
#[derive(Debug, Eq, PartialEq, FromBytes, IntoBytes, NoCell)]
#[repr(C)]
struct Foo {
a: u32,
Expand Down Expand Up @@ -7925,7 +7923,7 @@ mod tests {

#[test]
fn test_array() {
#[derive(TryFromBytes, FromZeros, FromBytes, IntoBytes, NoCell)]
#[derive(FromBytes, IntoBytes, NoCell)]
#[repr(C)]
struct Foo {
a: [u16; 33],
Expand Down Expand Up @@ -7989,7 +7987,7 @@ mod tests {

#[test]
fn test_transparent_packed_generic_struct() {
#[derive(IntoBytes, TryFromBytes, FromZeros, FromBytes, Unaligned)]
#[derive(IntoBytes, FromBytes, Unaligned)]
#[repr(transparent)]
struct Foo<T> {
_t: T,
Expand All @@ -7999,7 +7997,7 @@ mod tests {
assert_impl_all!(Foo<u32>: FromZeros, FromBytes, IntoBytes);
assert_impl_all!(Foo<u8>: Unaligned);

#[derive(IntoBytes, TryFromBytes, FromZeros, FromBytes, Unaligned)]
#[derive(IntoBytes, FromBytes, Unaligned)]
#[repr(packed)]
struct Bar<T, U> {
_t: T,
Expand Down
6 changes: 1 addition & 5 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,6 @@ pub(crate) mod testutil {
#[derive(
KnownLayout,
NoCell,
TryFromBytes,
FromZeros,
FromBytes,
IntoBytes,
Eq,
Expand Down Expand Up @@ -249,9 +247,7 @@ pub(crate) mod testutil {
}
}

#[derive(
NoCell, FromZeros, FromBytes, Eq, PartialEq, Ord, PartialOrd, Default, Debug, Copy, Clone,
)]
#[derive(NoCell, FromBytes, Eq, PartialEq, Ord, PartialOrd, Default, Debug, Copy, Clone)]
#[repr(C)]
pub(crate) struct Nested<T, U: ?Sized> {
_t: T,
Expand Down
2 changes: 1 addition & 1 deletion src/wrappers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use super::*;
#[derive(Default, Copy)]
#[cfg_attr(
any(feature = "derive", test),
derive(NoCell, KnownLayout, TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned)
derive(NoCell, KnownLayout, FromBytes, IntoBytes, Unaligned)
)]
#[repr(C, packed)]
pub struct Unalign<T>(T);
Expand Down
97 changes: 68 additions & 29 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,16 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt

#[proc_macro_derive(FromZeros)]
pub fn derive_from_zeros(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let try_from_bytes = derive_try_from_bytes(ts.clone());

let ast = syn::parse_macro_input!(ts as DeriveInput);
match &ast.data {
let from_zeros = match &ast.data {
Data::Struct(strct) => derive_from_zeros_struct(&ast, strct),
Data::Enum(enm) => derive_from_zeros_enum(&ast, enm),
Data::Union(unn) => derive_from_zeros_union(&ast, unn),
}
.into()
.into();
IntoIterator::into_iter([try_from_bytes, from_zeros]).collect()
}

/// Deprecated: prefer [`FromZeros`] instead.
Expand All @@ -314,13 +317,17 @@ pub fn derive_from_zeroes(ts: proc_macro::TokenStream) -> proc_macro::TokenStrea

#[proc_macro_derive(FromBytes)]
pub fn derive_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let from_zeros = derive_from_zeros(ts.clone());

let ast = syn::parse_macro_input!(ts as DeriveInput);
match &ast.data {
let from_bytes = match &ast.data {
Data::Struct(strct) => derive_from_bytes_struct(&ast, strct),
Data::Enum(enm) => derive_from_bytes_enum(&ast, enm),
Data::Union(unn) => derive_from_bytes_union(&ast, unn),
}
.into()
.into();

IntoIterator::into_iter([from_zeros, from_bytes]).collect()
}

#[proc_macro_derive(IntoBytes)]
Expand Down Expand Up @@ -447,25 +454,33 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2:
.to_compile_error();
}

// We don't actually care what the repr is; we just care that it's one of
// the allowed ones.
try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));
let reprs = try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));

// Figure out whether the enum could in theory implement `FromBytes`.
let from_bytes = enum_size_from_repr(reprs.as_slice())
.map(|size| {
// As of this writing, `enm.is_fieldless()` is redundant since we've
// already checked for it and returned if the check failed. However, if
// we ever remove that check, then without a similar check here, this
// code would become unsound.
enm.is_fieldless() && enm.variants.len() == 1usize << size
})
.unwrap_or(false);

let variant_names = enm.variants.iter().map(|v| &v.ident);
let extras = Some(quote!(
// SAFETY: We use `is_bit_valid` to validate that the bit pattern
// corresponds to one of the field-less enum's variant discriminants.
// Thus, this is a sound implementation of `is_bit_valid`.
fn is_bit_valid(
candidate: ::zerocopy::Ptr<
'_,
Self,
(
::zerocopy::pointer::invariant::Shared,
::zerocopy::pointer::invariant::AnyAlignment,
::zerocopy::pointer::invariant::Initialized,
),
>,
) -> ::zerocopy::macro_util::core_reexport::primitive::bool {
let is_bit_valid_body = if from_bytes {
// If the enum could implement `FromBytes`, we can avoid emitting a
// match statement. This is faster to compile, and generates code which
// performs better.
quote!({
// Prevent an "unused" warning.
let _ = candidate;
// SAFETY: If the enum could implement `FromBytes`, then all bit
// patterns are valid. Thus, this is a sound implementation.
true
})
} else {
quote!(
use ::zerocopy::macro_util::core_reexport;
// SAFETY:
// - `cast` is implemented as required.
Expand Down Expand Up @@ -499,6 +514,25 @@ fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2:
// `candidate` refers to a bit-valid `Self`.
discriminant == d
})*
)
};

let extras = Some(quote!(
// SAFETY: We use `is_bit_valid` to validate that the bit pattern
// corresponds to one of the field-less enum's variant discriminants.
// Thus, this is a sound implementation of `is_bit_valid`.
fn is_bit_valid(
candidate: ::zerocopy::Ptr<
'_,
Self,
(
::zerocopy::pointer::invariant::Shared,
::zerocopy::pointer::invariant::AnyAlignment,
::zerocopy::pointer::invariant::Initialized,
),
>,
) -> ::zerocopy::macro_util::core_reexport::primitive::bool {
#is_bit_valid_body
}
));
impl_block(ast, enm, Trait::TryFromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, extras)
Expand Down Expand Up @@ -608,13 +642,9 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok

let reprs = try_or_print!(ENUM_FROM_BYTES_CFG.validate_reprs(ast));

let variants_required = match reprs.as_slice() {
[EnumRepr::U8] | [EnumRepr::I8] => 1usize << 8,
[EnumRepr::U16] | [EnumRepr::I16] => 1usize << 16,
// `validate_reprs` has already validated that it's one of the preceding
// patterns.
_ => unreachable!(),
};
let variants_required = 1usize
<< enum_size_from_repr(reprs.as_slice())
.expect("internal error: `validate_reprs` has already validated that the reprs guarantee the enum's size");
if enm.variants.len() != variants_required {
return Error::new_spanned(
ast,
Expand All @@ -629,6 +659,15 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok
impl_block(ast, enm, Trait::FromBytes, FieldBounds::ALL_SELF, SelfBounds::None, None, None)
}

// Returns `None` if the enum's size is not guaranteed by the repr.
fn enum_size_from_repr(reprs: &[EnumRepr]) -> Option<usize> {
match reprs {
[EnumRepr::U8] | [EnumRepr::I8] => Some(8),
[EnumRepr::U16] | [EnumRepr::I16] => Some(16),
_ => None,
}
}

#[rustfmt::skip]
const ENUM_FROM_BYTES_CFG: Config<EnumRepr> = {
use EnumRepr::*;
Expand Down
12 changes: 6 additions & 6 deletions zerocopy-derive/tests/enum_from_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ include!("include.rs");
// `Variant128` has a discriminant of -128) since Rust won't automatically wrap
// a signed discriminant around without you explicitly telling it to.

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(u8)]
enum FooU8 {
Variant0,
Expand Down Expand Up @@ -292,7 +292,7 @@ enum FooU8 {

util_assert_impl_all!(FooU8: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(i8)]
enum FooI8 {
Variant0,
Expand Down Expand Up @@ -555,7 +555,7 @@ enum FooI8 {

util_assert_impl_all!(FooI8: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(u8, align(2))]
enum FooU8Align {
Variant0,
Expand Down Expand Up @@ -818,7 +818,7 @@ enum FooU8Align {

util_assert_impl_all!(FooU8Align: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(i8, align(2))]
enum FooI8Align {
Variant0,
Expand Down Expand Up @@ -1081,7 +1081,7 @@ enum FooI8Align {

util_assert_impl_all!(FooI8Align: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(u16)]
enum FooU16 {
Variant0,
Expand Down Expand Up @@ -66624,7 +66624,7 @@ enum FooU16 {

util_assert_impl_all!(FooU16: imp::FromBytes);

#[derive(imp::FromZeros, imp::FromBytes)]
#[derive(imp::FromBytes)]
#[repr(i16)]
enum FooI16 {
Variant0,
Expand Down
9 changes: 1 addition & 8 deletions zerocopy-derive/tests/hygiene.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,7 @@ include!("include.rs");

extern crate zerocopy as _zerocopy;

// #[macro_use]
// mod util;

// use std::{marker::PhantomData, option::IntoIter};

#[derive(
_zerocopy::KnownLayout, _zerocopy::FromZeros, _zerocopy::FromBytes, _zerocopy::Unaligned,
)]
#[derive(_zerocopy::KnownLayout, _zerocopy::FromBytes, _zerocopy::Unaligned)]
#[repr(C)]
struct TypeParams<'a, T, I: imp::Iterator> {
a: T,
Expand Down
2 changes: 0 additions & 2 deletions zerocopy-derive/tests/include.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ pub mod util {
#[derive(
super::imp::KnownLayout,
super::imp::NoCell,
super::imp::TryFromBytes,
super::imp::FromZeros,
super::imp::FromBytes,
super::imp::IntoBytes,
Copy,
Expand Down
6 changes: 3 additions & 3 deletions zerocopy-derive/tests/paths_and_modules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ include!("include.rs");
mod foo {
use super::*;

#[derive(imp::FromZeros, imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[derive(imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[repr(C)]
pub struct Foo {
foo: u8,
}

#[derive(imp::FromZeros, imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[derive(imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[repr(C)]
pub struct Bar {
bar: u8,
Expand All @@ -32,7 +32,7 @@ mod foo {

use foo::Foo;

#[derive(imp::FromZeros, imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[derive(imp::FromBytes, imp::IntoBytes, imp::Unaligned)]
#[repr(C)]
struct Baz {
foo: Foo,
Expand Down
Loading