Make absl::FunctionRef support non-const callables, aligning it with std::function_ref from C++26 PiperOrigin-RevId: 807349765 Change-Id: I2a59d749818c5df669f6332f88bc1c9d59a2174d
diff --git a/absl/functional/BUILD.bazel b/absl/functional/BUILD.bazel index f9c58d4..b7aa31f 100644 --- a/absl/functional/BUILD.bazel +++ b/absl/functional/BUILD.bazel
@@ -104,8 +104,10 @@ linkopts = ABSL_DEFAULT_LINKOPTS, deps = [ ":any_invocable", + "//absl/base:config", "//absl/base:core_headers", "//absl/meta:type_traits", + "//absl/utility", ], ) @@ -117,8 +119,10 @@ deps = [ ":any_invocable", ":function_ref", + "//absl/base:config", "//absl/container:test_instance_tracker", "//absl/memory", + "//absl/utility", "@googletest//:gtest", "@googletest//:gtest_main", ],
diff --git a/absl/functional/CMakeLists.txt b/absl/functional/CMakeLists.txt index 34d285d..07f3dc0 100644 --- a/absl/functional/CMakeLists.txt +++ b/absl/functional/CMakeLists.txt
@@ -87,9 +87,11 @@ COPTS ${ABSL_DEFAULT_COPTS} DEPS + absl::config absl::core_headers absl::any_invocable absl::meta + absl::utility PUBLIC ) @@ -101,9 +103,11 @@ COPTS ${ABSL_TEST_COPTS} DEPS + absl::config absl::function_ref absl::memory absl::test_instance_tracker + absl::utility GTest::gmock_main )
diff --git a/absl/functional/function_ref.h b/absl/functional/function_ref.h index f1d087a..edf61de 100644 --- a/absl/functional/function_ref.h +++ b/absl/functional/function_ref.h
@@ -47,12 +47,13 @@ #define ABSL_FUNCTIONAL_FUNCTION_REF_H_ #include <cassert> -#include <functional> #include <type_traits> #include "absl/base/attributes.h" +#include "absl/base/config.h" #include "absl/functional/internal/function_ref.h" #include "absl/meta/type_traits.h" +#include "absl/utility/utility.h" namespace absl { ABSL_NAMESPACE_BEGIN @@ -89,15 +90,17 @@ // signature of this FunctionRef. template <typename F, typename FR = std::invoke_result_t<F, Args&&...>> using EnableIfCompatible = - typename std::enable_if<std::is_void<R>::value || - std::is_convertible<FR, R>::value>::type; + std::enable_if_t<std::conditional_t<std::is_void_v<R>, std::true_type, + std::is_invocable_r<R, FR()>>::value>; public: // Constructs a FunctionRef from any invocable type. - template <typename F, typename = EnableIfCompatible<const F&>> - // NOLINTNEXTLINE(runtime/explicit) - FunctionRef(const F& f ABSL_ATTRIBUTE_LIFETIME_BOUND) - : invoker_(&absl::functional_internal::InvokeObject<F, R, Args...>) { + template <typename F, + typename = EnableIfCompatible<std::enable_if_t< + !std::is_same_v<FunctionRef, absl::remove_cvref_t<F>>, F&>>> + // NOLINTNEXTLINE(google-explicit-constructor) + FunctionRef(F&& f ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept + : invoker_(&absl::functional_internal::InvokeObject<F&, R, Args...>) { absl::functional_internal::AssertNonNull(f); ptr_.obj = &f; } @@ -111,14 +114,39 @@ template < typename F, typename = EnableIfCompatible<F*>, absl::functional_internal::EnableIf<absl::is_function<F>::value> = 0> - FunctionRef(F* f) // NOLINT(runtime/explicit) + // NOLINTNEXTLINE(google-explicit-constructor) + FunctionRef(F* f ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept : invoker_(&absl::functional_internal::InvokeFunction<F*, R, Args...>) { assert(f != nullptr); ptr_.fun = reinterpret_cast<decltype(ptr_.fun)>(f); } - FunctionRef& operator=(const FunctionRef& rhs) = default; - FunctionRef(const FunctionRef& rhs) = default; +#if ABSL_INTERNAL_CPLUSPLUS_LANG >= 202002L + // Similar to the other overloads, but passes the address of a known callable + // `F` at compile time. This allows calling arbitrary functions while avoiding + // an indirection. + // Needs C++20 as `nontype_t` needs C++20 for `auto` template parameters. + template <auto F> + FunctionRef(nontype_t<F>) noexcept // NOLINT(google-explicit-constructor) + : invoker_(&absl::functional_internal::InvokeFunction<decltype(F), F, R, + Args...>) {} + + template <auto F, typename Obj> + // NOLINTNEXTLINE(google-explicit-constructor) + FunctionRef(nontype_t<F>, Obj&& obj) noexcept + : invoker_(&absl::functional_internal::InvokeObject<Obj&, decltype(F), F, + R, Args...>) { + ptr_.obj = std::addressof(obj); + } + + template <auto F, typename Obj> + // NOLINTNEXTLINE(google-explicit-constructor) + FunctionRef(nontype_t<F>, Obj* obj) noexcept + : invoker_(&absl::functional_internal::InvokePtr<Obj, decltype(F), F, R, + Args...>) { + ptr_.obj = obj; + } +#endif // Call the underlying object. R operator()(Args... args) const { @@ -134,8 +162,39 @@ // constness anyway we can just make this a no-op. template <typename R, typename... Args> class FunctionRef<R(Args...) const> : public FunctionRef<R(Args...)> { + using Base = FunctionRef<R(Args...)>; + + template <typename F, typename T = void> + using EnableIfCallable = + std::enable_if_t<!std::is_same_v<FunctionRef, absl::remove_cvref_t<F>> && + std::is_invocable_r_v<R, F, Args...> && + std::is_constructible_v<Base, F>, + T>; + public: - using FunctionRef<R(Args...)>::FunctionRef; + template <typename F, typename = EnableIfCallable<const F&>> + // NOLINTNEXTLINE(google-explicit-constructor) + FunctionRef(const F& f ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept : Base(f) {} + + template <typename F, + typename = std::enable_if_t<std::is_constructible_v<Base, F*>>> + // NOLINTNEXTLINE(google-explicit-constructor) + FunctionRef(F* f ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept : Base(f) {} + +#if ABSL_INTERNAL_CPLUSPLUS_LANG >= 202002L + template <auto F, typename = EnableIfCallable<decltype(F)>> + // NOLINTNEXTLINE(google-explicit-constructor) + FunctionRef(nontype_t<F> arg) noexcept : Base(arg) {} + + template <auto F, typename Obj, typename = EnableIfCallable<decltype(F)>> + // NOLINTNEXTLINE(google-explicit-constructor) + FunctionRef(nontype_t<F> arg, Obj&& obj) noexcept + : Base(arg, std::forward<Obj>(obj)) {} + + template <auto F, typename Obj, typename = EnableIfCallable<decltype(F)>> + // NOLINTNEXTLINE(google-explicit-constructor) + FunctionRef(nontype_t<F> arg, Obj* obj) noexcept : Base(arg, obj) {} +#endif }; ABSL_NAMESPACE_END
diff --git a/absl/functional/function_ref_test.cc b/absl/functional/function_ref_test.cc index 98d11f7..c8ff080 100644 --- a/absl/functional/function_ref_test.cc +++ b/absl/functional/function_ref_test.cc
@@ -16,26 +16,31 @@ #include <functional> #include <memory> +#include <type_traits> +#include <utility> #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/container/internal/test_instance_tracker.h" #include "absl/functional/any_invocable.h" #include "absl/memory/memory.h" +#include "absl/utility/utility.h" namespace absl { ABSL_NAMESPACE_BEGIN namespace { -void RunFun(FunctionRef<void()> f) { f(); } +int Function() { return 1337; } -TEST(FunctionRefTest, Lambda) { - bool ran = false; - RunFun([&] { ran = true; }); - EXPECT_TRUE(ran); +template <typename T> +T Dereference(const T* v) { + return *v; } -int Function() { return 1337; } +template <typename T> +T Copy(const T& v) { + return v; +} TEST(FunctionRefTest, Function1) { FunctionRef<int()> ref(&Function); @@ -251,11 +256,11 @@ std::is_same<Invoker<void, Trivial>, void (*)(VoidPtr, Trivial)>::value, "Small trivial types should be passed by value"); static_assert(std::is_same<Invoker<void, LargeTrivial>, - void (*)(VoidPtr, LargeTrivial &&)>::value, + void (*)(VoidPtr, LargeTrivial&&)>::value, "Large trivial types should be passed by rvalue reference"); static_assert( std::is_same<Invoker<void, CopyableMovableInstance>, - void (*)(VoidPtr, CopyableMovableInstance &&)>::value, + void (*)(VoidPtr, CopyableMovableInstance&&)>::value, "Types with copy/move ctor should be passed by rvalue reference"); // References are passed as references. @@ -268,7 +273,7 @@ "Reference types should be preserved"); static_assert( std::is_same<Invoker<void, CopyableMovableInstance&&>, - void (*)(VoidPtr, CopyableMovableInstance &&)>::value, + void (*)(VoidPtr, CopyableMovableInstance&&)>::value, "Reference types should be preserved"); // Make sure the address of an object received by reference is the same as the @@ -298,6 +303,61 @@ ref(obj); } +TEST(FunctionRefTest, CorrectConstQualifiers) { + struct S { + int operator()() { return 42; } + int operator()() const { return 1337; } + }; + S s; + EXPECT_EQ(42, FunctionRef<int()>(s)()); + EXPECT_EQ(1337, FunctionRef<int() const>(s)()); + EXPECT_EQ(1337, FunctionRef<int()>(std::as_const(s))()); +} + +TEST(FunctionRefTest, Lambdas) { + // Stateless lambdas implicitly convert to function pointers, so their + // mutability is irrelevant. + EXPECT_TRUE(FunctionRef<bool()>([]() /*const*/ { return true; })()); + EXPECT_TRUE(FunctionRef<bool()>([]() mutable { return true; })()); + EXPECT_TRUE(FunctionRef<bool() const>([]() /*const*/ { return true; })()); +#if defined(__clang__) || (ABSL_INTERNAL_CPLUSPLUS_LANG >= 202002L && \ + defined(_MSC_VER) && !defined(__EDG__)) + // MSVC has problems compiling the following code pre-C++20: + // const auto f = []() mutable {}; + // f(); + // EDG's MSVC-compatible mode (which Visual C++ uses for Intellisense) + // exhibits the bug in C++20 as well. So we don't support them. + EXPECT_TRUE(FunctionRef<bool() const>([]() mutable { return true; })()); +#endif + + // Stateful lambdas are not implicitly convertible to function pointers, so + // a const stateful lambda is not mutably callable. + EXPECT_TRUE(FunctionRef<bool()>([v = true]() /*const*/ { return v; })()); + EXPECT_TRUE(FunctionRef<bool()>([v = true]() mutable { return v; })()); + EXPECT_TRUE( + FunctionRef<bool() const>([v = true]() /*const*/ { return v; })()); + const auto func = [v = true]() mutable { return v; }; + static_assert( + !std::is_convertible_v<decltype(func), FunctionRef<bool() const>>); +} + +#if ABSL_INTERNAL_CPLUSPLUS_LANG >= 202002L +TEST(FunctionRefTest, NonTypeParameter) { + EXPECT_EQ(1337, FunctionRef<int()>(nontype<&Function>)()); + EXPECT_EQ(42, FunctionRef<int()>(nontype<&Copy<int>>, 42)()); + EXPECT_EQ(42, FunctionRef<int()>(nontype<&Dereference<int>>, + &std::integral_constant<int, 42>::value)()); +} +#endif + +TEST(FunctionRefTest, OptionalArguments) { + struct S { + int operator()(int = 0) const { return 1337; } + }; + S s; + EXPECT_EQ(1337, FunctionRef<int()>(s)()); +} + } // namespace ABSL_NAMESPACE_END } // namespace absl
diff --git a/absl/functional/internal/function_ref.h b/absl/functional/internal/function_ref.h index 27d45b8..0796364 100644 --- a/absl/functional/internal/function_ref.h +++ b/absl/functional/internal/function_ref.h
@@ -72,8 +72,25 @@ // static_cast<R> handles the case the return type is void. template <typename Obj, typename R, typename... Args> R InvokeObject(VoidPtr ptr, typename ForwardT<Args>::type... args) { - auto o = static_cast<const Obj*>(ptr.obj); - return static_cast<R>(std::invoke(*o, std::forward<Args>(args)...)); + using T = std::remove_reference_t<Obj>; + return static_cast<R>(std::invoke( + std::forward<Obj>(*const_cast<T*>(static_cast<const T*>(ptr.obj))), + std::forward<typename ForwardT<Args>::type>(args)...)); +} + +template <typename Obj, typename Fun, Fun F, typename R, typename... Args> +R InvokeObject(VoidPtr ptr, typename ForwardT<Args>::type... args) { + using T = std::remove_reference_t<Obj>; + return static_cast<R>( + F(std::forward<Obj>(*const_cast<T*>(static_cast<const T*>(ptr.obj))), + std::forward<typename ForwardT<Args>::type>(args)...)); +} + +template <typename T, typename Fun, Fun F, typename R, typename... Args> +R InvokePtr(VoidPtr ptr, typename ForwardT<Args>::type... args) { + return static_cast<R>( + F(const_cast<T*>(static_cast<const T*>(ptr.obj)), + std::forward<typename ForwardT<Args>::type>(args)...)); } template <typename Fun, typename R, typename... Args> @@ -82,6 +99,12 @@ return static_cast<R>(std::invoke(f, std::forward<Args>(args)...)); } +template <typename Fun, Fun F, typename R, typename... Args> +R InvokeFunction(VoidPtr, typename ForwardT<Args>::type... args) { + return static_cast<R>( + F(std::forward<typename ForwardT<Args>::type>(args)...)); +} + template <typename Sig> void AssertNonNull(const std::function<Sig>& f) { assert(f != nullptr); @@ -98,7 +121,7 @@ void AssertNonNull(const F&) {} template <typename F, typename C> -void AssertNonNull(F C::*f) { +void AssertNonNull(F C::* f) { assert(f != nullptr); (void)f; }
diff --git a/absl/utility/utility.h b/absl/utility/utility.h index 4637b03..4d72c31 100644 --- a/absl/utility/utility.h +++ b/absl/utility/utility.h
@@ -49,6 +49,19 @@ using std::make_integer_sequence; using std::move; +#if ABSL_INTERNAL_CPLUSPLUS_LANG >= 202002L +// Backfill for std::nontype_t. An instance of this class can be provided as a +// disambiguation tag to `absl::function_ref` to pass the address of a known +// callable at compile time. +// Requires C++20 due to `auto` template parameter. +template <auto> +struct nontype_t { + explicit nontype_t() = default; +}; +template <auto V> +constexpr nontype_t<V> nontype{}; +#endif + ABSL_NAMESPACE_END } // namespace absl