// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once

#include <c10/util/TypeList.h>

#include <ATen/ATen.h>
#include <ATen/Operators.h>

#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchingMetaprogramming.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/VmapGeneratedPlumbing.h>

#include <utility>

// This file contains helper functions for batching rules.

namespace at::functorch {

TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);

TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x);
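
// Illustrative shape sketch (intended behavior; not an exhaustive spec):
// reshape_dim_into flattens dim `src` into dim `dst`, and reshape_dim_outof
// splits a factor of `size1` back out of dim `src`. For example:
//
//   auto x = at::randn({2, 3, 5});          // shape [2, 3, 5]
//   auto y = reshape_dim_into(0, 0, x);     // shape [6, 5]
//   auto z = reshape_dim_outof(0, 2, y);    // shape [2, 3, 5] again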

Tensor moveBatchDimToFront(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim);
std::optional<int64_t> valIfNonempty(std::optional<int64_t> maybe_empty, int64_t new_val);
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);
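
// Quick reference (illustrative): given a tensor `t` of shape [3, B, 5] whose
// batch dim is 1,
//   moveBatchDimToFront(t, 1)    // -> shape [B, 3, 5]
//   rankWithoutBatchDim(t, 1)    // -> 2 (logical rank, batch dim excluded)
//   numelWithoutBatchDim(t, 1)   // -> 15
// getPhysicalDim/getPhysicalDims map logical (batch-dim-free) dim indices to
// physical ones, assuming the batch dim (if any) has been moved to the front.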

void vmapIncompatibleInplaceError(const char* schema_name);

Tensor maybePadToLogicalRank(const Tensor& tensor, std::optional<int64_t> has_bdim, int64_t logical_rank);

void check_randomness(RandomnessType randomness);
void check_randomness(RandomnessType randomness, bool any_tensor_bdim);

inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) {
  if (has_bdim) {
    return tensor;
  }
  const auto sizes = tensor.sym_sizes();
  SymDimVector expanded_shape;
  expanded_shape.reserve(sizes.size());
  expanded_shape.emplace_back(std::move(batch_size));
  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
  return tensor.expand_symint(expanded_shape);
}
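
// Example (sketch): with batch_size = B,
//   ensure_has_bdim(t, /*has_bdim=*/true,  B)   // returns t unchanged
//   ensure_has_bdim(t, /*has_bdim=*/false, B)   // t of shape [3, 5] -> expanded view [B, 3, 5]
// The expansion is a view (no copy); callers typically pair the result with bdim 0.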

#define VMAP_SUPPORT(op, batch_rule) \
  m.impl(#op, op ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);

#define VMAP_SUPPORT2(op, overload, batch_rule) \
  m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);

#define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
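
// Typical registration pattern (sketch; `sin_batch_rule` and `add_batch_rule`
// are illustrative rule names, not declarations from this header):
//
//   TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
//     VMAP_SUPPORT(sin, sin_batch_rule);           // aten::sin -> hand-written batch rule
//     VMAP_SUPPORT2(add, Tensor, add_batch_rule);  // aten::add.Tensor overload
//   }
//
// OP_DECOMPOSE / OP_DECOMPOSE2 instead route an op to its composite native::
// implementation, so vmap only ever sees the decomposed ops.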

// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain
template <typename A, A a, typename C>
struct BasicUnaryBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct BasicUnaryBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    return std::make_tuple(Func(tensor, std::forward<T>(extra_args)...), batch_dim);
  }
};

// USAGE: BASIC_UNARY_BATCH_RULE(at::sin)
// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin)
// It is important that this macro is not passed a function pointer!!
#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\
    BasicUnaryBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define UNARY_POINTWISE(op) \
  VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
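
// For reference, UNARY_POINTWISE(sin) expands to
//   VMAP_SUPPORT(sin, BASIC_UNARY_BATCH_RULE(ATEN_FN(sin)));
// i.e. the registered rule calls at::sin on the physical tensor and forwards
// the incoming batch_dim unchanged, which is valid for any pointwise op.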

template <typename A, A a, typename C>
struct VariadicBdimsBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct VariadicBdimsBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    auto tensor_ = moveBatchDimToFront(tensor, batch_dim);
    return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0);
  }
};

// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse)
// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse)
// It is important that this macro is not passed a function pointer!!
#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\
    VariadicBdimsBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define VARIADIC_BDIMS(op) \
  VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op)));

#define VARIADIC_BDIMS2(op, overload) \
  VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload)));
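
// Shape sketch (illustrative): VARIADIC_BDIMS(cholesky_inverse) registers a
// rule that just moves the batch dim to the front and calls the op, relying
// on the op natively accepting arbitrary leading batch dims:
//   physical input [3, B, 5, 5] with bdim 1 -> [B, 3, 5, 5] -> op -> returned with bdim 0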
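// Boxed kernel used by POINTWISE_BOXED / VARIADIC_BDIMS_BOXED below: it
// unwraps every Tensor argument at the current vmap level into a
// (value, optional bdim) pair, lets `Func` preprocess those pairs in place,
// pushes the processed tensors back onto the stack, calls the op, and wraps
// each Tensor return as a BatchedTensor with bdim 0.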
template<class F, F Func>
void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");

  int64_t cur_level = maybe_layer->layerId();

  auto orig_arguments = torch::jit::last(*stack, num_arguments);
  if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  auto arguments = torch::jit::pop(*stack, num_arguments);
  std::vector<std::pair<Tensor, std::optional<int64_t>>> tensor_inputs;
  std::vector<int64_t> tensor_pos;
  for (const auto idx : c10::irange(0, num_arguments)) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
      tensor_inputs.emplace_back(tensor_value, tensor_bdim);
      tensor_pos.push_back(static_cast<int64_t>(idx));
    }
  }
  Func(tensor_inputs);

  size_t tensor_idx = 0;
  TORCH_INTERNAL_ASSERT(!tensor_pos.empty());
  for (const auto arg_idx : c10::irange(0, num_arguments)) {
    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
      torch::jit::push(stack, arguments[arg_idx]);
    } else {
      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
      tensor_idx++;
    }
  }

  op.callBoxed(stack);
  const auto returns = torch::jit::pop(*stack, num_returns);
  for (const auto& ret : returns) {
    if (ret.isTensor()) {
      torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
    } else {
      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
    }
  }
}

inline void handle_pointwise_ops(std::vector<std::pair<Tensor, std::optional<int64_t>>> &tensor_inputs) {
  int64_t out_logical_rank = 0;
  for (auto& tensor_input : tensor_inputs) {
    int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second);
    out_logical_rank = std::max(out_logical_rank, cur_logical_rank);
  }
  for (auto& tensor_input : tensor_inputs) {
    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
    tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank);
  }
}

#define POINTWISE_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());

#define POINTWISE_BOXED2(op, overload) \
  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());
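
// Example (sketch; the op name is illustrative): POINTWISE_BOXED(addcmul)
// registers a boxed kernel in which handle_pointwise_ops moves each tensor's
// batch dim to the front and pads every tensor to the common logical rank,
// so that ordinary broadcasting in the underlying op lines the inputs up.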

inline void handle_variadic_bdims(std::vector<std::pair<Tensor, std::optional<int64_t>>> &tensor_inputs) {
  for (auto & tensor_input : tensor_inputs) {
    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
  }
}

#define VARIADIC_BDIMS_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());

using UnpackedBatchedTensor = std::tuple<Tensor, std::optional<int64_t>>;

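// Scans the last `num_args` IValues on the stack, unwraps each Tensor at
// `cur_level` into an (unwrapped value, optional bdim) pair, records its
// position among the arguments, and computes the common batch size
// (asserting that all batched inputs agree on it).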
inline void find_and_unpack_tensors(
    const torch::jit::Stack* stack,
    int64_t num_args,
    int64_t cur_level,
    SmallVector<UnpackedBatchedTensor, 5>* tensors,
    SmallVector<int64_t, 5>* tensors_pos,
    int64_t* batch_size) {

  int64_t computed_batch_size = -1;
  int64_t args_begin = static_cast<int64_t>(stack->size()) - num_args;

  for (const auto idx : c10::irange(0, num_args)) {
    const auto& ivalue = (*stack)[args_begin + idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
    const auto& tensor_value = std::get<0>(unpacked);
    const auto tensor_bdim = std::get<1>(unpacked);
    if (tensor_bdim.has_value()) {
      auto candidate_batch_size = tensor_value.size(*tensor_bdim);
      if (computed_batch_size == -1) {
        computed_batch_size = candidate_batch_size;
      }
      TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
    }

    tensors->push_back(std::move(unpacked));
    tensors_pos->push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
  *batch_size = computed_batch_size;
}

inline void boxed_existing_bdim_all_batch_rule(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule");
  int64_t cur_level = maybe_layer->layerId();

  const auto arguments = torch::jit::last(stack, num_arguments);
  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  int64_t args_begin = static_cast<int64_t>(stack->size()) - num_arguments;
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size = 0;

  find_and_unpack_tensors(
      stack, num_arguments, cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  // for each tensor, ensure it has a bdim and reshape it.
  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
  }

  op.callBoxed(stack);

  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
  }
}

// Use when all tensor arguments accept one (normal) batch dim.
// This batching rule expands the batch dim on all Tensors, reshapes it into
// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
// This is not the most efficient thing; if there are alternatives, please try
// to use them. Use this only as a last resort.
#define EXISTING_BDIM_ALL_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
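
// Shape sketch (illustrative, batch size B): a physical input of shape
// [B, N, H, W] with bdim 0 reaches the op as [B*N, H, W]; an output of shape
// [B*N, H', W'] is reshaped back to [B, N, H', W'] and wrapped with bdim 0.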

template <int64_t feature_rank, int64_t contig_tensor_index=-1>
inline void boxed_all_tensors_have_optional_bdim(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
  int64_t cur_level = maybe_layer->layerId();

  const auto arguments = torch::jit::last(stack, num_arguments);
  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  int64_t args_begin = static_cast<int64_t>(stack->size() - num_arguments);
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size = 0;

  find_and_unpack_tensors(
      stack, static_cast<int64_t>(num_arguments), cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  std::optional<bool> is_no_batch_dim_case;

  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    const auto logical_rank = rankWithoutBatchDim(value, bdim);

    if (!is_no_batch_dim_case.has_value()) {
      is_no_batch_dim_case = (logical_rank == feature_rank);
    }
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    if (*is_no_batch_dim_case) {
      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
      value_ = moveBatchDimToFront(value_, bdim);
      if (tensor_idx == contig_tensor_index) {
        value_ = value_.contiguous();
      }
      (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
      continue;
    }
    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
    value_ = reshape_dim_into(*bdim, 0, value_);
    if (tensor_idx == contig_tensor_index) {
      value_ = value_.contiguous();
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
  }

  op.callBoxed(stack);

  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    if (*is_no_batch_dim_case) {
      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
    } else {
      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
    }
  }
}

// Useful for many NN operators.
// The operator must satisfy the following:
// - All arguments must accept an optional batch dim.
// - All arguments must be the same rank
#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());

#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
  m.impl(#op, \
         torch::CppFunction::makeFromBoxedFunction<\
             boxed_all_tensors_have_optional_bdim<\
                 feature_rank, \
                 contig_tensor_index>\
             >());
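
// Example (sketch; `some_nn_op` is a placeholder): for an op whose unbatched
// inputs have feature rank 3 ([C, H, W]) and whose batched inputs are
// [N, C, H, W], one would register
//   ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(3, some_nn_op);
// If the logical inputs are unbatched, the vmap dim is moved to the front and
// plays the role of the op's own batch dim N; otherwise the vmap dim is
// folded into the existing N via reshape_dim_into and split back out of the
// outputs afterwards.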

template <typename A, A a, typename C>
struct ExistingBdimBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& self,
      std::optional<int64_t> self_bdim,
      T... extra_args) {
    auto self_ = reshape_dim_into(*self_bdim, 0, self);
    auto out = Func(self_, std::forward<T>(extra_args)...);
    return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0);
  }
};

// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse)
// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse)
// It is important that this macro is not passed a function pointer!!
#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\
    ExistingBdimBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)


#define EXISTING_BDIM(op) \
  VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op)));

#define EXISTING_BDIM2(op, overload) \
  VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload)));
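
// Shape sketch (illustrative): EXISTING_BDIM(op) folds the vmap dim into the
// op's existing leading batch dim and splits it back out afterwards:
//   self: physical [3, B, 5] with bdim 1 -> reshape_dim_into(1, 0, self) -> [B*3, 5]
//   out = op([B*3, 5], ...)
//   reshape_dim_outof(0, B, out)         -> [B, 3, ...], returned with bdim 0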

#define INVOKE(object,ptrToMember) ((object).*(ptrToMember))


template <typename F, F Method, typename... ExtraArgs>
Tensor& unary_inplace_batch_rule(Tensor& self, std::optional<int64_t>, ExtraArgs... extra_args) {
  INVOKE(self, Method)(std::forward<ExtraArgs>(extra_args)...);
  return self;
}
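
// Example (sketch): instantiating unary_inplace_batch_rule with, say,
// &Tensor::sin_ yields an in-place batch rule that invokes the method
// directly on `self` and leaves the batch dim untouched, which is correct
// for pointwise in-place ops.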

inline int64_t get_bdim_size4(
    const Tensor& a_value, std::optional<int64_t> a_bdim,
    const Tensor& b_value, std::optional<int64_t> b_bdim,
    const Tensor& c_value, std::optional<int64_t> c_bdim,
    const Tensor& d_value, std::optional<int64_t> d_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  if (c_bdim)
    return c_value.size(*c_bdim);
  if (d_bdim)
    return d_value.size(*d_bdim);
  TORCH_INTERNAL_ASSERT(false);
}

inline int64_t get_bdim_size3(
    const Tensor& a_value, std::optional<int64_t> a_bdim,
    const Tensor& b_value, std::optional<int64_t> b_bdim,
    const Tensor& c_value, std::optional<int64_t> c_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  if (c_bdim)
    return c_value.size(*c_bdim);
  TORCH_INTERNAL_ASSERT(false);
}

inline int64_t get_bdim_size2(
    const Tensor& a_value, std::optional<int64_t> a_bdim,
    const Tensor& b_value, std::optional<int64_t> b_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  TORCH_INTERNAL_ASSERT(false);
}
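
// Typical use inside a hand-written binary batch rule (sketch; the variable
// names are illustrative):
//   auto batch_size = get_bdim_size2(self, self_bdim, other, other_bdim);
//   auto self_  = ensure_has_bdim(moveBatchDimToFront(self,  self_bdim),
//                                 self_bdim.has_value(),  batch_size);
//   auto other_ = ensure_has_bdim(moveBatchDimToFront(other, other_bdim),
//                                 other_bdim.has_value(), batch_size);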

// [start, start + 1, ..., stop - 1]
inline VmapDimVector range(int64_t start, int64_t stop) {
  TORCH_INTERNAL_ASSERT(stop >= start);
  VmapDimVector dims;
  dims.reserve(stop - start);
  for (int64_t i = start; i < stop; i++) {
    dims.emplace_back(i);
  }
  return dims;
}
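
// Example (sketch): range(1, 4) yields {1, 2, 3}; batch rules commonly use it
// to enumerate the non-batch dims once the batch dim has been moved to the
// front (e.g. the set of dims to reduce over).
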
std::tuple<Tensor, Tensor> _binary_pointwise_helper(
    const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, const Tensor& other, std::optional<int64_t> other_batch_dim,
    bool do_type_promotion=true);

} // namespace at::functorch