// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/ATen.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/BatchRulesHelper.h>

#include <utility>

// This file contains batching rules for random operations. These are different
// from our regular batching rules: regular batching rules get registered to the
// FuncTorchBatched key, but batching rules for random operations get
// registered to FuncTorchVmapMode. This is because we need to interpose on
// random operations even if they're not on a BatchedTensor.
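//
// The behavior of these rules depends on vmap's `randomness` setting:
// - "error" (the default): random operations under vmap raise an error.
// - "same": all batch elements see the same random draw.
// - "different": each batch element gets an independent random draw.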

namespace at::functorch {

template <typename F, F Func, typename... ExtraArgs>
Tensor random_batching_rule(SymIntArrayRef shape, ExtraArgs... extra_args) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  c10::SmallVector<SymInt> shapeVec(1, maybe_layer->batchSize());
  shapeVec.reserve(shape.size() + 1);
  shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
  RandomnessType randomness = maybe_layer->randomness();
  check_randomness(randomness);
  if (randomness == RandomnessType::Different) {
    return makeBatched(Func(shapeVec, std::forward<ExtraArgs>(extra_args)...), 0, maybe_layer->layerId());
  } else {
    return Func(shape, std::forward<ExtraArgs>(extra_args)...);
  }
}
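
// e.g. under vmap with batch size B and randomness="different", randn({3, 4})
// becomes a single randn({B, 3, 4}) call whose result is batched at dim 0;
// with randomness="same", the original unbatched call runs once and its
// result is shared across all batch elements.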

template <typename F, F Func, typename... ExtraArgs>
Tensor& random_inplace_batching_rule(Tensor& self, ExtraArgs... extra_args) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  const auto cur_level = maybe_layer->layerId();
  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  self_value = moveBatchDimToFront(self_value, self_bdim);
  RandomnessType randomness = maybe_layer->randomness();
  check_randomness(randomness);
  TORCH_CHECK(
    !(randomness == RandomnessType::Different && !self_bdim),
    "vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. ",
    "If this is necessary for your usage, please file an issue with functorch.");
  if (randomness == RandomnessType::Same && self_bdim) {
    auto intermediate = empty(self.sizes(), self.options());
    Func(intermediate, std::forward<ExtraArgs>(extra_args)...);
    self.copy_(intermediate); // batching should make this just work out...
    return self;
  } else {
    Func(self_value, std::forward<ExtraArgs>(extra_args)...);
    return self;
  }
}

static Tensor& bernoulli_inplace_Tensor_batching_rule(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  auto cur_level = maybe_layer->layerId();
  RandomnessType randomness = maybe_layer->randomness();

  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);

  auto [other_value, other_bdim] = unwrapTensorAtLevel(p_, cur_level);

  check_randomness(randomness, other_bdim.has_value());

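  // An unbatched self cannot accept a batched probability in place:
  // the write would need to create a batch dimension on self.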
  if (!self_bdim && other_bdim) {
    vmapIncompatibleInplaceError("inplace bernoulli");
  }

  // compute max logical rank
  auto self_logical_rank = rankWithoutBatchDim(self_value, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other_value, other_bdim);
  auto max_logical_rank = std::max(self_logical_rank, other_logical_rank);

  auto self_ = moveBatchDimToFront(self_value, self_bdim);
  auto other_ = moveBatchDimToFront(other_value, other_bdim);

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_bdim, max_logical_rank);
  TORCH_CHECK(
    !(randomness == RandomnessType::Different && !self_bdim),
    "vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. ",
    "If this is necessary for your usage, please file an issue with functorch.");
  if (randomness == RandomnessType::Same && self_bdim) {
    auto intermediate = empty(self.sizes(), self.options());
    intermediate.bernoulli_(other_, std::move(gen));
    self.copy_(intermediate); // batching should make this just work out...
    return self;
  } else {
    self_.bernoulli_(other_, std::move(gen));
    return self;
  }
}

template <typename F, F Func, typename... ExtraArgs>
Tensor randperm_batching_rule(int64_t n, ExtraArgs... extra_args) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  auto const batch_size = maybe_layer->batchSize();
  RandomnessType randomness = maybe_layer->randomness();
  check_randomness(randomness);
  if (randomness == RandomnessType::Different) {
    std::vector<at::Tensor> stackedList(batch_size.guard_int(__FILE__, __LINE__));
    for (int64_t idx = 0; idx < batch_size; ++idx) {
      // since this is done in a loop, need to pass by reference for generator to update
      stackedList[idx] = Func(n, extra_args...);
    }
    return makeBatched(at::stack(stackedList), 0, maybe_layer->layerId());
  } else {
    return Func(n, std::forward<ExtraArgs>(extra_args)...);
  }
}
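
// e.g. randperm(5) under vmap with batch size B and randomness="different"
// stacks B independently sampled permutations into one [B, 5] batched tensor.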

template <typename F, F Func, typename... ExtraArgs>
Tensor unary_pointwise_random_batch_rule(const Tensor& tensor, ExtraArgs... extra_args) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  const auto cur_level = maybe_layer->layerId();

  auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(tensor, cur_level);
  tensor_value = moveBatchDimToFront(tensor_value, tensor_bdim);

  RandomnessType randomness = maybe_layer->randomness();
  check_randomness(randomness, tensor_bdim.has_value());
  auto shape = tensor_value.sizes();
  VmapSymDimVector shapeVec(1, maybe_layer->batchSize());
  shapeVec.reserve(shape.size() + 1);
  shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());

  if (randomness == RandomnessType::Different && !tensor_bdim) {
    tensor_value = tensor_value.expand_symint(shapeVec);
  }
  auto out = Func(tensor_value, std::forward<ExtraArgs>(extra_args)...);
  if (randomness == RandomnessType::Same && !tensor_bdim) {
    return out;
  }
  return makeBatched(out, 0, cur_level);
}
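
// e.g. poisson(t) with an unbatched t of shape [3] and randomness="different":
// t is first expanded to [B, 3] so each batch element draws its own sample;
// with randomness="same", the unbatched result is returned as-is and shared.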

template<typename F, F Func, typename... ExtraArgs>
Tensor tensor_like_random_batch_rule(const Tensor& self, ExtraArgs... extra_args) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  const auto cur_level = maybe_layer->layerId();
  RandomnessType randomness = maybe_layer->randomness();
  check_randomness(randomness);

  auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(self, cur_level);
  tensor_value = moveBatchDimToFront(tensor_value, tensor_bdim);

  if (randomness == RandomnessType::Same && tensor_bdim) {
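    // Same randomness: slice off the batch dim so the *_like factory samples
    // once at the per-example shape, shared by all batch elements.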
    tensor_value = tensor_value[0];
  } else if (randomness == RandomnessType::Different && !tensor_bdim) {
    auto shape = tensor_value.sizes();
    VmapSymDimVector shapeVec(1, maybe_layer->batchSize());
    shapeVec.reserve(shape.size() + 1);
    shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
    tensor_value = tensor_value.expand_symint(shapeVec);
  }

  auto res = Func(tensor_value, std::forward<ExtraArgs>(extra_args)...);
  return (randomness == RandomnessType::Same) ? res : makeBatched(res, 0, cur_level);
}

static std::tuple<Tensor,Tensor> native_dropout_batching_rule(const Tensor& tensor, double p, std::optional<bool> train) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  const auto cur_level = maybe_layer->layerId();
  RandomnessType randomness = maybe_layer->randomness();

  auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(tensor, cur_level);
  tensor_value = moveBatchDimToFront(tensor_value, tensor_bdim);

  if (!train.has_value() || *train) {
    check_randomness(randomness); // if we are in eval mode, we don't care about randomness
  }
  if ((train.has_value() && !*train) ||
      randomness == RandomnessType::Different) {
    if (!tensor_bdim) {
      // if tensor is unbatched, add batch dim before
      // calling dropout.
      auto shape = tensor_value.sizes();
      VmapSymDimVector shapeVec(1, maybe_layer->batchSize());
      shapeVec.reserve(shape.size() + 1);
      shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
      tensor_value = tensor_value.expand_symint(shapeVec);
    }
    auto [output, mask] = at::native_dropout(tensor_value, p, train);
    return std::make_tuple(
        makeBatched(output, 0, cur_level),
        makeBatched(mask, 0, cur_level));
  }

  // repeated code from the CPU kernel since the CUDA one doesn't call bernoulli_ explicitly
  double p1m = 1. - p;
  // Check for probability of zero to avoid divide by zero and NaN results
  double scale = p1m == 0 ? 0. : 1. / p1m;
  Tensor mask = at::empty_like(tensor, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  mask.bernoulli_(p1m);
  const auto output = tensor.mul(mask).mul_(scale);
  return std::make_tuple(output, mask);
}

static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_samples, const bool replacement, const std::optional<Generator> generator) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  const auto cur_level = maybe_layer->layerId();

  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  self_value = moveBatchDimToFront(self_value, self_bdim);

  RandomnessType randomness = maybe_layer->randomness();
  check_randomness(randomness, self_bdim.has_value());

  if (randomness == RandomnessType::Different) {
    // 1D cases: S -> BS -> multinomial(BS)
    //           BS -> multinomial(BS)
    //
    // 2D cases: MS -> BMS -> (BM)S -> multinomial((BM)S) -> (BM)S -> BMS
    //           BMS -> (BM)S -> multinomial((BM)S) -> (BM)S -> BMS
    const auto is_2D_case = rankWithoutBatchDim(self_value, self_bdim) == 2;
    if (!self_bdim.has_value()) {
      self_value = ensure_has_bdim(self_value, self_bdim.has_value(), maybe_layer->batchSize());
    }
    if (is_2D_case) {
      self_value = reshape_dim_into(0, 0, self_value);
    }
    auto out = multinomial(self_value, num_samples, replacement, generator);
    if (is_2D_case) {
      out = reshape_dim_outof_symint(0, maybe_layer->batchSize(), out);
    }
    return makeBatched(out, 0, cur_level);
  }

  TORCH_INTERNAL_ASSERT(randomness == RandomnessType::Same); // check_randomness eliminates error randomness
  TORCH_INTERNAL_ASSERT(!self_bdim.has_value()); // check_randomness eliminates same randomness with batched input
  // Must be same randomness with unbatched input
  // 1D case: S -> multinomial(S) -> S
  // 2D case: MS -> multinomial(MS) -> MS
  return multinomial(self_value, num_samples, replacement, generator);
}

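// The Helper structs below adapt concrete ATen ops to the generic batching
// rules above: they peel the leading parameter type(s) off the op's schema
// (recovered via c10::guts::function_traits) and forward the rest.
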
template <typename A, A a, typename C>
struct RandomBatchRuleHelper;

template <typename F, F Func, typename T1, typename... T>
struct RandomBatchRuleHelper<F, Func, typelist<T1, T...>> {
  static Tensor apply(SymIntArrayRef shape, T... extra_args) {
    return random_batching_rule<F, Func, T...>(shape, std::forward<T>(extra_args)...);
  }
};

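// randint-style factories take `high` before the shape; this wrapper swaps
// the two so random_batching_rule, which expects a leading shape, can be
// reused.
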
template <typename F, F Func, typename... T>
Tensor rand_int_wrapper(SymIntArrayRef shape, c10::SymInt high, T... extra_args) {
  return Func(high, shape, std::forward<T>(extra_args)...);
}

template <typename A, A a, typename C>
struct RandomInplaceBatchRuleHelper;

template <typename F, F Func, typename T1, typename... T>
struct RandomInplaceBatchRuleHelper<F, Func, typelist<T1, T...>> {
  static Tensor& apply(Tensor& self, T... extra_args) {
    return random_inplace_batching_rule<F, Func, T...>(self, std::forward<T>(extra_args)...);
  }
};

template <typename A, A a, typename C>
struct RandIntBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename... T>
struct RandIntBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  static Tensor apply(c10::SymInt high, SymIntArrayRef shape, T... extra_args) {
    return random_batching_rule<decltype(&rand_int_wrapper<F, Func, T...>),
                                &rand_int_wrapper<F, Func, T...>,
                                c10::SymInt, T...>(shape, std::move(high), std::forward<T>(extra_args)...);
  }
};

template <typename F, F Func, typename T0, typename T1, typename... T>
Tensor rand_int_low_wrapper(SymIntArrayRef shape, T0 scalar0, T1 scalar1, T... extra_args) {
  return Func(scalar0, scalar1, shape, std::forward<T>(extra_args)...);
}

template <typename A, A a, typename C>
struct RandTwoLeadingScalarsBatchRuleHelper;

template <typename F, F Func, typename T0, typename T1, typename T2, typename... T>
struct RandTwoLeadingScalarsBatchRuleHelper<F, Func, typelist<T0, T1, T2, T...>> {
  static Tensor apply(T0 scalar0, T1 scalar1, SymIntArrayRef shape, T... extra_args) {
    return random_batching_rule<decltype(&rand_int_low_wrapper<F, Func, T0, T1, T...>),
                                &rand_int_low_wrapper<F, Func, T0, T1, T...>,
                                T0, T1, T...>(shape, scalar0, scalar1, std::forward<T>(extra_args)...);
  }
};
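
// Handles factories whose schemas lead with two scalars before the shape,
// e.g. randint.low (low, high, size) and normal.float_float (mean, std, size).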

template <typename A, A a, typename C>
struct RandpermBatchRuleHelper;

template <typename F, F Func, typename T1, typename... T>
struct RandpermBatchRuleHelper<F, Func, typelist<T1, T...>> {
  static Tensor apply(int64_t n, T... extra_args) {
    return randperm_batching_rule<F, Func, T...>(n, std::forward<T>(extra_args)...);
  }
};

template <typename A, A a, typename C>
struct UnaryPointwiseRandomBatchRule;

template <typename F, F Func, typename A0, typename... T>
struct UnaryPointwiseRandomBatchRule<F, Func, typelist<A0, T...>> {
  static Tensor apply(const Tensor& tensor, T... extra_args) {
    return unary_pointwise_random_batch_rule<F, Func, T...>(tensor, std::forward<T>(extra_args)...);
  }
};

template <typename A, A a, typename C>
struct NormalPointwiseBatchRule;

template <typename F, F Func, typename A0, typename... T>
struct NormalPointwiseBatchRule<F, Func, typelist<A0, T...>> {
  static Tensor apply(const Tensor& tensor, T... extra_args) {
    return unary_pointwise_random_batch_rule<F, Func, T...>(tensor, std::forward<T>(extra_args)...);
  }
};

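// normal.float_Tensor takes (double mean, const Tensor& std, ...); this
// wrapper moves the tensor to the front so unary_pointwise_random_batch_rule,
// which expects a leading tensor, can be reused.
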
template<typename F, F Func, typename... T>
Tensor normal_wrapper(const Tensor& tensor, double scalar, T... extra_args) {
  return Func(scalar, tensor, extra_args...);
}

template <typename A, A a, typename C>
struct UnaryPointwiseRandomLeadingFloatBatchRule;

template <typename F, F Func, typename A0, typename A1, typename... T>
struct UnaryPointwiseRandomLeadingFloatBatchRule<F, Func, typelist<A0, A1, T...>> {
  static Tensor apply(double scalar, const Tensor& tensor, T... extra_args) {
    return unary_pointwise_random_batch_rule<decltype(&normal_wrapper<F, Func, T...>),
                                         &normal_wrapper<F, Func, T...>, double,
                                         T...>(tensor, scalar, std::forward<T>(extra_args)...);
  }
};

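// bernoulli_.float gets a registration at the FuncTorchBatched key here in
// addition to the FuncTorchVmapMode registration below, so the in-place rule
// also fires when dispatch reaches the Batched key directly.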
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  #define RANDOM_INPLACE_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandomInplaceBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                            c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  RANDOM_INPLACE_BATCH_RULE2(bernoulli_, float);

  #undef RANDOM_INPLACE_BATCH_RULE2
}

TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
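  // Each macro below recovers the op's parameter typelist from its schema via
  // c10::guts::function_traits and registers the matching Helper's `apply` as
  // the batching rule for that op (or op.overload).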
  #define RANDOM_BATCH_RULE(op) \
    m.impl(#op, SINGLE_ARG(\
      RandomBatchRuleHelper<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                            c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define RANDOM_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandomBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                            c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  #define RANDOM_INPLACE_BATCH_RULE(op) \
    m.impl(#op, SINGLE_ARG(\
      RandomInplaceBatchRuleHelper<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                            c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define RANDOM_INPLACE_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandomInplaceBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                            c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  #define RANDINT_BATCH_RULE(op) \
    m.impl(#op, SINGLE_ARG(\
      RandIntBatchRuleHelper<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                             c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define RANDINT_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandIntBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                            c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  #define RAND_TWO_LEADING_SCALARS_BATCH_RULE(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandTwoLeadingScalarsBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                                c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  #define RANDPERM_BATCH_RULE(op) \
    m.impl(#op, SINGLE_ARG(\
      RandpermBatchRuleHelper<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                            c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define RANDPERM_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandpermBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                            c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  #define UNARY_POINTWISE_RANDOM(op) \
    m.impl(#op, SINGLE_ARG(\
      UnaryPointwiseRandomBatchRule<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                                    c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define UNARY_POINTWISE_RANDOM2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      UnaryPointwiseRandomBatchRule<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                                    c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  #define UNARY_POINTWISE_RANDOM_LEADING_FLOAT(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      UnaryPointwiseRandomLeadingFloatBatchRule<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                                                c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  RANDOM_BATCH_RULE(randn);
  RANDOM_BATCH_RULE2(randn, generator);
  RANDOM_BATCH_RULE2(randn, generator_with_names);
  RANDOM_BATCH_RULE2(randn, names);

  RANDOM_BATCH_RULE(rand);
  RANDOM_BATCH_RULE2(rand, generator);
  RANDOM_BATCH_RULE2(rand, generator_with_names);
  RANDOM_BATCH_RULE2(rand, names);

  RANDOM_INPLACE_BATCH_RULE(random_);
  RANDOM_INPLACE_BATCH_RULE2(random_, from);
  RANDOM_INPLACE_BATCH_RULE2(random_, to);

  RANDOM_INPLACE_BATCH_RULE(cauchy_);
  RANDOM_INPLACE_BATCH_RULE(exponential_);
  RANDOM_INPLACE_BATCH_RULE(geometric_);
  RANDOM_INPLACE_BATCH_RULE(log_normal_);
  RANDOM_INPLACE_BATCH_RULE(normal_);
  RANDOM_INPLACE_BATCH_RULE(uniform_);

  RANDINT_BATCH_RULE(randint);
  RANDINT_BATCH_RULE2(randint, generator);
  RAND_TWO_LEADING_SCALARS_BATCH_RULE(randint, low);
  RAND_TWO_LEADING_SCALARS_BATCH_RULE(randint, low_generator);

  m.impl("bernoulli_.Tensor", at::functorch::bernoulli_inplace_Tensor_batching_rule);
  RANDOM_INPLACE_BATCH_RULE2(bernoulli_, float);
  UNARY_POINTWISE_RANDOM2(bernoulli, p);

  RANDPERM_BATCH_RULE(randperm);
  RANDPERM_BATCH_RULE2(randperm, generator);

  RAND_TWO_LEADING_SCALARS_BATCH_RULE(normal, float_float);
  UNARY_POINTWISE_RANDOM2(normal, Tensor_float);
  UNARY_POINTWISE_RANDOM_LEADING_FLOAT(normal, float_Tensor);

  m.impl("native_dropout", native_dropout_batching_rule); // needs special casing because cuda version doesn't call bernoulli

  UNARY_POINTWISE_RANDOM(_standard_gamma);
  UNARY_POINTWISE_RANDOM(_sample_dirichlet);
  m.impl("multinomial", multinomial_batching_rule);
  UNARY_POINTWISE_RANDOM(poisson);
  UNARY_POINTWISE_RANDOM(bernoulli);

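  // The *_like factories share trailing (dtype, layout, device, pin_memory,
  // memory_format) optional arguments; bundle them so the registrations below
  // stay readable.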
  #define TENSOR_LIKE_COMMON_ARG_TYPES std::optional<ScalarType>, std::optional<Layout>, std::optional<Device>, std::optional<bool>, std::optional<MemoryFormat>
  m.impl("randint_like", tensor_like_random_batch_rule<decltype(&ATEN_FN(randint_like)), &ATEN_FN(randint_like), int64_t, TENSOR_LIKE_COMMON_ARG_TYPES>);
  m.impl("randint_like.low_dtype", tensor_like_random_batch_rule<
    decltype(&ATEN_FN2(randint_like, low_dtype)), &ATEN_FN2(randint_like, low_dtype), int64_t, int64_t, TENSOR_LIKE_COMMON_ARG_TYPES>);
  m.impl("rand_like", tensor_like_random_batch_rule<decltype(&ATEN_FN(rand_like)), &ATEN_FN(rand_like), TENSOR_LIKE_COMMON_ARG_TYPES>);
  m.impl("randn_like", tensor_like_random_batch_rule<decltype(&ATEN_FN(randn_like)), &ATEN_FN(randn_like), TENSOR_LIKE_COMMON_ARG_TYPES>);

  #undef RANDOM_BATCH_RULE
  #undef RANDOM_BATCH_RULE2
  #undef RANDOM_INPLACE_BATCH_RULE
  #undef RANDOM_INPLACE_BATCH_RULE2
  #undef RANDINT_BATCH_RULE
  #undef RANDINT_BATCH_RULE2
  #undef RAND_TWO_LEADING_SCALARS_BATCH_RULE
  #undef RANDPERM_BATCH_RULE
  #undef RANDPERM_BATCH_RULE2
  #undef UNARY_POINTWISE_RANDOM
  #undef UNARY_POINTWISE_RANDOM2
  #undef UNARY_POINTWISE_RANDOM_LEADING_FLOAT
  #undef TENSOR_LIKE_COMMON_ARG_TYPES
}

} // namespace at::functorch