/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some extensions to <functional>.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//
// Extra additions to <functional>
//===----------------------------------------------------------------------===//

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a FunctionRef.
// torch::executor: modified from llvm::function_ref
// see https://www.foonathan.net/2017/01/function-ref-implementation/

#pragma once

#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>

namespace executorch {
namespace extension {
namespace pytree {

//===----------------------------------------------------------------------===//
// Features from C++20
//===----------------------------------------------------------------------===//

namespace internal {

/// Backport of C++20 `std::remove_cvref`: strips a reference, then top-level
/// const/volatile qualifiers.
template <typename T>
struct remove_cvref {
  using type =
      typename std::remove_cv<typename std::remove_reference<T>::type>::type;
};

template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

} // namespace internal

template <typename Fn>
class FunctionRef;

/// Non-owning reference to a callable with signature `Ret(Params...)`.
///
/// A FunctionRef holds a small `Storage` union (either a type-erased pointer
/// to a callable object, or a plain function pointer) plus a trampoline
/// `callback_` that knows which union member is active and how to invoke it.
/// A null `callback_` denotes an empty FunctionRef. `operator()` calls
/// `callback_` unconditionally, so invoking an empty FunctionRef is undefined
/// behavior; check `operator bool` first when emptiness is possible.
template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
  union Storage {
    // Address of a referenced callable object. Constness is erased when
    // storing and restored by callCallable, which casts back to the exact
    // (possibly const-qualified) `Callable*` it was created from.
    void* callable;
    // A plain function pointer, stored directly (Cases 2 and 3).
    Ret (*function)(Params...);
  };

  using Callback = Ret (*)(const Storage& storage, Params... params);

  /// Trampoline for Case 1: invokes the referenced callable object as an
  /// lvalue, converting its result to `Ret` (or discarding it if `Ret` is
  /// void).
  template <typename Callable>
  static Ret callCallable(const Storage& storage, Params... params) {
    Callable& target = *static_cast<Callable*>(storage.callable);
    return static_cast<Ret>(target(std::forward<Params>(params)...));
  }

  /// Trampoline for Case 2: invokes the stored plain function pointer.
  static Ret callFunction(const Storage& storage, Params... params) {
    return storage.function(std::forward<Params>(params)...);
  }

  Callback callback_ = nullptr;
  // Value-initialized so that copying an empty FunctionRef never reads an
  // indeterminate value.
  Storage storage_{};

 public:
  /// Constructs an empty FunctionRef; `operator bool` returns false.
  FunctionRef() = default;

  /// Constructs an empty FunctionRef from `nullptr`.
  explicit FunctionRef(std::nullptr_t) {}

  /**
   * Case 1: A callable object passed by lvalue reference.
   * Taking an rvalue reference is error prone because the object would be
   * destroyed immediately, leaving this FunctionRef dangling.
   *
   * Both const and non-const lvalues are accepted; a const callable is
   * always invoked through a const reference.
   */
  template <
      typename Callable,
      // This is not the copy-constructor.
      typename std::enable_if<
          !std::is_same<internal::remove_cvref_t<Callable>, FunctionRef>::value,
          int32_t>::type = 0,
      // Avoid lvalue reference to non-capturing lambda.
      typename std::enable_if<
          !std::is_convertible<Callable, Ret (*)(Params...)>::value,
          int32_t>::type = 0,
      // Functor must be callable and return a suitable type.
      // To make this container type safe, we need to ensure either:
      // 1. The return type is void.
      // 2. Or the resulting type from calling the callable is convertible to
      // the declared return type.
      // The check uses `Callable&` because callCallable invokes the callable
      // as an lvalue.
      typename std::enable_if<
          std::is_void<Ret>::value ||
              std::is_convertible<
                  decltype(std::declval<Callable&>()(
                      std::declval<Params>()...)),
                  Ret>::value,
          int32_t>::type = 0>
  explicit FunctionRef(Callable& callable)
      : callback_(&callCallable<Callable>) {
    // const_cast only erases constness for storage; callCallable<Callable>
    // restores the original qualification, so a const object is never
    // invoked through a non-const path.
    storage_.callable = const_cast<void*>(
        static_cast<const void*>(std::addressof(callable)));
  }

  /**
   * Case 2: A plain function pointer.
   * Instead of storing an opaque pointer to underlying callable object,
   * store a function pointer directly.
   * A null function pointer produces an empty FunctionRef, so that
   * `operator bool` reports it correctly.
   * Note that in the future a variant which coerces compatible function
   * pointers could be implemented by erasing the storage type.
   */
  /* implicit */ FunctionRef(Ret (*ptr)(Params...))
      : callback_(ptr != nullptr ? &callFunction : nullptr) {
    storage_.function = ptr;
  }

  /**
   * Case 3: Implicit conversion from lambda to FunctionRef.
   * A common use pattern is like:
   *   void foo(FunctionRef<...>) {...}
   *   foo([](...){...})
   * Here constructors for non const lvalue reference or function pointer
   * would not work because they do not cover implicit conversion from rvalue
   * lambda.
   * We need to define a constructor for capturing temporary callables and
   * always try to convert the lambda to a function pointer behind the scene.
   */
  template <
      typename Function,
      // This is not the copy-constructor.
      typename std::enable_if<
          !std::is_same<Function, FunctionRef>::value,
          int32_t>::type = 0,
      // Function is convertible to pointer of (Params...) -> Ret.
      typename std::enable_if<
          std::is_convertible<Function, Ret (*)(Params...)>::value,
          int32_t>::type = 0>
  /* implicit */ FunctionRef(const Function& function)
      : FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}

  /// Invokes the referenced callable. Precondition: `*this` is non-empty.
  Ret operator()(Params... params) const {
    return callback_(storage_, std::forward<Params>(params)...);
  }

  /// Returns true iff this FunctionRef references a callable.
  explicit operator bool() const {
    return callback_ != nullptr;
  }
};

} // namespace pytree
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
namespace pytree {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::pytree::FunctionRef;
} // namespace pytree
} // namespace executor
} // namespace torch