1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 //===----------------------------------------------------------------------===// 10 /// \file runtime/kernel/make_aten_functor_from_et_functor.h 11 /// Defines a template that can be used to create a ATen version of an unboxed 12 /// ExecuTorch kernel. 13 //===----------------------------------------------------------------------===// 14 15 #pragma once 16 #include <type_traits> 17 #include <vector> 18 #if __cplusplus < 201703L 19 #error "This header requires C++17" 20 #endif 21 #include <ATen/native/Resize.h> 22 #include <executorch/extension/kernel_util/type_list.h> 23 #include <executorch/extension/tensor/tensor.h> 24 #include <executorch/runtime/core/evalue.h> 25 #include <torch/torch.h> 26 27 namespace executorch { 28 namespace extension { 29 namespace internal { 30 31 // Map types from ETen to ATen. 32 // This is used to convert ETen arguments into ATen. 33 template <typename T> 34 struct type_map final { 35 using type = T; 36 }; 37 38 // Const. 39 template <typename T> 40 struct type_map<const T> final { 41 using type = const typename type_map<T>::type; 42 }; 43 44 // Ref. 45 template <typename T> 46 struct type_map<T&> final { 47 using type = typename type_map<T>::type&; 48 }; 49 50 // Const ref. 51 template <typename T> 52 struct type_map<const T&> final { 53 using type = const typename type_map<T>::type&; 54 }; 55 56 // Tensor. 57 template <> 58 struct type_map<torch::executor::Tensor> final { 59 using type = at::Tensor; 60 }; 61 62 // Optional. 63 template <class T> 64 struct type_map<torch::executor::optional<T>> final { 65 using type = std::optional<typename type_map<T>::type>; 66 }; 67 68 template <class T> 69 struct type_map<torch::executor::optional<T>&> final { 70 using type = std::optional<typename type_map<T>::type>&; 71 }; 72 73 // ArrayRef. 74 template <class T> 75 struct type_map<torch::executor::ArrayRef<T>> final { 76 using type = at::ArrayRef<typename type_map<T>::type>; 77 }; 78 79 template <typename T> 80 struct remove_const_ref final { 81 using type = std::remove_const_t<std::remove_reference_t<T>>; 82 }; 83 84 // Convert ATen->ETen: input args. 85 // Convert ETen->ATen: output args. 86 // Default argument conversions between ATen and ETen (scalars). 87 template <typename F, typename T, typename Enable = void> 88 struct type_convert final { 89 public: 90 F val; 91 explicit type_convert(F value) : val(value) {} 92 T call() { 93 return static_cast<T>(val); 94 } 95 }; 96 97 // Tensors: ATen to ETen. 98 template <class ATensor, class ETensor> 99 struct type_convert< 100 ATensor, 101 ETensor, 102 std::enable_if_t< 103 std::is_same_v<typename remove_const_ref<ATensor>::type, at::Tensor> && 104 std::is_same_v< 105 typename remove_const_ref<ETensor>::type, 106 torch::executor::Tensor>>> 107 final { 108 explicit type_convert(ATensor value) 109 : value_(value), 110 converted_(from_blob( 111 value_.mutable_data_ptr(), 112 {value_.sizes().begin(), value_.sizes().end()}, 113 ::torch::executor::ScalarType(value_.scalar_type()))) {} 114 115 ETensor call() { 116 return *converted_; 117 } 118 119 private: 120 ATensor value_; 121 TensorPtr converted_; 122 }; 123 124 // Tensors: ETen to ATen. 125 template <class ETensor, class ATensor> 126 struct type_convert< 127 ETensor, 128 ATensor, 129 std::enable_if_t< 130 std::is_same_v<typename remove_const_ref<ATensor>::type, at::Tensor> && 131 std::is_same_v< 132 typename remove_const_ref<ETensor>::type, 133 ::torch::executor::Tensor>>> 134 final { 135 explicit type_convert(ETensor value) 136 : value_(value), 137 converted_(at::from_blob( 138 value_.mutable_data_ptr(), 139 std::vector<int64_t>{value_.sizes().begin(), value_.sizes().end()}, 140 c10::ScalarType(value_.scalar_type()))) {} 141 142 ATensor call() { 143 return converted_; 144 } 145 146 private: 147 ETensor value_; 148 at::Tensor converted_; 149 }; 150 151 // Optionals: ATen to ETen. 152 template <class F, class T> 153 struct type_convert<std::optional<F>, torch::executor::optional<T>> final { 154 public: 155 std::optional<F> val; 156 std::unique_ptr<struct type_convert<F, T>> convert_struct; 157 explicit type_convert(std::optional<F> value) : val(value) {} 158 torch::executor::optional<T> call() { 159 if (val.has_value()) { 160 convert_struct = std::make_unique<struct type_convert<F, T>>( 161 type_convert<F, T>(val.value())); 162 return torch::executor::optional<T>(convert_struct->call()); 163 } else { 164 return torch::executor::optional<T>(); 165 } 166 } 167 }; 168 169 // Optionals: ETen to ATen. 170 template <class F, class T> 171 struct type_convert<torch::executor::optional<F>, std::optional<T>> final { 172 public: 173 torch::executor::optional<F> val; 174 std::unique_ptr<struct type_convert<F, T>> convert_struct; 175 explicit type_convert(torch::executor::optional<F> value) : val(value) {} 176 std::optional<T> call() { 177 if (val.has_value()) { 178 convert_struct = std::make_unique<struct type_convert<F, T>>( 179 type_convert<F, T>(val.value())); 180 return std::optional<T>(convert_struct->call()); 181 } else { 182 return std::optional<T>(); 183 } 184 } 185 }; 186 187 // ArrayRefs: ATen to ETen. 188 template <class F, class T> 189 struct type_convert<c10::ArrayRef<F>, torch::executor::ArrayRef<T>> final { 190 public: 191 c10::ArrayRef<F> val; 192 std::vector<T> converted; 193 std::vector<type_convert<F, T>> converters; 194 explicit type_convert(c10::ArrayRef<F> value) : val(value) { 195 for (int i = 0; i < val.size(); i++) { 196 converters.push_back(type_convert<F, T>(val[i])); 197 } 198 } 199 torch::executor::ArrayRef<T> call() { 200 for (int i = 0; i < val.size(); i++) { 201 converted.push_back(converters[i].call()); 202 } 203 return torch::executor::ArrayRef<T>(converted.data(), converted.size()); 204 } 205 }; 206 207 // ArrayRefs: ETen to ATen. 208 template <class F, class T> 209 struct type_convert<torch::executor::ArrayRef<F>, c10::ArrayRef<T>> final { 210 public: 211 torch::executor::ArrayRef<F> val; 212 std::vector<T> converted; 213 std::vector<type_convert<F, T>> converters; 214 explicit type_convert(torch::executor::ArrayRef<F> value) : val(value) { 215 for (int i = 0; i < val.size(); i++) { 216 converters.push_back(type_convert<F, T>(val[i])); 217 } 218 } 219 c10::ArrayRef<T> call() { 220 for (int i = 0; i < val.size(); i++) { 221 converted.push_back(converters[i].call()); 222 } 223 return c10::ArrayRef<T>(converted); 224 } 225 }; 226 227 template <class F, F f, typename N = int, N index = N(-1)> 228 struct wrapper_impl; 229 230 template <class R, class... Args, R (*f)(Args...), int N> 231 struct wrapper_impl<R (*)(Args...), f, int, N> { 232 static_assert( 233 !(std::is_same<R, at::Tensor&>::value && N == -1), 234 "Can't wrap a kernel with 'Tensor &' return type without specifying an index to the out tensor"); 235 using ReturnType = typename type_map<R>::type; 236 using TupleConvertsType = 237 std::tuple<type_convert<typename type_map<Args>::type, Args>...>; 238 using TupleArgsType = std::tuple<typename type_map<Args>::type...>; 239 static constexpr size_t num_args = sizeof...(Args); 240 static_assert( 241 (N < num_args && 242 std::is_same_v< 243 executorch::extension::kernel_util_internal::element_t< 244 N, 245 executorch::extension::kernel_util_internal::typelist<Args...>>, 246 R>) || 247 N == -1, 248 "The index of the out tensor can't be greater or equal to num_args and " 249 "the Nth argument type has to be the same as the return type."); 250 251 static ReturnType wrap(typename type_map<Args>::type... args) { 252 // The wrapped function that takes ATen argument types, convert them into 253 // ExecuTorch equivalent, call `f` then return the result converted back to 254 // ATen. 255 TupleArgsType args_tuple = std::forward_as_tuple(args...); 256 TupleConvertsType converts = std::forward_as_tuple( 257 type_convert<typename type_map<Args>::type, Args>(args)...); 258 R result = 259 call_functor_with_args(converts, std::make_index_sequence<num_args>()); 260 typename std::remove_reference<ReturnType>::type converted_result = 261 type_convert<R, ReturnType>(result).call(); 262 if constexpr (N == -1) { 263 return converted_result; 264 } else { 265 static_assert( 266 std::is_same_v< 267 typename std::remove_reference<ReturnType>::type, 268 at::Tensor>, 269 "Only support at::Tensor-like return"); 270 ReturnType out = std::get<N>(args_tuple); 271 at::native::resize_output(out, converted_result.sizes()); 272 out.copy_(converted_result); 273 return out; 274 } 275 } 276 277 private: 278 template <size_t... indices> 279 static R call_functor_with_args( 280 TupleConvertsType& converts, 281 std::index_sequence<indices...>) { 282 return f(std::get<indices>(converts).call()...); 283 } 284 }; 285 286 } // namespace internal 287 } // namespace extension 288 } // namespace executorch 289 290 // Wrapper macro for out variant function. N is the index of the out tensor. 291 // We need N to know how to preserve the semantics of modifying out tensor and 292 // return the reference without allocating a new memory buffer for out tensor. 293 #define _WRAP_2(func, N) \ 294 ::executorch::extension::internal:: \ 295 wrapper_impl<decltype(&func), func, decltype(N), N>::wrap 296 #define _WRAP_1(func) \ 297 ::executorch::extension::internal::wrapper_impl<decltype(&func), func>::wrap 298 299 #define _GET_MACRO(_1, _2, NAME, ...) NAME 300 #define WRAP_TO_ATEN(...) _GET_MACRO(__VA_ARGS__, _WRAP_2, _WRAP_1)(__VA_ARGS__) 301