// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/Operators.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <utility>

namespace at::functorch {

static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

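// The *_decomp helpers below rewrite full reductions (e.g. sum(Tensor, dtype?))
// in terms of their dim-list / single-dim overloads, which have batching rules
// registered further down in this file. For example, with m.impl("sum", sum_decomp),
// at::sum(x, dtype) under vmap re-dispatches as
//   at::sum(x, /*dim=*/range(0, x.dim()), /*keepdim=*/false, dtype)
// so only sum.dim_IntList needs the boxed reduction rule.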
static Tensor sum_decomp(
    const Tensor& self, std::optional<ScalarType> dtype) {
  return at::sum(self, range(0, self.dim()), false, dtype);
}

static std::tuple<Tensor, std::optional<int64_t>> _is_all_true_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim) {
  return std::make_tuple(at::_is_all_true(self), std::nullopt);
}

static std::tuple<Tensor, std::optional<int64_t>> _is_any_true_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim) {
  return std::make_tuple(at::_is_any_true(self), std::nullopt);
}

static Tensor mean_decomp(
    const Tensor& self, std::optional<ScalarType> dtype) {
  return at::mean(self, range(0, self.dim()), false, dtype);
}

static Tensor prod_decomp(
    const Tensor& self, std::optional<ScalarType> dtype) {
  return at::prod(self.flatten(), 0, false, dtype);
}

static Tensor max_decomp(
    const Tensor& self) {
  return std::get<0>(at::max(self.flatten(), 0, false));
}

static Tensor min_decomp(
    const Tensor& self) {
  return std::get<0>(at::min(self.flatten(), 0, false));
}

static Tensor norm_scalar_decomp(
    const Tensor& self, const Scalar& p) {
  return at::norm(self, p, range(0, self.dim()), false);
}

static Tensor nanmedian_decomp(
    const Tensor& self) {
  return std::get<0>(at::nanmedian(self.flatten(), 0, false));
}

static Tensor median_decomp(
    const Tensor& self) {
  return std::get<0>(at::median(self.flatten(), 0, false));
}

static Tensor all_decomp(const Tensor& self) {
  return at::all(self.flatten(), 0, false);
}

static Tensor any_decomp(const Tensor& self) {
  return at::any(self.flatten(), 0, false);
}

enum class ReductionCase:uint8_t { DimArray, Dim };

// Macros and templates have a difficult time dealing with enums,
// so we didn't turn the keepdim cases below into an enum.
// See NOTE: [keepdim cases] for an explanation of what they are.
static constexpr int KEEPDIM_CASE_FALSE = 0;
static constexpr int KEEPDIM_CASE_TRUE = 1;
static constexpr int KEEPDIM_CASE_VARIABLE = 2;

// dim_arg_pos specifies the position of the dim/dim-array argument.
// For most PyTorch ops, this is equal to 1.
//
// NOTE: [keepdim cases]
// The operator in question either:
// - has a keepdim argument (KEEPDIM_CASE_VARIABLE)
//   In this case, `maybe_keepdim_arg_pos` gives the index of the keepdim arg.
//   example: sum(tensor, dim, keepdim)
// - always does a reduction with no keepdim (KEEPDIM_CASE_FALSE)
//   That is, the rank of the output tensor is less than the rank of the input tensor.
// - always does a reduction with keepdim=True semantics (KEEPDIM_CASE_TRUE)
//   That is, the rank of the output tensor is always the same as that of the input.
//   examples: log_softmax(tensor, dim), cumsum(tensor, dim)
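//
// For example, based on the registrations at the bottom of this file:
// - sum.dim_IntList takes (self, dim, keepdim, dtype), so it is registered with
//   <dim_arg_pos=1, KEEPDIM_CASE_VARIABLE, maybe_keepdim_arg_pos=2>
//   via REDUCTION_WITH_KEEPDIM_ARG.
// - cumsum takes (self, dim, dtype), has no keepdim, and preserves rank, so it is
//   registered with <1, KEEPDIM_CASE_TRUE, -1> via REDUCTION_NO_KEEPDIM_ARG.
// - kthvalue takes (self, k, dim, keepdim), so it is registered with
//   <2, KEEPDIM_CASE_VARIABLE, 3>.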
template<
  int dim_arg_pos,
  int keepdim_case,
  // optional cannot be used in a template, otherwise we would use it here.
  int maybe_keepdim_arg_pos
>
void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_reduction_batch_rule");
  int64_t cur_level = maybe_layer->layerId();

  auto orig_arguments = torch::jit::last(*stack, num_arguments);
  if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
    c10::impl::ExcludeDispatchKeyGuard guard_2(DispatchKey::FuncTorchBatched);
    op.callBoxed(stack);
    return;
  }

  auto arguments = torch::jit::pop(*stack, num_arguments);

  TORCH_INTERNAL_ASSERT(arguments[0].isTensor());
  auto [self, self_bdim] = unwrapTensorAtLevel(arguments[0].toTensor(), cur_level);

  self = moveBatchDimToFront(self, self_bdim);

  auto logical_dim = rankWithoutBatchDim(self, self_bdim);
  std::vector<int64_t> dims;
  ReductionCase reduction_case{};
  if (arguments[dim_arg_pos].isIntList()) {
    reduction_case = ReductionCase::DimArray;
    dims = arguments[dim_arg_pos].toIntList().vec();
    if (dims.empty()) {
      auto all_dims = range(0, std::max((int64_t)1, logical_dim));
      dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
    }
  } else if (arguments[dim_arg_pos].isInt()) {
    reduction_case = ReductionCase::Dim;
    dims = {arguments[dim_arg_pos].toInt()};
  } else if (arguments[dim_arg_pos].isNone()) {
    auto param_type = schema.arguments()[dim_arg_pos].type()->expect<OptionalType>()->getElementType();
    if (param_type->kind() == IntType::Kind) {
      reduction_case = ReductionCase::Dim;
      if (self.dim() > 1) {
        self = self.flatten(1);
      }
      dims = {0};
    } else if (param_type->kind() == ListType::Kind) {
      reduction_case = ReductionCase::DimArray;
      if (logical_dim == 0) {
        dims = {0};
      } else {
        auto all_dims = range(0, self.dim() - 1);
        dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
      }
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims");
    }
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims");
  }

  VmapDimVector new_dims;
  new_dims.reserve(dims.size());
  for (auto dim: dims) {
    new_dims.push_back(getPhysicalDim(self, self_bdim.has_value(), dim));
  }
  bool is_scalar_case = logical_dim == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0]);
  std::optional<bool> maybe_keepdim;
  if (is_scalar_case) {
    // NOTE: [boxed_reduction_batch_rule scalar tensor handling]
    // Reduction operations in PyTorch have an edge case where they allow
    // dim=0 and dim=-1 if the tensor has shape [].
    //
    // This can come up if we do something like
    // vmap(lambda x: x.sum(0))(torch.tensor([10.])).
    //
    // In order to handle this edge case, we unsqueeze a dimension on the Tensor,
    // run the operation (with dim=1 instead), and then process the output tensor.
    // There are two cases:
    // - keepdim = True
    //     unsqueeze   op      squeeze
    //   [B] -> [B, 1] -> [B, 1] -> [B]
    // - keepdim = False
    //     unsqueeze   op     no need to squeeze
    //   [B] -> [B, 1] -> [B]
    // If keepdim is True, then we need to squeeze the dimension of size 1.

    // Determine the value of keepdim
    switch (keepdim_case) {
      case KEEPDIM_CASE_FALSE:
        maybe_keepdim = false;
        break;
      case KEEPDIM_CASE_TRUE:
        maybe_keepdim = true;
        break;
      case KEEPDIM_CASE_VARIABLE:
        TORCH_INTERNAL_ASSERT(maybe_keepdim_arg_pos >= 0);
        maybe_keepdim = arguments[maybe_keepdim_arg_pos].toBool();
        break;
    }
    self = self.unsqueeze(-1);
    new_dims = {1};
  }
  arguments[0] = std::move(self);
  if (reduction_case == ReductionCase::DimArray) {
    arguments[dim_arg_pos] = std::vector<int64_t>(new_dims.begin(), new_dims.end());
  } else if (reduction_case == ReductionCase::Dim) {
    arguments[dim_arg_pos] = new_dims[0];
  }
  for (const auto arg_idx : c10::irange(0, num_arguments)) {
    torch::jit::push(stack, arguments[arg_idx]);
  }
  op.callBoxed(stack);

  const auto returns = torch::jit::pop(*stack, num_returns);
  for (const auto& ret : returns) {
    if (ret.isTensor()) {
      auto res = ret.toTensor();
      // see NOTE: [boxed_reduction_batch_rule scalar tensor handling]
      if (is_scalar_case && maybe_keepdim.value()) {
        // squeeze(-1) would silently be a no-op if the size of the last dim
        // were not 1, so to be safe we internally assert that it is 1 here.
        TORCH_INTERNAL_ASSERT(res.size(-1) == 1);
        res = res.squeeze(-1);
      }
      torch::jit::push(stack, makeBatched(res, 0, cur_level));
    } else {
      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
    }
  }
}
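
// Illustrative walkthrough of the boxed rule above (informal): suppose
// sum.dim_IntList is vmapped with `self` of physical shape [B, 3, 4] (bdim 0)
// and dim = {1}. moveBatchDimToFront is a no-op, logical_dim is 2,
// getPhysicalDim maps logical dim 1 to physical dim 2, the boxed op then
// reduces over physical dim 2, and each tensor return is re-wrapped with its
// batch dim at position 0 via makeBatched(res, 0, cur_level).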

// Skipping all/any since they don't have opinfo tests right now :P

static Tensor dist_decomp(const Tensor& self, const Tensor& other, const Scalar& p) {
  return at::norm((self - other), p);
}

static std::tuple<Tensor, Tensor> expand_bdims(
    const Tensor& a, bool a_has_bdim,
    const Tensor& b, bool b_has_bdim) {
  Tensor flagpole;
  if (a_has_bdim) {
    flagpole = a;
  } else if (b_has_bdim) {
    flagpole = b;
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
  return std::make_tuple(
      a_has_bdim ? a : a.expand_as(flagpole),
      b_has_bdim ? b : b.expand_as(flagpole));
}
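
// For example (assuming batch dims have already been moved to the front, as
// the callers below do): if grad_output_ has shape [B, 5] with a bdim and
// output_ has shape [5] without one, expand_bdims leaves grad_output_ alone
// and expands output_ to [B, 5] via expand_as, so both tensors line up for
// the batched kernel call.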

static std::tuple<Tensor, std::optional<int64_t>> _softmax_backward_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& output, std::optional<int64_t> output_bdim,
    int64_t dim,
    ScalarType input_dtype) {
  // softmax_backward's decomposition is y * gy - y * (y * gy).sum(dim, keepdim=True)
  // NB: the CUDA kernel handles strides, so we can just expand
  // all of the tensors and call it a day. The CPU kernel is not as good, but
  // idk if the perf on that really matters
  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
  auto output_ = moveBatchDimToFront(output, output_bdim);

  // Expand out that extra dimension for everyone
  std::tie(grad_output_, output_) = expand_bdims(
      grad_output_, grad_output_bdim.has_value(),
      output_, output_bdim.has_value());

  // Scalar tensor case. softmax turns into the identity when this happens.
  // I don't know why the output is zeros, but that's what softmax tells me...
  if (output_.dim() == 1 && (dim == 0 || dim == -1)) {
    return std::make_tuple(at::zeros_like(grad_output_), 0);
  }

  dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim);

  // Not sure why output_ needs to be marked as .contiguous(). Something must
  // have changed in PyTorch (and the output of softmax is probably always contiguous).
  return std::make_tuple(at::_softmax_backward_data(grad_output_, output_.contiguous(), dim, input_dtype), 0);
}

static std::tuple<Tensor, std::optional<int64_t>> _log_softmax_backward_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& output, std::optional<int64_t> output_bdim,
    int64_t dim,
    c10::ScalarType input_dtype) {
  // NB: It turns out that expanding + calling log_softmax_backward is generally
  // faster than the decomposition.
  // Benchmark here: https://gist.github.com/zou3519/ae3b33b5730a84aae8a80a05c89e078a
  // Decomposition is (grad_output - grad_output.sum(dim, keepdim=True) * result.exp())
  // We can squeeze out a last mile of performance by writing custom kernels.
  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
  auto output_ = moveBatchDimToFront(output, output_bdim);

  // Expand out that extra dimension for everyone
  std::tie(grad_output_, output_) = expand_bdims(
      grad_output_, grad_output_bdim.has_value(),
      output_, output_bdim.has_value());

  // Scalar tensor case. log_softmax returns zeros when this happens
  if (output_.dim() == 1 && (dim == 0 || dim == -1)) {
    return std::make_tuple(at::zeros_like(grad_output_), 0);
  }

  dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim);

  return std::make_tuple(at::_log_softmax_backward_data(grad_output_, output_, dim, input_dtype), 0);
}

static std::tuple<Tensor, std::optional<int64_t>> searchsorted_batch_rule(
    const Tensor& sorted_sequence,
    std::optional<int64_t> sorted_sequence_bdim,
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    bool out_int32,
    bool right,
    std::optional<c10::string_view> side,
    const std::optional<Tensor>& sorter,
    std::optional<int64_t> sorter_bdim) {
  auto buckets_logical_rank = rankWithoutBatchDim(sorted_sequence, sorted_sequence_bdim);
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);

  // Preprocess sorter and sorted_sequence.
  // If they both exist, and only one has a bdim, then we need to make sure both do.
  // After this step, we can forget about sorter for a bit.
  auto buckets = moveBatchDimToFront(sorted_sequence, sorted_sequence_bdim);
  std::optional<int64_t> buckets_bdim;
  if (sorted_sequence_bdim.has_value()) {
    buckets_bdim = 0;
  }

  std::optional<Tensor> sorter_;
  if (sorter.has_value() && sorter->defined()) {
    auto sorter__ = moveBatchDimToFront(*sorter, sorter_bdim);
    if (sorted_sequence_bdim.has_value() != sorter_bdim.has_value()) {
      auto bdim_size = get_bdim_size2(
          sorted_sequence, sorted_sequence_bdim,
          sorter.value(), sorter_bdim);
      sorter__ = ensure_has_bdim(sorter__, sorter_bdim.has_value(), bdim_size);
      buckets = ensure_has_bdim(buckets, sorted_sequence_bdim.has_value(), bdim_size);
      buckets_bdim = 0;
    }
    sorter_ = sorter__;
  }

  // Two cases: buckets_logical_rank is 1, or it is greater than 1.
  // searchsorted is basically two operators with different semantics jammed
  // into one.
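  //
  // Notation used in the case comments below (informal): B is the batch dim,
  // D is the bucket/boundary dim of `sorted_sequence`, V is the values dim of
  // `self`, <...> stands for arbitrary leading dims, * for an arbitrary shape,
  // and flat(*) for that shape flattened into a single dim.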
  if (buckets_logical_rank > 1) {
    // B<...>D, B<...>V -> no change
    if (buckets_bdim.has_value() && self_bdim.has_value()) {
      auto self_ = moveBatchDimToFront(self, self_bdim);
      auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_);
      return std::make_tuple(std::move(result), 0);
    }
    // B<...>D, <...>V -> B<...>D, B<...>V
    if (buckets_bdim.has_value() && !self_bdim.has_value()) {
      auto self_ = moveBatchDimToFront(self, self_bdim);
      self_ = ensure_has_bdim(self_, self_bdim.has_value(), buckets.size(0));
      auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_);
      return std::make_tuple(std::move(result), 0);
    }
    // <...>D, B<...>V -> <...>D, <...>(BV)
    if (!buckets_bdim.has_value() && self_bdim.has_value()) {
      auto bdim_size = self.size(*self_bdim);
      auto self_ = reshape_dim_into(*self_bdim, -1, self);
      auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_);
      result = reshape_dim_outof(-1, bdim_size, result);
      return std::make_tuple(result, result.dim() - 2);
    }
    TORCH_INTERNAL_ASSERT(false);
  }
  // buckets_logical_rank == 1 case.
  // BD, B* -> BD, B flat(*)
  if (buckets_bdim.has_value() && self_bdim.has_value()) {
    auto self_ = moveBatchDimToFront(self, self_bdim);
    auto self_view_ = self_logical_rank == 0 ? self_.unsqueeze(-1) : self_.flatten(1);
    auto result = at::searchsorted(buckets, self_view_, out_int32, right, std::move(side), sorter_);
    result = self_logical_rank == 0 ? result.squeeze(-1) : result.view(self_.sizes());
    return std::make_tuple(std::move(result), 0);
  }
  // BD, * -> BD, flat(*) -> BD, B flat(*)
  if (buckets_bdim.has_value() && !self_bdim.has_value()) {
    auto bdim_size = buckets.size(*buckets_bdim);
    auto self_ = ensure_has_bdim(self, false, bdim_size);
    auto self_view_ = self_logical_rank == 0 ? self_.unsqueeze(-1) : self_.flatten(1);
    auto result = at::searchsorted(buckets, self_view_, out_int32, right, std::move(side), sorter_);
    result = self_logical_rank == 0 ? result.squeeze(-1) : result.view(self_.sizes());
    return std::make_tuple(std::move(result), 0);
  }
  // D, B* -> no change
  if (!buckets_bdim.has_value() && self_bdim.has_value()) {
    auto result = at::searchsorted(buckets, self, out_int32, right, std::move(side), sorter_);
    return std::make_tuple(std::move(result), self_bdim);
  }
  TORCH_INTERNAL_ASSERT(false);
}

static Tensor bucketize_decomp_Tensor(
    const Tensor& self,
    const Tensor& boundaries,
    bool out_int32,
    bool right) {
  // checking logical rank
  TORCH_CHECK(boundaries.dim() == 1, "bucketize: boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
  return at::searchsorted(boundaries, self, out_int32, right, std::nullopt, std::nullopt);
}

static Tensor bucketize_decomp_Scalar(
    const Scalar& self,
    const Tensor& boundaries,
    bool out_int32,
    bool right) {
  // checking logical rank
  TORCH_CHECK(boundaries.dim() == 1, "bucketize: boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
  return at::searchsorted(boundaries, self, out_int32, right, std::nullopt, std::nullopt);
}

// Use when the other macros don't work out.
// - dim_pos: index of the dim argument
// - keepdim_case: one of KEEPDIM_CASE_TRUE, KEEPDIM_CASE_FALSE, or KEEPDIM_CASE_VARIABLE.
//   See NOTE: [keepdim cases] for more details.
// - maybe_keepdim_pos: the index of the keepdim argument, if it exists.
//   Otherwise, the value is ignored.
#define REDUCTION_BOXED_ARGS(op, dim_pos, keepdim_case, maybe_keepdim_pos) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction< \
      SINGLE_ARG(boxed_reduction_batch_rule<dim_pos, keepdim_case, maybe_keepdim_pos>)>());

// Provided for your convenience; most operators that have a keepdim arg
// will work with this macro.
// Assumes the dim arg is at position 1 and the keepdim arg is at position 2.
#define REDUCTION_WITH_KEEPDIM_ARG(op) \
  REDUCTION_BOXED_ARGS(op, 1, KEEPDIM_CASE_VARIABLE, 2)

// Provided for your convenience; most operators that do not have a keepdim
// arg will work with this macro.
// Assumes the dim arg is at position 1 and the operation always returns
// a tensor of the same rank as the input (rather than a smaller rank).
#define REDUCTION_NO_KEEPDIM_ARG(op) \
  REDUCTION_BOXED_ARGS(op, 1, KEEPDIM_CASE_TRUE, -1)
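
// For example, REDUCTION_WITH_KEEPDIM_ARG(amax) expands to
//   REDUCTION_BOXED_ARGS(amax, 1, KEEPDIM_CASE_VARIABLE, 2)
// which in turn registers
//   boxed_reduction_batch_rule<1, KEEPDIM_CASE_VARIABLE, 2>
// as the boxed FuncTorchBatched kernel for "amax".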

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT2(searchsorted, Tensor, searchsorted_batch_rule);
  REDUCTION_NO_KEEPDIM_ARG(_fft_r2c);
  REDUCTION_NO_KEEPDIM_ARG(_fft_c2r);
  REDUCTION_NO_KEEPDIM_ARG(_fft_c2c);
  REDUCTION_WITH_KEEPDIM_ARG(amax);
  REDUCTION_WITH_KEEPDIM_ARG(amin);
  REDUCTION_WITH_KEEPDIM_ARG(aminmax);
  m.impl("all", all_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(all.dim);
  REDUCTION_WITH_KEEPDIM_ARG(all.dims);
  m.impl("any", any_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(any.dim);
  REDUCTION_WITH_KEEPDIM_ARG(any.dims);
  REDUCTION_WITH_KEEPDIM_ARG(argmax);
  REDUCTION_WITH_KEEPDIM_ARG(argmin);
  m.impl("bucketize.Tensor", bucketize_decomp_Tensor);
  m.impl("bucketize.Scalar", bucketize_decomp_Scalar);
  REDUCTION_BOXED_ARGS(count_nonzero.dim_IntList, 1, KEEPDIM_CASE_FALSE, -1);
  REDUCTION_NO_KEEPDIM_ARG(cummax);
  REDUCTION_NO_KEEPDIM_ARG(cummin);
  REDUCTION_NO_KEEPDIM_ARG(cumprod);
  REDUCTION_NO_KEEPDIM_ARG(cumsum);
  m.impl("dist", dist_decomp);
  REDUCTION_BOXED_ARGS(kthvalue, 2, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_BOXED_ARGS(linalg_vector_norm, 2, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_NO_KEEPDIM_ARG(logcumsumexp);
  REDUCTION_WITH_KEEPDIM_ARG(logsumexp);
  m.impl("max", max_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(max.dim);
  m.impl("mean", mean_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(mean.dim);
  m.impl("median", median_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(median.dim);
  m.impl("min", min_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(min.dim);
  REDUCTION_WITH_KEEPDIM_ARG(mode);
  m.impl("nanmedian", nanmedian_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(nanmedian.dim);
  REDUCTION_WITH_KEEPDIM_ARG(nansum);
  m.impl("norm.Scalar", norm_scalar_decomp);
  REDUCTION_BOXED_ARGS(norm.ScalarOpt_dim, 2, KEEPDIM_CASE_VARIABLE, 3);
  m.impl("prod", prod_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(prod.dim_int);
  REDUCTION_BOXED_ARGS(std.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_NO_KEEPDIM_ARG(_softmax);
  REDUCTION_NO_KEEPDIM_ARG(_safe_softmax);
  REDUCTION_NO_KEEPDIM_ARG(sort);
  REDUCTION_BOXED_ARGS(sort.stable, 2, KEEPDIM_CASE_TRUE, -1);
  REDUCTION_BOXED_ARGS(std_mean.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  m.impl("sum", sum_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(sum.dim_IntList);
  REDUCTION_BOXED_ARGS(topk, 2, KEEPDIM_CASE_TRUE, -1);
  REDUCTION_BOXED_ARGS(var.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_BOXED_ARGS(var_mean.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_NO_KEEPDIM_ARG(_log_softmax);
  REDUCTION_BOXED_ARGS(rot90, 2, KEEPDIM_CASE_TRUE, -1);
  VMAP_SUPPORT(_log_softmax_backward_data, _log_softmax_backward_batch_rule);
  VMAP_SUPPORT(_softmax_backward_data, _softmax_backward_batch_rule);
  VMAP_SUPPORT(_is_all_true, _is_all_true_batch_rule);
  VMAP_SUPPORT(_is_any_true, _is_any_true_batch_rule);
}
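
// Rough usage sketch (Python-level functorch, not part of this file):
//   torch.vmap(lambda x: x.sum(0))(x)   # x of shape [B, N]
// hits the sum.dim_IntList registration above, i.e.
// boxed_reduction_batch_rule<1, KEEPDIM_CASE_VARIABLE, 2>, which maps logical
// dim 0 to the physical dim behind the batch dim and reduces there.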

} // namespace at::functorch