// xref: /aosp_15_r20/external/executorch/extension/pytree/function_ref.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some extensions to <functional>.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//
//     Extra additions to <functional>
//===----------------------------------------------------------------------===//

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a FunctionRef.

// torch::executor: modified from llvm::function_ref
// see https://www.foonathan.net/2017/01/function-ref-implementation/

#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>

namespace executorch {
namespace extension {
namespace pytree {

//===----------------------------------------------------------------------===//
//     Features from C++20
//===----------------------------------------------------------------------===//

namespace internal {

/// Stand-in for C++20 `std::remove_cvref`: `type` is `T` with any reference
/// and any top-level `const`/`volatile` qualifiers removed.
template <typename T>
struct remove_cvref {
  using type =
      typename std::remove_cv<typename std::remove_reference<T>::type>::type;
};

/// Convenience alias for `remove_cvref<T>::type`.
template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

} // namespace internal

template <typename Fn>
class FunctionRef;

/// A type-erasing, non-owning reference to something callable with signature
/// `Ret(Params...)`.
///
/// The object holds either a pointer to a caller-owned callable (Case 1) or a
/// plain function pointer (Cases 2 and 3), plus a trampoline that knows how to
/// invoke whichever one is stored. It never copies or owns the callable, so a
/// FunctionRef bound through Case 1 must not outlive the object it refers to.
template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
  // Signature of the trampoline: receives the address of `storage_` followed
  // by the call arguments.
  using Trampoline = Ret (*)(const void* memory, Params... params);

  // What the trampoline reads: exactly one member is meaningful, decided by
  // the constructor that installed the trampoline.
  union Storage {
    void* callable; // Case 1: address of the referenced callable object.
    Ret (*function)(Params...); // Cases 2 and 3: plain function pointer.
  };

  // Null while the reference is empty.
  Trampoline callback_ = nullptr;
  // Value-initialized so that copying an empty FunctionRef never reads an
  // indeterminate value.
  Storage storage_{};

 public:
  /// Creates an empty reference; `operator bool` yields false.
  FunctionRef() = default;

  /// Creates an empty reference from a literal `nullptr`.
  explicit FunctionRef(std::nullptr_t) {}

  /**
   * Case 1: A callable object passed by lvalue reference.
   * Taking an rvalue reference is error prone because the object would always
   * be destroyed immediately.
   *
   * `Callable` may be deduced as a cv-qualified type (e.g. a `const` lambda);
   * the qualifiers are dropped only for storage in the `void*` slot and are
   * restored by the trampoline before the call.
   */
  template <
      typename Callable,
      // This is not the copy-constructor.
      typename std::enable_if<
          !std::is_same<internal::remove_cvref_t<Callable>, FunctionRef>::value,
          int32_t>::type = 0,
      // Avoid lvalue reference to non-capturing lambda.
      typename std::enable_if<
          !std::is_convertible<Callable, Ret (*)(Params...)>::value,
          int32_t>::type = 0,
      // Functor must be callable and return a suitable type.
      // To make this container type safe, we need to ensure either:
      // 1. The return type is void.
      // 2. Or the resulting type from calling the callable is convertible to
      // the declared return type.
      // The check invokes an lvalue because that is how the trampoline below
      // calls the object.
      typename std::enable_if<
          std::is_void<Ret>::value ||
              std::is_convertible<
                  decltype(std::declval<Callable&>()(
                      std::declval<Params>()...)),
                  Ret>::value,
          int32_t>::type = 0>
  explicit FunctionRef(Callable& callable)
      // The explicit `-> Ret` keeps reference and cv-qualified return types
      // intact; deduced lambda return types would decay them and no longer
      // match `Trampoline`.
      : callback_([](const void* memory, Params... params) -> Ret {
          const Storage& slot = *static_cast<const Storage*>(memory);
          Callable& target = *static_cast<Callable*>(slot.callable);
          return static_cast<Ret>(target(std::forward<Params>(params)...));
        }) {
    // std::addressof sidesteps any overloaded operator&; the casts strip
    // cv-qualifiers so the address fits the `void*` slot.
    storage_.callable = const_cast<void*>(
        static_cast<const volatile void*>(std::addressof(callable)));
  }

  /**
   * Case 2: A plain function pointer.
   * Instead of storing an opaque pointer to underlying callable object,
   * store a function pointer directly.
   * A null `ptr` produces an empty reference (`operator bool` yields false),
   * mirroring the default and `nullptr` constructors.
   * Note that in the future a variant which coerces compatible function
   * pointers could be implemented by erasing the storage type.
   */
  /* implicit */ FunctionRef(Ret (*ptr)(Params...)) {
    storage_.function = ptr;
    if (ptr != nullptr) {
      callback_ = [](const void* memory, Params... params) -> Ret {
        const Storage& slot = *static_cast<const Storage*>(memory);
        return slot.function(std::forward<Params>(params)...);
      };
    }
  }

  /**
   * Case 3: Implicit conversion from lambda to FunctionRef.
   * A common use pattern is like:
   * void foo(FunctionRef<...>) {...}
   * foo([](...){...})
   * Here constructors for non const lvalue reference or function pointer
   * would not work because they do not cover implicit conversion from rvalue
   * lambda.
   * We need to define a constructor for capturing temporary callables and
   * always try to convert the lambda to a function pointer behind the scene.
   * Only the resulting function pointer is stored (via Case 2), so the
   * temporary itself is not referenced after construction.
   */
  template <
      typename Function,
      // This is not the copy-constructor.
      typename std::enable_if<
          !std::is_same<Function, FunctionRef>::value,
          int32_t>::type = 0,
      // Function is convertible to pointer of (Params...) -> Ret.
      typename std::enable_if<
          std::is_convertible<Function, Ret (*)(Params...)>::value,
          int32_t>::type = 0>
  /* implicit */ FunctionRef(const Function& function)
      : FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}

  /// Invokes the referenced callable, forwarding `params` to it.
  /// No emptiness check is performed here; test with `operator bool` first
  /// when the reference may be empty.
  Ret operator()(Params... params) const {
    return callback_(&storage_, std::forward<Params>(params)...);
  }

  /// True when a callable is bound, false for an empty reference.
  explicit operator bool() const {
    return callback_ != nullptr;
  }
};

} // namespace pytree
} // namespace extension
} // namespace executorch

164*523fa7a6SAndroid Build Coastguard Worker namespace torch {
165*523fa7a6SAndroid Build Coastguard Worker namespace executor {
166*523fa7a6SAndroid Build Coastguard Worker namespace pytree {
167*523fa7a6SAndroid Build Coastguard Worker // TODO(T197294990): Remove these deprecated aliases once all users have moved
168*523fa7a6SAndroid Build Coastguard Worker // to the new `::executorch` namespaces.
169*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::pytree::FunctionRef;
170*523fa7a6SAndroid Build Coastguard Worker } // namespace pytree
171*523fa7a6SAndroid Build Coastguard Worker } // namespace executor
172*523fa7a6SAndroid Build Coastguard Worker } // namespace torch
173