diff --git a/.gitignore b/.gitignore index f2c1f91..303cd33 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.dylib *.so *.whl +*.xlsx coverage.html/* _cache/* .coverage diff --git a/_unittests/ut_validation/test_f8.py b/_unittests/ut_validation/test_f8.py index 85f27aa..80611b5 100644 --- a/_unittests/ut_validation/test_f8.py +++ b/_unittests/ut_validation/test_f8.py @@ -1246,6 +1246,21 @@ def test_nan(self): f8 = float32_to_fe4m3(x) self.assertEqual(e, f8) + def test_negative_zero_uz(self): + self.assertEqual(numpy.float32(-0.0), numpy.float32(0.0)) + self.assertEqual(float32_to_fe4m3(-0.00000001, fn=True, uz=False), 128) + self.assertEqual(float32_to_fe4m3(0.00000001, fn=True, uz=True), 0) + self.assertEqual(float32_to_fe4m3(-0.00000001, fn=True, uz=True), 0) + self.assertEqual(float32_to_fe5m2(-0.00000001, fn=False, uz=False), 128) + self.assertEqual(float32_to_fe5m2(0.00000001, fn=True, uz=True), 0) + self.assertEqual(float32_to_fe5m2(-0.00000001, fn=True, uz=True), 0) + self.assertEqual(float32_to_fe4m3(-0.0001, fn=True, uz=False), 128) + self.assertEqual(float32_to_fe4m3(-0.0001, fn=True, uz=True), 0) + self.assertEqual(search_float32_into_fe4m3(-0.0001, fn=True, uz=False), 128) + self.assertEqual(search_float32_into_fe4m3(-0.0001, fn=True, uz=True), 0) + self.assertEqual(search_float32_into_fe5m2(-0.000001, fn=False, uz=False), 128) + self.assertEqual(search_float32_into_fe5m2(-0.000001, fn=True, uz=True), 0) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_array_api/validation/f8.py b/onnx_array_api/validation/f8.py index c630807..ecd68f8 100644 --- a/onnx_array_api/validation/f8.py +++ b/onnx_array_api/validation/f8.py @@ -445,6 +445,11 @@ def search_float32_into_fe4m3( return (max_value[1] | ret) if saturate else 0x7F | ret f = numpy.float32(value) i = CastFloat8.find_closest_value(f, set_values) + if uz: + ic = i & 0x7F + if ic == 0: + return 0 + return ic | ret return (i & 0x7F) | ret @@ -488,6 +493,11 @@ def search_float32_into_fe5m2( f = numpy.float32(value) i = CastFloat8.find_closest_value(f, set_values) + if uz: + ic = i & 0x7F + if ic == 0: + return 0 + return ic | ret return (i & 0x7F) | ret @@ -518,47 +528,45 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa - if e != 0: - if e < 116: - pass - elif e < 120: - # denormalized number - ex = e - 119 - if ex >= -2: - ret |= 1 << (2 + ex) - ret |= m >> (21 - ex) - elif m > 0: - ret |= 1 - mask = 1 << (20 - ex) - if m & mask and ( - ret & 1 - or m & (mask - 1) > 0 - or (m & mask and m & (mask << 1) and m & (mask - 1) == 0) - ): + if e < 116: + ret = 0 + elif e < 120: + # denormalized number + ex = e - 119 + if ex >= -2: + ret |= 1 << (2 + ex) + ret |= m >> (21 - ex) + elif m > 0: + ret |= 1 + else: + ret = 0 + mask = 1 << (20 - ex) + if m & mask and ( + ret & 1 + or m & (mask - 1) > 0 + or (m & mask and m & (mask << 1) and m & (mask - 1) == 0) + ): + # rounding + ret += 1 + elif e < 135: + # normalized number + ex = e - 119 # 127 - 8 + if ex == 0: + ret |= 0x4 + ret |= m >> 21 + else: + ret |= ex << 3 + ret |= m >> 20 + if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)): + if (ret & 0x7F) < 0x7F: # rounding ret += 1 - elif e < 135: - # normalized number - ex = e - 119 # 127 - 8 - if ex == 0: - ret |= 0x4 - ret |= m >> 21 - else: - ret |= ex << 3 - ret |= m >> 20 - if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)): - if (ret & 0x7F) < 0x7F: - # rounding - ret += 1 - elif not saturate: - return 0x80 - elif saturate: - ret |= 0x7F # 01111110 - else: - ret = 0x80 - elif m == 0: - # -0 - ret = 0 + elif not saturate: + return 0x80 + elif saturate: + ret |= 0x7F # 01111110 + else: + ret = 0x80 return int(ret) else: if (b & 0x7FFFFFFF) == 0x7F800000: @@ -640,45 +648,43 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa - if e != 0: - if e < 109: - pass - elif e < 112: - # denormalized number - ex = e - 111 - if ex >= -1: - ret |= 1 << (1 + ex) - ret |= m >> (22 - ex) - elif m > 0: - ret |= 1 - mask = 1 << (21 - ex) - if m & mask and ( - ret & 1 - or m & (mask - 1) > 0 - or (m & mask and m & (mask << 1) and m & (mask - 1) == 0) - ): + if e < 109: + ret = 0 + elif e < 112: + # denormalized number + ex = e - 111 + if ex >= -1: + ret |= 1 << (1 + ex) + ret |= m >> (22 - ex) + elif m > 0: + ret |= 1 + else: + ret = 0 + mask = 1 << (21 - ex) + if m & mask and ( + ret & 1 + or m & (mask - 1) > 0 + or (m & mask and m & (mask << 1) and m & (mask - 1) == 0) + ): + # rounding + ret += 1 + elif e < 143: + # normalized number + ex = e - 111 + ret |= ex << 2 + ret |= m >> 21 + if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)): + if (ret & 0x7F) < 0x7F: # rounding ret += 1 - elif e < 143: - # normalized number - ex = e - 111 - ret |= ex << 2 - ret |= m >> 21 - if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)): - if (ret & 0x7F) < 0x7F: - # rounding - ret += 1 - elif not saturate: - ret = 0x80 - elif e == 255 and m == 0: # inf - ret = 0x80 - elif saturate: - ret |= 0x7F # last possible number - else: - ret = 0x80 - elif m == 0: - # -0 - ret = 0 + elif not saturate: + ret = 0x80 + elif e == 255 and m == 0: # inf + ret = 0x80 + elif saturate: + ret |= 0x7F # last possible number + else: + ret = 0x80 return int(ret) elif not fn and not uz: if (b & 0x7FFFFFFF) == 0x7F800000: