// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/Operators.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <utility>

namespace at::functorch {

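// See NOTE: [boxed_reduction_batch_rule scalar tensor handling] below for why
// only dims 0 and -1 are considered valid reduction dims on a 0-dim tensor.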
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

static Tensor sum_decomp(
    const Tensor& self, std::optional<ScalarType> dtype) {
  return at::sum(self, range(0, self.dim()), false, dtype);
}

static std::tuple<Tensor, std::optional<int64_t>> _is_all_true_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim) {
  return std::make_tuple(at::_is_all_true(self), std::nullopt);
}

static std::tuple<Tensor, std::optional<int64_t>> _is_any_true_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim) {
  return std::make_tuple(at::_is_any_true(self), std::nullopt);
}

static Tensor mean_decomp(
    const Tensor& self, std::optional<ScalarType> dtype) {
  return at::mean(self, range(0, self.dim()), false, dtype);
}

static Tensor prod_decomp(
    const Tensor& self, std::optional<ScalarType> dtype) {
  return at::prod(self.flatten(), 0, false, dtype);
}

static Tensor max_decomp(
    const Tensor& self) {
  return std::get<0>(at::max(self.flatten(), 0, false));
}

static Tensor min_decomp(
    const Tensor& self) {
  return std::get<0>(at::min(self.flatten(), 0, false));
}

static Tensor norm_scalar_decomp(
    const Tensor& self, const Scalar& p) {
  return at::norm(self, p, range(0, self.dim()), false);
}

static Tensor nanmedian_decomp(
    const Tensor& self) {
  return std::get<0>(at::nanmedian(self.flatten(), 0, false));
}

static Tensor median_decomp(
    const Tensor& self) {
  return std::get<0>(at::median(self.flatten(), 0, false));
}

static Tensor all_decomp(const Tensor& self) {
  return at::all(self.flatten(), 0, false);
}

static Tensor any_decomp(const Tensor& self) {
  return at::any(self.flatten(), 0, false);
}

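// How the reduction dims were specified in the op's schema: as an int[]
// (DimArray) or as a single int (Dim).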
enum class ReductionCase : uint8_t { DimArray, Dim };

// Macros and templates have a difficult time dealing with enums,
// so we didn't turn this into an enum.
// See NOTE: [keepdim cases] for an explanation of what these are.
static constexpr int KEEPDIM_CASE_FALSE = 0;
static constexpr int KEEPDIM_CASE_TRUE = 1;
static constexpr int KEEPDIM_CASE_VARIABLE = 2;

// dim_arg_pos specifies the location of the dim/dim array argument.
// For most PyTorch ops, this is equal to 1.
//
// NOTE: [keepdim cases]
// The operator in question either:
// - has a keepdim argument (KEEPDIM_CASE_VARIABLE).
//   In this case, `maybe_keepdim_arg_pos` gives the position of the keepdim arg.
//   example: sum(tensor, dim, keepdim)
// - always does a reduction with no keepdim (KEEPDIM_CASE_FALSE).
//   That is, the rank of the output tensor is less than the rank of the input tensor.
// - always does a reduction with keepdim=True semantics (KEEPDIM_CASE_TRUE).
//   That is, the rank of the output tensor is always the same as that of the input.
//   examples: log_softmax(tensor, dim), cumsum(tensor, dim)
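// For the sum(tensor, dim, keepdim) example above, that means
// dim_arg_pos = 1, keepdim_case = KEEPDIM_CASE_VARIABLE, and
// maybe_keepdim_arg_pos = 2 (see REDUCTION_WITH_KEEPDIM_ARG below).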
template<
  int dim_arg_pos,
  int keepdim_case,
  // std::optional cannot be used as a template parameter, otherwise we would use it here.
  int maybe_keepdim_arg_pos
>
void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_reduction_batch_rule");
  int64_t cur_level = maybe_layer->layerId();

  auto orig_arguments = torch::jit::last(*stack, num_arguments);
  if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
    c10::impl::ExcludeDispatchKeyGuard guard_2(DispatchKey::FuncTorchBatched);
    op.callBoxed(stack);
    return;
  }

  auto arguments = torch::jit::pop(*stack, num_arguments);

  TORCH_INTERNAL_ASSERT(arguments[0].isTensor());
  auto [self, self_bdim] = unwrapTensorAtLevel(arguments[0].toTensor(), cur_level);

  self = moveBatchDimToFront(self, self_bdim);

  auto logical_dim = rankWithoutBatchDim(self, self_bdim);
  std::vector<int64_t> dims;
  ReductionCase reduction_case{};
  if (arguments[dim_arg_pos].isIntList()) {
    reduction_case = ReductionCase::DimArray;
    dims = arguments[dim_arg_pos].toIntList().vec();
    if (dims.empty()) {
      auto all_dims = range(0, std::max((int64_t)1, logical_dim));
      dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
    }
  } else if (arguments[dim_arg_pos].isInt()) {
    reduction_case = ReductionCase::Dim;
    dims = {arguments[dim_arg_pos].toInt()};
  } else if (arguments[dim_arg_pos].isNone()) {
    auto param_type = schema.arguments()[dim_arg_pos].type()->expect<OptionalType>()->getElementType();
    if (param_type->kind() == IntType::Kind) {
      reduction_case = ReductionCase::Dim;
      if (self.dim() > 1) {
        self = self.flatten(1);
      }
      dims = {0};
    } else if (param_type->kind() == ListType::Kind) {
      reduction_case = ReductionCase::DimArray;
      if (logical_dim == 0) {
        dims = {0};
      } else {
        auto all_dims = range(0, self.dim() - 1);
        dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
      }
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims");
    }
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims");
  }

  VmapDimVector new_dims;
  new_dims.reserve(dims.size());
  for (auto dim: dims) {
    new_dims.push_back(getPhysicalDim(self, self_bdim.has_value(), dim));
  }
  bool is_scalar_case = logical_dim == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0]);
  std::optional<bool> maybe_keepdim;
  if (is_scalar_case) {
    // NOTE: [boxed_reduction_batch_rule scalar tensor handling]
    // Reduction operations in PyTorch have an edge case where they allow
    // dim=0 and dim=-1 if the tensor has shape [].
    //
    // This can come up if we do something like
    // vmap(lambda x: x.sum(0))(torch.tensor([10.])).
    //
    // In order to handle this edge case, we unsqueeze a dimension on the Tensor,
    // run the operation (with dim=1 instead), and then process the output tensor.
    // There are two cases:
    // - keepdim = True
    //     unsqueeze    op    squeeze
    //   [B] -> [B, 1] -> [B, 1] -> [B]
    // - keepdim = False
    //     unsqueeze    op    no need to squeeze
    //   [B] -> [B, 1] -> [B]
    // If keepdim is True, then we need to squeeze the dimension of size 1.
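    // For the example above, B = 1 and each per-sample x has shape [], so the
    // unsqueezed self has shape [B, 1] and the reduction runs over dim 1.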

    // Determine the value of keepdim
    switch (keepdim_case) {
      case KEEPDIM_CASE_FALSE:
        maybe_keepdim = false;
        break;
      case KEEPDIM_CASE_TRUE:
        maybe_keepdim = true;
        break;
      case KEEPDIM_CASE_VARIABLE:
        TORCH_INTERNAL_ASSERT(maybe_keepdim_arg_pos >= 0);
        maybe_keepdim = arguments[maybe_keepdim_arg_pos].toBool();
        break;
    }
    self = self.unsqueeze(-1);
    new_dims = {1};
  }
  arguments[0] = std::move(self);
  if (reduction_case == ReductionCase::DimArray) {
    arguments[dim_arg_pos] = std::vector<int64_t>(new_dims.begin(), new_dims.end());
  } else if (reduction_case == ReductionCase::Dim) {
    arguments[dim_arg_pos] = new_dims[0];
  }
  for (const auto arg_idx : c10::irange(0, num_arguments)) {
    torch::jit::push(stack, arguments[arg_idx]);
  }
  op.callBoxed(stack);

  const auto returns = torch::jit::pop(*stack, num_returns);
  for (const auto& ret : returns) {
    if (ret.isTensor()) {
      auto res = ret.toTensor();
      // see NOTE: [boxed_reduction_batch_rule scalar tensor handling]
      if (is_scalar_case && maybe_keepdim.value()) {
        // squeeze(-1) is a no-op if the shape of the dim is not 1.
        // To make it safer, we internally assert here.
        TORCH_INTERNAL_ASSERT(res.size(-1) == 1);
        res = res.squeeze(-1);
      }
      torch::jit::push(stack, makeBatched(res, 0, cur_level));
    } else {
      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
    }
  }
}

// Skipping all/any since they don't have opinfo tests right now :P

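// dist(self, other, p) is norm(self - other, p), so decompose it and let the
// existing norm batching rules handle it.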
static Tensor dist_decomp(const Tensor& self, const Tensor& other, const Scalar& p) {
  return at::norm((self - other), p);
}

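// Given two tensors where at least one has a leading batch dim, expand the one
// without a batch dim to match the shape of the one that has it.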
static std::tuple<Tensor, Tensor> expand_bdims(
    const Tensor& a, bool a_has_bdim,
    const Tensor& b, bool b_has_bdim) {
  Tensor flagpole;
  if (a_has_bdim) {
    flagpole = a;
  } else if (b_has_bdim) {
    flagpole = b;
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
  return std::make_tuple(
      a_has_bdim ? a : a.expand_as(flagpole),
      b_has_bdim ? b : b.expand_as(flagpole));
}

static std::tuple<Tensor, std::optional<int64_t>> _softmax_backward_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& output, std::optional<int64_t> output_bdim,
    int64_t dim,
    ScalarType input_dtype) {
  // softmax_backward's decomposition is y * gy - y * (y * gy).sum(dim, keepdim=True)
  // NB: the CUDA kernel handles strides, so we can just expand
  // all of the tensors and call it a day. The CPU kernel is not as good, but
  // it's unclear whether the perf there really matters.
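  // Equivalently, per element along `dim`: grad_input_i = y_i * (gy_i - sum_j(y_j * gy_j)).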
  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
  auto output_ = moveBatchDimToFront(output, output_bdim);

  // Expand out that extra dimension for everyone
  std::tie(grad_output_, output_) = expand_bdims(
      grad_output_, grad_output_bdim.has_value(),
      output_, output_bdim.has_value());

  // Scalar tensor case. softmax over a single-element (per-example scalar) input
  // is constantly 1, so its gradient with respect to the input is zero.
  if (output_.dim() == 1 && (dim == 0 || dim == -1)) {
    return std::make_tuple(at::zeros_like(grad_output_), 0);
  }

  dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim);

  // Not sure why output_ needs to be marked as .contiguous(). Something must
  // have changed in PyTorch (and the output of softmax is probably always contiguous).
  return std::make_tuple(at::_softmax_backward_data(grad_output_, output_.contiguous(), dim, input_dtype), 0);
}

static std::tuple<Tensor, std::optional<int64_t>> _log_softmax_backward_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& output, std::optional<int64_t> output_bdim,
    int64_t dim,
    c10::ScalarType input_dtype) {
  // NB: It turns out that expanding + calling log_softmax_backward is generally
  // faster than the decomposition.
  // Benchmark here: https://gist.github.com/zou3519/ae3b33b5730a84aae8a80a05c89e078a
  // Decomposition is (grad_output - grad_output.sum(dim, keepdim=True) * result.exp())
  // We can squeeze out a last mile of performance by writing custom kernels.
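  // (`output` is the saved log_softmax result, so result.exp() in the
  // decomposition above recovers the softmax probabilities.)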
  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
  auto output_ = moveBatchDimToFront(output, output_bdim);

  // Expand out that extra dimension for everyone
  std::tie(grad_output_, output_) = expand_bdims(
      grad_output_, grad_output_bdim.has_value(),
      output_, output_bdim.has_value());

  // Scalar tensor case. log_softmax returns zeros when this happens
  if (output_.dim() == 1 && (dim == 0 || dim == -1)) {
    return std::make_tuple(at::zeros_like(grad_output_), 0);
  }

  dim = getPhysicalDim(output_, /*has_batch_dim*/true, dim);

  return std::make_tuple(at::_log_softmax_backward_data(grad_output_, output_, dim, input_dtype), 0);
}

static std::tuple<Tensor, std::optional<int64_t>> searchsorted_batch_rule(
    const Tensor& sorted_sequence,
    std::optional<int64_t> sorted_sequence_bdim,
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    bool out_int32,
    bool right,
    std::optional<c10::string_view> side,
    const std::optional<Tensor>& sorter,
    std::optional<int64_t> sorter_bdim) {
  auto buckets_logical_rank = rankWithoutBatchDim(sorted_sequence, sorted_sequence_bdim);
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);

  // Preprocess sorter and sorted_sequence.
  // If they both exist, and only one has a bdim, then we need to make sure both do.
  // After this step, we can forget about sorter for a bit.
  auto buckets = moveBatchDimToFront(sorted_sequence, sorted_sequence_bdim);
  std::optional<int64_t> buckets_bdim;
  if (sorted_sequence_bdim.has_value()) {
    buckets_bdim = 0;
  }

  std::optional<Tensor> sorter_;
  if (sorter.has_value() && sorter->defined()) {
    auto sorter__ = moveBatchDimToFront(*sorter, sorter_bdim);
    if (sorted_sequence_bdim.has_value() != sorter_bdim.has_value()) {
      auto bdim_size = get_bdim_size2(
          sorted_sequence, sorted_sequence_bdim,
          sorter.value(), sorter_bdim);
      sorter__ = ensure_has_bdim(sorter__, sorter_bdim.has_value(), bdim_size);
      buckets = ensure_has_bdim(buckets, sorted_sequence_bdim.has_value(), bdim_size);
      buckets_bdim = 0;
    }
    sorter_ = sorter__;
  }

  // Two cases: buckets_logical_rank is 1, or it is greater than 1.
  // searchsorted is basically two operators with different semantics jammed
  // into one.
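  // Notation for the cases below: B = batch dim, D = the sorted (boundaries) dim,
  // V = the values being searched, <...> = matching leading dims,
  // flat(*) = the value dims flattened into a single dim.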
  if (buckets_logical_rank > 1) {
    // B<...>D, B<...>V -> no change
    if (buckets_bdim.has_value() && self_bdim.has_value()) {
      auto self_ = moveBatchDimToFront(self, self_bdim);
      auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_);
      return std::make_tuple(std::move(result), 0);
    }
    // B<...>D, <...>V -> B<...>D, B<...>V
    if (buckets_bdim.has_value() && !self_bdim.has_value()) {
      auto self_ = moveBatchDimToFront(self, self_bdim);
      self_ = ensure_has_bdim(self_, self_bdim.has_value(), buckets.size(0));
      auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_);
      return std::make_tuple(std::move(result), 0);
    }
    // <...>D, B<...>V -> <...>D, <...>(BV)
    if (!buckets_bdim.has_value() && self_bdim.has_value()) {
      auto bdim_size = self.size(*self_bdim);
      auto self_ = reshape_dim_into(*self_bdim, -1, self);
      auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_);
      result = reshape_dim_outof(-1, bdim_size, result);
      return std::make_tuple(result, result.dim() - 2);
    }
    TORCH_INTERNAL_ASSERT(false);
  }
  // buckets_logical_rank == 1 case.
  // BD, B* -> BD, B flat(*)
  if (buckets_bdim.has_value() && self_bdim.has_value()) {
    auto self_ = moveBatchDimToFront(self, self_bdim);
    auto self_view_ = self_logical_rank == 0 ? self_.unsqueeze(-1) : self_.flatten(1);
    auto result = at::searchsorted(buckets, self_view_, out_int32, right, std::move(side), sorter_);
    result = self_logical_rank == 0 ? result.squeeze(-1) : result.view(self_.sizes());
    return std::make_tuple(std::move(result), 0);
  }
  // BD, * -> BD, flat(*) -> BD, B flat(*)
  if (buckets_bdim.has_value() && !self_bdim.has_value()) {
    auto bdim_size = buckets.size(*buckets_bdim);
    auto self_ = ensure_has_bdim(self, false, bdim_size);
    auto self_view_ = self_logical_rank == 0 ? self_.unsqueeze(-1) : self_.flatten(1);
    auto result = at::searchsorted(buckets, self_view_, out_int32, right, std::move(side), sorter_);
    result = self_logical_rank == 0 ? result.squeeze(-1) : result.view(self_.sizes());
    return std::make_tuple(std::move(result), 0);
  }
  // D, B* -> no change
  if (!buckets_bdim.has_value() && self_bdim.has_value()) {
    auto result = at::searchsorted(buckets, self, out_int32, right, std::move(side), sorter_);
    return std::make_tuple(std::move(result), self_bdim);
  }
  TORCH_INTERNAL_ASSERT(false);
}

static Tensor bucketize_decomp_Tensor(
    const Tensor& self,
    const Tensor& boundaries,
    bool out_int32,
    bool right) {
  // checking logical rank
  TORCH_CHECK(boundaries.dim() == 1, "bucketize: boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
  return at::searchsorted(boundaries, self, out_int32, right, std::nullopt, std::nullopt);
}

static Tensor bucketize_decomp_Scalar(
    const Scalar& self,
    const Tensor& boundaries,
    bool out_int32,
    bool right) {
  // checking logical rank
  TORCH_CHECK(boundaries.dim() == 1, "bucketize: boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
  return at::searchsorted(boundaries, self, out_int32, right, std::nullopt, std::nullopt);
}

// Use when the other macros don't work out.
// - dim_pos: index of the dim argument
// - keepdim_case: either True, False, or Variable.
//   See NOTE: [keepdim cases] for more details.
// - maybe_keepdim_pos: the index of the keepdim argument, if it exists.
//   Otherwise, the value is ignored.
#define REDUCTION_BOXED_ARGS(op, dim_pos, keepdim_case, maybe_keepdim_pos) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction< \
      SINGLE_ARG(boxed_reduction_batch_rule<dim_pos, keepdim_case, maybe_keepdim_pos>)>());
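// For example, REDUCTION_BOXED_ARGS(kthvalue, 2, KEEPDIM_CASE_VARIABLE, 3) below
// registers boxed_reduction_batch_rule<2, KEEPDIM_CASE_VARIABLE, 3> for kthvalue,
// whose dim argument sits at position 2 and keepdim at position 3.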

// Provided for your convenience; most operators that have a keepdim arg
// will work with this macro.
// Assumes the dim arg is at position 1 and the keepdim arg is at pos 2.
#define REDUCTION_WITH_KEEPDIM_ARG(op) \
  REDUCTION_BOXED_ARGS(op, 1, KEEPDIM_CASE_VARIABLE, 2)

// Provided for your convenience; most operators that do not have a keepdim
// arg will work with this macro.
// Assumes the dim arg is at position 1 and the operation always returns
// a tensor of the same rank (instead of a smaller rank).
#define REDUCTION_NO_KEEPDIM_ARG(op) \
  REDUCTION_BOXED_ARGS(op, 1, KEEPDIM_CASE_TRUE, -1)

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT2(searchsorted, Tensor, searchsorted_batch_rule);
  REDUCTION_NO_KEEPDIM_ARG(_fft_r2c);
  REDUCTION_NO_KEEPDIM_ARG(_fft_c2r);
  REDUCTION_NO_KEEPDIM_ARG(_fft_c2c);
  REDUCTION_WITH_KEEPDIM_ARG(amax);
  REDUCTION_WITH_KEEPDIM_ARG(amin);
  REDUCTION_WITH_KEEPDIM_ARG(aminmax);
  m.impl("all", all_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(all.dim);
  REDUCTION_WITH_KEEPDIM_ARG(all.dims);
  m.impl("any", any_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(any.dim);
  REDUCTION_WITH_KEEPDIM_ARG(any.dims);
  REDUCTION_WITH_KEEPDIM_ARG(argmax);
  REDUCTION_WITH_KEEPDIM_ARG(argmin);
  m.impl("bucketize.Tensor", bucketize_decomp_Tensor);
  m.impl("bucketize.Scalar", bucketize_decomp_Scalar);
  REDUCTION_BOXED_ARGS(count_nonzero.dim_IntList, 1, KEEPDIM_CASE_FALSE, -1);
  REDUCTION_NO_KEEPDIM_ARG(cummax);
  REDUCTION_NO_KEEPDIM_ARG(cummin);
  REDUCTION_NO_KEEPDIM_ARG(cumprod);
  REDUCTION_NO_KEEPDIM_ARG(cumsum);
  m.impl("dist", dist_decomp);
  REDUCTION_BOXED_ARGS(kthvalue, 2, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_BOXED_ARGS(linalg_vector_norm, 2, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_NO_KEEPDIM_ARG(logcumsumexp);
  REDUCTION_WITH_KEEPDIM_ARG(logsumexp);
  m.impl("max", max_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(max.dim);
  m.impl("mean", mean_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(mean.dim);
  m.impl("median", median_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(median.dim);
  m.impl("min", min_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(min.dim);
  REDUCTION_WITH_KEEPDIM_ARG(mode);
  m.impl("nanmedian", nanmedian_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(nanmedian.dim);
  REDUCTION_WITH_KEEPDIM_ARG(nansum);
  m.impl("norm.Scalar", norm_scalar_decomp);
  REDUCTION_BOXED_ARGS(norm.ScalarOpt_dim, 2, KEEPDIM_CASE_VARIABLE, 3);
  m.impl("prod", prod_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(prod.dim_int);
  REDUCTION_BOXED_ARGS(std.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_NO_KEEPDIM_ARG(_softmax);
  REDUCTION_NO_KEEPDIM_ARG(_safe_softmax);
  REDUCTION_NO_KEEPDIM_ARG(sort);
  REDUCTION_BOXED_ARGS(sort.stable, 2, KEEPDIM_CASE_TRUE, -1);
  REDUCTION_BOXED_ARGS(std_mean.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  m.impl("sum", sum_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(sum.dim_IntList);
  REDUCTION_BOXED_ARGS(topk, 2, KEEPDIM_CASE_TRUE, -1);
  REDUCTION_BOXED_ARGS(var.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_BOXED_ARGS(var_mean.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_NO_KEEPDIM_ARG(_log_softmax);
  REDUCTION_BOXED_ARGS(rot90, 2, KEEPDIM_CASE_TRUE, -1);
  VMAP_SUPPORT(_log_softmax_backward_data, _log_softmax_backward_batch_rule);
  VMAP_SUPPORT(_softmax_backward_data, _softmax_backward_batch_rule);
  VMAP_SUPPORT(_is_all_true, _is_all_true_batch_rule);
  VMAP_SUPPORT(_is_any_true, _is_any_true_batch_rule);
}

} // namespace at::functorch