diff --git a/Cargo.lock b/Cargo.lock index 11ba1835c..2206c016d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3286,6 +3286,7 @@ dependencies = [ "bytes", "chrono", "pgdog-vector", + "postgres-types", "rust_decimal", "schemars", "serde", diff --git a/pgdog-postgres-types/Cargo.toml b/pgdog-postgres-types/Cargo.toml index 22381c9eb..8c5042bad 100644 --- a/pgdog-postgres-types/Cargo.toml +++ b/pgdog-postgres-types/Cargo.toml @@ -12,3 +12,4 @@ serde = { version = "*", features = ["derive"]} schemars.workspace = true pgdog-vector = { path = "../pgdog-vector" } rust_decimal = { version = "1.36", features = ["db-postgres"] } +postgres-types = "0.2.13" diff --git a/pgdog-postgres-types/src/numeric.rs b/pgdog-postgres-types/src/numeric.rs index b40884fa8..1ebe49f65 100644 --- a/pgdog-postgres-types/src/numeric.rs +++ b/pgdog-postgres-types/src/numeric.rs @@ -1,6 +1,7 @@ use std::{cmp::Ordering, fmt::Display, hash::Hash, ops::Add, str::FromStr}; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use postgres_types::{FromSql, ToSql, Type}; use rust_decimal::Decimal; use serde::Deserialize; use serde::{ @@ -12,6 +13,9 @@ use crate::Data; use super::*; +/// Internal sign representation of NaN +const POSTGRES_NAN: u16 = 0xC000; + /// Enum to represent different numeric values including NaN. #[derive(Copy, Clone, Debug)] enum NumericValue { @@ -85,17 +89,13 @@ impl Ord for Numeric { } impl Add for Numeric { - type Output = Numeric; + type Output = Self; fn add(self, rhs: Self) -> Self::Output { match (self.value, rhs.value) { - (NumericValue::Number(a), NumericValue::Number(b)) => Numeric { - value: NumericValue::Number(a + b), - }, + (NumericValue::Number(a), NumericValue::Number(b)) => Self::new(a + b), // Any operation with NaN yields NaN - _ => Numeric { - value: NumericValue::NaN, - }, + _ => Self::nan(), } } } @@ -141,169 +141,24 @@ impl FromDataType for Numeric { if dscale < 0 { return Err(Error::UnexpectedPayload); } - // Downstream arithmetic computes (-weight - 1) * 4 as i16. - // That overflows when -weight - 1 > 8191, i.e. weight < -8192. - if weight < -8192 { - return Err(Error::UnexpectedPayload); - } - - // Handle special sign values using pattern matching - let is_negative = match sign { - 0x0000 => false, // Positive - 0x4000 => true, // Negative - 0xC000 => { - // NaN value - ndigits should be 0 for NaN - if ndigits != 0 { - return Err(Error::UnexpectedPayload); - } - return Ok(Self { - value: NumericValue::NaN, - }); - } - _ => { - // Invalid sign value - return Err(Error::UnexpectedPayload); - } - }; - - if ndigits == 0 { - return Ok(Self { - value: NumericValue::Number(Decimal::ZERO), - }); - } - - if buf.len() < (ndigits as usize) * 2 { - return Err(Error::WrongSizeBinary(bytes.len())); - } - - // Read digits (base 10000) - let mut digits = Vec::with_capacity(ndigits as usize); - for _ in 0..ndigits { - digits.push(buf.get_i16()); - } - - // Reconstruct the decimal number from base-10000 digits - let mut result = String::new(); - - if is_negative { - result.push('-'); - } - - // PostgreSQL format with dscale: - // - Integer digits represent the integer part - // - Fractional digits are stored after the integer digits - // - dscale tells us how many decimal places to extract from fractional digits - - // Build the integer part - let mut integer_str = String::new(); - let mut fractional_str = String::new(); - - // Determine how many digits are for the integer part - // Weight tells us the position of the first digit - let integer_digit_count = if weight >= 0 { - // Check for overflow before adding - if weight == i16::MAX { - return Err(Error::UnexpectedPayload); - } - (weight + 1) as usize - } else { - 0 - }; - - // Process integer digits - for (i, digit) in digits - .iter() - .enumerate() - .take(integer_digit_count.min(digits.len())) - { - if i == 0 && *digit < 1000 && weight >= 0 { - // First digit, no leading zeros - integer_str.push_str(&digit.to_string()); - } else { - // Subsequent digits or first digit >= 1000 - if i == 0 && weight >= 0 { - integer_str.push_str(&digit.to_string()); - } else { - integer_str.push_str(&format!("{:04}", digit)); - } - } - } - // Add trailing zeros for missing integer digits - if (0..i16::MAX).contains(&weight) { - let expected_integer_digits = (weight + 1) as usize; - for _ in digits.len()..expected_integer_digits { - integer_str.push_str("0000"); - } - } - - // Process fractional digits - for digit in digits.iter().skip(integer_digit_count) { - fractional_str.push_str(&format!("{:04}", digit)); + // Verify that weight is within a reasonable range to prevent + // decoding from hitting an overflow + if weight < -8192 || weight == i16::MAX { + return Err(Error::UnexpectedPayload); } - // Build final result based on dscale - if dscale == 0 { - // Pure integer - if !integer_str.is_empty() { - result.push_str(&integer_str); + if sign == POSTGRES_NAN { + if ndigits == 0 { + Ok(Self::nan()) } else { - result.push('0'); - } - } else if weight < 0 { - // Pure fractional (weight < 0) - result.push_str("0."); - - // For negative weight, add leading zeros - // Each negative weight unit represents 4 decimal places - let leading_zeros = ((-weight - 1) * 4) as usize; - for _ in 0..leading_zeros { - result.push('0'); - } - - // We've added `leading_zeros` decimal places so far - // We need `dscale` total decimal places - // Calculate how many more we need from fractional_str - let remaining_needed = (dscale as usize).saturating_sub(leading_zeros); - - if remaining_needed > 0 { - // Add digits from fractional_str, up to remaining_needed - let to_take = remaining_needed.min(fractional_str.len()); - result.push_str(&fractional_str[..to_take]); - - // Pad with zeros if we don't have enough digits - for _ in to_take..remaining_needed { - result.push('0'); - } + Err(Error::UnexpectedPayload) } - // If remaining_needed is 0, we've already added enough leading zeros } else { - // Mixed integer and fractional - if !integer_str.is_empty() { - result.push_str(&integer_str); - } else { - result.push('0'); - } - - if dscale > 0 { - result.push('.'); - // Take exactly dscale digits from fractional part - if fractional_str.len() >= dscale as usize { - result.push_str(&fractional_str[..dscale as usize]); - } else { - result.push_str(&fractional_str); - // Pad with zeros if needed - for _ in fractional_str.len()..(dscale as usize) { - result.push('0'); - } - } - } + Decimal::from_sql(&Type::NUMERIC, bytes) + .map(Self::new) + .map_err(|_| Error::UnexpectedPayload) } - - let decimal = Decimal::from_str(&result).map_err(|_| Error::UnexpectedPayload)?; - Ok(Self { - value: NumericValue::Number(decimal), - }) } } } @@ -320,125 +175,16 @@ impl FromDataType for Numeric { let mut buf = BytesMut::new(); buf.put_i16(0); // ndigits buf.put_i16(0); // weight - buf.put_u16(0xC000); // NaN sign + buf.put_u16(POSTGRES_NAN); // NaN sign buf.put_i16(0); // dscale Ok(buf.freeze()) } NumericValue::Number(decimal) => { - // Handle zero case - if decimal.is_zero() { - let mut buf = BytesMut::new(); - buf.put_i16(0); // ndigits - buf.put_i16(0); // weight - buf.put_u16(0); // sign (positive) - buf.put_i16(0); // dscale - return Ok(buf.freeze()); - } - - // Handle all numbers (integers and decimals, positive and negative) - let is_negative = decimal.is_sign_negative(); - let abs_decimal = decimal.abs(); - let decimal_str = abs_decimal.to_string(); - - // Split into integer and fractional parts - let parts: Vec<&str> = decimal_str.split('.').collect(); - let integer_part = parts[0]; - let fractional_part = parts.get(1).unwrap_or(&""); - let dscale = fractional_part.len() as i16; - - // PostgreSQL keeps integer and fractional parts separate - // Process them independently to match PostgreSQL's format - - // Process integer part (right to left, in groups of 4) - let mut integer_digits = Vec::new(); - - if integer_part != "0" { - let int_chars: Vec = integer_part.chars().collect(); - let mut pos = int_chars.len(); - - while pos > 0 { - let start = pos.saturating_sub(4); - let chunk: String = int_chars[start..pos].iter().collect(); - let digit_value: i16 = - chunk.parse().map_err(|_| Error::UnexpectedPayload)?; - integer_digits.insert(0, digit_value); - pos = start; - } - } - - // Process fractional part (left to right, in groups of 4) - let mut fractional_digits = Vec::new(); - if !fractional_part.is_empty() { - let frac_chars: Vec = fractional_part.chars().collect(); - let mut pos = 0; - - while pos < frac_chars.len() { - let end = std::cmp::min(pos + 4, frac_chars.len()); - let mut chunk: String = frac_chars[pos..end].iter().collect(); - - // Pad the last chunk with zeros if needed - while chunk.len() < 4 { - chunk.push('0'); - } - - let digit_value: i16 = - chunk.parse().map_err(|_| Error::UnexpectedPayload)?; - fractional_digits.push(digit_value); - pos = end; - } - } - - // Calculate initial weight before optimization - let initial_weight = if integer_part == "0" || integer_part.is_empty() { - // Pure fractional number - weight is negative - -1 - } else { - // Based on number of integer digits - integer_digits.len() as i16 - 1 - }; - - // Combine integer and fractional parts - let mut digits = integer_digits; - digits.extend(fractional_digits.clone()); - - // PostgreSQL optimization: if we have no fractional part and integer part - // has trailing zeros, we can remove them and adjust the weight - let weight = if fractional_digits.is_empty() - && !digits.is_empty() - && initial_weight >= 0 - { - // Count and remove trailing zero i16 values - let original_len = digits.len(); - while digits.len() > 1 && digits.last() == Some(&0) { - digits.pop(); - } - let _removed_count = (original_len - digits.len()) as i16; - // Weight stays the same even after removing trailing zeros - // because weight represents the position of the first digit - initial_weight - } else { - initial_weight - }; - - if digits.is_empty() { - digits.push(0); - } - let mut buf = BytesMut::new(); - let ndigits = digits.len() as i16; - let sign = if is_negative { 0x4000_u16 } else { 0_u16 }; - - buf.put_i16(ndigits); - buf.put_i16(weight); - buf.put_u16(sign); - buf.put_i16(dscale); - - // Write all digits - for digit in digits { - buf.put_i16(digit); - } - - Ok(buf.freeze()) + decimal + .to_sql(&Type::NUMERIC, &mut buf) + .map(|_| buf.freeze()) + .map_err(|_| Error::UnexpectedPayload) } }, } @@ -503,14 +249,19 @@ impl From for Numeric { impl From for Numeric { fn from(value: Decimal) -> Self { - Self { - value: NumericValue::Number(value), - } + Self::new(value) } } // Helper methods for Numeric impl Numeric { + /// Create a new Numeric value + pub fn new(value: Decimal) -> Self { + Self { + value: NumericValue::Number(value), + } + } + /// Create a NaN Numeric value pub fn nan() -> Self { Self { @@ -704,11 +455,11 @@ mod tests { }, TestCase { value: "10000", - expected_ndigits: 1, // PostgreSQL uses 1 digit with weight=1 + expected_ndigits: 2, // Decimal crate adds an extra zero expected_weight: 1, expected_sign: 0x0000, expected_dscale: 0, - expected_digits: vec![1], // PostgreSQL format: [1] + expected_digits: vec![1, 0], // [1] and [1, 0] are equivalent }, TestCase { value: "0.0001", @@ -728,11 +479,11 @@ mod tests { }, TestCase { value: "100000000000000000000", // 10^20 - expected_ndigits: 1, + expected_ndigits: 6, expected_weight: 5, expected_sign: 0x0000, expected_dscale: 0, - expected_digits: vec![1], // Just [1] with weight=5 + expected_digits: vec![1, 0, 0, 0, 0, 0], // Decimal crate adds a lot of extra zeroes but it's fine }, ];