StrFormat: format scientific notation without heap allocation PiperOrigin-RevId: 852845312 Change-Id: I26dcfc5784383d1cf86fab795ba931445f24c575
diff --git a/absl/strings/internal/str_format/convert_test.cc b/absl/strings/internal/str_format/convert_test.cc index ea0329c..f340df4 100644 --- a/absl/strings/internal/str_format/convert_test.cc +++ b/absl/strings/internal/str_format/convert_test.cc
@@ -284,6 +284,11 @@ return native_traits; } +bool IsNativeHexFloatConversion(char f) { return f == 'a' || f == 'A'; } +bool IsNativeFloatConversion(char f) { + return f == 'f' || f == 'F' || f == 'e' || f == 'E' || f == 'a' || f == 'A'; +} + class FormatConvertTest : public ::testing::Test { }; template <typename T> @@ -799,20 +804,19 @@ 'e', 'E'}) { std::string fmt_str = std::string(fmt) + f; - if (fmt == absl::string_view("%.5000") && f != 'f' && f != 'F' && - f != 'a' && f != 'A') { + if (fmt == absl::string_view("%.5000") && !IsNativeFloatConversion(f)) { // This particular test takes way too long with snprintf. // Disable for the case we are not implementing natively. continue; } - if ((f == 'a' || f == 'A') && + if (IsNativeHexFloatConversion(f) && !native_traits.hex_float_has_glibc_rounding) { continue; } if (!native_traits.hex_float_prefers_denormal_repr && - (f == 'a' || f == 'A') && + IsNativeHexFloatConversion(f) && std::fpclassify(tested_float) == FP_SUBNORMAL) { continue; } @@ -1328,14 +1332,13 @@ 'e', 'E'}) { std::string fmt_str = std::string(fmt) + 'L' + f; - if (fmt == absl::string_view("%.5000") && f != 'f' && f != 'F' && - f != 'a' && f != 'A') { + if (fmt == absl::string_view("%.5000") && !IsNativeFloatConversion(f)) { // This particular test takes way too long with snprintf. // Disable for the case we are not implementing natively. continue; } - if (f == 'a' || f == 'A') { + if (IsNativeHexFloatConversion(f)) { if (!native_traits.hex_float_has_glibc_rounding || !native_traits.hex_float_optimizes_leading_digit_bit_count) { continue;
diff --git a/absl/strings/internal/str_format/float_conversion.cc b/absl/strings/internal/str_format/float_conversion.cc index aa31998..0168662 100644 --- a/absl/strings/internal/str_format/float_conversion.cc +++ b/absl/strings/internal/str_format/float_conversion.cc
@@ -20,7 +20,10 @@ #include <array> #include <cassert> #include <cmath> +#include <cstdint> +#include <cstring> #include <limits> +#include <optional> #include <string> #include "absl/base/attributes.h" @@ -31,7 +34,9 @@ #include "absl/numeric/bits.h" #include "absl/numeric/int128.h" #include "absl/numeric/internal/representation.h" +#include "absl/strings/internal/str_format/extension.h" #include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -451,6 +456,71 @@ return p; } +struct FractionalDigitPrinterResult { + char* end; + size_t skipped_zeros; + bool nonzero_remainder; +}; + +FractionalDigitPrinterResult PrintFractionalDigitsScientific( + uint64_t v, char* start, int exp, size_t precision, bool skip_zeros) { + char* p = start; + v <<= (64 - exp); + + size_t skipped_zeros = 0; + while (v != 0 && precision > 0) { + char carry = MultiplyBy10WithCarry(&v, 0); + if (skip_zeros) { + if (carry == 0) { + ++skipped_zeros; + continue; + } + skip_zeros = false; + } + *p++ = carry + '0'; + --precision; + } + return {p, skipped_zeros, v != 0}; +} + +FractionalDigitPrinterResult PrintFractionalDigitsScientific( + uint128 v, char* start, int exp, size_t precision, bool skip_zeros) { + char* p = start; + v <<= (128 - exp); + auto high = static_cast<uint64_t>(v >> 64); + auto low = static_cast<uint64_t>(v); + + size_t skipped_zeros = 0; + while (precision > 0 && low != 0) { + char carry = MultiplyBy10WithCarry(&low, 0); + carry = MultiplyBy10WithCarry(&high, carry); + if (skip_zeros) { + if (carry == 0) { + ++skipped_zeros; + continue; + } + skip_zeros = false; + } + *p++ = carry + '0'; + --precision; + } + + while (precision > 0 && high != 0) { + char carry = MultiplyBy10WithCarry(&high, 0); + if (skip_zeros) { + if (carry == 0) { + ++skipped_zeros; + continue; + } + skip_zeros = false; + } + *p++ = carry + '0'; + --precision; + } + + return {p, skipped_zeros, high != 0 || low != 0}; +} + struct FormatState { char sign_char; size_t precision; @@ -1333,6 +1403,427 @@ sink->Append(right_spaces, ' '); } +template <typename Int> +void FormatE(Int mantissa, int exp, bool uppercase, const FormatState& state) { + if (exp > 0) { + const int total_bits = + static_cast<int>(sizeof(Int) * 8) - LeadingZeros(mantissa) + exp; + if (total_bits > 128) { + FormatEPositiveExpSlow(mantissa, exp, uppercase, state); + return; + } + } else { + if (ABSL_PREDICT_FALSE(exp < -128)) { + FormatENegativeExpSlow(mantissa, exp, uppercase, state); + return; + } + } + FormatEFast(mantissa, exp, uppercase, state); +} + +// Guaranteed to fit into 128 bits at this point +template <typename Int> +void FormatEFast(Int v, int exp, bool uppercase, const FormatState& state) { + if (!v) { + absl::string_view mantissa_str = state.ShouldPrintDot() ? "0." : "0"; + FinalPrint(state, mantissa_str, 0, state.precision, + uppercase ? "E+00" : "e+00"); + return; + } + constexpr int kInputBits = sizeof(Int) * 8; + constexpr int kMaxFractionalDigits = 128; + constexpr int kBufferSize = 2 + // '.' + rounding + kMaxFixedPrecision + // Integral + kMaxFractionalDigits; // Fractional + const int total_bits = kInputBits - LeadingZeros(v) + exp; + char buffer[kBufferSize]; + char* integral_start = buffer + 2; + char* integral_end = buffer + 2 + kMaxFixedPrecision; + char* final_start; + char* final_end; + bool zero_integral = false; + int scientific_exp = 0; + size_t digits_printed = 0; + size_t trailing_zeros = 0; + bool has_more_non_zero = false; + + auto check_integral_zeros = + [](char* const begin, char* const end, + const size_t precision, size_t digits_processed) -> bool { + // When considering rounding to even, we care about the digits after the + // round digit which means the total digits to move from the start is + // precision + 2 since the first digit we print before the decimal point + // is not a part of precision. + size_t digit_upper_bound = precision + 2; + if (digits_processed > digit_upper_bound) { + return std::any_of(begin + digit_upper_bound, end, + [](char c) { return c != '0'; }); + } + return false; + }; + + if (exp >= 0) { + integral_end = total_bits <= 64 ? numbers_internal::FastIntToBuffer( + static_cast<uint64_t>(v) << exp, integral_start) + : numbers_internal::FastIntToBuffer( + static_cast<uint128>(v) << exp, integral_start); + *integral_end = '0'; + final_start = integral_start; + // Integral is guaranteed to be non-zero at this point. + scientific_exp = static_cast<int>(integral_end - integral_start) - 1; + digits_printed = static_cast<size_t>(integral_end - integral_start); + final_end = integral_end; + has_more_non_zero = check_integral_zeros(integral_start, integral_end, + state.precision, digits_printed); + } else { + exp = -exp; + if (exp < kInputBits) { + integral_end = + numbers_internal::FastIntToBuffer(v >> exp, integral_start); + } + *integral_end = '0'; + // We didn't move integral_start and it gets set to 0 in + zero_integral = exp >= kInputBits || v >> exp == 0; + if (!zero_integral) { + digits_printed = static_cast<size_t>(integral_end - integral_start); + has_more_non_zero = check_integral_zeros(integral_start, integral_end, + state.precision, digits_printed); + final_end = integral_end; + } + // Print fractional digits + char* fractional_start = integral_end; + + size_t digits_to_print = (state.precision + 1) >= digits_printed + ? state.precision + 1 - digits_printed + : 0; + bool print_extra = digits_printed <= state.precision + 1; + auto [fractional_end, skipped_zeros, has_nonzero_rem] = + exp <= 64 ? PrintFractionalDigitsScientific( + v, fractional_start, exp, digits_to_print + print_extra, + zero_integral) + : PrintFractionalDigitsScientific( + static_cast<uint128>(v), fractional_start, exp, + digits_to_print + print_extra, zero_integral); + final_end = fractional_end; + *fractional_end = '0'; + has_more_non_zero |= has_nonzero_rem; + digits_printed += static_cast<size_t>(fractional_end - fractional_start); + if (zero_integral) { + scientific_exp = -1 * static_cast<int>(skipped_zeros + 1); + } else { + scientific_exp = static_cast<int>(integral_end - integral_start) - 1; + } + // Don't do any rounding here, we will do it ourselves. + final_start = zero_integral ? fractional_start : integral_start; + } + + // For rounding + if (digits_printed >= state.precision + 1) { + final_start[-1] = '0'; + char* round_digit_ptr = final_start + 1 + state.precision; + if (*round_digit_ptr > '5') { + RoundUp(round_digit_ptr - 1); + } else if (*round_digit_ptr == '5') { + if (has_more_non_zero) { + RoundUp(round_digit_ptr - 1); + } else { + RoundToEven(round_digit_ptr - 1); + } + } + final_end = round_digit_ptr; + if (final_start[-1] == '1') { + --final_start; + ++scientific_exp; + --final_end; + } + } else { + // Need to pad with zeros. + trailing_zeros = state.precision - (digits_printed - 1); + } + + if (state.precision > 0 || state.ShouldPrintDot()) { + final_start[-1] = *final_start; + *final_start = '.'; + --final_start; + } + + // We need to add 2 to the buffer size for the +/- sign and the e + constexpr size_t kExpBufferSize = numbers_internal::kFastToBufferSize + 2; + char exp_buffer[kExpBufferSize]; + char* exp_ptr_start = exp_buffer; + char* exp_ptr = exp_ptr_start; + *exp_ptr++ = uppercase ? 'E' : 'e'; + if (scientific_exp >= 0) { + *exp_ptr++ = '+'; + } else { + *exp_ptr++ = '-'; + scientific_exp = -scientific_exp; + } + + if (scientific_exp < 10) { + *exp_ptr++ = '0'; + } + exp_ptr = numbers_internal::FastIntToBuffer(scientific_exp, exp_ptr); + FinalPrint(state, + absl::string_view(final_start, + static_cast<size_t>(final_end - final_start)), + 0, trailing_zeros, + absl::string_view(exp_ptr_start, + static_cast<size_t>(exp_ptr - exp_ptr_start))); +} + +void FormatENegativeExpSlow(uint128 mantissa, int exp, bool uppercase, + const FormatState& state) { + assert(exp < 0); + + FractionalDigitGenerator::RunConversion( + mantissa, -exp, + [&](FractionalDigitGenerator digit_gen) { + int first_digit = 0; + size_t nines = 0; + int num_leading_zeros = 0; + while (digit_gen.HasMoreDigits()) { + auto digits = digit_gen.GetDigits(); + if (digits.digit_before_nine != 0) { + first_digit = digits.digit_before_nine; + nines = digits.num_nines; + break; + } else if (digits.num_nines > 0) { + // This also means the first digit is 0 + first_digit = 9; + nines = digits.num_nines - 1; + num_leading_zeros++; + break; + } + num_leading_zeros++; + } + + bool change_to_zeros = false; + if (nines >= state.precision || state.precision == 0) { + bool round_up = false; + if (nines == state.precision) { + round_up = digit_gen.IsGreaterThanHalf(); + } else { + round_up = nines > 0 || digit_gen.IsGreaterThanHalf(); + } + if (round_up) { + first_digit = (first_digit == 9 ? 1 : first_digit + 1); + num_leading_zeros -= (first_digit == 1); + change_to_zeros = true; + } + } + int scientific_exp = -(num_leading_zeros + 1); + assert(scientific_exp < 0); + char exp_buffer[numbers_internal::kFastToBufferSize]; + char* exp_start = exp_buffer; + *exp_start++ = '-'; + if (scientific_exp > -10) { + *exp_start++ = '0'; + } + scientific_exp *= -1; + char* exp_end = + numbers_internal::FastIntToBuffer(scientific_exp, exp_start); + const size_t total_digits = + 1 // First digit + + (state.ShouldPrintDot() ? 1 : 0) // Decimal point + + state.precision // Digits after decimal + + 1 // 'e' or 'E' + + static_cast<size_t>(exp_end - exp_buffer); // Exponent digits + + const auto padding = ExtraWidthToPadding( + total_digits + (state.sign_char != '\0' ? 1 : 0), state); + state.sink->Append(padding.left_spaces, ' '); + + if (state.sign_char != '\0') { + state.sink->Append(1, state.sign_char); + } + + state.sink->Append(1, static_cast<char>(first_digit + '0')); + if (state.ShouldPrintDot()) { + state.sink->Append(1, '.'); + } + size_t digits_to_go = state.precision; + size_t nines_to_print = std::min(nines, digits_to_go); + state.sink->Append(nines_to_print, change_to_zeros ? '0' : '9'); + digits_to_go -= nines_to_print; + while (digits_to_go > 0 && digit_gen.HasMoreDigits()) { + auto digits = digit_gen.GetDigits(); + + if (digits.num_nines + 1 < digits_to_go) { + state.sink->Append(1, digits.digit_before_nine + '0'); + state.sink->Append(digits.num_nines, '9'); + digits_to_go -= digits.num_nines + 1; + } else { + bool round_up = false; + if (digits.num_nines + 1 > digits_to_go) { + round_up = true; + } else if (digit_gen.IsGreaterThanHalf()) { + round_up = true; + } else if (digit_gen.IsExactlyHalf()) { + round_up = + digits.num_nines != 0 || digits.digit_before_nine % 2 == 1; + } + if (round_up) { + state.sink->Append(1, digits.digit_before_nine + '1'); + --digits_to_go; + } else { + state.sink->Append(1, digits.digit_before_nine + '0'); + state.sink->Append(digits_to_go - 1, '9'); + digits_to_go = 0; + } + break; + } + } + state.sink->Append(digits_to_go, '0'); + state.sink->Append(1, uppercase ? 'E' : 'e'); + state.sink->Append(absl::string_view( + exp_buffer, static_cast<size_t>(exp_end - exp_buffer))); + state.sink->Append(padding.right_spaces, ' '); + }); +} + +std::optional<int> GetOneDigit(BinaryToDecimal& btd, + absl::string_view& digits_view) { + if (digits_view.empty()) { + if (!btd.AdvanceDigits()) return std::nullopt; + digits_view = btd.CurrentDigits(); + } + char d = digits_view.front(); + digits_view.remove_prefix(1); + return d - '0'; +} + +struct DigitRun { + std::optional<int> digit; + size_t nines; +}; + +DigitRun GetDigits(BinaryToDecimal& btd, absl::string_view& digits_view) { + auto peek_digit = [&]() -> std::optional<int> { + if (digits_view.empty()) { + if (!btd.AdvanceDigits()) return std::nullopt; + digits_view = btd.CurrentDigits(); + } + return digits_view.front() - '0'; + }; + + auto digit_before_nines = GetOneDigit(btd, digits_view); + if (!digit_before_nines.has_value()) return {std::nullopt, 0}; + + auto next_digit = peek_digit(); + size_t num_nines = 0; + while (next_digit == 9) { + // consume the 9 + GetOneDigit(btd, digits_view); + ++num_nines; + next_digit = peek_digit(); + } + return digit_before_nines == 9 + ? DigitRun{std::nullopt, num_nines + 1} + : DigitRun{digit_before_nines, num_nines}; +} + +void FormatEPositiveExpSlow(uint128 mantissa, int exp, bool uppercase, + const FormatState& state) { + BinaryToDecimal::RunConversion( + mantissa, exp, [&](BinaryToDecimal btd) { + int scientific_exp = static_cast<int>(btd.TotalDigits() - 1); + absl::string_view digits_view = btd.CurrentDigits(); + + size_t digits_to_go = state.precision + 1; + auto [first_digit_opt, nines] = GetDigits(btd, digits_view); + if (!first_digit_opt.has_value() && nines == 0) { + return; + } + + int first_digit = first_digit_opt.value_or(9); + if (!first_digit_opt) { + --nines; + } + + // At this point we are guaranteed to have some sort of first digit + bool change_to_zeros = false; + if (nines + 1 >= digits_to_go) { + // Everything we need to print is in the first DigitRun + auto [next_digit_opt, next_nines] = GetDigits(btd, digits_view); + if (nines == state.precision) { + change_to_zeros = next_digit_opt.value_or(0) > 4; + } else { + change_to_zeros = true; + } + if (change_to_zeros) { + if (first_digit != 9) { + first_digit = first_digit + 1; + } else { + first_digit = 1; + ++scientific_exp; + } + } + } + + char exp_buffer[numbers_internal::kFastToBufferSize]; + char* exp_buffer_end = + numbers_internal::FastIntToBuffer(scientific_exp, exp_buffer); + const size_t total_digits_out = + 1 + state.ShouldPrintDot() + state.precision + 2 + + (static_cast<size_t>(exp_buffer_end - exp_buffer)); + + const auto padding = ExtraWidthToPadding( + total_digits_out + (state.sign_char != '\0' ? 1 : 0), state); + + state.sink->Append(padding.left_spaces, ' '); + if (state.sign_char != '\0') { + state.sink->Append(1, state.sign_char); + } + state.sink->Append(1, static_cast<char>(first_digit + '0')); + --digits_to_go; + if (state.precision > 0 || state.ShouldPrintDot()) { + state.sink->Append(1, '.'); + } + state.sink->Append(std::min(digits_to_go, nines), + change_to_zeros ? '0' : '9'); + digits_to_go -= std::min(digits_to_go, nines); + while (digits_to_go > 0) { + auto [digit_opt, curr_nines] = GetDigits(btd, digits_view); + if (!digit_opt.has_value()) break; + int digit = *digit_opt; + if (curr_nines + 1 < digits_to_go) { + state.sink->Append(1, static_cast<char>(digit + '0')); + state.sink->Append(curr_nines, '9'); + digits_to_go -= curr_nines + 1; + } else { + bool need_round_up = false; + auto [next_digit_opt, next_nines] = GetDigits(btd, digits_view); + if (digits_to_go == 1) { + need_round_up = curr_nines > 0 || next_digit_opt > 4; + } else if (digits_to_go == curr_nines + 1) { + // Only round if next digit is > 4 + need_round_up = next_digit_opt.value_or(0) > 4; + } else { + // we know we need to round since nine is after precision ends + need_round_up = true; + } + state.sink->Append(1, + static_cast<char>(digit + need_round_up + '0')); + state.sink->Append(digits_to_go - 1, need_round_up ? '0' : '9'); + digits_to_go = 0; + } + } + + if (digits_to_go > 0) { + state.sink->Append(digits_to_go, '0'); + } + state.sink->Append(1, uppercase ? 'E' : 'e'); + state.sink->Append(1, scientific_exp >= 0 ? '+' : '-'); + if (scientific_exp < 10) { + state.sink->Append(1, '0'); + } + state.sink->Append(absl::string_view( + exp_buffer, static_cast<size_t>(exp_buffer_end - exp_buffer))); + state.sink->Append(padding.right_spaces, ' '); + }); +} + template <typename Float> bool FloatToSink(const Float v, const FormatConversionSpecImpl &conv, FormatSinkImpl *sink) { @@ -1371,14 +1862,10 @@ return true; } else if (c == FormatConversionCharInternal::e || c == FormatConversionCharInternal::E) { - if (!FloatToBuffer<FormatStyle::Precision>(decomposed, precision, &buffer, - &exp)) { - return FallbackToSnprintf(v, conv, sink); - } - if (!conv.has_alt_flag() && buffer.back() == '.') buffer.pop_back(); - PrintExponent( - exp, FormatConversionCharIsUpper(conv.conversion_char()) ? 'E' : 'e', - &buffer); + FormatE(decomposed.mantissa, decomposed.exponent, + FormatConversionCharIsUpper(conv.conversion_char()), + {sign_char, precision, conv, sink}); + return true; } else if (c == FormatConversionCharInternal::g || c == FormatConversionCharInternal::G) { precision = std::max(precision, size_t{1}) - 1;