xref: /aosp_15_r20/external/executorch/extension/aten_util/make_aten_functor_from_et_functor.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 //===----------------------------------------------------------------------===//
/// \file extension/aten_util/make_aten_functor_from_et_functor.h
/// Defines a template that can be used to create an ATen version of an unboxed
/// ExecuTorch kernel.
13 //===----------------------------------------------------------------------===//
14 
15 #pragma once
16 #include <type_traits>
17 #include <vector>
18 #if __cplusplus < 201703L
19 #error "This header requires C++17"
20 #endif
21 #include <ATen/native/Resize.h>
22 #include <executorch/extension/kernel_util/type_list.h>
23 #include <executorch/extension/tensor/tensor.h>
24 #include <executorch/runtime/core/evalue.h>
25 #include <torch/torch.h>
26 
27 namespace executorch {
28 namespace extension {
29 namespace internal {
30 
31 // Map types from ETen to ATen.
32 // This is used to convert ETen arguments into ATen.
33 template <typename T>
34 struct type_map final {
35   using type = T;
36 };
37 
38 // Const.
39 template <typename T>
40 struct type_map<const T> final {
41   using type = const typename type_map<T>::type;
42 };
43 
44 // Ref.
45 template <typename T>
46 struct type_map<T&> final {
47   using type = typename type_map<T>::type&;
48 };
49 
50 // Const ref.
51 template <typename T>
52 struct type_map<const T&> final {
53   using type = const typename type_map<T>::type&;
54 };
55 
56 // Tensor.
57 template <>
58 struct type_map<torch::executor::Tensor> final {
59   using type = at::Tensor;
60 };
61 
62 // Optional.
63 template <class T>
64 struct type_map<torch::executor::optional<T>> final {
65   using type = std::optional<typename type_map<T>::type>;
66 };
67 
68 template <class T>
69 struct type_map<torch::executor::optional<T>&> final {
70   using type = std::optional<typename type_map<T>::type>&;
71 };
72 
73 // ArrayRef.
74 template <class T>
75 struct type_map<torch::executor::ArrayRef<T>> final {
76   using type = at::ArrayRef<typename type_map<T>::type>;
77 };
78 
79 template <typename T>
80 struct remove_const_ref final {
81   using type = std::remove_const_t<std::remove_reference_t<T>>;
82 };
83 
84 // Convert ATen->ETen: input args.
85 // Convert ETen->ATen: output args.
86 // Default argument conversions between ATen and ETen (scalars).
87 template <typename F, typename T, typename Enable = void>
88 struct type_convert final {
89  public:
90   F val;
91   explicit type_convert(F value) : val(value) {}
92   T call() {
93     return static_cast<T>(val);
94   }
95 };
96 
97 // Tensors: ATen to ETen.
98 template <class ATensor, class ETensor>
99 struct type_convert<
100     ATensor,
101     ETensor,
102     std::enable_if_t<
103         std::is_same_v<typename remove_const_ref<ATensor>::type, at::Tensor> &&
104         std::is_same_v<
105             typename remove_const_ref<ETensor>::type,
106             torch::executor::Tensor>>>
107     final {
108   explicit type_convert(ATensor value)
109       : value_(value),
110         converted_(from_blob(
111             value_.mutable_data_ptr(),
112             {value_.sizes().begin(), value_.sizes().end()},
113             ::torch::executor::ScalarType(value_.scalar_type()))) {}
114 
115   ETensor call() {
116     return *converted_;
117   }
118 
119  private:
120   ATensor value_;
121   TensorPtr converted_;
122 };
123 
124 // Tensors: ETen to ATen.
125 template <class ETensor, class ATensor>
126 struct type_convert<
127     ETensor,
128     ATensor,
129     std::enable_if_t<
130         std::is_same_v<typename remove_const_ref<ATensor>::type, at::Tensor> &&
131         std::is_same_v<
132             typename remove_const_ref<ETensor>::type,
133             ::torch::executor::Tensor>>>
134     final {
135   explicit type_convert(ETensor value)
136       : value_(value),
137         converted_(at::from_blob(
138             value_.mutable_data_ptr(),
139             std::vector<int64_t>{value_.sizes().begin(), value_.sizes().end()},
140             c10::ScalarType(value_.scalar_type()))) {}
141 
142   ATensor call() {
143     return converted_;
144   }
145 
146  private:
147   ETensor value_;
148   at::Tensor converted_;
149 };
150 
151 // Optionals: ATen to ETen.
152 template <class F, class T>
153 struct type_convert<std::optional<F>, torch::executor::optional<T>> final {
154  public:
155   std::optional<F> val;
156   std::unique_ptr<struct type_convert<F, T>> convert_struct;
157   explicit type_convert(std::optional<F> value) : val(value) {}
158   torch::executor::optional<T> call() {
159     if (val.has_value()) {
160       convert_struct = std::make_unique<struct type_convert<F, T>>(
161           type_convert<F, T>(val.value()));
162       return torch::executor::optional<T>(convert_struct->call());
163     } else {
164       return torch::executor::optional<T>();
165     }
166   }
167 };
168 
169 // Optionals: ETen to ATen.
170 template <class F, class T>
171 struct type_convert<torch::executor::optional<F>, std::optional<T>> final {
172  public:
173   torch::executor::optional<F> val;
174   std::unique_ptr<struct type_convert<F, T>> convert_struct;
175   explicit type_convert(torch::executor::optional<F> value) : val(value) {}
176   std::optional<T> call() {
177     if (val.has_value()) {
178       convert_struct = std::make_unique<struct type_convert<F, T>>(
179           type_convert<F, T>(val.value()));
180       return std::optional<T>(convert_struct->call());
181     } else {
182       return std::optional<T>();
183     }
184   }
185 };
186 
187 // ArrayRefs: ATen to ETen.
188 template <class F, class T>
189 struct type_convert<c10::ArrayRef<F>, torch::executor::ArrayRef<T>> final {
190  public:
191   c10::ArrayRef<F> val;
192   std::vector<T> converted;
193   std::vector<type_convert<F, T>> converters;
194   explicit type_convert(c10::ArrayRef<F> value) : val(value) {
195     for (int i = 0; i < val.size(); i++) {
196       converters.push_back(type_convert<F, T>(val[i]));
197     }
198   }
199   torch::executor::ArrayRef<T> call() {
200     for (int i = 0; i < val.size(); i++) {
201       converted.push_back(converters[i].call());
202     }
203     return torch::executor::ArrayRef<T>(converted.data(), converted.size());
204   }
205 };
206 
207 // ArrayRefs: ETen to ATen.
208 template <class F, class T>
209 struct type_convert<torch::executor::ArrayRef<F>, c10::ArrayRef<T>> final {
210  public:
211   torch::executor::ArrayRef<F> val;
212   std::vector<T> converted;
213   std::vector<type_convert<F, T>> converters;
214   explicit type_convert(torch::executor::ArrayRef<F> value) : val(value) {
215     for (int i = 0; i < val.size(); i++) {
216       converters.push_back(type_convert<F, T>(val[i]));
217     }
218   }
219   c10::ArrayRef<T> call() {
220     for (int i = 0; i < val.size(); i++) {
221       converted.push_back(converters[i].call());
222     }
223     return c10::ArrayRef<T>(converted);
224   }
225 };
226 
227 template <class F, F f, typename N = int, N index = N(-1)>
228 struct wrapper_impl;
229 
230 template <class R, class... Args, R (*f)(Args...), int N>
231 struct wrapper_impl<R (*)(Args...), f, int, N> {
232   static_assert(
233       !(std::is_same<R, at::Tensor&>::value && N == -1),
234       "Can't wrap a kernel with 'Tensor &' return type without specifying an index to the out tensor");
235   using ReturnType = typename type_map<R>::type;
236   using TupleConvertsType =
237       std::tuple<type_convert<typename type_map<Args>::type, Args>...>;
238   using TupleArgsType = std::tuple<typename type_map<Args>::type...>;
239   static constexpr size_t num_args = sizeof...(Args);
240   static_assert(
241       (N < num_args &&
242        std::is_same_v<
243            executorch::extension::kernel_util_internal::element_t<
244                N,
245                executorch::extension::kernel_util_internal::typelist<Args...>>,
246            R>) ||
247           N == -1,
248       "The index of the out tensor can't be greater or equal to num_args and "
249       "the Nth argument type has to be the same as the return type.");
250 
251   static ReturnType wrap(typename type_map<Args>::type... args) {
252     // The wrapped function that takes ATen argument types, convert them into
253     // ExecuTorch equivalent, call `f` then return the result converted back to
254     // ATen.
255     TupleArgsType args_tuple = std::forward_as_tuple(args...);
256     TupleConvertsType converts = std::forward_as_tuple(
257         type_convert<typename type_map<Args>::type, Args>(args)...);
258     R result =
259         call_functor_with_args(converts, std::make_index_sequence<num_args>());
260     typename std::remove_reference<ReturnType>::type converted_result =
261         type_convert<R, ReturnType>(result).call();
262     if constexpr (N == -1) {
263       return converted_result;
264     } else {
265       static_assert(
266           std::is_same_v<
267               typename std::remove_reference<ReturnType>::type,
268               at::Tensor>,
269           "Only support at::Tensor-like return");
270       ReturnType out = std::get<N>(args_tuple);
271       at::native::resize_output(out, converted_result.sizes());
272       out.copy_(converted_result);
273       return out;
274     }
275   }
276 
277  private:
278   template <size_t... indices>
279   static R call_functor_with_args(
280       TupleConvertsType& converts,
281       std::index_sequence<indices...>) {
282     return f(std::get<indices>(converts).call()...);
283   }
284 };
285 
286 } // namespace internal
287 } // namespace extension
288 } // namespace executorch
289 
290 // Wrapper macro for out variant function. N is the index of the out tensor.
291 // We need N to know how to preserve the semantics of modifying out tensor and
292 // return the reference without allocating a new memory buffer for out tensor.
293 #define _WRAP_2(func, N)              \
294   ::executorch::extension::internal:: \
295       wrapper_impl<decltype(&func), func, decltype(N), N>::wrap
296 #define _WRAP_1(func) \
297   ::executorch::extension::internal::wrapper_impl<decltype(&func), func>::wrap
298 
299 #define _GET_MACRO(_1, _2, NAME, ...) NAME
300 #define WRAP_TO_ATEN(...) _GET_MACRO(__VA_ARGS__, _WRAP_2, _WRAP_1)(__VA_ARGS__)
301