StrFormat: format scientific notation without heap allocation

PiperOrigin-RevId: 852845312
Change-Id: I26dcfc5784383d1cf86fab795ba931445f24c575
diff --git a/absl/strings/internal/str_format/convert_test.cc b/absl/strings/internal/str_format/convert_test.cc
index ea0329c..f340df4 100644
--- a/absl/strings/internal/str_format/convert_test.cc
+++ b/absl/strings/internal/str_format/convert_test.cc
@@ -284,6 +284,11 @@
   return native_traits;
 }
 
+bool IsNativeHexFloatConversion(char f) { return f == 'a' || f == 'A'; }  // True for the %a / %A conversion characters.
+bool IsNativeFloatConversion(char f) {  // True for f/F, e/E and a/A; false for any other character.
+  return f == 'f' || f == 'F' || f == 'e' || f == 'E' || f == 'a' || f == 'A';
+}
+
 class FormatConvertTest : public ::testing::Test { };
 
 template <typename T>
@@ -799,20 +804,19 @@
                    'e', 'E'}) {
       std::string fmt_str = std::string(fmt) + f;
 
-      if (fmt == absl::string_view("%.5000") && f != 'f' && f != 'F' &&
-          f != 'a' && f != 'A') {
+      if (fmt == absl::string_view("%.5000") && !IsNativeFloatConversion(f)) {
         // This particular test takes way too long with snprintf.
         // Disable for the case we are not implementing natively.
         continue;
       }
 
-      if ((f == 'a' || f == 'A') &&
+      if (IsNativeHexFloatConversion(f) &&
           !native_traits.hex_float_has_glibc_rounding) {
         continue;
       }
 
       if (!native_traits.hex_float_prefers_denormal_repr &&
-          (f == 'a' || f == 'A') &&
+          IsNativeHexFloatConversion(f) &&
           std::fpclassify(tested_float) == FP_SUBNORMAL) {
         continue;
       }
@@ -1328,14 +1332,13 @@
                    'e', 'E'}) {
       std::string fmt_str = std::string(fmt) + 'L' + f;
 
-      if (fmt == absl::string_view("%.5000") && f != 'f' && f != 'F' &&
-          f != 'a' && f != 'A') {
+      if (fmt == absl::string_view("%.5000") && !IsNativeFloatConversion(f)) {
         // This particular test takes way too long with snprintf.
         // Disable for the case we are not implementing natively.
         continue;
       }
 
-      if (f == 'a' || f == 'A') {
+      if (IsNativeHexFloatConversion(f)) {
         if (!native_traits.hex_float_has_glibc_rounding ||
             !native_traits.hex_float_optimizes_leading_digit_bit_count) {
           continue;
diff --git a/absl/strings/internal/str_format/float_conversion.cc b/absl/strings/internal/str_format/float_conversion.cc
index aa31998..0168662 100644
--- a/absl/strings/internal/str_format/float_conversion.cc
+++ b/absl/strings/internal/str_format/float_conversion.cc
@@ -20,7 +20,10 @@
 #include <array>
 #include <cassert>
 #include <cmath>
+#include <cstdint>
+#include <cstring>
 #include <limits>
+#include <optional>
 #include <string>
 
 #include "absl/base/attributes.h"
@@ -31,7 +34,9 @@
 #include "absl/numeric/bits.h"
 #include "absl/numeric/int128.h"
 #include "absl/numeric/internal/representation.h"
+#include "absl/strings/internal/str_format/extension.h"
 #include "absl/strings/numbers.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 
@@ -451,6 +456,71 @@
   return p;
 }
 
+struct FractionalDigitPrinterResult {  // Result of PrintFractionalDigitsScientific().
+  char* end;               // One past the last digit character written.
+  size_t skipped_zeros;    // Leading zero digits that were counted but not written.
+  bool nonzero_remainder;  // True if non-zero fraction bits were left unprinted.
+};
+
+FractionalDigitPrinterResult PrintFractionalDigitsScientific(
+    uint64_t v, char* start, int exp, size_t precision, bool skip_zeros) {
+  char* p = start;  // Write cursor; at most `precision` digit characters are stored.
+  v <<= (64 - exp);  // Discard the upper bits so only the low `exp` bits remain, left-aligned.
+
+  size_t skipped_zeros = 0;
+  while (v != 0 && precision > 0) {  // Stop once the fraction is exhausted or enough digits exist.
+    char carry = MultiplyBy10WithCarry(&v, 0);  // The carry-out is used as the next decimal digit.
+    if (skip_zeros) {
+      if (carry == 0) {
+        ++skipped_zeros;  // Leading zero: counted, not written, and not charged to precision.
+        continue;
+      }
+      skip_zeros = false;  // The first non-zero digit ends zero skipping.
+    }
+    *p++ = carry + '0';
+    --precision;
+  }
+  return {p, skipped_zeros, v != 0};  // v != 0 means digits remained beyond what was written.
+}
+
+FractionalDigitPrinterResult PrintFractionalDigitsScientific(
+    uint128 v, char* start, int exp, size_t precision, bool skip_zeros) {
+  char* p = start;  // Write cursor; at most `precision` digit characters are stored.
+  v <<= (128 - exp);  // Discard the upper bits so only the low `exp` bits remain, left-aligned.
+  auto high = static_cast<uint64_t>(v >> 64);  // Upper 64-bit limb of the fraction.
+  auto low = static_cast<uint64_t>(v);  // Lower 64-bit limb of the fraction.
+
+  size_t skipped_zeros = 0;
+  while (precision > 0 && low != 0) {  // Two-limb phase: runs while the low limb still has bits.
+    char carry = MultiplyBy10WithCarry(&low, 0);
+    carry = MultiplyBy10WithCarry(&high, carry);  // Feed the low limb's carry into the high limb.
+    if (skip_zeros) {
+      if (carry == 0) {
+        ++skipped_zeros;  // Leading zero: counted, not written.
+        continue;
+      }
+      skip_zeros = false;
+    }
+    *p++ = carry + '0';
+    --precision;
+  }
+
+  while (precision > 0 && high != 0) {  // Single-limb phase: only the high limb is multiplied.
+    char carry = MultiplyBy10WithCarry(&high, 0);
+    if (skip_zeros) {
+      if (carry == 0) {
+        ++skipped_zeros;  // Leading zero: counted, not written.
+        continue;
+      }
+      skip_zeros = false;
+    }
+    *p++ = carry + '0';
+    --precision;
+  }
+
+  return {p, skipped_zeros, high != 0 || low != 0};  // Non-zero limbs mean digits were left over.
+}
+
 struct FormatState {
   char sign_char;
   size_t precision;
@@ -1333,6 +1403,427 @@
   sink->Append(right_spaces, ' ');
 }
 
+template <typename Int>
+void FormatE(Int mantissa, int exp, bool uppercase, const FormatState& state) {  // Picks the fast or a slow %e/%E path.
+  if (exp > 0) {
+    const int total_bits =  // Significant bits of mantissa plus the left shift by exp.
+        static_cast<int>(sizeof(Int) * 8) - LeadingZeros(mantissa) + exp;
+    if (total_bits > 128) {  // The shifted value would not fit into 128 bits.
+      FormatEPositiveExpSlow(mantissa, exp, uppercase, state);
+      return;
+    }
+  } else {
+    if (ABSL_PREDICT_FALSE(exp < -128)) {  // More than 128 fractional bits.
+      FormatENegativeExpSlow(mantissa, exp, uppercase, state);
+      return;
+    }
+  }
+  FormatEFast(mantissa, exp, uppercase, state);  // Everything else is handled by FormatEFast.
+}
+
+// Formats v * 2^exp in %e/%E style using a local buffer; the value is guaranteed to fit into 128 bits here.
+template <typename Int>
+void FormatEFast(Int v, int exp, bool uppercase, const FormatState& state) {
+  if (!v) {  // Zero mantissa: emit the literal zero form directly.
+    absl::string_view mantissa_str = state.ShouldPrintDot() ? "0." : "0";
+    FinalPrint(state, mantissa_str, 0, state.precision,
+               uppercase ? "E+00" : "e+00");
+    return;
+  }
+  constexpr int kInputBits = sizeof(Int) * 8;
+  constexpr int kMaxFractionalDigits = 128;
+  constexpr int kBufferSize = 2 +                    // '.' + rounding
+                              kMaxFixedPrecision +   // Integral
+                              kMaxFractionalDigits;  // Fractional
+  const int total_bits = kInputBits - LeadingZeros(v) + exp;
+  char buffer[kBufferSize];
+  char* integral_start = buffer + 2;  // Two spare leading slots: rounding carry and '.'.
+  char* integral_end = buffer + 2 + kMaxFixedPrecision;  // Default when no integral digits get printed.
+  char* final_start;
+  char* final_end;
+  bool zero_integral = false;
+  int scientific_exp = 0;  // Decimal exponent printed after 'e'/'E'.
+  size_t digits_printed = 0;  // Significant digits currently held in the buffer.
+  size_t trailing_zeros = 0;  // Zero padding requested from FinalPrint.
+  bool has_more_non_zero = false;  // Any non-zero digit beyond the round digit.
+
+  auto check_integral_zeros =
+      [](char* const begin, char* const end,
+         const size_t precision, size_t digits_processed) -> bool {
+    // When considering rounding to even, we care about the digits after the
+    // round digit which means the total digits to move from the start is
+    // precision + 2 since the first digit we print before the decimal point
+    // is not a part of precision.
+    size_t digit_upper_bound = precision + 2;
+    if (digits_processed > digit_upper_bound) {
+      return std::any_of(begin + digit_upper_bound, end,
+                         [](char c) { return c != '0'; });
+    }
+    return false;
+  };
+
+  if (exp >= 0) {  // Pure integer: every digit comes from v << exp.
+    integral_end = total_bits <= 64 ? numbers_internal::FastIntToBuffer(
+                               static_cast<uint64_t>(v) << exp, integral_start)
+                         : numbers_internal::FastIntToBuffer(
+                               static_cast<uint128>(v) << exp, integral_start);
+    *integral_end = '0';  // Sentinel read as the round digit when no further digit exists.
+    final_start = integral_start;
+    // Integral is guaranteed to be non-zero at this point.
+    scientific_exp = static_cast<int>(integral_end - integral_start) - 1;
+    digits_printed = static_cast<size_t>(integral_end - integral_start);
+    final_end = integral_end;
+    has_more_non_zero = check_integral_zeros(integral_start, integral_end,
+                                             state.precision, digits_printed);
+  } else {
+    exp = -exp;  // From here on exp is the number of fractional bits.
+    if (exp < kInputBits) {
+      integral_end =
+          numbers_internal::FastIntToBuffer(v >> exp, integral_start);
+    }
+    *integral_end = '0';
+    // The integral part is zero when no bits of v remain after shifting out the fraction.
+    zero_integral = exp >= kInputBits || v >> exp == 0;
+    if (!zero_integral) {
+      digits_printed = static_cast<size_t>(integral_end - integral_start);
+      has_more_non_zero = check_integral_zeros(integral_start, integral_end,
+                                               state.precision, digits_printed);
+      final_end = integral_end;
+    }
+    // Print fractional digits
+    char* fractional_start = integral_end;
+
+    size_t digits_to_print = (state.precision + 1) >= digits_printed
+                                 ? state.precision + 1 - digits_printed
+                                 : 0;
+    bool print_extra = digits_printed <= state.precision + 1;  // One more digit to act as the round digit.
+    auto [fractional_end, skipped_zeros, has_nonzero_rem] =
+        exp <= 64 ? PrintFractionalDigitsScientific(
+                        v, fractional_start, exp, digits_to_print + print_extra,
+                        zero_integral)
+                  : PrintFractionalDigitsScientific(
+                        static_cast<uint128>(v), fractional_start, exp,
+                        digits_to_print + print_extra, zero_integral);
+    final_end = fractional_end;
+    *fractional_end = '0';  // Sentinel read as the round digit when the fraction ended early.
+    has_more_non_zero |= has_nonzero_rem;
+    digits_printed += static_cast<size_t>(fractional_end - fractional_start);
+    if (zero_integral) {
+      scientific_exp = -1 * static_cast<int>(skipped_zeros + 1);  // First non-zero digit is skipped_zeros + 1 places after the point.
+    } else {
+      scientific_exp = static_cast<int>(integral_end - integral_start) - 1;
+    }
+    // Don't do any rounding here, we will do it ourselves.
+    final_start = zero_integral ? fractional_start : integral_start;
+  }
+
+  // For rounding
+  if (digits_printed >= state.precision + 1) {
+    final_start[-1] = '0';  // Slot that receives a carry out of the first digit.
+    char* round_digit_ptr = final_start + 1 + state.precision;  // First digit that is not printed.
+    if (*round_digit_ptr > '5') {
+      RoundUp(round_digit_ptr - 1);
+    } else if (*round_digit_ptr == '5') {
+      if (has_more_non_zero) {  // Strictly above the halfway point.
+        RoundUp(round_digit_ptr - 1);
+      } else {
+        RoundToEven(round_digit_ptr - 1);
+      }
+    }
+    final_end = round_digit_ptr;
+    if (final_start[-1] == '1') {  // Rounding produced a new leading digit.
+      --final_start;
+      ++scientific_exp;
+      --final_end;
+    }
+  } else {
+    // Need to pad with zeros.
+    trailing_zeros = state.precision - (digits_printed - 1);
+  }
+
+  if (state.precision > 0 || state.ShouldPrintDot()) {
+    final_start[-1] = *final_start;  // Move the first digit one slot left...
+    *final_start = '.';  // ...and put the decimal point right after it.
+    --final_start;
+  }
+
+  // We need to add 2 to the buffer size for the +/- sign and the e
+  constexpr size_t kExpBufferSize = numbers_internal::kFastToBufferSize + 2;
+  char exp_buffer[kExpBufferSize];
+  char* exp_ptr_start = exp_buffer;
+  char* exp_ptr = exp_ptr_start;
+  *exp_ptr++ = uppercase ? 'E' : 'e';
+  if (scientific_exp >= 0) {
+    *exp_ptr++ = '+';
+  } else {
+    *exp_ptr++ = '-';
+    scientific_exp = -scientific_exp;  // Only the magnitude is printed after the sign.
+  }
+
+  if (scientific_exp < 10) {  // The exponent is written with at least two digits.
+    *exp_ptr++ = '0';
+  }
+  exp_ptr = numbers_internal::FastIntToBuffer(scientific_exp, exp_ptr);
+  FinalPrint(state,
+             absl::string_view(final_start,
+                               static_cast<size_t>(final_end - final_start)),
+             0, trailing_zeros,
+             absl::string_view(exp_ptr_start,
+                               static_cast<size_t>(exp_ptr - exp_ptr_start)));
+}
+
+void FormatENegativeExpSlow(uint128 mantissa, int exp, bool uppercase,
+                            const FormatState& state) {
+  assert(exp < 0);  // %e/%E output is streamed to state.sink run-by-run via FractionalDigitGenerator.
+
+  FractionalDigitGenerator::RunConversion(
+      mantissa, -exp,
+      [&](FractionalDigitGenerator digit_gen) {
+        int first_digit = 0;
+        size_t nines = 0;  // Nines directly following first_digit.
+        int num_leading_zeros = 0;  // Zero digits seen before the first significant digit.
+        while (digit_gen.HasMoreDigits()) {  // Skip runs until a significant digit is found.
+          auto digits = digit_gen.GetDigits();
+          if (digits.digit_before_nine != 0) {
+            first_digit = digits.digit_before_nine;
+            nines = digits.num_nines;
+            break;
+          } else if (digits.num_nines > 0) {
+            // This also means the first digit is 0
+            first_digit = 9;
+            nines = digits.num_nines - 1;  // One of the nines became the first digit.
+            num_leading_zeros++;
+            break;
+          }
+          num_leading_zeros++;
+        }
+
+        bool change_to_zeros = false;  // Set when rounding up turns the printed nines into zeros.
+        if (nines >= state.precision || state.precision == 0) {  // The first run already covers every requested digit.
+          bool round_up = false;
+          if (nines == state.precision) {
+            round_up = digit_gen.IsGreaterThanHalf();
+          } else {
+            round_up = nines > 0 || digit_gen.IsGreaterThanHalf();
+          }
+          if (round_up) {
+            first_digit = (first_digit == 9 ? 1 : first_digit + 1);  // 9 wraps around to 1.
+            num_leading_zeros -= (first_digit == 1);  // A wrap moves the first digit one place left.
+            change_to_zeros = true;
+          }
+        }
+        int scientific_exp = -(num_leading_zeros + 1);
+        assert(scientific_exp < 0);
+        char exp_buffer[numbers_internal::kFastToBufferSize];
+        char* exp_start = exp_buffer;
+        *exp_start++ = '-';
+        if (scientific_exp > -10) {  // The exponent is written with at least two digits.
+          *exp_start++ = '0';
+        }
+        scientific_exp *= -1;  // Only the magnitude is written after the '-'.
+        char* exp_end =
+            numbers_internal::FastIntToBuffer(scientific_exp, exp_start);
+        const size_t total_digits =
+            1                                   // First digit
+            + (state.ShouldPrintDot() ? 1 : 0)  // Decimal point
+            + state.precision                   // Digits after decimal
+            + 1                                 // 'e' or 'E'
+            + static_cast<size_t>(exp_end - exp_buffer);  // Exponent digits
+
+        const auto padding = ExtraWidthToPadding(
+            total_digits + (state.sign_char != '\0' ? 1 : 0), state);
+        state.sink->Append(padding.left_spaces, ' ');
+
+        if (state.sign_char != '\0') {
+          state.sink->Append(1, state.sign_char);
+        }
+
+        state.sink->Append(1, static_cast<char>(first_digit + '0'));
+        if (state.ShouldPrintDot()) {
+          state.sink->Append(1, '.');
+        }
+        size_t digits_to_go = state.precision;  // Digits still owed after the decimal point.
+        size_t nines_to_print = std::min(nines, digits_to_go);
+        state.sink->Append(nines_to_print, change_to_zeros ? '0' : '9');
+        digits_to_go -= nines_to_print;
+        while (digits_to_go > 0 && digit_gen.HasMoreDigits()) {  // Emit remaining runs; the last one decides rounding.
+          auto digits = digit_gen.GetDigits();
+
+          if (digits.num_nines + 1 < digits_to_go) {  // The whole run fits; no rounding decision yet.
+            state.sink->Append(1, digits.digit_before_nine + '0');
+            state.sink->Append(digits.num_nines, '9');
+            digits_to_go -= digits.num_nines + 1;
+          } else {
+            bool round_up = false;
+            if (digits.num_nines + 1 > digits_to_go) {  // A nine lies past the last printed digit.
+              round_up = true;
+            } else if (digit_gen.IsGreaterThanHalf()) {
+              round_up = true;
+            } else if (digit_gen.IsExactlyHalf()) {  // Tie: round up only when the last kept digit is odd.
+              round_up =
+                  digits.num_nines != 0 || digits.digit_before_nine % 2 == 1;
+            }
+            if (round_up) {
+              state.sink->Append(1, digits.digit_before_nine + '1');  // Prints the digit plus one.
+              --digits_to_go;  // The nines that follow become zeros, appended after the loop.
+            } else {
+              state.sink->Append(1, digits.digit_before_nine + '0');
+              state.sink->Append(digits_to_go - 1, '9');
+              digits_to_go = 0;
+            }
+            break;
+          }
+        }
+        state.sink->Append(digits_to_go, '0');  // Zero-fill whatever is still owed.
+        state.sink->Append(1, uppercase ? 'E' : 'e');
+        state.sink->Append(absl::string_view(
+            exp_buffer, static_cast<size_t>(exp_end - exp_buffer)));
+        state.sink->Append(padding.right_spaces, ' ');
+      });
+}
+
+std::optional<int> GetOneDigit(BinaryToDecimal& btd,  // Consumes and returns the next decimal digit value.
+                               absl::string_view& digits_view) {
+  if (digits_view.empty()) {  // Current chunk used up: ask btd for the next one.
+    if (!btd.AdvanceDigits()) return std::nullopt;  // No digits left.
+    digits_view = btd.CurrentDigits();
+  }
+  char d = digits_view.front();
+  digits_view.remove_prefix(1);  // Mark the digit as consumed.
+  return d - '0';
+}
+
+struct DigitRun {  // One digit followed by a run of nines, as produced by GetDigits().
+  std::optional<int> digit;  // Leading digit; nullopt if the run began with a 9 or no digits remained.
+  size_t nines;  // Number of nines in the run (includes a leading 9 when digit is nullopt).
+};
+
+DigitRun GetDigits(BinaryToDecimal& btd, absl::string_view& digits_view) {  // Reads one digit plus all nines right after it.
+  auto peek_digit = [&]() -> std::optional<int> {  // Like GetOneDigit(), but leaves the digit in the view.
+    if (digits_view.empty()) {
+      if (!btd.AdvanceDigits()) return std::nullopt;
+      digits_view = btd.CurrentDigits();
+    }
+    return digits_view.front() - '0';
+  };
+
+  auto digit_before_nines = GetOneDigit(btd, digits_view);
+  if (!digit_before_nines.has_value()) return {std::nullopt, 0};  // Out of digits: empty run.
+
+  auto next_digit = peek_digit();
+  size_t num_nines = 0;
+  while (next_digit == 9) {  // A nullopt peek compares unequal to 9 and ends the loop.
+    // consume the 9
+    GetOneDigit(btd, digits_view);
+    ++num_nines;
+    next_digit = peek_digit();
+  }
+  return digit_before_nines == 9  // A leading 9 is folded into the nines count.
+             ? DigitRun{std::nullopt, num_nines + 1}
+             : DigitRun{digit_before_nines, num_nines};
+}
+
+void FormatEPositiveExpSlow(uint128 mantissa, int exp, bool uppercase,  // Streams %e/%E output to state.sink.
+                            const FormatState& state) {
+  BinaryToDecimal::RunConversion(
+      mantissa, exp, [&](BinaryToDecimal btd) {
+        int scientific_exp = static_cast<int>(btd.TotalDigits() - 1);  // One digit precedes the decimal point.
+        absl::string_view digits_view = btd.CurrentDigits();
+
+        size_t digits_to_go = state.precision + 1;  // First digit plus `precision` digits.
+        auto [first_digit_opt, nines] = GetDigits(btd, digits_view);
+        if (!first_digit_opt.has_value() && nines == 0) {  // No digits at all: nothing is printed.
+          return;
+        }
+
+        int first_digit = first_digit_opt.value_or(9);  // nullopt with nines > 0 means the run began with a 9.
+        if (!first_digit_opt) {
+          --nines;  // That leading 9 is now accounted for as first_digit.
+        }
+
+        // At this point we are guaranteed to have some sort of first digit
+        bool change_to_zeros = false;  // Set when rounding up turns the printed nines into zeros.
+        if (nines + 1 >= digits_to_go) {
+          // Everything we need to print is in the first DigitRun
+          auto [next_digit_opt, next_nines] = GetDigits(btd, digits_view);
+          if (nines == state.precision) {
+            change_to_zeros = next_digit_opt.value_or(0) > 4;  // NOTE(review): a next digit of exactly 5 always rounds up; confirm an exact tie cannot reach this path.
+          } else {
+            change_to_zeros = true;  // A nine lies past the last printed digit, so round up.
+          }
+          if (change_to_zeros) {
+            if (first_digit != 9) {
+              first_digit = first_digit + 1;
+            } else {
+              first_digit = 1;  // 9 rolls over to 10: one more decimal digit.
+              ++scientific_exp;
+            }
+          }
+        }
+
+        char exp_buffer[numbers_internal::kFastToBufferSize];
+        char* exp_buffer_end =
+            numbers_internal::FastIntToBuffer(scientific_exp, exp_buffer);
+        const size_t total_digits_out =  // First digit, optional '.', precision digits, 'e' and sign, exponent digits.
+            1 + state.ShouldPrintDot() + state.precision + 2 +
+            (static_cast<size_t>(exp_buffer_end - exp_buffer));
+
+        const auto padding = ExtraWidthToPadding(
+            total_digits_out + (state.sign_char != '\0' ? 1 : 0), state);
+
+        state.sink->Append(padding.left_spaces, ' ');
+        if (state.sign_char != '\0') {
+          state.sink->Append(1, state.sign_char);
+        }
+        state.sink->Append(1, static_cast<char>(first_digit + '0'));
+        --digits_to_go;
+        if (state.precision > 0 || state.ShouldPrintDot()) {
+          state.sink->Append(1, '.');
+        }
+        state.sink->Append(std::min(digits_to_go, nines),
+                           change_to_zeros ? '0' : '9');
+        digits_to_go -= std::min(digits_to_go, nines);
+        while (digits_to_go > 0) {  // Emit remaining runs; the last one decides rounding.
+          auto [digit_opt, curr_nines] = GetDigits(btd, digits_view);
+          if (!digit_opt.has_value()) break;  // Out of digits; the rest is zero-filled below.
+          int digit = *digit_opt;
+          if (curr_nines + 1 < digits_to_go) {  // The whole run fits; no rounding decision yet.
+            state.sink->Append(1, static_cast<char>(digit + '0'));
+            state.sink->Append(curr_nines, '9');
+            digits_to_go -= curr_nines + 1;
+          } else {
+            bool need_round_up = false;
+            auto [next_digit_opt, next_nines] = GetDigits(btd, digits_view);
+            if (digits_to_go == 1) {
+              need_round_up = curr_nines > 0 || next_digit_opt > 4;  // An empty optional compares false here.
+            } else if (digits_to_go == curr_nines + 1) {
+              // Only round if next digit is > 4
+              need_round_up = next_digit_opt.value_or(0) > 4;
+            } else {
+              // we know we need to round since nine is after precision ends
+              need_round_up = true;
+            }
+            state.sink->Append(1,
+                               static_cast<char>(digit + need_round_up + '0'));
+            state.sink->Append(digits_to_go - 1, need_round_up ? '0' : '9');
+            digits_to_go = 0;
+          }
+        }
+
+        if (digits_to_go > 0) {
+          state.sink->Append(digits_to_go, '0');  // Zero-fill whatever is still owed.
+        }
+        state.sink->Append(1, uppercase ? 'E' : 'e');
+        state.sink->Append(1, scientific_exp >= 0 ? '+' : '-');
+        if (scientific_exp < 10) {  // NOTE(review): this pad digit is not counted in total_digits_out; presumably unreachable for values this large - confirm.
+          state.sink->Append(1, '0');
+        }
+        state.sink->Append(absl::string_view(
+            exp_buffer, static_cast<size_t>(exp_buffer_end - exp_buffer)));
+        state.sink->Append(padding.right_spaces, ' ');
+      });
+}
+
 template <typename Float>
 bool FloatToSink(const Float v, const FormatConversionSpecImpl &conv,
                  FormatSinkImpl *sink) {
@@ -1371,14 +1862,10 @@
     return true;
   } else if (c == FormatConversionCharInternal::e ||
              c == FormatConversionCharInternal::E) {
-    if (!FloatToBuffer<FormatStyle::Precision>(decomposed, precision, &buffer,
-                                               &exp)) {
-      return FallbackToSnprintf(v, conv, sink);
-    }
-    if (!conv.has_alt_flag() && buffer.back() == '.') buffer.pop_back();
-    PrintExponent(
-        exp, FormatConversionCharIsUpper(conv.conversion_char()) ? 'E' : 'e',
-        &buffer);
+    FormatE(decomposed.mantissa, decomposed.exponent,
+            FormatConversionCharIsUpper(conv.conversion_char()),
+            {sign_char, precision, conv, sink});
+    return true;
   } else if (c == FormatConversionCharInternal::g ||
              c == FormatConversionCharInternal::G) {
     precision = std::max(precision, size_t{1}) - 1;