// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/Operators.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <utility>

namespace at::functorch {

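// Pointwise batch rules take each argument as a (Tensor, optional batch dim) pair
// and return the result together with the output's batch dim. Illustrative shapes:
// vmapping add over dim 0 with tensor = [B, 3] (bdim 0) and other = [5, 3] (no bdim),
// the helper moves B to the front, pads tensor to [B, 1, 3], and the op broadcasts
// to a [B, 5, 3] result that is returned with bdim 0.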
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, std::optional<int64_t>> _binary_pointwise_batch_rule(
    const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
    const Tensor& other, std::optional<int64_t> other_batch_dim,
    ExtraArgs... extra_args) {

  auto tensor_other = _binary_pointwise_helper(
      tensor, tensor_batch_dim, other, other_batch_dim);
  auto tensor_ = std::get<0>(tensor_other);
  auto other_ = std::get<1>(tensor_other);

  auto result = Func(tensor_, other_, std::forward<ExtraArgs>(extra_args)...);
  return std::make_tuple(result, 0);
}

template <typename A, A a, typename C>
struct BinaryPointwiseBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename... T>
struct BinaryPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
      const Tensor& other, std::optional<int64_t> other_batch_dim,
      T... extra_args) {
    return _binary_pointwise_batch_rule<F, Func, T...>(
        tensor, tensor_batch_dim, other, other_batch_dim,
        std::forward<T>(extra_args)...);
  }
};

#define BINARY_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\
    BinaryPointwiseBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
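// Illustrative expansion (assuming ATEN_FN2 resolves to the usual at::_ops wrapper):
// BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(add, Tensor)) selects the helper specialization
// for the op's parameter typelist, so the two leading Tensors get batch dims threaded
// through and any trailing arguments (e.g. add's alpha Scalar) are forwarded unchanged.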

template <typename A, A a, typename C>
struct BinaryRandomPointwiseBatchRuleHelper;

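// Randomness handling, summarizing the branches below: under
// vmap(..., randomness="different") each batch element must draw independent samples,
// so when neither input is batched the first operand is expanded to [B, ...] before
// calling the op; under randomness="same" with no batched inputs the op is called
// once and the unbatched result is returned directly.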
template <typename F, F Func, typename T1, typename T2, typename... T>
struct BinaryRandomPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  static Tensor apply(const Tensor& tensor, const Tensor& other, T... extra_args) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
    auto maybe_layer = maybeCurrentDynamicLayer();
    auto cur_level = maybe_layer->layerId();
    RandomnessType randomness = maybe_layer->randomness();

    auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(tensor, cur_level);

    auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level);

    check_randomness(randomness, (tensor_bdim || other_bdim));
    if (randomness == RandomnessType::Different && !tensor_bdim && !other_bdim) {
      auto shape = tensor_value.sizes();
      VmapSymDimVector shapeVec(1, maybe_layer->batchSize());
      shapeVec.reserve(shape.size() + 1);
      shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());

      // not taken care of with binary batch rule, which assumes at least one input is batched
      tensor_value = tensor_value.expand_symint(shapeVec);
      tensor_bdim = 0;
    } else if (randomness == RandomnessType::Same && !tensor_bdim && !other_bdim) {

      // avoids unnecessary checks and batch rule assuming output is batched
      return Func(tensor_value, other_value, std::forward<T>(extra_args)...);
    }
    auto res = _binary_pointwise_batch_rule<F, Func, T...>(
      tensor_value, tensor_bdim, other_value, other_bdim,
      std::forward<T>(extra_args)...);
    return makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
  }
};

#define BINARY_RANDOM_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\
    BinaryRandomPointwiseBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

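// In-place variant: vmapping self.op_(other) is only expressible when `self` carries
// the batch dimension. Illustrative failure: vmapping over `other` alone in
// self.add_(other) would have to write a [B, ...] result into an unbatched `self`,
// which is why the rule below raises vmapIncompatibleInplaceError.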
template <typename M, M Meth, typename... ExtraArgs>
void binary_pointwise_inplace_batch_rule(
    Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
    const Tensor& other, std::optional<int64_t> other_batch_dim,
    ExtraArgs... extra_args) {
  if (!tensor_batch_dim && other_batch_dim) {
    vmapIncompatibleInplaceError("inplace arithmetic");
  }

  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  (tensor_.*Meth)(other_, std::forward<ExtraArgs>(extra_args)...);
}

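// Comparison ops only need rank alignment: unlike _binary_pointwise_helper, the rule
// below skips the scalar-vs-tensor dtype promotion step and leaves dtype handling to
// the comparison op itself, since the output is boolean either way.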
template <typename F, F Func>
std::tuple<Tensor, std::optional<int64_t>> comparison_pointwise_batch_rule(
    const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
    const Tensor& other, std::optional<int64_t> other_batch_dim) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  auto result = Func(tensor_, other_);
  return std::make_tuple(std::move(result), 0);
}

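// Ternary analogue of the pointwise rules above: condition, self, and other are all
// moved to a bdim-0 layout and padded to a common logical rank before calling
// at::where, so regular broadcasting handles the rest.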
static std::tuple<Tensor, std::optional<int64_t>> where_self_batch_rule(
    const Tensor& condition, std::optional<int64_t> condition_bdim,
    const Tensor& self, std::optional<int64_t> self_bdim, const Tensor& other, std::optional<int64_t> other_bdim) {
  auto condition_logical_rank = rankWithoutBatchDim(condition, condition_bdim);
  auto tensor_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  auto max_logical_rank = std::max({tensor_logical_rank, other_logical_rank, condition_logical_rank});

  auto condition_ = moveBatchDimToFront(condition, condition_bdim);
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto other_ = moveBatchDimToFront(other, other_bdim);

  condition_ = maybePadToLogicalRank(condition_, condition_bdim, max_logical_rank);
  self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_bdim, max_logical_rank);
  return std::make_tuple(at::where(condition_, self_, other_), 0);
}

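// Pattern used by several backward rules below: when an op does not broadcast a
// missing batch dimension well, unbatched operands are materialized to [B, ...] via
// ensure_has_bdim so that every input explicitly carries the batch dimension.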
static std::tuple<Tensor, std::optional<int64_t>> gelu_backward_batch_rule(
    const Tensor& grad_out, std::optional<int64_t> grad_out_bdim, const Tensor& input, std::optional<int64_t> input_bdim,
    c10::string_view approximate) {

  // repeat the preprocessing from _binary_pointwise_batch_rule
  const auto tensor_other = _binary_pointwise_helper(grad_out, grad_out_bdim, input, input_bdim);
  auto grad_out_ = std::get<0>(tensor_other);
  auto input_ = std::get<1>(tensor_other);

  // gelu_backward doesn't broadcast well, so we need to insist that all inputs have a bdim
  const auto batch_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), batch_size);
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);

  return std::make_tuple(at::gelu_backward(grad_out_, input_, approximate), 0);
}

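// Shape walkthrough (illustrative): self = [B, 2, 3] with bdim 0 and an unbatched
// mask = [2, 3] containing k true entries yields a flat [B * k] result from
// at::masked_select, which the view below reshapes to [B, k]. A batched mask is
// rejected because k could differ across batch elements.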
static std::tuple<Tensor, std::optional<int64_t>> masked_select_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& mask, std::optional<int64_t> mask_bdim) {
  TORCH_CHECK(!mask_bdim.has_value(),
      "vmap: Attempted to vmap over `mask` in torch.masked_select(self, mask) ",
      "We cannot support this because for each batch this would return a ",
      "differently shaped Tensor. "
      "Please voice your support in https://github.com/pytorch/functorch/issues/256");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  const auto batch_size = self_.size(0);
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto max_logical_rank = std::max(self_logical_rank, mask.dim());
  self_ = maybePadToLogicalRank(self_, 0, max_logical_rank);

  // masked_select returns a 1D tensor, so we have to reshape it into 2D
  const auto result = at::masked_select(self_, mask).view({ batch_size, -1 });
  return std::make_tuple(result, 0);
}

static std::tuple<Tensor, std::optional<int64_t>> masked_select_backward_batch_rule(
    const Tensor& grad, std::optional<int64_t> grad_bdim,
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& mask, std::optional<int64_t> mask_bdim) {
  TORCH_CHECK(!mask_bdim.has_value(),
      "vmap: Attempted to vmap over `mask` in torch.masked_select_backward(grad, self, mask) ",
      "We cannot support this because for each batch this would return a ",
      "differently shaped Tensor. "
      "Please voice your support in https://github.com/pytorch/functorch/issues/256");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto grad_ = moveBatchDimToFront(grad, grad_bdim);

  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto max_logical_rank = std::max(self_logical_rank, mask.dim());

  self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank);

  const auto batch_size = get_bdim_size2(grad, grad_bdim, self, self_bdim);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), batch_size);

  const auto result = at::masked_select_backward(grad_, self_.contiguous(), mask);
  return std::make_tuple(result, 0);
}

static std::tuple<Tensor, std::optional<int64_t>> cdist_backward_batch_rule(
    const Tensor& grad, std::optional<int64_t> grad_bdim,
    const Tensor& x1, std::optional<int64_t> x1_bdim,
    const Tensor& x2, std::optional<int64_t> x2_bdim,
    const double p,
    const Tensor& cdist, std::optional<int64_t> cdist_bdim) {

  auto x1_ = x1;
  if (cdist_bdim && !x1_bdim) {
    // We need to make sure that x1 has a batch dim if cdist has one;
    // otherwise, we get
    // RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
    // but expected shape compatible with [4, 5]
    auto bs = cdist.size(*cdist_bdim);
    x1_ = ensure_has_bdim(x1, false, bs);
    x1_ = x1_.contiguous();
    x1_bdim = 0;
  }

  // We need to apply the same preprocessing on x1 and x2 as in the forward pass
  // (_binary_pointwise_batch_rule)
  auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
  x1_ = std::get<0>(x12);
  auto x2_ = std::get<1>(x12);

  auto grad_ = moveBatchDimToFront(grad, grad_bdim);
  if ((x1_bdim || x2_bdim) && !grad_bdim) {
    // We need to make sure that grad has a batch dim if x1 or x2 have one.
    // There is probably an assumption on the strides; otherwise the grad
    // input contains garbage values, e.g. -7.0816e+29, 7.0816e+29
    auto bs = get_bdim_size2(x1_, 0, x2_, 0);
    grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs);
    grad_ = grad_.contiguous();
  }

  auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist);

  std::optional<int64_t> out_bdim = std::nullopt;
  if (x1_bdim || x2_bdim) {
    out_bdim = 0;
  }

  return std::make_tuple(out, out_bdim);
}

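// Filling from a batched `other` becomes a broadcasting copy_. Illustrative shapes:
// self = [B, 2, 3] (bdim 0) and other = [B] (bdim 0) -> other is viewed as [B, 1, 1]
// by the pointwise helper and then copied into self.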
static void fill__Tensor_batch_rule(
    Tensor& self,
    std::optional<int64_t> self_bdim,
    const Tensor& other,
    std::optional<int64_t> other_bdim) {
  if (!other_bdim.has_value()) {
    // Optimization: fill_ is faster than the other path which does
    // reshaping + copy_
    self.fill_(other);
    return;
  }
  if (!self_bdim) {
    vmapIncompatibleInplaceError("fill_");
  }
  auto self_and_other = _binary_pointwise_helper(
      self, self_bdim, other, other_bdim, /*do_type_promotion*/false);
  std::get<0>(self_and_other).copy_(std::get<1>(self_and_other));
}

static std::tuple<Tensor, std::optional<int64_t>> log_sigmoid_backward_batch_rule(
  Tensor& grad, std::optional<int64_t> grad_bdim,
  Tensor& self, std::optional<int64_t> self_bdim,
  Tensor& buffer, std::optional<int64_t> buffer_bdim) {
  // NB: This emulates handle_pointwise_ops, except that we ignore the last argument
  // (buffer) when any of the inputs are on CUDA.
  // We do this because on CUDA, buffer is a dummy tensor that always has logical rank 1,
  // which becomes an issue when the rest of the inputs are scalar.
  int64_t out_logical_rank = std::max(rankWithoutBatchDim(grad, grad_bdim), rankWithoutBatchDim(self, self_bdim));
  if (!grad.is_cuda() && !self.is_cuda() && !buffer.is_cuda()) {
    out_logical_rank = std::max(out_logical_rank, rankWithoutBatchDim(buffer, buffer_bdim));
  }
  Tensor out_grad = maybePadToLogicalRank(moveBatchDimToFront(grad, grad_bdim), grad_bdim, out_logical_rank);
  Tensor out_self = maybePadToLogicalRank(moveBatchDimToFront(self, self_bdim), self_bdim, out_logical_rank);
  Tensor out_buffer = maybePadToLogicalRank(moveBatchDimToFront(buffer, buffer_bdim), buffer_bdim, out_logical_rank);
  return std::make_tuple(at::log_sigmoid_backward(out_grad, out_self, out_buffer), 0);
}

static Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, std::optional<Generator> gen) {
  return at::binomial(count, prob.contiguous(), std::move(gen)); // Bug in PyTorch, prob shouldn't need to be contiguous
}

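// Random binary ops are registered at the FuncTorchVmapMode key: the
// BinaryRandomPointwiseBatchRuleHelper kernels registered below exclude that key,
// check the vmap layer's randomness setting, and only then hand the unwrapped
// tensors to the regular pointwise batch rule.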
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
  #define BINARY_RANDOM_POINTWISE(op) \
    m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
  #define BINARY_RANDOM_POINTWISE2(op, overload) \
    m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));

  BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor);
  m.impl("binomial", BINARY_RANDOM_POINTWISE_BATCH_RULE(at::functorch::binomial_wrapper));
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
#define BINARY_POINTWISE2(op, overload) \
  VMAP_SUPPORT2(op, overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
#define BINARY_POINTWISE(op) \
  VMAP_SUPPORT(op, BINARY_POINTWISE_BATCH_RULE(ATEN_FN(op)));
#define UNARY_POINTWISE2(op, overload) \
  VMAP_SUPPORT2(op, overload, BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));
#define UNARY_POINTWISE(op) \
  VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
#define UNARY_SCALAR_POINTWISE2(op, overload) \
  VMAP_SUPPORT2(op, overload, SCALAR_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));

#define BINARY_SCALAR_2(op, tensor_tensor, tensor_scalar) \
  BINARY_POINTWISE2(op, tensor_tensor);\
  UNARY_POINTWISE2(op, tensor_scalar);
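// Illustrative expansion: BINARY_SCALAR_2(add, Tensor, Scalar) registers a binary
// pointwise batch rule for add.Tensor and a unary one for add.Scalar; the Scalar
// overload has only one Tensor argument that can carry a batch dimension.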

// For all 3 combinations of Tensor x Tensor, Tensor x Scalar, Scalar x Tensor
#define BINARY_SCALAR_3(op, tensor_tensor, tensor_scalar, scalar_tensor) \
  BINARY_POINTWISE2(op, tensor_tensor);\
  UNARY_POINTWISE2(op, tensor_scalar);\
  POINTWISE_BOXED(op.scalar_tensor);

#define BINARY_SCALAR_3_Tensor(op, tensor_scalar, scalar_tensor) \
  BINARY_POINTWISE(op);\
  UNARY_POINTWISE2(op, tensor_scalar);\
  POINTWISE_BOXED(op.scalar_tensor);

  // Batching rule registrations start
  POINTWISE_BOXED(__ilshift__.Tensor);
  POINTWISE_BOXED(__ilshift__.Scalar);
  POINTWISE_BOXED(__irshift__.Tensor);
  POINTWISE_BOXED(__irshift__.Scalar);
  BINARY_SCALAR_2(__lshift__, Tensor, Scalar);
  BINARY_SCALAR_2(__rshift__, Tensor, Scalar);

  BINARY_SCALAR_2(add, Tensor, Scalar);
  POINTWISE_BOXED(addcdiv);
  POINTWISE_BOXED(addcmul);
  BINARY_POINTWISE(atan2);
  BINARY_SCALAR_2(bitwise_and, Tensor, Scalar);
  BINARY_POINTWISE2(bitwise_and_, Tensor);
  POINTWISE_BOXED(bitwise_and_.Scalar);
  POINTWISE_BOXED(bitwise_and.Scalar_Tensor);
  BINARY_SCALAR_2(bitwise_or, Tensor, Scalar);
  BINARY_POINTWISE2(bitwise_or_, Tensor);
  POINTWISE_BOXED(bitwise_or_.Scalar);
  POINTWISE_BOXED(bitwise_or.Scalar_Tensor);
  BINARY_SCALAR_2(bitwise_xor, Tensor, Scalar);
  BINARY_POINTWISE2(bitwise_xor_, Tensor);
  POINTWISE_BOXED(bitwise_xor_.Scalar);
  POINTWISE_BOXED(bitwise_xor.Scalar_Tensor);
  BINARY_SCALAR_3(bitwise_left_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
  POINTWISE_BOXED(bitwise_left_shift_.Tensor_Scalar);
  POINTWISE_BOXED(bitwise_left_shift_.Tensor);
  BINARY_SCALAR_3(bitwise_right_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
  POINTWISE_BOXED(bitwise_right_shift_.Tensor_Scalar);
  POINTWISE_BOXED(bitwise_right_shift_.Tensor);

  UNARY_POINTWISE(clamp);
  POINTWISE_BOXED(clamp.Tensor);
  BINARY_POINTWISE2(clamp_min, Tensor);
  UNARY_POINTWISE(clamp_min);
  POINTWISE_BOXED(clamp_min_);
  BINARY_POINTWISE2(clamp_max, Tensor);
  UNARY_POINTWISE(clamp_max);
  POINTWISE_BOXED(clamp_max_);
  BINARY_POINTWISE(complex);

  VARIADIC_BDIMS_BOXED(_euclidean_dist);
  // Implementation note: _binary_pointwise_helper performs a dtype promotion if args are scalars,
  // but cdist doesn't work with scalars; it requires at least 2d tensors.
  BINARY_POINTWISE(_cdist_forward);
  VMAP_SUPPORT(_cdist_backward, cdist_backward_batch_rule);

  BINARY_SCALAR_2(copysign, Tensor, Scalar);
  POINTWISE_BOXED(copysign_.Tensor);
  POINTWISE_BOXED(copysign_.Scalar);
  BINARY_SCALAR_2(div, Tensor, Scalar);
  BINARY_SCALAR_2(div, Tensor_mode, Scalar_mode);

  BINARY_POINTWISE(floor_divide);
  UNARY_POINTWISE2(floor_divide, Scalar);

  BINARY_POINTWISE(fmax);
  BINARY_POINTWISE(fmin);
  BINARY_SCALAR_2(fmod, Tensor, Scalar);
  POINTWISE_BOXED(frexp.Tensor);
  BINARY_POINTWISE(heaviside);
  BINARY_POINTWISE(hypot);
  BINARY_POINTWISE(gcd);
  BINARY_POINTWISE(igamma);
  BINARY_POINTWISE(igammac);
  BINARY_POINTWISE(logaddexp);
  BINARY_POINTWISE(logaddexp2);
  POINTWISE_BOXED(lerp.Scalar);
  POINTWISE_BOXED(lerp.Tensor);
  BINARY_POINTWISE(lcm);
  POINTWISE_BOXED(log_sigmoid_forward);
  BINARY_POINTWISE(maximum);
  BINARY_POINTWISE(minimum);

  BINARY_SCALAR_2(mul, Tensor, Scalar);
  BINARY_POINTWISE(nextafter);
  BINARY_SCALAR_3(pow, Tensor_Tensor, Tensor_Scalar, Scalar);
  POINTWISE_BOXED2(pow_, Scalar);
  BINARY_POINTWISE(polar);
  POINTWISE_BOXED(polygamma);
  BINARY_SCALAR_2(sub, Tensor, Scalar);
  BINARY_SCALAR_3(remainder, Tensor, Scalar, Scalar_Tensor);
  BINARY_POINTWISE(rrelu_with_noise);
  BINARY_SCALAR_2(rsub, Tensor, Scalar);

  BINARY_SCALAR_3_Tensor(special_xlog1py, other_scalar, self_scalar);
  BINARY_SCALAR_3_Tensor(special_zeta, other_scalar, self_scalar);

  VMAP_SUPPORT2(where, self, where_self_batch_rule);

  BINARY_SCALAR_3(xlogy, Tensor, Scalar_Other, Scalar_Self);

  POINTWISE_BOXED(elu_backward);
  BINARY_POINTWISE(hardsigmoid_backward);
  BINARY_POINTWISE(hardtanh_backward);
  BINARY_POINTWISE(hardshrink_backward);
  BINARY_POINTWISE(hardswish_backward);
  BINARY_POINTWISE(_prelu_kernel);
  VARIADIC_BDIMS_BOXED(_prelu_kernel_backward);
  BINARY_POINTWISE(leaky_relu_backward);
  BINARY_POINTWISE(logit_backward);
  VMAP_SUPPORT(log_sigmoid_backward, log_sigmoid_backward_batch_rule);
  VMAP_SUPPORT(gelu_backward, gelu_backward_batch_rule);
  BINARY_POINTWISE(sigmoid_backward);
  POINTWISE_BOXED(softplus_backward);
  BINARY_POINTWISE(softshrink_backward);
  BINARY_POINTWISE(tanh_backward);
  BINARY_POINTWISE(threshold_backward);
  BINARY_POINTWISE(silu_backward);

  using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
  using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
  using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const;
  using TensorInplaceModeT = Tensor& (Tensor::*)(const Tensor&, std::optional<c10::string_view>) const;
  using ScalarInplaceT = Tensor& (Tensor::*)(const Scalar&) const;
  using CopyT = Tensor& (Tensor::*)(const Tensor&, bool) const;

  POINTWISE_BOXED(add_.Tensor); // just testing
  POINTWISE_BOXED(atan2_);
  POINTWISE_BOXED(gcd_);
  POINTWISE_BOXED(lcm_);
  VMAP_SUPPORT2(add_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarScalarInplaceT, &Tensor::add_, const Scalar&, const Scalar&>));
  VMAP_SUPPORT2(sub_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorScalarInplaceT, &Tensor::sub_, const Scalar&>));
  VMAP_SUPPORT2(sub_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarScalarInplaceT, &Tensor::sub_, const Scalar&, const Scalar&>));
  VMAP_SUPPORT2(mul_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::mul_>));
  VMAP_SUPPORT2(mul_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::mul_, const Scalar&>));
  VMAP_SUPPORT2(div_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::div_>));
  VMAP_SUPPORT2(div_, Tensor_mode, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceModeT, &Tensor::div_, std::optional<c10::string_view>>));
  VMAP_SUPPORT2(div_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::div_, const Scalar&>));
  VMAP_SUPPORT2(clamp_min_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_min_>));
  VMAP_SUPPORT2(clamp_max_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_max_>));
  VMAP_SUPPORT2(masked_fill_, Scalar, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorScalarInplaceT, &Tensor::masked_fill_, const Scalar&>));
  VMAP_SUPPORT(copy_, SINGLE_ARG(binary_pointwise_inplace_batch_rule<CopyT, &Tensor::copy_, bool>));

#define COMPARISON_POINTWISE(op) \
  VMAP_SUPPORT2(op, Tensor, \
      SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>)); \
  UNARY_POINTWISE2(op, Scalar)

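  // Illustrative expansion: COMPARISON_POINTWISE(eq) registers
  // comparison_pointwise_batch_rule for eq.Tensor and a basic unary batch rule for
  // eq.Scalar.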
  COMPARISON_POINTWISE(eq);
  COMPARISON_POINTWISE(gt);
  COMPARISON_POINTWISE(ge);
  COMPARISON_POINTWISE(le);
  COMPARISON_POINTWISE(lt);
  COMPARISON_POINTWISE(ne);

#undef COMPARISON_POINTWISE
#undef BINARY_POINTWISE2
#undef BINARY_POINTWISE
#undef UNARY_POINTWISE2
#undef UNARY_POINTWISE
#undef UNARY_SCALAR_POINTWISE2
#undef BINARY_SCALAR_3

#define LOGICAL_COMPARISON_POINTWISE(op) \
  VMAP_SUPPORT(op, \
      SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN(op)), &ATEN_FN(op)>)); \
  VMAP_SUPPORT(op ## _, \
      SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor:: op ## _ >));

  LOGICAL_COMPARISON_POINTWISE(logical_and);
  LOGICAL_COMPARISON_POINTWISE(logical_or);
  LOGICAL_COMPARISON_POINTWISE(logical_xor);

#undef SINGLE_ARG
#undef LOGICAL_COMPARISON_POINTWISE
  VMAP_SUPPORT(masked_select, masked_select_batch_rule);
  VMAP_SUPPORT(masked_select_backward, masked_select_backward_batch_rule);

  VMAP_SUPPORT2(fill_, Tensor, fill__Tensor_batch_rule);
}

} // namespace at::functorch