#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/nll_loss2d_backward_native.h>
#include <ATen/ops/nll_loss2d_forward.h>
#include <ATen/ops/nll_loss2d_forward_native.h>
#include <ATen/ops/nll_loss2d_native.h>
#include <ATen/ops/zeros_like.h>

#include <utility>
#endif

namespace at::native {

namespace {

// Returns a contiguous tensor if the source tensor
// is defined. Otherwise returns the undefined
// source tensor unmodified.
inline Tensor optional_contiguous(const Tensor& source) {
  return source.defined() ? source.contiguous() : source;
}

// Returns the address of the first element of a tensor
// or nullptr if the tensor is undefined.
template <typename scalar_t>
inline scalar_t* optional_data(const Tensor& source) {
  if constexpr (std::is_const<scalar_t>::value) {
    return source.defined() ? source.const_data_ptr<scalar_t>() : nullptr;
  } else {
    return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
  }
}

inline void check_inputs_nll_loss2d(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight) {
  TORCH_CHECK(
      target.dim() == 3,
      "only batches of spatial targets supported (3D tensors)"
      " but got targets of dimension: ",
      target.dim());
  TORCH_CHECK(
      input.dim() == 4,
      "only batches of spatial inputs supported (4D tensors), "
      "but got input of dimension: ",
      input.dim());
  TORCH_CHECK(
      !weight.defined() || weight.numel() == input.size(1),
      "weight tensor should be defined either for all or no classes");

  const int64_t input0 = input.size(0);
  const int64_t input2 = input.size(2);
  const int64_t input3 = input.size(3);
  const int64_t target0 = target.size(0);
  const int64_t target1 = target.size(1);
  const int64_t target2 = target.size(2);
  TORCH_CHECK(
      input0 == target0 && input2 == target1 && input3 == target2,
      "size mismatch (got input: ",
      input.sizes(),
      " , target: ",
      target.sizes());
}

inline void check_gradout_shape_nll_loss2d(
    const Tensor& grad_output,
    const Tensor& target) {
  TORCH_CHECK(
      grad_output.dim() == 3,
      "grad_output must have same dimension as target (3) but got dimension: ",
      grad_output.sizes());

  const int64_t grad_output0 = grad_output.size(0);
  const int64_t grad_output1 = grad_output.size(1);
  const int64_t grad_output2 = grad_output.size(2);
  const int64_t target0 = target.size(0);
  const int64_t target1 = target.size(1);
  const int64_t target2 = target.size(2);
  TORCH_CHECK(
      grad_output0 == target0 && grad_output1 == target1 &&
          grad_output2 == target2,
      "size mismatch (got grad_output: ",
      grad_output.sizes(),
      " target: ",
      target.sizes());
}

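// Computes the forward 2d NLL loss for a single floating-point dtype.
// With reduction == None, per-element losses are written to `output`
// (shape [batch, H, W]); otherwise the losses are accumulated into a
// scalar `output` and the summed class weights are stored in `total_weight`.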
template <typename scalar_t>
static void nll_loss2d_forward_out_frame(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  const int64_t n_classes = input.size(1);

  scalar_t* total_weight_data = total_weight.data_ptr<scalar_t>();
  *total_weight_data = 0;

  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == Reduction::None) {
    const int64_t batch_size = input.size(0);
    const int64_t H = input.size(2);
    const int64_t W = input.size(3);

    at::native::resize_output(output, {batch_size, H, W});
    auto input_acc = input.accessor<const scalar_t, 4>();
    auto output_acc = output.accessor<scalar_t, 3>();
    auto target_acc = target.accessor<const int64_t, 3>();

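    // Per-element loss: output[b][h][w] = -weight[target] * input[b][target][h][w]
    // (weight defaults to 1 when no weight tensor is given); elements whose
    // target equals ignore_index produce a loss of 0. Batches run in parallel.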
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto b : c10::irange(start, end)) {
        for (const auto h : c10::irange(H)) {
          for (const auto w : c10::irange(W)) {
            const int64_t cur_target = (int64_t)target_acc[b][h][w];

            if (cur_target == ignore_index) {
              output_acc[b][h][w] = static_cast<scalar_t>(0);
              continue;
            }

            TORCH_CHECK_INDEX(
                cur_target >= 0 && cur_target < n_classes,
                "Target ",
                cur_target,
                " is out of bounds.");

            // load optional weight value
            const scalar_t cur_weight = weight_data != nullptr
                ? weight_data[cur_target]
                : static_cast<scalar_t>(1);
            output_acc[b][h][w] = -input_acc[b][cur_target][h][w] * cur_weight;
          }
        }
      }
    });

    return;
  }

  // produce scalar outputs for the reduction case
  at::native::resize_output(output, {});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();

  const int64_t batch_size = input.size(0);
  const int64_t map_size = input.size(2) * input.size(3);
  const int64_t sample_size = map_size * n_classes;
  const int64_t numiter = batch_size * map_size;

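  // Cascade (multi-level) summation: losses and weights are accumulated into
  // level-0 buckets and periodically promoted to higher levels, which keeps
  // each partial sum short and bounds floating-point rounding error for
  // low-precision dtypes such as Half and BFloat16.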
  constexpr int64_t cascade_sum_num_levels = 8;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
  const int64_t level_power =
      std::max(int64_t(4), utils::CeilLog2(numiter) / cascade_sum_num_levels);
  const int64_t level_step = (1 << level_power);
  const int64_t level_mask = level_step - 1;

  int64_t num_ignored = 0;
  for (const auto b : c10::irange(batch_size)) {
    for (const auto elem : c10::irange(map_size)) {
      const int64_t cur_target = target_data[b * map_size + elem];
      if (cur_target == ignore_index) {
        ++num_ignored;
        continue;
      }

      TORCH_CHECK_INDEX(
          cur_target >= 0 && cur_target < n_classes,
          "Target ",
          cur_target,
          " is out of bounds.");

      const auto data = input_data[b * sample_size + cur_target * map_size + elem];
      if (weight_data) {
        const scalar_t weight_val = weight_data[cur_target];
        loss_partial_sums[0] -= data * weight_val;
        weight_partial_sums[0] += weight_val;
      } else {
        loss_partial_sums[0] -= data;
      }

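      // Promote partial sums up the cascade: level j folds into level j + 1
      // whenever linear_idx is a multiple of level_step^(j+1), so most
      // iterations only touch level 0 before the early break.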
      const int64_t linear_idx = b * map_size + elem;
      for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
        const auto mask = (level_mask << (j * level_power));
        if (C10_LIKELY((linear_idx & mask) != 0)) {
          break;
        }

        weight_partial_sums[j + 1] += weight_partial_sums[j];
        loss_partial_sums[j + 1] += loss_partial_sums[j];

        weight_partial_sums[j] = 0;
        loss_partial_sums[j] = 0;
      }
    }
  }

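  // Without a weight tensor every counted element contributes weight 1, so the
  // total weight is simply the number of non-ignored elements; otherwise it is
  // the sum of the per-element class weights accumulated above.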
  const scalar_t total_weight_val = !weight_data ?
    static_cast<scalar_t>(numiter - num_ignored) :
    std::accumulate(std::begin(weight_partial_sums),
                    std::end(weight_partial_sums),
                    scalar_t{0});

  scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
                                        std::end(loss_partial_sums),
                                        scalar_t{0});

  if (reduction == Reduction::Mean) {
    output_val /= total_weight_val;
  }

  *total_weight_data = total_weight_val;
  *output.data_ptr<scalar_t>() = output_val;
}

void nll_loss2d_forward_out_cpu_template(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  check_inputs_nll_loss2d(input, target, weight);
  total_weight.resize_({});

  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss2d_forward_out_frame",
      [&] {
        nll_loss2d_forward_out_frame<scalar_t>(
            output,
            total_weight,
            input,
            target,
            weight,
            reduction,
            ignore_index);
      });
}

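// Computes the gradient of the 2d NLL loss for a single floating-point dtype.
// Only the target class channel of each spatial location receives a non-zero
// gradient; grad_input is zero-initialized by the caller.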
template <typename scalar_t>
static void nll_loss2d_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == at::Reduction::None) {
    check_gradout_shape_nll_loss2d(grad_output, target);

    const int64_t batch_size = input.size(0);
    const int64_t H = input.size(2);
    const int64_t W = input.size(3);

    auto grad_input_acc = grad_input.accessor<scalar_t, 4>();
    auto grad_output_acc = grad_output.accessor<const scalar_t, 3>();
    auto target_acc = target.accessor<const int64_t, 3>();

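    // For reduction == None each spatial location has its own upstream
    // gradient: grad_input[b][target][h][w] = -weight[target] * grad_output[b][h][w].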
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto b : c10::irange(start, end)) {
        for (const auto h : c10::irange(H)) {
          for (const auto w : c10::irange(W)) {
            const int64_t cur_target = target_acc[b][h][w];
            if (cur_target == ignore_index) {
              continue;
            }
            const scalar_t value =
                -(weight_data ? weight_data[cur_target]
                              : static_cast<scalar_t>(1));
            const scalar_t grad_output_value = grad_output_acc[b][h][w];
            grad_input_acc[b][cur_target][h][w] = value * grad_output_value;
          }
        }
      }
    });

    return;
  }

  const scalar_t total_weight_value = *total_weight.const_data_ptr<scalar_t>();

  TORCH_CHECK(
      grad_output.dim() <= 1 && grad_output.numel() == 1,
      "Expected a single element grad_output tensor, but got: ",
      grad_output.sizes());

  const scalar_t grad_output_value = *grad_output.const_data_ptr<scalar_t>();

  const auto target_contiguous = target.contiguous();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();

  scalar_t* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  const int64_t batch_size = input.size(0);
  const int64_t n_classes = input.size(1);
  const int64_t map_size = input.size(2) * input.size(3);
  const int64_t sample_size = map_size * n_classes;

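  // For Sum reduction every element shares the scalar upstream gradient; for
  // Mean reduction it is additionally divided by the total weight accumulated
  // in the forward pass.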
  const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
                                                   : grad_output_value);

  at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      for (const auto elem : c10::irange(map_size)) {
        const int64_t t = target_data[b * map_size + elem];

        if (t != ignore_index) {
          TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");

          const int64_t index = b * sample_size + t * map_size + elem;
          grad_input_data[index] = weight_data != nullptr ? weight_data[t] * grad
                                                          : grad;
        }
      }
    }
  });
}

void nll_loss2d_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  check_inputs_nll_loss2d(input, target, weight);
  grad_input.resize_as_(input);
  grad_input.zero_();
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  TORCH_CHECK(
      total_weight.numel() == 1,
      "expected total_weight to be a single element tensor, got: ",
      total_weight.sizes(),
      " (",
      total_weight.numel(),
      " elements)");

  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss2d_backward_out_frame",
      [&] {
        nll_loss2d_backward_out_frame<scalar_t>(
            grad_input,
            grad_output,
            input,
            target,
            weight,
            reduction,
            ignore_index,
            total_weight);
      });
}

} // namespace

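// The functions below are the CPU implementations of the nll_loss2d ops; they
// unwrap the optional weight tensor and forward to the templates above.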
std::tuple<Tensor&, Tensor&> nll_loss2d_forward_out_cpu(const Tensor& self,
    const Tensor& target, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    Tensor& output,
    Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  nll_loss2d_forward_out_cpu_template(
      output, total_weight, self, target, weight, reduction, ignore_index);
  return std::tuple<Tensor&, Tensor&>(output, total_weight);
}

std::tuple<Tensor, Tensor> nll_loss2d_forward_cpu(
    const Tensor& self,
    const Tensor& target, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  auto output = at::empty({0}, self.options());
  auto total_weight = at::empty({0}, self.options());
  at::native::nll_loss2d_forward_out_cpu(
      self, target, weight, reduction, ignore_index, output, total_weight);
  return std::make_tuple(output, total_weight);
}

Tensor& nll_loss2d_backward_out_cpu(const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight,
    Tensor& grad_input) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  nll_loss2d_backward_out_cpu_template(
      grad_input,
      grad_output,
      self,
      target,
      weight,
      reduction,
      ignore_index,
      total_weight);
  return grad_input;
}

Tensor nll_loss2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  auto grad_input = at::zeros_like(self);
  at::native::nll_loss2d_backward_out_cpu(
      grad_output,
      self,
      target,
      weight,
      reduction,
      ignore_index,
      total_weight,
      grad_input);
  return grad_input;
}

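// The user-facing out= overload only returns the loss; a scratch total_weight
// tensor is allocated to satisfy the nll_loss2d_forward_out signature and then
// discarded.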
Tensor & nll_loss2d_out(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor total_weight = at::empty({0}, self.options());
  return std::get<0>(at::nll_loss2d_forward_out(output, total_weight, self, target, weight, reduction, ignore_index));
}

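// Convenience overload that forwards to nll_loss2d_forward and returns only
// the loss, dropping the total_weight tensor.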
Tensor nll_loss2d_symint(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, c10::SymInt ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight, reduction, std::move(ignore_index)));
}

} // namespace at::native