#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/native/AmpKernels.h>
#include <cmath>
#include <ATen/DeviceGuard.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>

namespace at::native {

namespace {
// Follows the CUDA implementation.
// Multiplies each tensor in scaled_grads by inv_scale in-place.
// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf
// to 1.0.
//
// Args:
// scaled_grads: A TensorList of scaled gradient tensors. May contain infs or
//               NaNs.
// found_inf: A single-element float tensor to which 1.0 will be written if
//            any gradient contains infs/NaNs.
//            Pre-zeroing found_inf, if appropriate, is the responsibility of
//            the caller.
// inv_scale: The inverse of the scale factor by which scaled_grads are
//            currently multiplied.
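//
// A rough caller-side sketch (hedged: the tensor names and the `scale` value
// are assumptions for illustration; GradScaler reaches this kernel through the
// _amp_foreach_non_finite_check_and_unscale_ op and the CPU stub registered at
// the bottom of this file):
//
//   at::Tensor found_inf = at::zeros({1}, at::kFloat);
//   at::Tensor inv_scale = at::full({1}, 1.0f / scale, at::kFloat);
//   at::_amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale);
//   // found_inf now holds 1.0 iff any grad contained an inf/NaN, and every
//   // grad has been multiplied by inv_scale in place.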
void _amp_foreach_non_finite_check_and_unscale_cpu_kernel(
    TensorList scaled_grads,
    at::Tensor& found_inf,
    const at::Tensor& inv_scale) {
  if (scaled_grads.empty()) {
    return;
  }

  TORCH_CHECK(inv_scale.is_cpu(), "inv_scale must be a CPU tensor.");
  TORCH_CHECK(found_inf.is_cpu(), "found_inf must be a CPU tensor.");
  TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(
      inv_scale.scalar_type() == at::ScalarType::Float,
      "inv_scale must be a float tensor.");
  TORCH_CHECK(
      found_inf.scalar_type() == at::ScalarType::Float,
      "found_inf must be a float tensor.");

  // Ensures client code (GradScaler) filtered scaled_grads by dtype.
  at::native::check_foreach_api_restrictions(scaled_grads);
  for (const at::Tensor& t : scaled_grads) {
    TORCH_CHECK(t.is_cpu(), "one of scaled_grads was not a CPU tensor.");
    TORCH_CHECK(
        t.layout() == at::kStrided,
        "one of scaled_grads was not a strided tensor.");
    auto iter = at::TensorIterator::unary_op(
        const_cast<at::Tensor&>(t), t);
    if (at::isReducedFloatingType(iter.dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(
          iter.dtype(),
          "_amp_foreach_non_finite_check_and_unscale_cpu",
          [&iter, &found_inf, &inv_scale] {
            auto* found_inf_ptr = found_inf.data_ptr<float>();
            auto* inv_scale_ptr = inv_scale.data_ptr<float>();

            using opmath_t = at::opmath_type<scalar_t>;

            at::native::cpu_kernel_vec(
                iter,
                [found_inf_ptr, inv_scale_ptr](scalar_t val_in) -> scalar_t {
                  auto val = static_cast<opmath_t>(val_in);
                  if (!std::isfinite(val)) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  return static_cast<scalar_t>(
                      inv_scale_val == 1.f ? val : val * inv_scale_val);
                },
                [found_inf_ptr, inv_scale_ptr](Vectorized<scalar_t> val_vec) -> Vectorized<scalar_t> {
                  auto [val_vec0, val_vec1] = convert_to_float<scalar_t>(val_vec);
                  if (val_vec0.has_inf_nan() || val_vec1.has_inf_nan()) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  val_vec0 = inv_scale_val == 1.f ? val_vec0 : val_vec0 * Vectorized<opmath_t>(inv_scale_val);
                  val_vec1 = inv_scale_val == 1.f ? val_vec1 : val_vec1 * Vectorized<opmath_t>(inv_scale_val);
                  return convert_from_float<scalar_t>(val_vec0, val_vec1);
                });
          });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          iter.dtype(),
          "_amp_foreach_non_finite_check_and_unscale_cpu",
          [&iter, &found_inf, &inv_scale] {
            auto* found_inf_ptr = found_inf.data_ptr<float>();
            auto* inv_scale_ptr = inv_scale.data_ptr<float>();
            at::native::cpu_kernel_vec(
                iter,
                [found_inf_ptr, inv_scale_ptr](scalar_t val_in) -> scalar_t {
                  if (!std::isfinite(val_in)) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  return static_cast<scalar_t>(
                      inv_scale_val == 1.f ? val_in : val_in * inv_scale_val);
                },
                [found_inf_ptr, inv_scale_ptr](Vectorized<scalar_t> val_vec) -> Vectorized<scalar_t> {
                  if (val_vec.has_inf_nan()) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  return inv_scale_val == 1.f ? val_vec : val_vec * Vectorized<scalar_t>(inv_scale_val);
                });
          });
    }
  }
}

// _amp_update_scale_cpu updates the scale tensor in place.
//
// Args:
// current_scale: A one-element float tensor containing the scale value.
// growth_tracker: A one-element IntTensor containing the number of recent
//                 consecutive unskipped steps.
// found_inf: A one-element float tensor. If > 0, indicates that infs/NaNs
//            were found by the relevant prior
//            _amp_non_finite_check_and_unscale_cpu call, and 0 if no
//            infs/NaNs were found.
// growth_factor: Multiplier if no infs/NaNs were found (typically slightly > 1).
// backoff_factor: Multiplier if infs/NaNs were found (typically 0.5).
// growth_interval: Number of consecutive unskipped steps that must occur for
//                  current_scale to be multiplied by growth_factor.
//
// Returns:
// current_scale
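//
// A minimal sketch of the update rule, assuming the typical GradScaler
// defaults (growth_factor = 2.0, backoff_factor = 0.5, growth_interval = 2000):
//
//   if (found_inf) { scale *= 0.5; tracker = 0; }              // back off
//   else if (++tracker == 2000) { scale *= 2.0; tracker = 0; } // grow, unless
//                                                              // doubling would
//                                                              // overflow fp32
//   // otherwise the incremented tracker is simply stored back.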
at::Tensor& _amp_update_scale_cpu_kernel(
    at::Tensor& current_scale,
    at::Tensor& growth_tracker,
    const at::Tensor& found_inf,
    double growth_factor,
    double backoff_factor,
    int64_t growth_interval) {
  TORCH_CHECK(growth_tracker.is_cpu(), "growth_tracker must be a CPU tensor.");
  TORCH_CHECK(current_scale.is_cpu(), "current_scale must be a CPU tensor.");
  TORCH_CHECK(found_inf.is_cpu(), "found_inf must be a CPU tensor.");
  TORCH_CHECK(
      growth_tracker.numel() == 1,
      "growth_tracker must be a 1-element tensor.");
  TORCH_CHECK(
      current_scale.numel() == 1, "current_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(
      growth_tracker.scalar_type() == at::ScalarType::Int,
      "growth_tracker must be an int tensor.");
  TORCH_CHECK(
      current_scale.scalar_type() == at::ScalarType::Float,
      "current_scale must be a float tensor.");
  TORCH_CHECK(
      found_inf.scalar_type() == at::ScalarType::Float,
      "found_inf must be a float tensor.");

  float* current_scale_ptr = current_scale.data_ptr<float>();
  int* growth_tracker_ptr = growth_tracker.data_ptr<int>();
  float* found_inf_ptr = found_inf.data_ptr<float>();

  if (*found_inf_ptr) {
    *current_scale_ptr = (*current_scale_ptr) * backoff_factor;
    *growth_tracker_ptr = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker_ptr) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale_ptr) * growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (std::isfinite(new_scale)) {
        *current_scale_ptr = new_scale;
      }
      *growth_tracker_ptr = 0;
    } else {
      *growth_tracker_ptr = successful;
    }
  }

  return current_scale;
}

} // namespace

REGISTER_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu_stub, &_amp_foreach_non_finite_check_and_unscale_cpu_kernel);
REGISTER_DISPATCH(_amp_update_scale_cpu_stub, &_amp_update_scale_cpu_kernel);

} // namespace at::native