From 63b768bb849879886df6c71f438b9b7bff04686a Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Tue, 19 May 2026 15:44:11 -0600 Subject: [PATCH] Use rust_decimal's builtin PG encode/decode We had enabled the feature on the rust_decimal crate to compile these functions, but were never actually using them. The postgres-types crate already existed as a transitive dependency of rust_decimal; it's only added to our Cargo.toml so we can use the trait impls that decimal has. Our tests for all the various malformed inputs were helpful here; the only one I think is excessive is rejecting NaN with digits. All the other cases caught places where the decimal crate hit unchecked overflow. It wouldn't have resulted in UB but the fact that it returns nonsense instead of an error is something that's good for us to catch. This does slightly change our behavior, as their encoding implementation fails to remove trailing zeroes from the output when the result is an exact power of 10000. This shouldn't matter, as the two values are completely equivalent as far as PG is concerned (and I've verified as such manually). I've updated the test case to reflect this behavior, though I think if we accept this change we should probably just delete it since we're testing a crate's behavior and not our own. 
--- Cargo.lock | 1 + pgdog-postgres-types/Cargo.toml | 1 + pgdog-postgres-types/src/numeric.rs | 319 +++------------------------- 3 files changed, 37 insertions(+), 284 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 11ba1835c..2206c016d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3286,6 +3286,7 @@ dependencies = [ "bytes", "chrono", "pgdog-vector", + "postgres-types", "rust_decimal", "schemars", "serde", diff --git a/pgdog-postgres-types/Cargo.toml b/pgdog-postgres-types/Cargo.toml index 22381c9eb..8c5042bad 100644 --- a/pgdog-postgres-types/Cargo.toml +++ b/pgdog-postgres-types/Cargo.toml @@ -12,3 +12,4 @@ serde = { version = "*", features = ["derive"]} schemars.workspace = true pgdog-vector = { path = "../pgdog-vector" } rust_decimal = { version = "1.36", features = ["db-postgres"] } +postgres-types = "0.2.13" diff --git a/pgdog-postgres-types/src/numeric.rs b/pgdog-postgres-types/src/numeric.rs index b40884fa8..1ebe49f65 100644 --- a/pgdog-postgres-types/src/numeric.rs +++ b/pgdog-postgres-types/src/numeric.rs @@ -1,6 +1,7 @@ use std::{cmp::Ordering, fmt::Display, hash::Hash, ops::Add, str::FromStr}; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use postgres_types::{FromSql, ToSql, Type}; use rust_decimal::Decimal; use serde::Deserialize; use serde::{ @@ -12,6 +13,9 @@ use crate::Data; use super::*; +/// Internal sign representation of NaN +const POSTGRES_NAN: u16 = 0xC000; + /// Enum to represent different numeric values including NaN. 
#[derive(Copy, Clone, Debug)] enum NumericValue { @@ -85,17 +89,13 @@ impl Ord for Numeric { } impl Add for Numeric { - type Output = Numeric; + type Output = Self; fn add(self, rhs: Self) -> Self::Output { match (self.value, rhs.value) { - (NumericValue::Number(a), NumericValue::Number(b)) => Numeric { - value: NumericValue::Number(a + b), - }, + (NumericValue::Number(a), NumericValue::Number(b)) => Self::new(a + b), // Any operation with NaN yields NaN - _ => Numeric { - value: NumericValue::NaN, - }, + _ => Self::nan(), } } } @@ -141,169 +141,24 @@ impl FromDataType for Numeric { if dscale < 0 { return Err(Error::UnexpectedPayload); } - // Downstream arithmetic computes (-weight - 1) * 4 as i16. - // That overflows when -weight - 1 > 8191, i.e. weight < -8192. - if weight < -8192 { - return Err(Error::UnexpectedPayload); - } - - // Handle special sign values using pattern matching - let is_negative = match sign { - 0x0000 => false, // Positive - 0x4000 => true, // Negative - 0xC000 => { - // NaN value - ndigits should be 0 for NaN - if ndigits != 0 { - return Err(Error::UnexpectedPayload); - } - return Ok(Self { - value: NumericValue::NaN, - }); - } - _ => { - // Invalid sign value - return Err(Error::UnexpectedPayload); - } - }; - - if ndigits == 0 { - return Ok(Self { - value: NumericValue::Number(Decimal::ZERO), - }); - } - - if buf.len() < (ndigits as usize) * 2 { - return Err(Error::WrongSizeBinary(bytes.len())); - } - - // Read digits (base 10000) - let mut digits = Vec::with_capacity(ndigits as usize); - for _ in 0..ndigits { - digits.push(buf.get_i16()); - } - - // Reconstruct the decimal number from base-10000 digits - let mut result = String::new(); - - if is_negative { - result.push('-'); - } - - // PostgreSQL format with dscale: - // - Integer digits represent the integer part - // - Fractional digits are stored after the integer digits - // - dscale tells us how many decimal places to extract from fractional digits - - // Build the integer part - 
let mut integer_str = String::new(); - let mut fractional_str = String::new(); - - // Determine how many digits are for the integer part - // Weight tells us the position of the first digit - let integer_digit_count = if weight >= 0 { - // Check for overflow before adding - if weight == i16::MAX { - return Err(Error::UnexpectedPayload); - } - (weight + 1) as usize - } else { - 0 - }; - - // Process integer digits - for (i, digit) in digits - .iter() - .enumerate() - .take(integer_digit_count.min(digits.len())) - { - if i == 0 && *digit < 1000 && weight >= 0 { - // First digit, no leading zeros - integer_str.push_str(&digit.to_string()); - } else { - // Subsequent digits or first digit >= 1000 - if i == 0 && weight >= 0 { - integer_str.push_str(&digit.to_string()); - } else { - integer_str.push_str(&format!("{:04}", digit)); - } - } - } - // Add trailing zeros for missing integer digits - if (0..i16::MAX).contains(&weight) { - let expected_integer_digits = (weight + 1) as usize; - for _ in digits.len()..expected_integer_digits { - integer_str.push_str("0000"); - } - } - - // Process fractional digits - for digit in digits.iter().skip(integer_digit_count) { - fractional_str.push_str(&format!("{:04}", digit)); + // Verify that weight is within a reasonable range to prevent + // decoding from hitting an overflow + if weight < -8192 || weight == i16::MAX { + return Err(Error::UnexpectedPayload); } - // Build final result based on dscale - if dscale == 0 { - // Pure integer - if !integer_str.is_empty() { - result.push_str(&integer_str); + if sign == POSTGRES_NAN { + if ndigits == 0 { + Ok(Self::nan()) } else { - result.push('0'); - } - } else if weight < 0 { - // Pure fractional (weight < 0) - result.push_str("0."); - - // For negative weight, add leading zeros - // Each negative weight unit represents 4 decimal places - let leading_zeros = ((-weight - 1) * 4) as usize; - for _ in 0..leading_zeros { - result.push('0'); - } - - // We've added `leading_zeros` decimal 
places so far - // We need `dscale` total decimal places - // Calculate how many more we need from fractional_str - let remaining_needed = (dscale as usize).saturating_sub(leading_zeros); - - if remaining_needed > 0 { - // Add digits from fractional_str, up to remaining_needed - let to_take = remaining_needed.min(fractional_str.len()); - result.push_str(&fractional_str[..to_take]); - - // Pad with zeros if we don't have enough digits - for _ in to_take..remaining_needed { - result.push('0'); - } + Err(Error::UnexpectedPayload) } - // If remaining_needed is 0, we've already added enough leading zeros } else { - // Mixed integer and fractional - if !integer_str.is_empty() { - result.push_str(&integer_str); - } else { - result.push('0'); - } - - if dscale > 0 { - result.push('.'); - // Take exactly dscale digits from fractional part - if fractional_str.len() >= dscale as usize { - result.push_str(&fractional_str[..dscale as usize]); - } else { - result.push_str(&fractional_str); - // Pad with zeros if needed - for _ in fractional_str.len()..(dscale as usize) { - result.push('0'); - } - } - } + Decimal::from_sql(&Type::NUMERIC, bytes) + .map(Self::new) + .map_err(|_| Error::UnexpectedPayload) } - - let decimal = Decimal::from_str(&result).map_err(|_| Error::UnexpectedPayload)?; - Ok(Self { - value: NumericValue::Number(decimal), - }) } } } @@ -320,125 +175,16 @@ impl FromDataType for Numeric { let mut buf = BytesMut::new(); buf.put_i16(0); // ndigits buf.put_i16(0); // weight - buf.put_u16(0xC000); // NaN sign + buf.put_u16(POSTGRES_NAN); // NaN sign buf.put_i16(0); // dscale Ok(buf.freeze()) } NumericValue::Number(decimal) => { - // Handle zero case - if decimal.is_zero() { - let mut buf = BytesMut::new(); - buf.put_i16(0); // ndigits - buf.put_i16(0); // weight - buf.put_u16(0); // sign (positive) - buf.put_i16(0); // dscale - return Ok(buf.freeze()); - } - - // Handle all numbers (integers and decimals, positive and negative) - let is_negative = 
decimal.is_sign_negative(); - let abs_decimal = decimal.abs(); - let decimal_str = abs_decimal.to_string(); - - // Split into integer and fractional parts - let parts: Vec<&str> = decimal_str.split('.').collect(); - let integer_part = parts[0]; - let fractional_part = parts.get(1).unwrap_or(&""); - let dscale = fractional_part.len() as i16; - - // PostgreSQL keeps integer and fractional parts separate - // Process them independently to match PostgreSQL's format - - // Process integer part (right to left, in groups of 4) - let mut integer_digits = Vec::new(); - - if integer_part != "0" { - let int_chars: Vec = integer_part.chars().collect(); - let mut pos = int_chars.len(); - - while pos > 0 { - let start = pos.saturating_sub(4); - let chunk: String = int_chars[start..pos].iter().collect(); - let digit_value: i16 = - chunk.parse().map_err(|_| Error::UnexpectedPayload)?; - integer_digits.insert(0, digit_value); - pos = start; - } - } - - // Process fractional part (left to right, in groups of 4) - let mut fractional_digits = Vec::new(); - if !fractional_part.is_empty() { - let frac_chars: Vec = fractional_part.chars().collect(); - let mut pos = 0; - - while pos < frac_chars.len() { - let end = std::cmp::min(pos + 4, frac_chars.len()); - let mut chunk: String = frac_chars[pos..end].iter().collect(); - - // Pad the last chunk with zeros if needed - while chunk.len() < 4 { - chunk.push('0'); - } - - let digit_value: i16 = - chunk.parse().map_err(|_| Error::UnexpectedPayload)?; - fractional_digits.push(digit_value); - pos = end; - } - } - - // Calculate initial weight before optimization - let initial_weight = if integer_part == "0" || integer_part.is_empty() { - // Pure fractional number - weight is negative - -1 - } else { - // Based on number of integer digits - integer_digits.len() as i16 - 1 - }; - - // Combine integer and fractional parts - let mut digits = integer_digits; - digits.extend(fractional_digits.clone()); - - // PostgreSQL optimization: if we have no 
fractional part and integer part - // has trailing zeros, we can remove them and adjust the weight - let weight = if fractional_digits.is_empty() - && !digits.is_empty() - && initial_weight >= 0 - { - // Count and remove trailing zero i16 values - let original_len = digits.len(); - while digits.len() > 1 && digits.last() == Some(&0) { - digits.pop(); - } - let _removed_count = (original_len - digits.len()) as i16; - // Weight stays the same even after removing trailing zeros - // because weight represents the position of the first digit - initial_weight - } else { - initial_weight - }; - - if digits.is_empty() { - digits.push(0); - } - let mut buf = BytesMut::new(); - let ndigits = digits.len() as i16; - let sign = if is_negative { 0x4000_u16 } else { 0_u16 }; - - buf.put_i16(ndigits); - buf.put_i16(weight); - buf.put_u16(sign); - buf.put_i16(dscale); - - // Write all digits - for digit in digits { - buf.put_i16(digit); - } - - Ok(buf.freeze()) + decimal + .to_sql(&Type::NUMERIC, &mut buf) + .map(|_| buf.freeze()) + .map_err(|_| Error::UnexpectedPayload) } }, } @@ -503,14 +249,19 @@ impl From for Numeric { impl From for Numeric { fn from(value: Decimal) -> Self { - Self { - value: NumericValue::Number(value), - } + Self::new(value) } } // Helper methods for Numeric impl Numeric { + /// Create a new Numeric value + pub fn new(value: Decimal) -> Self { + Self { + value: NumericValue::Number(value), + } + } + /// Create a NaN Numeric value pub fn nan() -> Self { Self { @@ -704,11 +455,11 @@ mod tests { }, TestCase { value: "10000", - expected_ndigits: 1, // PostgreSQL uses 1 digit with weight=1 + expected_ndigits: 2, // Decimal crate adds an extra zero expected_weight: 1, expected_sign: 0x0000, expected_dscale: 0, - expected_digits: vec![1], // PostgreSQL format: [1] + expected_digits: vec![1, 0], // [1] and [1, 0] are equivalent }, TestCase { value: "0.0001", @@ -728,11 +479,11 @@ mod tests { }, TestCase { value: "100000000000000000000", // 10^20 - expected_ndigits: 
1, + expected_ndigits: 6, expected_weight: 5, expected_sign: 0x0000, expected_dscale: 0, - expected_digits: vec![1], // Just [1] with weight=5 + expected_digits: vec![1, 0, 0, 0, 0, 0], // Decimal crate adds a lot of extra zeroes but it's fine }, ];