xref: /aosp_15_r20/external/executorch/extension/pytree/function_ref.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 //===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
10 //
11 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
12 // See https://llvm.org/LICENSE.txt for license information.
13 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14 //
15 //===----------------------------------------------------------------------===//
16 //
17 // This file contains some extension to <functional>.
18 //
19 // No library is required when using these functions.
20 //
21 //===----------------------------------------------------------------------===//
22 //     Extra additions to <functional>
23 //===----------------------------------------------------------------------===//
24 
25 /// An efficient, type-erasing, non-owning reference to a callable. This is
26 /// intended for use as the type of a function parameter that is not used
27 /// after the function in question returns.
28 ///
29 /// This class does not own the callable, so it is not in general safe to store
30 /// a FunctionRef.
31 
32 // torch::executor: modified from llvm::function_ref
33 // see https://www.foonathan.net/2017/01/function-ref-implementation/
34 
35 #pragma once
36 
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
40 
namespace executorch {
namespace extension {
namespace pytree {

//===----------------------------------------------------------------------===//
//     Features from C++20
//===----------------------------------------------------------------------===//

namespace internal {

/// Backport of C++20 `std::remove_cvref`: removes a reference from `T`, then
/// removes any top-level const/volatile qualifiers from the result.
template <typename T>
struct remove_cvref {
  using type =
      typename std::remove_cv<typename std::remove_reference<T>::type>::type;
};

template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

} // namespace internal

/// Primary template. Only the function-type specialization
/// `FunctionRef<Ret(Params...)>` below is defined.
template <typename Fn>
class FunctionRef;

/**
 * Non-owning, type-erased reference to something callable as
 * `Ret(Params...)`.
 *
 * It holds either the address of an external callable object (Case 1) or a
 * plain function pointer by value (Cases 2 and 3), together with a
 * trampoline that knows how to invoke what is stored. A FunctionRef built
 * from an lvalue callable must not outlive that callable.
 *
 * A default-constructed FunctionRef, one constructed from `nullptr`, and one
 * constructed from a null function pointer are all empty: `operator bool`
 * returns false and they must not be invoked.
 */
template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
  // Trampoline that recovers the target from `storage_` and calls it.
  // nullptr marks an empty FunctionRef.
  Ret (*callback_)(const void* memory, Params... params) = nullptr;

  // Type-erased target: `callable` for Case 1, `function` otherwise.
  // Value-initialized so an empty FunctionRef never holds indeterminate bytes.
  union Storage {
    void* callable;
    Ret (*function)(Params...);
  } storage_{};

 public:
  /// Constructs an empty FunctionRef.
  FunctionRef() = default;

  /// Constructs an empty FunctionRef from `nullptr`.
  explicit FunctionRef(std::nullptr_t) {}

  /**
   * Case 1: A callable object passed by lvalue reference (const or not).
   * Only the object's address is stored, so the object must outlive this
   * FunctionRef. Rvalues are deliberately not accepted: a temporary callable
   * would be destroyed at the end of the full-expression and leave the
   * FunctionRef dangling.
   *
   * The trampoline declares `-> Ret` explicitly. A deduced lambda return type
   * would decay `T&` to `T`, which would make reference-returning signatures
   * fail to convert to `callback_`.
   */
  template <
      typename Callable,
      // This is not the copy-constructor.
      typename std::enable_if<
          !std::is_same<internal::remove_cvref_t<Callable>, FunctionRef>::value,
          int32_t>::type = 0,
      // Anything convertible to a plain function pointer (e.g. a non-capturing
      // lambda) is handled by Case 3, which stores the pointer by value.
      typename std::enable_if<
          !std::is_convertible<Callable, Ret (*)(Params...)>::value,
          int32_t>::type = 0,
      // The functor must be callable as an lvalue (that is how the trampoline
      // invokes it), and, unless Ret is void, its result must be convertible
      // to Ret.
      typename std::enable_if<
          std::is_void<Ret>::value ||
              std::is_convertible<
                  decltype(std::declval<Callable&>()(
                      std::declval<Params>()...)),
                  Ret>::value,
          int32_t>::type = 0>
  explicit FunctionRef(Callable& callable)
      : callback_([](const void* memory, Params... params) -> Ret {
          const auto& storage = *static_cast<const Storage*>(memory);
          auto& target = *static_cast<Callable*>(storage.callable);
          return static_cast<Ret>(target(std::forward<Params>(params)...));
        }) {
    // cv-qualifiers are dropped only for storage. The trampoline casts back to
    // `Callable*`, which restores them, so a const callable is never mutated.
    storage_.callable =
        const_cast<void*>(static_cast<const volatile void*>(&callable));
  }

  /**
   * Case 2: A plain function pointer.
   * The pointer is stored directly instead of an opaque pointer to a callable
   * object. A null pointer produces an empty FunctionRef rather than one that
   * claims to be callable.
   * Note that in the future a variant which coerces compatible function
   * pointers could be implemented by erasing the storage type.
   */
  /* implicit */ FunctionRef(Ret (*ptr)(Params...)) {
    storage_.function = ptr;
    if (ptr != nullptr) {
      callback_ = [](const void* memory, Params... params) -> Ret {
        const auto& storage = *static_cast<const Storage*>(memory);
        return storage.function(std::forward<Params>(params)...);
      };
    }
  }

  /**
   * Case 3: Implicit conversion from lambda to FunctionRef.
   * A common use pattern is like:
   * void foo(FunctionRef<...>) {...}
   * foo([](...){...})
   * The non-const lvalue reference constructor cannot bind the temporary
   * lambda, and the function-pointer constructor is not reachable through a
   * second user-defined conversion. This constructor accepts any value
   * convertible to `Ret (*)(Params...)` and stores the converted pointer,
   * so nothing refers to the temporary afterwards.
   */
  template <
      typename Function,
      // This is not the copy-constructor.
      typename std::enable_if<
          !std::is_same<Function, FunctionRef>::value,
          int32_t>::type = 0,
      // Function is convertible to pointer of (Params...) -> Ret.
      typename std::enable_if<
          std::is_convertible<Function, Ret (*)(Params...)>::value,
          int32_t>::type = 0>
  /* implicit */ FunctionRef(const Function& function)
      : FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}

  /// Invokes the referenced target with `params`. Calling an empty
  /// FunctionRef calls a null function pointer (undefined behavior); test
  /// `operator bool` first when emptiness is possible.
  Ret operator()(Params... params) const {
    return callback_(&storage_, std::forward<Params>(params)...);
  }

  /// True iff this FunctionRef refers to a target.
  explicit operator bool() const {
    return callback_ != nullptr;
  }
};

} // namespace pytree
} // namespace extension
} // namespace executorch
163 
164 namespace torch {
165 namespace executor {
166 namespace pytree {
167 // TODO(T197294990): Remove these deprecated aliases once all users have moved
168 // to the new `::executorch` namespaces.
169 using ::executorch::extension::pytree::FunctionRef;
170 } // namespace pytree
171 } // namespace executor
172 } // namespace torch
173