/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some extensions to <functional>.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//
// Extra additions to <functional>
//===----------------------------------------------------------------------===//

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a FunctionRef.

// torch::executor: modified from llvm::function_ref
// see https://www.foonathan.net/2017/01/function-ref-implementation/

#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>

namespace executorch {
namespace extension {
namespace pytree {

//===----------------------------------------------------------------------===//
// Features from C++20
//===----------------------------------------------------------------------===//

namespace internal {

/// Backport of C++20 `std::remove_cvref`: strips a reference, then any
/// top-level const/volatile qualifiers.
template <typename T>
struct remove_cvref {
  using type =
      typename std::remove_cv<typename std::remove_reference<T>::type>::type;
};

template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

} // namespace internal

template <typename Fn>
class FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
  /// Trampoline that unpacks `storage_` and invokes the referenced callable.
  /// A null trampoline means this FunctionRef is empty.
  Ret (*callback_)(const void* memory, Params... params) = nullptr;

  /// Either the address of a referenced callable object (Case 1) or a plain
  /// function pointer (Cases 2 and 3). Which member is active is determined
  /// by the trampoline installed in `callback_`.
  union Storage {
    void* callable;
    Ret (*function)(Params...);
  } storage_{};

 public:
  /// Constructs an empty FunctionRef; `operator bool` returns false.
  FunctionRef() = default;

  /// Constructs an empty FunctionRef; `operator bool` returns false.
  explicit FunctionRef(std::nullptr_t) {}

  /**
   * Case 1: A callable object passed by lvalue reference.
   * Taking an rvalue reference is error prone because the object would
   * always be destroyed immediately.
   *
   * `Callable` may be const-qualified, so const functors are accepted too.
   * The referenced object must outlive every call made through this
   * FunctionRef.
   */
  template <
      typename Callable,
      // This is not the copy-constructor.
      typename std::enable_if<
          !std::is_same<internal::remove_cvref_t<Callable>, FunctionRef>::
              value,
          int32_t>::type = 0,
      // Avoid lvalue reference to non-capturing lambda.
      typename std::enable_if<
          !std::is_convertible<Callable, Ret (*)(Params...)>::value,
          int32_t>::type = 0,
      // Functor must be callable and return a suitable type.
      // To make this container type safe, we need to ensure either:
      // 1. The return type is void.
      // 2. Or the resulting type from calling the callable is convertible to
      //    the declared return type.
      // The callable is invoked as an lvalue below, so check it as one.
      typename std::enable_if<
          std::is_void<Ret>::value ||
              std::is_convertible<
                  decltype(std::declval<Callable&>()(
                      std::declval<Params>()...)),
                  Ret>::value,
          int32_t>::type = 0>
  explicit FunctionRef(Callable& callable)
      : callback_([](const void* memory, Params... params) -> Ret {
          const auto& storage = *static_cast<const Storage*>(memory);
          // Cast back to the exact (possibly const) type that was stored.
          auto& target = *static_cast<Callable*>(storage.callable);
          return static_cast<Ret>(target(std::forward<Params>(params)...));
        }) {
    // The union stores a `void*`; const is stripped here only for storage and
    // restored by the `Callable*` cast in the trampoline above, so a const
    // object is never invoked through a non-const path.
    storage_.callable = const_cast<void*>(
        static_cast<const void*>(std::addressof(callable)));
  }

  /**
   * Case 2: A plain function pointer.
   * Instead of storing an opaque pointer to underlying callable object,
   * store a function pointer directly.
   * Note that in the future a variant which coerces compatible function
   * pointers could be implemented by erasing the storage type.
   *
   * A null function pointer yields an empty FunctionRef, so `operator bool`
   * reports false instead of claiming to hold a callable.
   */
  /* implicit */ FunctionRef(Ret (*ptr)(Params...)) {
    if (ptr != nullptr) {
      callback_ = [](const void* memory, Params... params) -> Ret {
        const auto& storage = *static_cast<const Storage*>(memory);
        return storage.function(std::forward<Params>(params)...);
      };
      storage_.function = ptr;
    }
  }

  /**
   * Case 3: Implicit conversion from lambda to FunctionRef.
   * A common use pattern is like:
   *   void foo(FunctionRef<...>) {...}
   *   foo([](...){...})
   * Here constructors for non const lvalue reference or function pointer
   * would not work because they do not cover implicit conversion from rvalue
   * lambda.
   * We need to define a constructor for capturing temporary callables and
   * always try to convert the lambda to a function pointer behind the scene.
   * Only types convertible to `Ret (*)(Params...)` (e.g. non-capturing
   * lambdas) are accepted; the result delegates to Case 2.
   */
  template <
      typename Function,
      // This is not the copy-constructor.
      typename std::enable_if<
          !std::is_same<Function, FunctionRef>::value,
          int32_t>::type = 0,
      // Function is convertible to pointer of (Params...) -> Ret.
      typename std::enable_if<
          std::is_convertible<Function, Ret (*)(Params...)>::value,
          int32_t>::type = 0>
  /* implicit */ FunctionRef(const Function& function)
      : FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}

  /// Invokes the referenced callable with `params`.
  /// Precondition: this FunctionRef is non-empty (`operator bool` is true);
  /// invoking an empty FunctionRef calls through a null function pointer.
  Ret operator()(Params... params) const {
    return callback_(&storage_, std::forward<Params>(params)...);
  }

  /// Returns true if this FunctionRef refers to a callable.
  explicit operator bool() const {
    return callback_ != nullptr;
  }
};

} // namespace pytree
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
namespace pytree {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::pytree::FunctionRef;
} // namespace pytree
} // namespace executor
} // namespace torch