// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once

#include <c10/util/TypeList.h>

#include <ATen/ATen.h>
#include <ATen/Operators.h>

#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchingMetaprogramming.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/VmapGeneratedPlumbing.h>

#include <utility>

// This file contains helper functions for batching rules.

namespace at::functorch {

TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);

TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x);

Tensor moveBatchDimToFront(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
std::optional<int64_t> valIfNonempty(std::optional<int64_t> maybe_empty, int64_t new_val);
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);

void vmapIncompatibleInplaceError(const char* schema_name);

Tensor maybePadToLogicalRank(const Tensor& tensor, std::optional<int64_t> has_bdim, int64_t logical_rank);

void check_randomness(RandomnessType randomness);
void check_randomness(RandomnessType randomness, bool any_tensor_bdim);

inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) {
  if (has_bdim) {
    return tensor;
  }
  const auto sizes = tensor.sym_sizes();
  SymDimVector expanded_shape;
  expanded_shape.reserve(sizes.size());
  expanded_shape.emplace_back(std::move(batch_size));
  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
  return tensor.expand_symint(expanded_shape);
}
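// Illustrative sketch (not part of this header's API surface): given a tensor
// of shape [3, 4] with no batch dim and batch_size = 5, ensure_has_bdim
// returns an expanded view of shape [5, 3, 4], so callers can assume dim 0 is
// the batch dim. If has_bdim is already true, the tensor is returned as-is.
//
//   auto x = at::randn({3, 4});
//   auto xb = ensure_has_bdim(x, /*has_bdim=*/false, /*batch_size=*/5);
//   // xb.sym_sizes() == [5, 3, 4]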

#define VMAP_SUPPORT(op, batch_rule) \
  m.impl(#op, op ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);

#define VMAP_SUPPORT2(op, overload, batch_rule) \
  m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);
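// Illustrative sketch of how these macros are typically used (the batch rule
// names below are hypothetical): inside a TORCH_LIBRARY_IMPL block for the
// FuncTorchBatched key, `m` is the torch::Library being populated.
//
//   TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
//     VMAP_SUPPORT(sin, sin_batch_rule);           // registers "sin"
//     VMAP_SUPPORT2(add, Tensor, add_batch_rule);  // registers "add.Tensor"
//   }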

#define OP_DECOMPOSE(op)  m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload)  m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));

// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain
template <typename A, A a, typename C>
struct BasicUnaryBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct BasicUnaryBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    return std::make_tuple(Func(tensor, std::forward<T>(extra_args)...), batch_dim);
  }
};

// USAGE: BASIC_UNARY_BATCH_RULE(at::sin)
// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin)
// It is important that this macro is not passed a function pointer!!
#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\
    BasicUnaryBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define UNARY_POINTWISE(op) \
  VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
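// Illustrative sketch: UNARY_POINTWISE(abs) registers a generated batch rule
// for "abs" whose core is BasicUnaryBatchRuleHelper::apply, i.e. it calls
// at::abs on the physical tensor and forwards the batch dim unchanged. This is
// valid precisely because unary pointwise ops act elementwise, so the batch
// dim's position does not need to change.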

template <typename A, A a, typename C>
struct VariadicBdimsBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct VariadicBdimsBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    auto tensor_ = moveBatchDimToFront(tensor, batch_dim);
    return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0);
  }
};

// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse)
// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse)
// It is important that this macro is not passed a function pointer!!
#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\
    VariadicBdimsBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define VARIADIC_BDIMS(op) \
  VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op)));

#define VARIADIC_BDIMS2(op, overload) \
  VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload)));
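// Illustrative sketch: VARIADIC_BDIMS(cholesky_inverse) registers a batch rule
// that moves the vmap batch dim to the front and then calls the op directly,
// relying on the op treating all leading dimensions as batch dimensions
// ("variadic bdims"). The returned batch dim is therefore always 0.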

template<class F, F Func>
void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");

  int64_t cur_level = maybe_layer->layerId();

  auto orig_arguments = torch::jit::last(*stack, num_arguments);
  if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  auto arguments = torch::jit::pop(*stack, num_arguments);
  std::vector<std::pair<Tensor, std::optional<int64_t>>> tensor_inputs;
  std::vector<int64_t> tensor_pos;
  for (const auto idx : c10::irange(0, num_arguments)) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
      tensor_inputs.emplace_back(tensor_value, tensor_bdim);
      tensor_pos.push_back(static_cast<int64_t>(idx));
    }
  }
  Func(tensor_inputs);

  size_t tensor_idx = 0;
  TORCH_INTERNAL_ASSERT(!tensor_pos.empty());
  for (const auto arg_idx : c10::irange(0, num_arguments)) {
    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
      torch::jit::push(stack, arguments[arg_idx]);
    } else {
      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
      tensor_idx++;
    }
  }

  op.callBoxed(stack);
  const auto returns = torch::jit::pop(*stack, num_returns);
  for (const auto& ret : returns) {
    if (ret.isTensor()) {
      torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
    } else {
      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
    }
  }
}

inline void handle_pointwise_ops(std::vector<std::pair<Tensor, std::optional<int64_t>>> &tensor_inputs) {
  int64_t out_logical_rank = 0;
  for (auto& tensor_input : tensor_inputs) {
    int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second);
    out_logical_rank = std::max(out_logical_rank, cur_logical_rank);
  }
  for (auto& tensor_input: tensor_inputs) {
    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
    tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank);
  }
}

#define POINTWISE_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());

#define POINTWISE_BOXED2(op, overload) \
  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());
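// Illustrative sketch: POINTWISE_BOXED(clamp) registers the boxed kernel
// boxed_tensor_inputs_batch_rule<..., &handle_pointwise_ops>. At dispatch time
// it unwraps every Tensor argument at the current vmap level, moves each batch
// dim to the front, pads the inputs to a common logical rank so broadcasting
// lines up, re-invokes the op, and wraps each Tensor return with batch dim 0.
//
//   TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
//     POINTWISE_BOXED(clamp);           // registers "clamp"
//     POINTWISE_BOXED2(clamp, Tensor);  // registers "clamp.Tensor"
//   }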

inline void handle_variadic_bdims(std::vector<std::pair<Tensor, std::optional<int64_t>>> &tensor_inputs) {
  for (auto & tensor_input : tensor_inputs) {
    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
  }
}

#define VARIADIC_BDIMS_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());

using UnpackedBatchedTensor = std::tuple<Tensor, std::optional<int64_t>>;

inline void find_and_unpack_tensors(
    const torch::jit::Stack* stack,
    int64_t num_args,
    int64_t cur_level,
    SmallVector<UnpackedBatchedTensor, 5>* tensors,
    SmallVector<int64_t, 5>* tensors_pos,
    int64_t* batch_size) {

  int64_t computed_batch_size = -1;
  int64_t args_begin = static_cast<int64_t>(stack->size()) - num_args;

  for (const auto idx : c10::irange(0, num_args)) {
    const auto& ivalue = (*stack)[args_begin + idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
    const auto& tensor_value = std::get<0>(unpacked);
    const auto tensor_bdim = std::get<1>(unpacked);
    if (tensor_bdim.has_value()) {
      auto candidate_batch_size = tensor_value.size(*tensor_bdim);
      if (computed_batch_size == -1) {
        computed_batch_size = candidate_batch_size;
      }
      TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
    }

    tensors->push_back(std::move(unpacked));
    tensors_pos->push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
  *batch_size = computed_batch_size;
}
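// Illustrative sketch: for a stack holding (Tensor A of shape [5, 3] with
// bdim 0, Tensor B of shape [3] with no bdim, int64_t k), this fills
// tensors = {(A, 0), (B, nullopt)}, tensors_pos = {0, 1}, batch_size = 5.
// All batched inputs must agree on the batch size, otherwise the assert fires.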

inline void boxed_existing_bdim_all_batch_rule(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule");
  int64_t cur_level = maybe_layer->layerId();

  const auto arguments = torch::jit::last(stack, num_arguments);
  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  int64_t args_begin = static_cast<int64_t>(stack->size()) - num_arguments;
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size = 0;

  find_and_unpack_tensors(
      stack, num_arguments, cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  // for each tensor, ensure it has a bdim and reshape it.
  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
  }

  op.callBoxed(stack);

  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
  }
}

// Use when all tensor arguments accept one (normal) batch dim.
// This batching rule expands the batch dim on all Tensors, reshapes it into
// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
// This is not the most efficient thing; if there are alternatives, please try
// to use them. Use this only as a last resort.
#define EXISTING_BDIM_ALL_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
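// Illustrative sketch: EXISTING_BDIM_ALL_BOXED handles an op whose tensor
// arguments each already accept a leading batch dim. A tensor whose
// per-example shape is [N, ...] and whose vmap batch dim has size B is folded
// into shape [B*N, ...] via reshape_dim_into; tensors without a vmap batch dim
// are first expanded with ensure_has_bdim. After the call, each Tensor output
// is split back into [B, N', ...] with reshape_dim_outof and re-wrapped at the
// current vmap level.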

template <int64_t feature_rank, int64_t contig_tensor_index=-1>
inline void boxed_all_tensors_have_optional_bdim(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
  int64_t cur_level = maybe_layer->layerId();

  const auto arguments = torch::jit::last(stack, num_arguments);
  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  int64_t args_begin = static_cast<int64_t>(stack->size() - num_arguments);
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size = 0;

  find_and_unpack_tensors(
      stack, static_cast<int64_t>(num_arguments), cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  std::optional<bool> is_no_batch_dim_case;

  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    const auto logical_rank = rankWithoutBatchDim(value, bdim);

    if (!is_no_batch_dim_case.has_value()) {
      is_no_batch_dim_case = (logical_rank == feature_rank);
    }
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    if (*is_no_batch_dim_case) {
      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
      value_ = moveBatchDimToFront(value_, bdim);
      if (tensor_idx == contig_tensor_index) {
        value_ = value_.contiguous();
      }
      (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
      continue;
    }
    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
    value_ = reshape_dim_into(*bdim, 0, value_);
    if (tensor_idx == contig_tensor_index) {
      value_ = value_.contiguous();
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
  }

  op.callBoxed(stack);

  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    if (*is_no_batch_dim_case) {
      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
    } else {
      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
    }
  }
}

// Useful for many NN operators.
// The operator must satisfy the following:
// - All arguments must accept an optional batch dim.
// - All arguments must be the same rank
#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());

#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
  m.impl(#op, \
         torch::CppFunction::makeFromBoxedFunction<\
             boxed_all_tensors_have_optional_bdim<\
                 feature_rank, \
                 contig_tensor_index>\
             >());
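// Illustrative sketch: many NN ops accept inputs either with or without a
// leading batch dim (e.g. a rank-3 or rank-4 conv input). With feature_rank
// describing the unbatched rank, the boxed rule above picks one of two paths:
// if the logical rank equals feature_rank, the vmap dim itself becomes the
// op's batch dim (moveBatchDimToFront); if it is feature_rank + 1, the vmap
// dim is folded into the op's existing batch dim (reshape_dim_into) and
// unfolded from the outputs afterwards. The _CONTIG1 variant additionally
// makes the tensor at contig_tensor_index contiguous before the call.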

template <typename A, A a, typename C>
struct ExistingBdimBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& self,
      std::optional<int64_t> self_bdim,
      T... extra_args) {
    auto self_ = reshape_dim_into(*self_bdim, 0, self);
    auto out = Func(self_, std::forward<T>(extra_args)...);
    return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0);
  }
};

// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse)
// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse)
// It is important that this macro is not passed a function pointer!!
#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\
    ExistingBdimBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)


#define EXISTING_BDIM(op) \
  VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op)));

#define EXISTING_BDIM2(op, overload) \
  VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload)));
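// Illustrative sketch (hypothetical registration): EXISTING_BDIM(pixel_shuffle)
// would register a batch rule that folds the vmap batch dim into dim 0
// (reshape_dim_into), calls at::pixel_shuffle once on the flattened input, and
// then splits the batch dim back out of dim 0 (reshape_dim_outof_symint),
// returning 0 as the new batch dim.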

#define INVOKE(object,ptrToMember)  ((object).*(ptrToMember))


template <typename F, F Method, typename... ExtraArgs>
Tensor& unary_inplace_batch_rule(Tensor& self, std::optional<int64_t>, ExtraArgs... extra_args) {
  INVOKE(self, Method)(std::forward<ExtraArgs>(extra_args)...);
  return self;
}
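// Illustrative sketch: for an in-place unary method such as Tensor::sin_,
// unary_inplace_batch_rule<decltype(&Tensor::sin_), &Tensor::sin_> invokes the
// method directly on the physical self tensor and returns it. The batch dim's
// position is irrelevant for a pointwise in-place op, which is why the
// std::optional<int64_t> parameter is ignored.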

inline int64_t get_bdim_size4(
    const Tensor& a_value, std::optional<int64_t> a_bdim,
    const Tensor& b_value, std::optional<int64_t> b_bdim,
    const Tensor& c_value, std::optional<int64_t> c_bdim,
    const Tensor& d_value, std::optional<int64_t> d_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  if (c_bdim)
    return c_value.size(*c_bdim);
  if (d_bdim)
    return d_value.size(*d_bdim);
  TORCH_INTERNAL_ASSERT(false);
}

inline int64_t get_bdim_size3(
    const Tensor& a_value, std::optional<int64_t> a_bdim,
    const Tensor& b_value, std::optional<int64_t> b_bdim,
    const Tensor& c_value, std::optional<int64_t> c_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  if (c_bdim)
    return c_value.size(*c_bdim);
  TORCH_INTERNAL_ASSERT(false);
}

inline int64_t get_bdim_size2(
    const Tensor& a_value, std::optional<int64_t> a_bdim,
    const Tensor& b_value, std::optional<int64_t> b_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  TORCH_INTERNAL_ASSERT(false);
}
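// Illustrative sketch: get_bdim_size2(a, /*a_bdim=*/std::nullopt, b, /*b_bdim=*/0)
// returns b.size(0). At least one argument is expected to carry a batch dim;
// if none does, these helpers assert.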

// [start, start + 1, ..., stop - 1]
inline VmapDimVector range(int64_t start, int64_t stop) {
  TORCH_INTERNAL_ASSERT(stop >= start);
  VmapDimVector dims;
  dims.reserve(stop - start);
  for (int64_t i = start; i < stop; i++) {
    dims.emplace_back(i);
  }
  return dims;
}
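// Illustrative sketch: range(2, 5) yields the dims {2, 3, 4}; range(3, 3)
// yields an empty VmapDimVector.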
std::tuple<Tensor, Tensor> _binary_pointwise_helper(
    const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, const Tensor& other, std::optional<int64_t> other_batch_dim,
    bool do_type_promotion=true);

} // namespace at::functorch