#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/native/AmpKernels.h>
#include <cmath>
#include <ATen/DeviceGuard.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>

namespace at::native {

namespace {
// Follows the CUDA implementation.
// Multiplies each tensor in scaled_grads by inv_scale in place.
// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf
// to 1.0.
//
// Args:
// scaled_grads:  A TensorList of scaled gradient tensors. May contain infs
//                or NaNs.
// found_inf:  A single-element float tensor to which 1.0 will be written if
//             any gradient contains infs/NaNs. Pre-zeroing found_inf, if
//             appropriate, is the responsibility of the caller.
// inv_scale:  The inverse of the scale factor by which scaled_grads are
//             currently multiplied.
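//
// A minimal usage sketch (illustrative only; the tensor names and the scale
// value below are hypothetical, not part of this file):
//
//   at::Tensor found_inf = at::zeros({1}, at::kFloat);
//   at::Tensor inv_scale = at::full({1}, 1.0f / 65536.0f, at::kFloat);
//   // grads: a TensorList of CPU float gradients scaled by 65536.
//   _amp_foreach_non_finite_check_and_unscale_cpu_kernel(
//       grads, found_inf, inv_scale);
//   // Afterwards, found_inf holds 1.0 iff any gradient element was inf/NaN,
//   // and every gradient has been multiplied by inv_scale in place.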
void _amp_foreach_non_finite_check_and_unscale_cpu_kernel(
    TensorList scaled_grads,
    at::Tensor& found_inf,
    const at::Tensor& inv_scale) {
  if (scaled_grads.empty()) {
    return;
  }

  TORCH_CHECK(inv_scale.is_cpu(), "inv_scale must be a CPU tensor.");
  TORCH_CHECK(found_inf.is_cpu(), "found_inf must be a CPU tensor.");
  TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(
      inv_scale.scalar_type() == at::ScalarType::Float,
      "inv_scale must be a float tensor.");
  TORCH_CHECK(
      found_inf.scalar_type() == at::ScalarType::Float,
      "found_inf must be a float tensor.");

  // Ensures client code (GradScaler) filtered scaled_grads by dtype.
  at::native::check_foreach_api_restrictions(scaled_grads);
  for (const at::Tensor& t : scaled_grads) {
    TORCH_CHECK(t.is_cpu(), "one of scaled_grads was not a CPU tensor.");
    TORCH_CHECK(
        t.layout() == at::kStrided,
        "one of scaled_grads was not a strided tensor.");
    auto iter = at::TensorIterator::unary_op(const_cast<at::Tensor&>(t), t);
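    // Reduced-precision floating types (e.g. BFloat16, Half) are upcast to
    // float (opmath_t) so the finiteness check and the multiply run at full
    // precision before converting back to the storage type.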
    if (at::isReducedFloatingType(iter.dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(
          iter.dtype(),
          "_amp_foreach_non_finite_check_and_unscale_cpu",
          [&iter, &found_inf, &inv_scale] {
            auto* found_inf_ptr = found_inf.data_ptr<float>();
            auto* inv_scale_ptr = inv_scale.data_ptr<float>();

            using opmath_t = at::opmath_type<scalar_t>;

            at::native::cpu_kernel_vec(
                iter,
                [found_inf_ptr, inv_scale_ptr](scalar_t val_in) -> scalar_t {
                  auto val = static_cast<opmath_t>(val_in);
                  if (!std::isfinite(val)) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  return static_cast<scalar_t>(
                      inv_scale_val == 1.f ? val : val * inv_scale_val);
                },
                [found_inf_ptr, inv_scale_ptr](Vectorized<scalar_t> val_vec)
                    -> Vectorized<scalar_t> {
                  auto [val_vec0, val_vec1] =
                      convert_to_float<scalar_t>(val_vec);
                  if (val_vec0.has_inf_nan() || val_vec1.has_inf_nan()) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  val_vec0 = inv_scale_val == 1.f
                      ? val_vec0
                      : val_vec0 * Vectorized<opmath_t>(inv_scale_val);
                  val_vec1 = inv_scale_val == 1.f
                      ? val_vec1
                      : val_vec1 * Vectorized<opmath_t>(inv_scale_val);
                  return convert_from_float<scalar_t>(val_vec0, val_vec1);
                });
          });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          iter.dtype(),
          "_amp_foreach_non_finite_check_and_unscale_cpu",
          [&iter, &found_inf, &inv_scale] {
            auto* found_inf_ptr = found_inf.data_ptr<float>();
            auto* inv_scale_ptr = inv_scale.data_ptr<float>();
            at::native::cpu_kernel_vec(
                iter,
                [found_inf_ptr, inv_scale_ptr](scalar_t val_in) -> scalar_t {
                  if (!std::isfinite(val_in)) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  return static_cast<scalar_t>(
                      inv_scale_val == 1.f ? val_in : val_in * inv_scale_val);
                },
                [found_inf_ptr, inv_scale_ptr](Vectorized<scalar_t> val_vec)
                    -> Vectorized<scalar_t> {
                  if (val_vec.has_inf_nan()) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  return inv_scale_val == 1.f
                      ? val_vec
                      : val_vec * Vectorized<scalar_t>(inv_scale_val);
                });
          });
    }
  }
}

// _amp_update_scale_cpu updates the scale tensor in place.
//
// Args:
// current_scale:  A one-element float tensor containing the scale value.
// growth_tracker:  A one-element IntTensor containing the number of recent
//                  consecutive unskipped steps.
// found_inf:  A one-element float tensor. If > 0, indicates that infs/NaNs
//             were found by the relevant prior
//             _amp_non_finite_check_and_unscale_cpu call; 0 if no infs/NaNs
//             were found.
// growth_factor:  Multiplier if no infs/NaNs were found (typically slightly
//                 > 1).
// backoff_factor:  Multiplier if infs/NaNs were found (typically 0.5).
// growth_interval:  Number of consecutive unskipped steps that must occur
//                   for current_scale to be multiplied by growth_factor.
//
// Returns:
// current_scale
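//
// A worked example (hypothetical values, not taken from this file): with
// growth_factor = 2.0, backoff_factor = 0.5, growth_interval = 2000:
//
//   at::Tensor scale = at::full({1}, 65536.f, at::kFloat);
//   at::Tensor tracker = at::zeros({1}, at::kInt);
//   at::Tensor found_inf = at::zeros({1}, at::kFloat);
//   _amp_update_scale_cpu_kernel(scale, tracker, found_inf, 2.0, 0.5, 2000);
//   // found_inf == 0, so tracker becomes 1. After 2000 consecutive such
//   // calls the scale would double to 131072 and tracker would reset to 0.
//   // If found_inf held 1.0 instead, the scale would halve to 32768 and
//   // tracker would reset to 0 immediately.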
at::Tensor& _amp_update_scale_cpu_kernel(
    at::Tensor& current_scale,
    at::Tensor& growth_tracker,
    const at::Tensor& found_inf,
    double growth_factor,
    double backoff_factor,
    int64_t growth_interval) {
  TORCH_CHECK(growth_tracker.is_cpu(), "growth_tracker must be a CPU tensor.");
  TORCH_CHECK(current_scale.is_cpu(), "current_scale must be a CPU tensor.");
  TORCH_CHECK(found_inf.is_cpu(), "found_inf must be a CPU tensor.");
  TORCH_CHECK(
      growth_tracker.numel() == 1,
      "growth_tracker must be a 1-element tensor.");
  TORCH_CHECK(
      current_scale.numel() == 1, "current_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(
      growth_tracker.scalar_type() == at::ScalarType::Int,
      "growth_tracker must be an int tensor.");
  TORCH_CHECK(
      current_scale.scalar_type() == at::ScalarType::Float,
      "current_scale must be a float tensor.");
  TORCH_CHECK(
      found_inf.scalar_type() == at::ScalarType::Float,
      "found_inf must be a float tensor.");

  float* current_scale_ptr = current_scale.data_ptr<float>();
  int* growth_tracker_ptr = growth_tracker.data_ptr<int>();
  float* found_inf_ptr = found_inf.data_ptr<float>();

  if (*found_inf_ptr) {
    *current_scale_ptr = (*current_scale_ptr) * backoff_factor;
    *growth_tracker_ptr = 0;
  } else {
    // Entering this branch means we just carried out a successful step, so
    // growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker_ptr) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale_ptr) * growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (std::isfinite(new_scale)) {
        *current_scale_ptr = new_scale;
      }
      *growth_tracker_ptr = 0;
    } else {
      *growth_tracker_ptr = successful;
    }
  }

  return current_scale;
}

} // namespace

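// Register the CPU kernels above with the dispatch stubs declared in
// ATen/native/AmpKernels.h.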
REGISTER_DISPATCH(
    _amp_foreach_non_finite_check_and_unscale_cpu_stub,
    &_amp_foreach_non_finite_check_and_unscale_cpu_kernel);
REGISTER_DISPATCH(_amp_update_scale_cpu_stub, &_amp_update_scale_cpu_kernel);

} // namespace at::native