spirv-as: Avoid overflow when parsing exponents on hex floats (#4874)

* spirv-as: Avoid overflow when parsing exponents on hex floats

When an exponent is so large that it would overflow the int
type in the parser, saturate the exponent.
This allows extremely large exponents, and saturates
to infinity when the exponent is positive, and zero when the exponent
is negative.

Fixes #4721.

* Avoid unexpected narrowing conversions from arithmetic operations

Co-authored-by: Alastair F. Donaldson <alastair.donaldson@imperial.ac.uk>
diff --git a/source/util/hex_float.h b/source/util/hex_float.h
index 903b628..06e3c57 100644
--- a/source/util/hex_float.h
+++ b/source/util/hex_float.h
@@ -209,9 +209,10 @@
 // be the default for any non-specialized type.
 template <typename T>
 struct HexFloatTraits {
-  // Integer type that can store this hex-float.
+  // Integer type that can store the bit representation of this hex-float.
   using uint_type = void;
-  // Signed integer type that can store this hex-float.
+  // Signed integer type that can store the bit representation of this
+  // hex-float.
   using int_type = void;
   // The numerical type that this HexFloat represents.
   using underlying_type = void;
@@ -958,9 +959,15 @@
   // This "looks" like a hex-float so treat it as one.
   bool seen_p = false;
   bool seen_dot = false;
+
+  // The mantissa bits, without the most significant 1 bit, and with the
+  // most recently read bits in the least significant positions.
+  uint_type fraction = 0;
+  // The number of mantissa bits that have been read, including the leading 1
+  // bit that is not written into 'fraction'.
   uint_type fraction_index = 0;
 
-  uint_type fraction = 0;
+  // TODO(dneto): handle overflow and underflow
   int_type exponent = HF::exponent_bias;
 
   // Strip off leading zeros so we don't have to special-case them later.
@@ -968,11 +975,13 @@
     is.get();
   }
 
-  bool is_denorm =
-      true;  // Assume denorm "representation" until we hear otherwise.
-             // NB: This does not mean the value is actually denorm,
-             // it just means that it was written 0.
+  // Does the mantissa, as written, have non-zero digits to the left of
+  // the radix point?  Assume no until proven otherwise.
+  bool has_integer_part = false;
   bool bits_written = false;  // Stays false until we write a bit.
+
+  // Scan the mantissa hex digits until we see a '.' or the 'p' that
+  // starts the exponent.
   while (!seen_p && !seen_dot) {
     // Handle characters that are left of the fractional part.
     if (next_char == '.') {
@@ -980,9 +989,8 @@
     } else if (next_char == 'p') {
       seen_p = true;
     } else if (::isxdigit(next_char)) {
-      // We know this is not denormalized since we have stripped all leading
-      // zeroes and we are not a ".".
-      is_denorm = false;
+      // We have stripped all leading zeroes and we have not yet seen a ".".
+      has_integer_part = true;
       int number = get_nibble_from_character(next_char);
       for (int i = 0; i < 4; ++i, number <<= 1) {
         uint_type write_bit = (number & 0x8) ? 0x1 : 0x0;
@@ -993,8 +1001,12 @@
               fraction |
               static_cast<uint_type>(
                   write_bit << (HF::top_bit_left_shift - fraction_index++)));
+          // TODO(dneto): Avoid overflow. Testing would require
+          // parameterization.
           exponent = static_cast<int_type>(exponent + 1);
         }
+        // Since this is updated after setting fraction bits, this effectively
+        // drops the leading 1 bit.
         bits_written |= write_bit != 0;
       }
     } else {
@@ -1018,10 +1030,12 @@
       for (int i = 0; i < 4; ++i, number <<= 1) {
         uint_type write_bit = (number & 0x8) ? 0x01 : 0x00;
         bits_written |= write_bit != 0;
-        if (is_denorm && !bits_written) {
+        if ((!has_integer_part) && !bits_written) {
           // Handle modifying the exponent here this way we can handle
           // an arbitrary number of hex values without overflowing our
           // integer.
+          // TODO(dneto): Handle underflow. Testing would require extra
+          // parameterization.
           exponent = static_cast<int_type>(exponent - 1);
         } else {
           fraction = static_cast<uint_type>(
@@ -1043,25 +1057,40 @@
   // Finished reading the part preceding 'p'.
   // In hex floats syntax, the binary exponent is required.
 
-  bool seen_sign = false;
+  bool seen_exponent_sign = false;
   int8_t exponent_sign = 1;
   bool seen_written_exponent_digits = false;
+  // The magnitude of the exponent, as written, or the sentinel value to signal
+  // overflow.
   int_type written_exponent = 0;
+  // A sentinel value signalling overflow of the magnitude of the written
+  // exponent.  We'll assume that -written_exponent_overflow is valid for the
+  // type. Later we may add 1 or subtract 1 from the adjusted exponent, so leave
+  // room for an extra 1.
+  const int_type written_exponent_overflow =
+      std::numeric_limits<int_type>::max() - 1;
   while (true) {
     if (!seen_written_exponent_digits &&
         (next_char == '-' || next_char == '+')) {
-      if (seen_sign) {
+      if (seen_exponent_sign) {
         is.setstate(std::ios::failbit);
         return is;
       }
-      seen_sign = true;
+      seen_exponent_sign = true;
       exponent_sign = (next_char == '-') ? -1 : 1;
     } else if (::isdigit(next_char)) {
       seen_written_exponent_digits = true;
       // Hex-floats express their exponent as decimal.
-      written_exponent = static_cast<int_type>(written_exponent * 10);
-      written_exponent =
-          static_cast<int_type>(written_exponent + (next_char - '0'));
+      int_type digit =
+          static_cast<int_type>(static_cast<int_type>(next_char) - '0');
+      if (written_exponent >= (written_exponent_overflow - digit) / 10) {
+        // The exponent is very big. Saturate rather than overflow the signed
+        // integer, which would be undefined behaviour.
+        written_exponent = written_exponent_overflow;
+      } else {
+        written_exponent = static_cast<int_type>(
+            static_cast<int_type>(written_exponent * 10) + digit);
+      }
     } else {
       break;
     }
@@ -1075,10 +1104,29 @@
   }
 
   written_exponent = static_cast<int_type>(written_exponent * exponent_sign);
-  exponent = static_cast<int_type>(exponent + written_exponent);
+  // Now fold the written exponent into the biased exponent, updating exponent.
+  // But avoid undefined behaviour that would result from overflowing int_type.
+  if (written_exponent >= 0 && exponent >= 0) {
+    // Saturate up to written_exponent_overflow.
+    if (written_exponent_overflow - exponent > written_exponent) {
+      exponent = static_cast<int_type>(written_exponent + exponent);
+    } else {
+      exponent = written_exponent_overflow;
+    }
+  } else if (written_exponent < 0 && exponent < 0) {
+    // Saturate down to -written_exponent_overflow.
+    if (written_exponent_overflow + exponent > -written_exponent) {
+      exponent = static_cast<int_type>(written_exponent + exponent);
+    } else {
+      exponent = static_cast<int_type>(-written_exponent_overflow);
+    }
+  } else {
+    // They're of opposing sign, so it's safe to add.
+    exponent = static_cast<int_type>(written_exponent + exponent);
+  }
 
-  bool is_zero = is_denorm && (fraction == 0);
-  if (is_denorm && !is_zero) {
+  bool is_zero = (!has_integer_part) && (fraction == 0);
+  if ((!has_integer_part) && !is_zero) {
     fraction = static_cast<uint_type>(fraction << 1);
     exponent = static_cast<int_type>(exponent - 1);
   } else if (is_zero) {
@@ -1095,7 +1143,7 @@
   const int_type max_exponent =
       SetBits<uint_type, 0, HF::num_exponent_bits>::get;
 
-  // Handle actual denorm numbers
+  // Handle denorm numbers
   while (exponent < 0 && !is_zero) {
     fraction = static_cast<uint_type>(fraction >> 1);
     exponent = static_cast<int_type>(exponent + 1);
diff --git a/test/hex_float_test.cpp b/test/hex_float_test.cpp
index 7edfd43..25d3c70 100644
--- a/test/hex_float_test.cpp
+++ b/test/hex_float_test.cpp
@@ -1395,6 +1395,47 @@
         {"0x1.0p1+", true, "+", 2.0f},
         {"0x1.0p1-", true, "-", 2.0f}}));
 
+INSTANTIATE_TEST_SUITE_P(
+    HexFloatPositiveExponentOverflow, FloatStreamParseTest,
+    ::testing::ValuesIn(std::vector<StreamParseCase<float>>{
+        // Positive exponents
+        {"0x1.0p1", true, "", 2.0f},       // fine, a normal number
+        {"0x1.0p15", true, "", 32768.0f},  // fine, a normal number
+        {"0x1.0p127", true, "", float(ldexp(1.0f, 127))},   // good large number
+        {"0x0.8p128", true, "", float(ldexp(1.0f, 127))},   // good large number
+        {"0x0.1p131", true, "", float(ldexp(1.0f, 127))},   // good large number
+        {"0x0.01p135", true, "", float(ldexp(1.0f, 127))},  // good large number
+        {"0x1.0p128", true, "", float(ldexp(1.0f, 128))},   // infinity
+        {"0x1.0p4294967295", true, "", float(ldexp(1.0f, 128))},  // infinity
+        {"0x1.0p5000000000", true, "", float(ldexp(1.0f, 128))},  // infinity
+        {"0x0.0p5000000000", true, "", 0.0f},  // zero mantissa, zero result
+    }));
+
+INSTANTIATE_TEST_SUITE_P(
+    HexFloatNegativeExponentOverflow, FloatStreamParseTest,
+    ::testing::ValuesIn(std::vector<StreamParseCase<float>>{
+        // Positive results, with and without non-zero digits before '.'
+        {"0x1.0p-126", true, "",
+         float(ldexp(1.0f, -126))},  // fine, a small normal number
+        {"0x1.0p-127", true, "", float(ldexp(1.0f, -127))},  // denorm number
+        {"0x1.0p-149", true, "",
+         float(ldexp(1.0f, -149))},  // smallest positive denormal
+        {"0x0.8p-148", true, "",
+         float(ldexp(1.0f, -149))},  // smallest positive denormal
+        {"0x0.1p-145", true, "",
+         float(ldexp(1.0f, -149))},  // smallest positive denormal
+        {"0x0.01p-141", true, "",
+         float(ldexp(1.0f, -149))},  // smallest positive denormal
+
+        // underflow rounds down to zero
+        {"0x1.0p-150", true, "", 0.0f},
+        {"0x1.0p-4294967296", true, "",
+         0.0f},  // avoid exponent overflow in parser
+        {"0x1.0p-5000000000", true, "",
+         0.0f},  // avoid exponent overflow in parser
+        {"0x0.0p-5000000000", true, "", 0.0f},  // zero mantissa, zero result
+    }));
+
 // TODO(awoloszyn): Add fp16 tests and HexFloatTraits.
 }  // namespace
 }  // namespace utils