diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index 6bde6e1b472..c103ae26221 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -110,7 +110,9 @@ impl DType { .all(|(l, r)| l.eq_ignore_nullability(&r))) } (Struct(..), _) => false, - (Extension(lhs_extdtype), Extension(rhs_extdtype)) => lhs_extdtype == rhs_extdtype, + (Extension(lhs_extdtype), Extension(rhs_extdtype)) => { + lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype) + } (Extension(_), _) => false, } } diff --git a/vortex-dtype/src/extension.rs b/vortex-dtype/src/extension.rs index 305b3dff4df..a02475683f4 100644 --- a/vortex-dtype/src/extension.rs +++ b/vortex-dtype/src/extension.rs @@ -58,7 +58,7 @@ impl From<&[u8]> for ExtMetadata { } /// A type descriptor for an extension type -#[derive(Debug, Clone, PartialOrd, Eq)] +#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct ExtDType { id: ExtID, @@ -66,18 +66,6 @@ pub struct ExtDType { metadata: Option, } -impl PartialEq for ExtDType { - fn eq(&self, other: &Self) -> bool { - self.id == other.id - } -} - -impl std::hash::Hash for ExtDType { - fn hash(&self, state: &mut H) { - self.id.hash(state); - } -} - impl ExtDType { /// Creates a new `ExtDType`. /// @@ -148,4 +136,13 @@ impl ExtDType { pub fn metadata(&self) -> Option<&ExtMetadata> { self.metadata.as_ref() } + + /// Check if `self` and `other` are equal, ignoring the storage nullability + pub fn eq_ignore_nullability(&self, other: &Self) -> bool { + self.id() == other.id() + && self.metadata() == other.metadata() + && self + .storage_dtype() + .eq_ignore_nullability(other.storage_dtype()) + } }