1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6
7 #include <ATen/ATen.h>
8 #include <ATen/functorch/DynamicLayer.h>
9 #include <ATen/functorch/BatchRulesHelper.h>
10
11 #include <utility>
12
13 // This file contains batching rules for random operations. These are different
14 // from our regular batching rules: regular batching rules get registered to the
15 // FuncTorchBatched key, but batching rules for random operations get
16 // registered to FuncTorchVmapMode. This is because we need to interpose on
17 // random operations even if they're not on a BatchedTensor.
18
19 namespace at::functorch {
20
21 template <typename F, F Func, typename... ExtraArgs>
random_batching_rule(SymIntArrayRef shape,ExtraArgs...extra_args)22 Tensor random_batching_rule(SymIntArrayRef shape, ExtraArgs... extra_args) {
23 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
24 auto maybe_layer = maybeCurrentDynamicLayer();
25 c10::SmallVector<SymInt> shapeVec(1, maybe_layer->batchSize());
26 shapeVec.reserve(shape.size() + 1);
27 shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
28 RandomnessType randomness = maybe_layer->randomness();
29 check_randomness(randomness);
30 if (randomness == RandomnessType::Different) {
31 return makeBatched(Func(shapeVec, std::forward<ExtraArgs>(extra_args)...), 0, maybe_layer->layerId());
32 } else {
33 return Func(shape, std::forward<ExtraArgs>(extra_args)...);
34 }
35 }
36
37 template <typename F, F Func, typename... ExtraArgs>
random_inplace_batching_rule(Tensor & self,ExtraArgs...extra_args)38 Tensor& random_inplace_batching_rule(Tensor& self, ExtraArgs... extra_args) {
39 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
40 auto maybe_layer = maybeCurrentDynamicLayer();
41 const auto cur_level = maybe_layer->layerId();
42 auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
43 self_value = moveBatchDimToFront(self_value, self_bdim);
44 RandomnessType randomness = maybe_layer->randomness();
45 check_randomness(randomness);
46 TORCH_CHECK(
47 !(randomness == RandomnessType::Different && !self_bdim),
48 "vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. ",
49 "If this is necessary for your usage, please file an issue with functorch.");
50 if (randomness == RandomnessType::Same && self_bdim) {
51 auto intermediate = empty(self.sizes(), self.options());
52 Func(intermediate, std::forward<ExtraArgs>(extra_args)...);
53 self.copy_(intermediate); // batching should make this just work out...
54 return self;
55 } else {
56 Func(self_value, std::forward<ExtraArgs>(extra_args)...);
57 return self;
58 }
59 }
60
// In-place bernoulli_(p) where the probability `p_` is itself a Tensor.
// Registered under FuncTorchVmapMode.
//
// Steps:
//  1. Unwrap `self` and `p_` at the current vmap level.
//  2. Move both batch dims to the front.
//  3. Pad both to a common logical rank so `p` lines up with `self`.
//  4. Then:
//     - Same randomness with a batched `self`: draw once into a fresh tensor
//       of `self.sizes()` and copy_ it into `self`.
//     - Otherwise: call bernoulli_ directly on the unwrapped, front-batched
//       `self`.
//
// Errors:
//  - `p_` is batched but `self` is not (vmapIncompatibleInplaceError).
//  - Different randomness is requested on an unbatched `self`.
// Always returns `self`.
static Tensor& bernoulli_inplace_Tensor_batching_rule(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  auto cur_level = maybe_layer->layerId();
  RandomnessType randomness = maybe_layer->randomness();

  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);

  auto [other_value, other_bdim] = unwrapTensorAtLevel(p_, cur_level);

  // Only the probability tensor's batchedness is passed here.
  // NOTE(review): presumably this rejects Error randomness, and Same randomness
  // when `p_` is batched — confirm.
  check_randomness(randomness, other_bdim.has_value());

  // An in-place op on an unbatched `self` cannot absorb a batched `p_`.
  if (!self_bdim && other_bdim) {
    vmapIncompatibleInplaceError("inplace bernoulli");
  }

  // compute max logical rank
  auto self_logical_rank = rankWithoutBatchDim(self_value, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other_value, other_bdim);
  auto max_logical_rank = std::max(self_logical_rank, other_logical_rank);

  auto self_ = moveBatchDimToFront(self_value, self_bdim);
  auto other_ = moveBatchDimToFront(other_value, other_bdim);

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_bdim, max_logical_rank);
  TORCH_CHECK(
    !(randomness == RandomnessType::Different && !self_bdim),
    "vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. ",
    "If this is necessary for your usage, please file an issue with functorch.");
  if (randomness == RandomnessType::Same && self_bdim) {
    // Draw a single sample of `self.sizes()` and copy it into the batched
    // `self`, so every batch entry sees the same values.
    auto intermediate = empty(self.sizes(), self.options());
    intermediate.bernoulli_(other_, std::move(gen));
    self.copy_(intermediate); // batching should make this just work out...
    return self;
  } else {
    // Fill the unwrapped (front-batched) storage of `self` directly.
    self_.bernoulli_(other_, std::move(gen));
    return self;
  }
}
105
106 template <typename F, F Func, typename... ExtraArgs>
randperm_batching_rule(int64_t n,ExtraArgs...extra_args)107 Tensor randperm_batching_rule(int64_t n, ExtraArgs... extra_args) {
108 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
109 auto maybe_layer = maybeCurrentDynamicLayer();
110 auto const batch_size = maybe_layer->batchSize();
111 RandomnessType randomness = maybe_layer->randomness();
112 check_randomness(randomness);
113 if (randomness == RandomnessType::Different) {
114 std::vector<at::Tensor> stackedList(batch_size.guard_int(__FILE__, __LINE__));
115 for (int64_t idx = 0; idx < batch_size; ++idx) {
116 // since this is done in a loop, need to pass by reference for generator to update
117 stackedList[idx] = Func(n, extra_args...);
118 }
119 return makeBatched(at::stack(stackedList), 0, maybe_layer->layerId());
120 } else {
121 return Func(n, std::forward<ExtraArgs>(extra_args)...);
122 }
123 }
124
125 template <typename F, F Func, typename... ExtraArgs>
unary_pointwise_random_batch_rule(const Tensor & tensor,ExtraArgs...extra_args)126 Tensor unary_pointwise_random_batch_rule(const Tensor& tensor, ExtraArgs... extra_args) {
127 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
128 auto maybe_layer = maybeCurrentDynamicLayer();
129 const auto cur_level = maybe_layer->layerId();
130
131 auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(tensor, cur_level);
132 tensor_value = moveBatchDimToFront(tensor_value, tensor_bdim);
133
134 RandomnessType randomness = maybe_layer->randomness();
135 check_randomness(randomness, tensor_bdim.has_value());
136 auto shape = tensor_value.sizes();
137 VmapSymDimVector shapeVec(1, maybe_layer->batchSize());
138 shapeVec.reserve(shape.size() + 1);
139 shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
140
141 if (randomness == RandomnessType::Different && !tensor_bdim) {
142 tensor_value = tensor_value.expand_symint(shapeVec);
143 }
144 auto out = Func(tensor_value, std::forward<ExtraArgs>(extra_args)...);
145 if (randomness == RandomnessType::Same && !tensor_bdim) {
146 return out;
147 }
148 return makeBatched(out, 0, cur_level);
149 }
150
151 template<typename F, F Func, typename... ExtraArgs>
tensor_like_random_batch_rule(const Tensor & self,ExtraArgs...extra_args)152 Tensor tensor_like_random_batch_rule(const Tensor& self, ExtraArgs... extra_args) {
153 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
154 auto maybe_layer = maybeCurrentDynamicLayer();
155 const auto cur_level = maybe_layer->layerId();
156 RandomnessType randomness = maybe_layer->randomness();
157 check_randomness(randomness);
158
159 auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(self, cur_level);
160 tensor_value = moveBatchDimToFront(tensor_value, tensor_bdim);
161
162 if (randomness == RandomnessType::Same && tensor_bdim) {
163 tensor_value = tensor_value[0];
164 } else if (randomness == RandomnessType::Different && !tensor_bdim) {
165 auto shape = tensor_value.sizes();
166 VmapSymDimVector shapeVec(1, maybe_layer->batchSize());
167 shapeVec.reserve(shape.size() + 1);
168 shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
169 tensor_value = tensor_value.expand_symint(shapeVec);
170 }
171
172 auto res = Func(tensor_value, std::forward<ExtraArgs>(extra_args)...);
173 return (randomness == RandomnessType::Same) ? res : makeBatched(res, 0, cur_level);
174 }
175
// Batching rule for native_dropout; returns (output, mask).
//
//  - Eval mode (train == false) or Different randomness:
//    - if the input is unbatched, expand it to [B, *shape] first;
//    - call at::native_dropout on the unwrapped value;
//    - wrap both results as batched at dim 0 of the current level.
//  - Otherwise (train true or unset, and randomness not Different): redo the
//    CPU dropout math on the original `tensor` via bernoulli_. NOTE(review):
//    presumably the mask draw is then shared across the batch through
//    bernoulli_'s own Same-randomness rule — confirm.
static std::tuple<Tensor,Tensor> native_dropout_batching_rule(const Tensor& tensor, double p, std::optional<bool> train) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  const auto cur_level = maybe_layer->layerId();
  RandomnessType randomness = maybe_layer->randomness();

  auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(tensor, cur_level);
  tensor_value = moveBatchDimToFront(tensor_value, tensor_bdim);

  if (!train.has_value() || *train) {
    check_randomness(randomness); // randomness is irrelevant in eval mode, so only check it when training
  }

  if ((train.has_value() && !*train) ||
      randomness == RandomnessType::Different) {
    if (!tensor_bdim) {
      // if tensor is unbatched, add batch dim before
      // calling dropout.
      auto shape = tensor_value.sizes();
      VmapSymDimVector shapeVec(1, maybe_layer->batchSize());
      shapeVec.reserve(shape.size() + 1);
      shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
      tensor_value = tensor_value.expand_symint(shapeVec);
    }
    auto [output, mask] = at::native_dropout(tensor_value, p, train);
    return std::make_tuple(
        makeBatched(output, 0, cur_level),
        makeBatched(mask, 0, cur_level));
  }

  // repeated code from the CPU kernel since the CUDA one doesn't call bernoulli_ explicitly
  // Note: this path operates on the original (possibly batched) `tensor`, not
  // on the unwrapped `tensor_value`.
  double p1m = 1. - p;
  // Check for probability of zero to avoid divide by zero and NaN results
  // (p1m == 0 means p == 1: every element is dropped, so scale is 0).
  double scale = p1m == 0 ? 0. : 1. / p1m;
  Tensor mask = at::empty_like(tensor, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  mask.bernoulli_(p1m);
  const auto output = tensor.mul(mask).mul_(scale);
  return std::make_tuple(output, mask);
}
215
// Batching rule for multinomial. It is registered under FuncTorchVmapMode, so
// it also runs when `self` is not batched.
//
//  - Different randomness: fold the batch dim into multinomial's leading dim,
//    so each batch entry (and each row in the 2-D case) draws independently.
//    Then unfold it back into a leading batch dim.
//  - Same randomness: only reachable with an unbatched `self` (see the asserts
//    below), so one draw is returned unbatched.
static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_samples, const bool replacement, const std::optional<Generator> generator) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
  auto maybe_layer = maybeCurrentDynamicLayer();
  const auto cur_level = maybe_layer->layerId();

  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  self_value = moveBatchDimToFront(self_value, self_bdim);

  RandomnessType randomness = maybe_layer->randomness();
  check_randomness(randomness, self_bdim.has_value());

  if (randomness == RandomnessType::Different) {
    // 1D cases: S -> BS -> multinomial(BS)
    //           BS -> multinomial(BS)
    //
    // 2D cases: MS -> BMS -> (BM)S -> multinomial((BM)S) -> (BM)S -> BMS
    //           BMS -> (BM)S -> multinomial((BM)S) -> (BM)S -> BMS
    // The logical rank must be computed before a batch dim is added below.
    const auto is_2D_case = rankWithoutBatchDim(self_value, self_bdim) == 2;
    if (!self_bdim.has_value()) {
      self_value = ensure_has_bdim(self_value, self_bdim.has_value(), maybe_layer->batchSize());
    }
    if (is_2D_case) {
      self_value = reshape_dim_into(0, 0, self_value);
    }
    auto out = multinomial(self_value, num_samples, replacement, generator);
    if (is_2D_case) {
      out = reshape_dim_outof_symint(0, maybe_layer->batchSize(), out);
    }
    return makeBatched(out, 0, cur_level);
  }

  TORCH_INTERNAL_ASSERT(randomness == RandomnessType::Same); // check_randomness eliminates error randomness
  TORCH_INTERNAL_ASSERT(!self_bdim.has_value()); // check_randomness eliminates same randomness with batched input
  // Must be same randomness with unbatched input
  // 1D case: S -> multinomial(S) -> S
  // 2D case: MS -> multinomial(MS) -> MS
  return multinomial(self_value, num_samples, replacement, generator);
}
254
255 template <typename A, A a, typename C>
256 struct RandomBatchRuleHelper;
257
258 template <typename F, F Func, typename T1, typename... T>
259 struct RandomBatchRuleHelper<F, Func, typelist<T1, T...>> {
applyat::functorch::RandomBatchRuleHelper260 static Tensor apply(SymIntArrayRef shape, T... extra_args) {
261 return random_batching_rule<F, Func, T...>(shape, std::forward<T>(extra_args)...);
262 }
263 };
264
265 template <typename F, F Func, typename... T>
rand_int_wrapper(SymIntArrayRef shape,c10::SymInt high,T...extra_args)266 Tensor rand_int_wrapper(SymIntArrayRef shape, c10::SymInt high, T... extra_args) {
267 return Func(high, shape, std::forward<T>(extra_args)...);
268 }
269
270 template <typename A, A a, typename C>
271 struct RandomInplaceBatchRuleHelper;
272
273 template <typename F, F Func, typename T1, typename... T>
274 struct RandomInplaceBatchRuleHelper<F, Func, typelist<T1, T...>> {
applyat::functorch::RandomInplaceBatchRuleHelper275 static Tensor& apply(Tensor& self, T... extra_args) {
276 return random_inplace_batching_rule<F, Func, T...>(self, std::forward<T>(extra_args)...);
277 }
278 };
279
280 template <typename A, A a, typename C>
281 struct RandIntBatchRuleHelper;
282
283 template <typename F, F Func, typename T1, typename T2, typename... T>
284 struct RandIntBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
applyat::functorch::RandIntBatchRuleHelper285 static Tensor apply(c10::SymInt high, SymIntArrayRef shape, T... extra_args) {
286 return random_batching_rule<decltype(&rand_int_wrapper<F, Func, T...>),
287 &rand_int_wrapper<F, Func, T...>,
288 c10::SymInt, T...>(shape, std::move(high), std::forward<T>(extra_args)...);
289 }
290 };
291
292 template <typename F, F Func, typename T0, typename T1, typename... T>
rand_int_low_wrapper(SymIntArrayRef shape,T0 scalar0,T1 scalar1,T...extra_args)293 Tensor rand_int_low_wrapper(SymIntArrayRef shape, T0 scalar0, T1 scalar1, T... extra_args) {
294 return Func(scalar0, scalar1, shape, std::forward<T>(extra_args)...);
295 }
296
297 template <typename A, A a, typename C>
298 struct RandTwoLeadingScalarsBatchRuleHelper;
299
300 template <typename F, F Func, typename T0, typename T1, typename T2, typename... T>
301 struct RandTwoLeadingScalarsBatchRuleHelper<F, Func, typelist<T0, T1, T2, T...>> {
applyat::functorch::RandTwoLeadingScalarsBatchRuleHelper302 static Tensor apply(T0 scalar0, T1 scalar1, SymIntArrayRef shape, T... extra_args) {
303 return random_batching_rule<decltype(&rand_int_low_wrapper<F, Func, T0, T1, T...>),
304 &rand_int_low_wrapper<F, Func, T0, T1, T...>,
305 T0, T1, T...>(shape, scalar0, scalar1, std::forward<T>(extra_args)...);
306 }
307 };
308
309 template <typename A, A a, typename C>
310 struct RandpermBatchRuleHelper;
311
312 template <typename F, F Func, typename T1, typename... T>
313 struct RandpermBatchRuleHelper<F, Func, typelist<T1, T...>> {
applyat::functorch::RandpermBatchRuleHelper314 static Tensor apply(int64_t n, T... extra_args) {
315 return randperm_batching_rule<F, Func, T...>(n, std::forward<T>(extra_args)...);
316 }
317 };
318
319 template <typename A, A a, typename C>
320 struct UnaryPointwiseRandomBatchRule;
321
322 template <typename F, F Func, typename A0, typename... T>
323 struct UnaryPointwiseRandomBatchRule<F, Func, typelist<A0, T...>> {
applyat::functorch::UnaryPointwiseRandomBatchRule324 static Tensor apply(const Tensor& tensor, T... extra_args) {
325 return unary_pointwise_random_batch_rule<F, Func, T...>(tensor, std::forward<T>(extra_args)...);
326 }
327 };
328
329 template <typename A, A a, typename C>
330 struct NormalPointwiseBatchRule;
331
332 template <typename F, F Func, typename A0, typename... T>
333 struct NormalPointwiseBatchRule<F, Func, typelist<A0, T...>> {
applyat::functorch::NormalPointwiseBatchRule334 static Tensor apply(const Tensor& tensor, T... extra_args) {
335 return unary_pointwise_random_batch_rule<F, Func, T...>(tensor, std::forward<T>(extra_args)...);
336 }
337 };
338
339 template<typename F, F Func, typename... T>
normal_wrapper(const Tensor & tensor,double scalar,T...extra_args)340 Tensor normal_wrapper(const Tensor& tensor, double scalar, T... extra_args) {
341 return Func(scalar, tensor, extra_args...);
342 }
343
344 template <typename A, A a, typename C>
345 struct UnaryPointwiseRandomLeadingFloatBatchRule;
346
347 template <typename F, F Func, typename A0, typename A1, typename... T>
348 struct UnaryPointwiseRandomLeadingFloatBatchRule<F, Func, typelist<A0, A1, T...>> {
applyat::functorch::UnaryPointwiseRandomLeadingFloatBatchRule349 static Tensor apply(double scalar, const Tensor& tensor, T... extra_args) {
350 return unary_pointwise_random_batch_rule<decltype(&normal_wrapper<F, Func, T...>),
351 &normal_wrapper<F, Func, T...>, double,
352 T...>(tensor, scalar, std::forward<T>(extra_args)...);
353 }
354 };
355
// Registers the in-place bernoulli_.float rule under FuncTorchBatched as well
// as under FuncTorchVmapMode.
// NOTE(review): presumably this is needed because rules in this file run with
// FuncTorchVmapMode excluded (ExcludeDispatchKeyGuard) and may still call
// bernoulli_(float) on a BatchedTensor. An example is the Same-randomness path
// of native_dropout_batching_rule (mask.bernoulli_(p1m)). In that case the
// VmapMode registration would not fire — confirm.
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  #define RANDOM_INPLACE_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandomInplaceBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                                   c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  RANDOM_INPLACE_BATCH_RULE2(bernoulli_, float);

  #undef RANDOM_INPLACE_BATCH_RULE2
}
366
// Registers the randomness-aware batching rules for ATen random ops under
// FuncTorchVmapMode, so they intercept random calls inside vmap even when
// their inputs are not BatchedTensors.
//
// Each macro below pairs an op (plus optional overload) with the adapter
// matching that op's signature. The adapter is given the ATen function's
// `parameter_types` typelist, from which it recovers the trailing argument
// types to forward.
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
  // Shape-first factories: (shape, extras...).
  #define RANDOM_BATCH_RULE(op) \
    m.impl(#op, SINGLE_ARG(\
      RandomBatchRuleHelper<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                            c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define RANDOM_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandomBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                            c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  // In-place random fills: (self, extras...).
  #define RANDOM_INPLACE_BATCH_RULE(op) \
    m.impl(#op, SINGLE_ARG(\
      RandomInplaceBatchRuleHelper<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                            c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define RANDOM_INPLACE_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandomInplaceBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                                   c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  // randint-style: (high, shape, extras...).
  #define RANDINT_BATCH_RULE(op) \
    m.impl(#op, SINGLE_ARG(\
      RandIntBatchRuleHelper<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                             c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define RANDINT_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandIntBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                             c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  // Two leading scalars then shape: (s0, s1, shape, extras...).
  #define RAND_TWO_LEADING_SCALARS_BATCH_RULE(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandTwoLeadingScalarsBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                                c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))
  // randperm-style: (n, extras...).
  #define RANDPERM_BATCH_RULE(op) \
    m.impl(#op, SINGLE_ARG(\
      RandpermBatchRuleHelper<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                            c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define RANDPERM_BATCH_RULE2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      RandpermBatchRuleHelper<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                            c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  // Tensor-first pointwise sampling: (tensor, extras...).
  #define UNARY_POINTWISE_RANDOM(op) \
    m.impl(#op, SINGLE_ARG(\
      UnaryPointwiseRandomBatchRule<decltype(&ATEN_FN(op)), &ATEN_FN(op), \
                            c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))

  #define UNARY_POINTWISE_RANDOM2(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      UnaryPointwiseRandomBatchRule<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                            c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  // Leading double then tensor: (scalar, tensor, extras...).
  #define UNARY_POINTWISE_RANDOM_LEADING_FLOAT(op, overload) \
    m.impl(#op"."#overload, SINGLE_ARG(\
      UnaryPointwiseRandomLeadingFloatBatchRule<decltype(&ATEN_FN2(op, overload)), &ATEN_FN2(op, overload), \
                            c10::guts::function_traits<decltype(ATEN_FN2(op, overload))>::parameter_types>::apply))

  RANDOM_BATCH_RULE(randn);
  RANDOM_BATCH_RULE2(randn, generator);
  RANDOM_BATCH_RULE2(randn, generator_with_names);
  RANDOM_BATCH_RULE2(randn, names);

  RANDOM_BATCH_RULE(rand);
  RANDOM_BATCH_RULE2(rand, generator);
  RANDOM_BATCH_RULE2(rand, generator_with_names);
  RANDOM_BATCH_RULE2(rand, names);

  RANDOM_INPLACE_BATCH_RULE(random_);
  RANDOM_INPLACE_BATCH_RULE2(random_, from);
  RANDOM_INPLACE_BATCH_RULE2(random_, to);

  RANDOM_INPLACE_BATCH_RULE(cauchy_);
  RANDOM_INPLACE_BATCH_RULE(exponential_);
  RANDOM_INPLACE_BATCH_RULE(geometric_);
  RANDOM_INPLACE_BATCH_RULE(log_normal_);
  RANDOM_INPLACE_BATCH_RULE(normal_);
  RANDOM_INPLACE_BATCH_RULE(uniform_);

  RANDINT_BATCH_RULE(randint);
  RANDINT_BATCH_RULE2(randint, generator);
  RAND_TWO_LEADING_SCALARS_BATCH_RULE(randint, low);
  RAND_TWO_LEADING_SCALARS_BATCH_RULE(randint, low_generator);

  // bernoulli_.Tensor needs its own rule, because the probability argument is
  // a (possibly batched) tensor that must be aligned with `self`.
  m.impl("bernoulli_.Tensor", at::functorch::bernoulli_inplace_Tensor_batching_rule);
  RANDOM_INPLACE_BATCH_RULE2(bernoulli_, float);
  UNARY_POINTWISE_RANDOM2(bernoulli, p);

  RANDPERM_BATCH_RULE(randperm);
  RANDPERM_BATCH_RULE2(randperm, generator);

  RAND_TWO_LEADING_SCALARS_BATCH_RULE(normal, float_float);
  UNARY_POINTWISE_RANDOM2(normal, Tensor_float);
  UNARY_POINTWISE_RANDOM_LEADING_FLOAT(normal, float_Tensor);

  m.impl("native_dropout", native_dropout_batching_rule); // needs special casing because cuda version doesn't call bernoulli

  UNARY_POINTWISE_RANDOM(_standard_gamma);
  UNARY_POINTWISE_RANDOM(_sample_dirichlet);
  m.impl("multinomial", multinomial_batching_rule);
  UNARY_POINTWISE_RANDOM(poisson);
  UNARY_POINTWISE_RANDOM(bernoulli);

  // *_like factories: the template argument list is spelled out explicitly
  // (leading ints for randint_like, then the common tensor-options types).
  #define TENSOR_LIKE_COMMON_ARG_TYPES std::optional<ScalarType>, std::optional<Layout>, std::optional<Device>, std::optional<bool>, std::optional<MemoryFormat>
  m.impl("randint_like", tensor_like_random_batch_rule<decltype(&ATEN_FN(randint_like)), &ATEN_FN(randint_like), int64_t, TENSOR_LIKE_COMMON_ARG_TYPES>);
  m.impl("randint_like.low_dtype", tensor_like_random_batch_rule<\
    decltype(&ATEN_FN2(randint_like, low_dtype)), &ATEN_FN2(randint_like, low_dtype), int64_t, int64_t, TENSOR_LIKE_COMMON_ARG_TYPES>);
  m.impl("rand_like", tensor_like_random_batch_rule<decltype(&ATEN_FN(rand_like)), &ATEN_FN(rand_like), TENSOR_LIKE_COMMON_ARG_TYPES>);
  m.impl("randn_like", tensor_like_random_batch_rule<decltype(&ATEN_FN(randn_like)), &ATEN_FN(randn_like), TENSOR_LIKE_COMMON_ARG_TYPES>);

  #undef RANDOM_BATCH_RULE
  #undef RANDOM_BATCH_RULE2
  #undef RANDOM_INPLACE_BATCH_RULE
  #undef RANDOM_INPLACE_BATCH_RULE2
  #undef RANDINT_BATCH_RULE
  #undef RANDINT_BATCH_RULE2
  #undef RAND_TWO_LEADING_SCALARS_BATCH_RULE
  #undef RANDPERM_BATCH_RULE
  #undef RANDPERM_BATCH_RULE2
  #undef UNARY_POINTWISE_RANDOM
  #undef UNARY_POINTWISE_RANDOM2
  #undef UNARY_POINTWISE_RANDOM_LEADING_FLOAT
  #undef TENSOR_LIKE_COMMON_ARG_TYPES
}
493
494 } // namespace at::functorch
495