#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/alpha_dropout_native.h>
#include <ATen/ops/dropout_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/feature_alpha_dropout_native.h>
#include <ATen/ops/feature_dropout_native.h>
#include <ATen/ops/native_dropout.h>
#include <ATen/ops/native_dropout_backward_native.h>
#include <ATen/ops/native_dropout_native.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

namespace {

template<bool inplace>
using Ctype = typename std::conditional<inplace, Tensor&, Tensor>::type;

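// Builds an uninitialized noise tensor of shape (N, C, 1, 1, ...): the batch and
// channel dimensions are kept and every remaining dimension is collapsed to 1, so
// that the Bernoulli mask broadcasts against the input and drops whole feature maps.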
Tensor make_feature_noise(const Tensor& input) {
  auto input_sizes = input.sym_sizes();
  TORCH_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input");
  c10::SymDimVector sizes;
  sizes.reserve(input.dim());
  sizes.push_back(input_sizes[0]);
  sizes.push_back(input_sizes[1]);
  for (C10_UNUSED const auto i : c10::irange(2, input.dim())) {
    sizes.push_back(1);
  }
  return input.new_empty_symint(sizes);
}

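// The fused native_dropout kernel is only used on backends that provide it
// (CUDA, XPU, lazy, PrivateUse1) and only when 0 < p < 1 on a non-empty input;
// the degenerate probabilities are handled by the decomposed path below.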
bool is_fused_kernel_acceptable(const Tensor& input, double p) {
  return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.is_privateuseone()) && p > 0 && p < 1 && input.sym_numel() > 0;
}

// NB: sure, we could have used different overloads here, but I would feel insecure
// knowing that this dispatch depends only on the constness of the references
template<bool inplace>
Tensor& multiply(Tensor& input, const Tensor& noise) {
  static_assert(inplace, "Wrong multiply overload triggered in Dropout.cpp");
  return input.mul_(noise);
}

template<bool inplace>
Tensor multiply(const Tensor& input, const Tensor& noise) {
  static_assert(!inplace, "Wrong multiply overload triggered in Dropout.cpp");
  return input.mul(noise);
}

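// Shared implementation behind all eight dropout entry points. `feature_dropout`
// selects channel-wise noise (see make_feature_noise), `alpha_dropout` selects the
// self-normalizing variant from "Self-Normalizing Neural Networks" (Klambauer et al.),
// and `inplace` selects whether `input` is mutated or a new tensor is returned.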
template<bool feature_dropout, bool alpha_dropout, bool inplace, typename T>
Ctype<inplace> _dropout_impl(T& input, double p, bool train) {
  TORCH_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p);
  if (p == 0 || !train || input.sym_numel() == 0) {
    return input;
  }

  if (p == 1) {
    return multiply<inplace>(input, at::zeros({}, input.options()));
  }

  at::Tensor b; // used for alpha_dropout only
  auto noise = feature_dropout ? make_feature_noise(input) : at::empty_like(input);
  noise.bernoulli_(1 - p);
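  // Standard dropout rescales the kept elements by 1/(1 - p) so the expected value
  // is unchanged. Alpha dropout instead maps dropped elements to the SELU negative
  // saturation value -alpha (alpha = selu_scale * selu_alpha ~= 1.7581) and applies
  // the affine transform a*x + b, with a and b chosen so that zero mean and unit
  // variance are preserved for SELU activations.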
  if (alpha_dropout) {
    constexpr double alpha = 1.7580993408473766;
    double a = 1. / std::sqrt((alpha * alpha * p + 1) * (1 - p));
    b = noise.add(-1).mul_(alpha * a).add_(alpha * a * p);
    noise.mul_(a);
  } else {
    noise.div_(1 - p);
  }

  if (!alpha_dropout) {
    return multiply<inplace>(input, noise);
  } else {
    return multiply<inplace>(input, noise).add_(b);
  }
}

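// Generates the four internal helpers below by fixing the feature/alpha template
// flags of _dropout_impl; only `inplace` is left for the callers to choose.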
#define ALIAS_SPECIALIZATION(ALIAS_NAME, IS_FEATURE, IS_ALPHA)                      \
template <bool inplace, typename... Args>                                           \
Ctype<inplace> ALIAS_NAME(Args&&... args) {                                         \
  return _dropout_impl<IS_FEATURE, IS_ALPHA, inplace>(std::forward<Args>(args)...); \
}

ALIAS_SPECIALIZATION(_dropout,               false, false)
ALIAS_SPECIALIZATION(_feature_dropout,       true,  false)
ALIAS_SPECIALIZATION(_alpha_dropout,         false, true )
ALIAS_SPECIALIZATION(_feature_alpha_dropout, true,  true )

} // anonymous namespace

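// CPU kernel behind at::native_dropout. Returns the dropped-out output together
// with the boolean keep-mask; in training mode the kept elements are already
// scaled by 1/(1 - p) (with the scale forced to 0 when p == 1 to avoid a division
// by zero), so no further rescaling is needed at evaluation time.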
std::tuple<Tensor,Tensor>
native_dropout_cpu(const Tensor& input, double p, std::optional<bool> train) {
  if (input.numel() == 0) {
    return std::make_tuple(input, at::empty_like(input, input.options()));
  }

  Tensor mask;
  Tensor output;

  if (!train.has_value() || *train) {
    double p1m = 1. - p;
    // Check for probability of zero to avoid divide by zero and NaN results
    double scale = p1m == 0 ? 0. : 1. / p1m;
    mask = at::empty_like(input, input.options().dtype(c10::CppTypeToScalarType<bool>::value));
    mask.bernoulli_(p1m);
    output = input.mul(mask).mul_(scale);
  } else {
    mask = at::ones_like(input, input.options().dtype(c10::CppTypeToScalarType<bool>::value));
    output = input.clone();
  }
  return std::make_tuple(output, mask);
}

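// Backward of native_dropout: the gradient simply flows through the kept elements,
// re-applying the mask and the same scale that was used in the forward pass.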
Tensor native_dropout_backward(const Tensor& grad, const Tensor& mask, double scale) {
  Tensor result = grad * mask * scale;
  return result;
}

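// Public composite entry point. Routes nested tensors, and accelerator inputs that
// satisfy is_fused_kernel_acceptable, to the fused at::native_dropout kernel;
// everything else goes through the decomposed _dropout_impl path. Named-tensor
// names are stripped for the computation and propagated back onto the result.
//
// Illustrative usage from C++ (a sketch; assumes the usual ATen factory functions):
//   auto x = at::randn({8, 16});
//   auto y = at::dropout(x, /*p=*/0.5, /*train=*/true);  // ~50% of elements zeroed,
//                                                        // survivors scaled by 2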
Tensor dropout(const Tensor& input, double p, bool train) {
  auto result = [&]() {
    NoNamesGuard guard;
    // TODO: we can remove this is_nested() code smell in the future
    //       if we find a way to support _dropout for nested tensor
    //       e.g. make it an op (at::_dropout) to use dispatcher?
    if (input.is_nested() || (train && is_fused_kernel_acceptable(input, p))) {
      return std::get<0>(at::native_dropout(input, p, train));
    }
    return _dropout<false>(input, p, train);
  }();
  namedinference::propagate_names(result, input);
  return result;
}

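// The remaining entry points are thin wrappers over the ALIAS_SPECIALIZATION
// helpers: the trailing-underscore variants mutate `input` in place, the others
// return a new tensor.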
Tensor& dropout_(Tensor& input, double p, bool train) {
  return _dropout<true>(input, p, train);
}

Tensor feature_dropout(const Tensor& input, double p, bool train) {
  return _feature_dropout<false>(input, p, train);
}

Tensor& feature_dropout_(Tensor& input, double p, bool train) {
  return _feature_dropout<true>(input, p, train);
}

Tensor alpha_dropout(const Tensor& input, double p, bool train) {
  return _alpha_dropout<false>(input, p, train);
}

Tensor& alpha_dropout_(Tensor& input, double p, bool train) {
  return _alpha_dropout<true>(input, p, train);
}

Tensor feature_alpha_dropout(const Tensor& input, double p, bool train) {
  return _feature_alpha_dropout<false>(input, p, train);
}

Tensor& feature_alpha_dropout_(Tensor& input, double p, bool train) {
  return _feature_alpha_dropout<true>(input, p, train);
}

} // namespace at::native