// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/Operators.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <utility>

namespace at::functorch {

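// Batch rules for binary pointwise operations under vmap.
//
// A batch rule receives each tensor argument together with an optional batch
// dimension (bdim) and returns the result plus the bdim of the output.
// Conceptually (a rough sketch; the real plumbing lives in the VMAP_SUPPORT /
// POINTWISE_BOXED machinery registered at the bottom of this file):
//   vmap over both inputs of add
//     -> the rule is called as rule(x_value, /*x_bdim=*/0, y_value, /*y_bdim=*/0)
//     -> it returns (result, /*out_bdim=*/0), which is re-wrapped as a batched tensor.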
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, std::optional<int64_t>> _binary_pointwise_batch_rule(
    const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
    const Tensor& other, std::optional<int64_t> other_batch_dim,
    ExtraArgs... extra_args) {

  auto tensor_other = _binary_pointwise_helper(
      tensor, tensor_batch_dim, other, other_batch_dim);
  auto tensor_ = std::get<0>(tensor_other);
  auto other_ = std::get<1>(tensor_other);

  auto result = Func(tensor_, other_, std::forward<ExtraArgs>(extra_args)...);
  return std::make_tuple(result, 0);
}

template <typename A, A a, typename C>
struct BinaryPointwiseBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename... T>
struct BinaryPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
      const Tensor& other, std::optional<int64_t> other_batch_dim,
      T... extra_args) {
    return _binary_pointwise_batch_rule<F, Func, T...>(
        tensor, tensor_batch_dim, other, other_batch_dim,
        std::forward<T>(extra_args)...);
  }
};

#define BINARY_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\
    BinaryPointwiseBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
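// Roughly, BINARY_POINTWISE_BATCH_RULE(at::add) expands to
//   BinaryPointwiseBatchRuleHelper<decltype(&at::add), &at::add,
//       c10::guts::function_traits<decltype(at::add)>::parameter_types>::apply
// i.e. it stamps out an `apply` whose extra (non-Tensor) parameters mirror the
// signature of the wrapped op.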

template <typename A, A a, typename C>
struct BinaryRandomPointwiseBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename... T>
struct BinaryRandomPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  static Tensor apply(const Tensor& tensor, const Tensor& other, T... extra_args) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
    auto maybe_layer = maybeCurrentDynamicLayer();
    auto cur_level = maybe_layer->layerId();
    RandomnessType randomness = maybe_layer->randomness();
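    // Randomness semantics under vmap: RandomnessType::Different asks for an
    // independent random draw per batch element, RandomnessType::Same shares a
    // single draw across the batch, and RandomnessType::Error rejects random
    // ops entirely (enforced by check_randomness below).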

    auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(tensor, cur_level);

    auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level);

    check_randomness(randomness, (tensor_bdim || other_bdim));
    if (randomness == RandomnessType::Different && !tensor_bdim && !other_bdim) {
      auto shape = tensor_value.sizes();
      VmapSymDimVector shapeVec(1, maybe_layer->batchSize());
      shapeVec.reserve(shape.size() + 1);
      shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());

      // This case is not handled by the binary batch rule, which assumes that
      // at least one input is batched.
      tensor_value = tensor_value.expand_symint(shapeVec);
      tensor_bdim = 0;
    } else if (randomness == RandomnessType::Same && !tensor_bdim && !other_bdim) {

      // Avoids unnecessary checks and the batch rule's assumption that the output is batched.
      return Func(tensor_value, other_value, std::forward<T>(extra_args)...);
    }
    auto res = _binary_pointwise_batch_rule<F, Func, T...>(
        tensor_value, tensor_bdim, other_value, other_bdim,
        std::forward<T>(extra_args)...);
    return makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
  }
};

#define BINARY_RANDOM_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\
    BinaryRandomPointwiseBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

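// In-place binary ops can only be vmapped when the written-to tensor carries the
// batch dimension. If `tensor` is unbatched but `other` is batched, there is a
// single physical `tensor` but B logically distinct results, so we raise
// vmapIncompatibleInplaceError below. For example, vmapping over `other` in
// x.add_(other) with an unbatched x falls into this case.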
template <typename M, M Meth, typename... ExtraArgs>
void binary_pointwise_inplace_batch_rule(
    Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
    const Tensor& other, std::optional<int64_t> other_batch_dim,
    ExtraArgs... extra_args) {
  if (!tensor_batch_dim && other_batch_dim) {
    vmapIncompatibleInplaceError("inplace arithmetic");
  }

  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  (tensor_.*Meth)(other_, std::forward<ExtraArgs>(extra_args)...);
}

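// Same rank alignment as _binary_pointwise_batch_rule, but done by hand and
// without the dtype-promotion step of _binary_pointwise_helper: comparison ops
// determine their own (boolean) result dtype, so only the logical ranks need to
// be aligned here.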
template <typename F, F Func>
std::tuple<Tensor, std::optional<int64_t>> comparison_pointwise_batch_rule(
    const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
    const Tensor& other, std::optional<int64_t> other_batch_dim) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  auto result = Func(tensor_, other_);
  return std::make_tuple(std::move(result), 0);
}

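// Three-tensor variant of the alignment above: condition, self, and other are
// all moved to a bdim-at-front layout and padded to a common logical rank before
// calling at::where, so broadcasting behaves as it would per batch element.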
static std::tuple<Tensor, std::optional<int64_t>> where_self_batch_rule(
    const Tensor& condition, std::optional<int64_t> condition_bdim,
    const Tensor& self, std::optional<int64_t> self_bdim, const Tensor& other, std::optional<int64_t> other_bdim) {
  auto condition_logical_rank = rankWithoutBatchDim(condition, condition_bdim);
  auto tensor_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  auto max_logical_rank = std::max({tensor_logical_rank, other_logical_rank, condition_logical_rank});

  auto condition_ = moveBatchDimToFront(condition, condition_bdim);
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto other_ = moveBatchDimToFront(other, other_bdim);

  condition_ = maybePadToLogicalRank(condition_, condition_bdim, max_logical_rank);
  self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_bdim, max_logical_rank);
  return std::make_tuple(at::where(condition_, self_, other_), 0);
}

static std::tuple<Tensor, std::optional<int64_t>> gelu_backward_batch_rule(
    const Tensor& grad_out, std::optional<int64_t> grad_out_bdim, const Tensor& input, std::optional<int64_t> input_bdim,
    c10::string_view approximate) {

  // repeat the preprocessing from _binary_pointwise_batch_rule
  const auto tensor_other = _binary_pointwise_helper(grad_out, grad_out_bdim, input, input_bdim);
  auto grad_out_ = std::get<0>(tensor_other);
  auto input_ = std::get<1>(tensor_other);

  // gelu_backward doesn't broadcast well, so we insist that all inputs have a bdim
  const auto batch_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
  grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), batch_size);
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);

  return std::make_tuple(at::gelu_backward(grad_out_, input_, approximate), 0);
}

static std::tuple<Tensor, std::optional<int64_t>> masked_select_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& mask, std::optional<int64_t> mask_bdim) {
  TORCH_CHECK(!mask_bdim.has_value(),
      "vmap: Attempted to vmap over `mask` in torch.masked_select(self, mask). ",
      "We cannot support this because for each batch this would return a ",
      "differently shaped Tensor. "
      "Please voice your support in https://github.com/pytorch/functorch/issues/256");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  const auto batch_size = self_.size(0);
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto max_logical_rank = std::max(self_logical_rank, mask.dim());
  self_ = maybePadToLogicalRank(self_, 0, max_logical_rank);

  // masked_select returns a 1D tensor, so we have to reshape it into 2D
  const auto result = at::masked_select(self_, mask).view({ batch_size, -1 });
  return std::make_tuple(result, 0);
}

static std::tuple<Tensor, std::optional<int64_t>> masked_select_backward_batch_rule(
    const Tensor& grad, std::optional<int64_t> grad_bdim,
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& mask, std::optional<int64_t> mask_bdim) {
  TORCH_CHECK(!mask_bdim.has_value(),
      "vmap: Attempted to vmap over `mask` in torch.masked_select_backward(grad, self, mask). ",
      "We cannot support this because for each batch this would return a ",
      "differently shaped Tensor. "
      "Please voice your support in https://github.com/pytorch/functorch/issues/256");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto grad_ = moveBatchDimToFront(grad, grad_bdim);

  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto max_logical_rank = std::max(self_logical_rank, mask.dim());

  self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank);

  const auto batch_size = get_bdim_size2(grad, grad_bdim, self, self_bdim);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), batch_size);

  const auto result = at::masked_select_backward(grad_, self_.contiguous(), mask);
  return std::make_tuple(result, 0);
}

static std::tuple<Tensor, std::optional<int64_t>> cdist_backward_batch_rule(
    const Tensor& grad, std::optional<int64_t> grad_bdim,
    const Tensor& x1, std::optional<int64_t> x1_bdim,
    const Tensor& x2, std::optional<int64_t> x2_bdim,
    const double p,
    const Tensor& cdist, std::optional<int64_t> cdist_bdim) {

  auto x1_ = x1;
  if (cdist_bdim && !x1_bdim) {
    // We need to make sure that x1 has a batch dim if cdist has one;
    // otherwise, we get
    // RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
    // but expected shape compatible with [4, 5]
    auto bs = cdist.size(*cdist_bdim);
    x1_ = ensure_has_bdim(x1, false, bs);
    x1_ = x1_.contiguous();
    x1_bdim = 0;
  }

  // We need to apply the same preprocessing on x1 and x2 as in the forward pass
  // (_binary_pointwise_batch_rule)
  auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
  x1_ = std::get<0>(x12);
  auto x2_ = std::get<1>(x12);

  auto grad_ = moveBatchDimToFront(grad, grad_bdim);
  if ((x1_bdim || x2_bdim) && !grad_bdim) {
    // We need to make sure that grad has a batch dim if x1 or x2 has one.
    // There is probably an assumption on the strides; otherwise the grad input
    // contains garbage values, e.g. -7.0816e+29, 7.0816e+29
    auto bs = get_bdim_size2(x1_, 0, x2_, 0);
    grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs);
    grad_ = grad_.contiguous();
  }

  auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist);

  std::optional<int64_t> out_bdim = std::nullopt;
  if (x1_bdim || x2_bdim) {
    out_bdim = 0;
  }

  return std::make_tuple(out, out_bdim);
}

static void fill__Tensor_batch_rule(
    Tensor& self,
    std::optional<int64_t> self_bdim,
    const Tensor& other,
    std::optional<int64_t> other_bdim) {
  if (!other_bdim.has_value()) {
    // Optimization: fill_ is faster than the other path, which does
    // reshaping + copy_
    self.fill_(other);
    return;
  }
  if (!self_bdim) {
    vmapIncompatibleInplaceError("fill_");
  }
  auto self_and_other = _binary_pointwise_helper(
      self, self_bdim, other, other_bdim, /*do_type_promotion*/false);
  std::get<0>(self_and_other).copy_(std::get<1>(self_and_other));
}

static std::tuple<Tensor, std::optional<int64_t>> log_sigmoid_backward_batch_rule(
    Tensor& grad, std::optional<int64_t> grad_bdim,
    Tensor& self, std::optional<int64_t> self_bdim,
    Tensor& buffer, std::optional<int64_t> buffer_bdim) {
  // NB: This emulates handle_pointwise_ops, except that we ignore the last
  // argument, buffer, when any of the inputs are on CUDA.
  // We do this because on CUDA, buffer is a dummy tensor that always has logical
  // rank 1, which becomes an issue when the rest of the inputs are scalar.
  int64_t out_logical_rank = std::max(rankWithoutBatchDim(grad, grad_bdim), rankWithoutBatchDim(self, self_bdim));
  if (!grad.is_cuda() && !self.is_cuda() && !buffer.is_cuda()) {
    out_logical_rank = std::max(out_logical_rank, rankWithoutBatchDim(buffer, buffer_bdim));
  }
  Tensor out_grad = maybePadToLogicalRank(moveBatchDimToFront(grad, grad_bdim), grad_bdim, out_logical_rank);
  Tensor out_self = maybePadToLogicalRank(moveBatchDimToFront(self, self_bdim), self_bdim, out_logical_rank);
  Tensor out_buffer = maybePadToLogicalRank(moveBatchDimToFront(buffer, buffer_bdim), buffer_bdim, out_logical_rank);
  return std::make_tuple(at::log_sigmoid_backward(out_grad, out_self, out_buffer), 0);
}

static Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, std::optional<Generator> gen) {
  return at::binomial(count, prob.contiguous(), std::move(gen)); // Bug in PyTorch: prob shouldn't need to be contiguous
}

TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
  #define BINARY_RANDOM_POINTWISE(op) \
    m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
  #define BINARY_RANDOM_POINTWISE2(op, overload) \
    m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));

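  // For instance, the registration below expands (roughly) to
  //   m.impl("normal.Tensor_Tensor",
  //          BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(normal, Tensor_Tensor)));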
  BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor);
  m.impl("binomial", BINARY_RANDOM_POINTWISE_BATCH_RULE(at::functorch::binomial_wrapper));
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  #define BINARY_POINTWISE2(op, overload) \
    VMAP_SUPPORT2(op, overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
  #define BINARY_POINTWISE(op) \
    VMAP_SUPPORT(op, BINARY_POINTWISE_BATCH_RULE(ATEN_FN(op)));
  #define UNARY_POINTWISE2(op, overload) \
    VMAP_SUPPORT2(op, overload, BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));
  #define UNARY_POINTWISE(op) \
    VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
  #define UNARY_SCALAR_POINTWISE2(op, overload) \
    VMAP_SUPPORT2(op, overload, SCALAR_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));

  #define BINARY_SCALAR_2(op, tensor_tensor, tensor_scalar) \
    BINARY_POINTWISE2(op, tensor_tensor);\
    UNARY_POINTWISE2(op, tensor_scalar);

  // For all 3 combinations of Tensor x Tensor, Tensor x Scalar, Scalar x Tensor
  #define BINARY_SCALAR_3(op, tensor_tensor, tensor_scalar, scalar_tensor) \
    BINARY_POINTWISE2(op, tensor_tensor);\
    UNARY_POINTWISE2(op, tensor_scalar);\
    POINTWISE_BOXED(op.scalar_tensor);

  #define BINARY_SCALAR_3_Tensor(op, tensor_scalar, scalar_tensor) \
    BINARY_POINTWISE(op);\
    UNARY_POINTWISE2(op, tensor_scalar);\
    POINTWISE_BOXED(op.scalar_tensor);

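  // As an example of the helpers above, BINARY_SCALAR_2(add, Tensor, Scalar)
  // registers the binary pointwise batch rule for add.Tensor and the basic
  // unary rule (BASIC_UNARY_BATCH_RULE) for add.Scalar, since the Scalar
  // overload has only one Tensor argument.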
  // Batching rule registrations start
  POINTWISE_BOXED(__ilshift__.Tensor);
  POINTWISE_BOXED(__ilshift__.Scalar);
  POINTWISE_BOXED(__irshift__.Tensor);
  POINTWISE_BOXED(__irshift__.Scalar);
  BINARY_SCALAR_2(__lshift__, Tensor, Scalar);
  BINARY_SCALAR_2(__rshift__, Tensor, Scalar);

  BINARY_SCALAR_2(add, Tensor, Scalar);
  POINTWISE_BOXED(addcdiv);
  POINTWISE_BOXED(addcmul);
  BINARY_POINTWISE(atan2);
  BINARY_SCALAR_2(bitwise_and, Tensor, Scalar);
  BINARY_POINTWISE2(bitwise_and_, Tensor);
  POINTWISE_BOXED(bitwise_and_.Scalar);
  POINTWISE_BOXED(bitwise_and.Scalar_Tensor);
  BINARY_SCALAR_2(bitwise_or, Tensor, Scalar);
  BINARY_POINTWISE2(bitwise_or_, Tensor);
  POINTWISE_BOXED(bitwise_or_.Scalar);
  POINTWISE_BOXED(bitwise_or.Scalar_Tensor);
  BINARY_SCALAR_2(bitwise_xor, Tensor, Scalar);
  BINARY_POINTWISE2(bitwise_xor_, Tensor);
  POINTWISE_BOXED(bitwise_xor_.Scalar);
  POINTWISE_BOXED(bitwise_xor.Scalar_Tensor);
  BINARY_SCALAR_3(bitwise_left_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
  POINTWISE_BOXED(bitwise_left_shift_.Tensor_Scalar);
  POINTWISE_BOXED(bitwise_left_shift_.Tensor);
  BINARY_SCALAR_3(bitwise_right_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
  POINTWISE_BOXED(bitwise_right_shift_.Tensor_Scalar);
  POINTWISE_BOXED(bitwise_right_shift_.Tensor);

  UNARY_POINTWISE(clamp);
  POINTWISE_BOXED(clamp.Tensor);
  BINARY_POINTWISE2(clamp_min, Tensor);
  UNARY_POINTWISE(clamp_min);
  POINTWISE_BOXED(clamp_min_);
  BINARY_POINTWISE2(clamp_max, Tensor);
  UNARY_POINTWISE(clamp_max);
  POINTWISE_BOXED(clamp_max_);
  BINARY_POINTWISE(complex);

  VARIADIC_BDIMS_BOXED(_euclidean_dist);
  // Implementation note: _binary_pointwise_helper performs a dtype promotion if args are scalars,
  // but cdist can't work with scalars; it needs at least 2-D tensors.
  BINARY_POINTWISE(_cdist_forward);
  VMAP_SUPPORT(_cdist_backward, cdist_backward_batch_rule);

  BINARY_SCALAR_2(copysign, Tensor, Scalar);
  POINTWISE_BOXED(copysign_.Tensor);
  POINTWISE_BOXED(copysign_.Scalar);
  BINARY_SCALAR_2(div, Tensor, Scalar);
  BINARY_SCALAR_2(div, Tensor_mode, Scalar_mode);

  BINARY_POINTWISE(floor_divide);
  UNARY_POINTWISE2(floor_divide, Scalar);

  BINARY_POINTWISE(fmax);
  BINARY_POINTWISE(fmin);
  BINARY_SCALAR_2(fmod, Tensor, Scalar);
  POINTWISE_BOXED(frexp.Tensor);
  BINARY_POINTWISE(heaviside);
  BINARY_POINTWISE(hypot);
  BINARY_POINTWISE(gcd);
  BINARY_POINTWISE(igamma);
  BINARY_POINTWISE(igammac);
  BINARY_POINTWISE(logaddexp);
  BINARY_POINTWISE(logaddexp2);
  POINTWISE_BOXED(lerp.Scalar);
  POINTWISE_BOXED(lerp.Tensor);
  BINARY_POINTWISE(lcm);
  POINTWISE_BOXED(log_sigmoid_forward);
  BINARY_POINTWISE(maximum);
  BINARY_POINTWISE(minimum);

  BINARY_SCALAR_2(mul, Tensor, Scalar);
  BINARY_POINTWISE(nextafter);
  BINARY_SCALAR_3(pow, Tensor_Tensor, Tensor_Scalar, Scalar);
  POINTWISE_BOXED2(pow_, Scalar);
  BINARY_POINTWISE(polar);
  POINTWISE_BOXED(polygamma);
  BINARY_SCALAR_2(sub, Tensor, Scalar);
  BINARY_SCALAR_3(remainder, Tensor, Scalar, Scalar_Tensor);
  BINARY_POINTWISE(rrelu_with_noise);
  BINARY_SCALAR_2(rsub, Tensor, Scalar);

  BINARY_SCALAR_3_Tensor(special_xlog1py, other_scalar, self_scalar);
  BINARY_SCALAR_3_Tensor(special_zeta, other_scalar, self_scalar);

  VMAP_SUPPORT2(where, self, where_self_batch_rule);

  BINARY_SCALAR_3(xlogy, Tensor, Scalar_Other, Scalar_Self);

  POINTWISE_BOXED(elu_backward);
  BINARY_POINTWISE(hardsigmoid_backward);
  BINARY_POINTWISE(hardtanh_backward);
  BINARY_POINTWISE(hardshrink_backward);
  BINARY_POINTWISE(hardswish_backward);
  BINARY_POINTWISE(_prelu_kernel);
  VARIADIC_BDIMS_BOXED(_prelu_kernel_backward);
  BINARY_POINTWISE(leaky_relu_backward);
  BINARY_POINTWISE(logit_backward);
  VMAP_SUPPORT(log_sigmoid_backward, log_sigmoid_backward_batch_rule);
  VMAP_SUPPORT(gelu_backward, gelu_backward_batch_rule);
  BINARY_POINTWISE(sigmoid_backward);
  POINTWISE_BOXED(softplus_backward);
  BINARY_POINTWISE(softshrink_backward);
  BINARY_POINTWISE(tanh_backward);
  BINARY_POINTWISE(threshold_backward);
  BINARY_POINTWISE(silu_backward);

  using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
  using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
  using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const;
  using TensorInplaceModeT = Tensor& (Tensor::*)(const Tensor&, std::optional<c10::string_view>) const;
  using ScalarInplaceT = Tensor& (Tensor::*)(const Scalar&) const;
  using CopyT = Tensor& (Tensor::*)(const Tensor&, bool) const;
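  // The aliases above are pointer-to-member-function types used to instantiate
  // the in-place batch rules below; e.g. VMAP_SUPPORT2(mul_, Tensor, ...)
  // instantiates binary_pointwise_inplace_batch_rule with Meth = &Tensor::mul_,
  // which has type TensorInplaceT.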

  POINTWISE_BOXED(add_.Tensor); // just testing
  POINTWISE_BOXED(atan2_);
  POINTWISE_BOXED(gcd_);
  POINTWISE_BOXED(lcm_);
  VMAP_SUPPORT2(add_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarScalarInplaceT, &Tensor::add_, const Scalar&, const Scalar&>));
  VMAP_SUPPORT2(sub_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorScalarInplaceT, &Tensor::sub_, const Scalar&>));
  VMAP_SUPPORT2(sub_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarScalarInplaceT, &Tensor::sub_, const Scalar&, const Scalar&>));
  VMAP_SUPPORT2(mul_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::mul_>));
  VMAP_SUPPORT2(mul_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::mul_, const Scalar&>));
  VMAP_SUPPORT2(div_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::div_>));
  VMAP_SUPPORT2(div_, Tensor_mode, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceModeT, &Tensor::div_, std::optional<c10::string_view>>));
  VMAP_SUPPORT2(div_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::div_, const Scalar&>));
  VMAP_SUPPORT2(clamp_min_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_min_>));
  VMAP_SUPPORT2(clamp_max_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_max_>));
  VMAP_SUPPORT2(masked_fill_, Scalar, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorScalarInplaceT, &Tensor::masked_fill_, const Scalar&>));
  VMAP_SUPPORT(copy_, SINGLE_ARG(binary_pointwise_inplace_batch_rule<CopyT, &Tensor::copy_, bool>));

  #define COMPARISON_POINTWISE(op) \
    VMAP_SUPPORT2(op, Tensor, \
        SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>)); \
    UNARY_POINTWISE2(op, Scalar)

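  // COMPARISON_POINTWISE(eq), for example, registers eq.Tensor through
  // comparison_pointwise_batch_rule and eq.Scalar through the basic unary rule.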
  COMPARISON_POINTWISE(eq);
  COMPARISON_POINTWISE(gt);
  COMPARISON_POINTWISE(ge);
  COMPARISON_POINTWISE(le);
  COMPARISON_POINTWISE(lt);
  COMPARISON_POINTWISE(ne);

  #undef COMPARISON_POINTWISE
  #undef BINARY_POINTWISE2
  #undef BINARY_POINTWISE
  #undef UNARY_POINTWISE2
  #undef UNARY_POINTWISE
  #undef UNARY_SCALAR_POINTWISE2
  #undef BINARY_SCALAR_3

  #define LOGICAL_COMPARISON_POINTWISE(op) \
    VMAP_SUPPORT(op, \
        SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN(op)), &ATEN_FN(op)>)); \
    VMAP_SUPPORT(op ## _, \
        SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor:: op ## _ >));
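  // Each LOGICAL_COMPARISON_POINTWISE(op) below registers both the functional op
  // (e.g. logical_and) and its in-place variant (logical_and_).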

  LOGICAL_COMPARISON_POINTWISE(logical_and);
  LOGICAL_COMPARISON_POINTWISE(logical_or);
  LOGICAL_COMPARISON_POINTWISE(logical_xor);

  #undef SINGLE_ARG
  #undef LOGICAL_COMPARISON_POINTWISE
  VMAP_SUPPORT(masked_select, masked_select_batch_rule);
  VMAP_SUPPORT(masked_select_backward, masked_select_backward_batch_rule);

  VMAP_SUPPORT2(fill_, Tensor, fill__Tensor_batch_rule);
}

} // namespace at::functorch