// aten/src/ATen/native/cuda/NLLLoss2d.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/block_reduce.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/nll_loss2d_forward_native.h>
#include <ATen/ops/nll_loss2d_backward_native.h>
#endif

namespace at::native {

namespace {

// Returns a contiguous tensor if the source tensor
// is defined. Otherwise returns the undefined
// source tensor unmodified.
inline Tensor optional_contiguous(const Tensor& source) {
  return source.defined() ? source.contiguous() : source;
}

// Returns the address of the first element of a tensor
// or nullptr if the tensor is undefined.
template <typename scalar_t>
inline const scalar_t* optional_data(const Tensor& source) {
  return source.defined() ? source.const_data_ptr<scalar_t>() : nullptr;
}
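
// The kernels below receive the optional class-weight tensor through these
// two helpers: a defined weight is made contiguous and passed as a raw
// pointer, while an undefined one becomes nullptr and the kernels substitute
// a weight of 1.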

using at::cuda::detail::CUDA_NUM_THREADS;
using at::cuda::detail::GET_BLOCKS;

// TODO(crcrpar): Think about introducing `canUse32BitIndexMath` and choose int or int64_t for `target`.
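// One thread per output element: the flat index is decomposed with the batch
// dimension fastest, i.e. index = b + batch_size * (h + H * w). For a target
// class t = target[b][h][w] the per-element loss is
//   output[b][h][w] = -weight[t] * input[b][t][h][w]
// (weight defaults to 1 when no weight tensor is given), and elements whose
// target equals ignore_index produce a loss of 0.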
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_no_reduce_kernel(
  int64_t n_threads,
  PackedTensorAccessor64<scalar_t, 4> input,
  PackedTensorAccessor64<int64_t, 3> target,
  PackedTensorAccessor64<scalar_t, 3> output,
  const scalar_t* weight,
  int64_t ignore_index
) {
  int64_t batch_size = input.size(0);
  int64_t n_classes = input.size(1);
  int64_t H = input.size(2);
  int64_t W = input.size(3);

  CUDA_KERNEL_LOOP(index, n_threads) {
    const int64_t b = index % batch_size;
    const int64_t h = (index / batch_size) % H;
    const int64_t w = (index / (batch_size * H)) % W;

    int64_t cur_target = target[b][h][w];
    if (cur_target == ignore_index) {
      output[b][h][w] = static_cast<scalar_t>(0);
      continue;
    }
    CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
    scalar_t value = input[b][cur_target][h][w];
    scalar_t cur_weight = weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1);
    output[b][h][w] = -value * cur_weight;
  }
}

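// Reduced ('sum'/'mean') forward: the grid is partitioned so each sample in
// the batch gets `blocks_per_sample` blocks, and every block strides through
// that sample's H*W plane accumulating -weight[t] * input and the
// corresponding weight in accscalar_t precision. The per-thread partial sums
// are combined with a block-wide reduction and then atomically added into the
// scalar `output` and `total_weight` (this atomicAdd is why the reduction
// paths are flagged as nondeterministic on the host side).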
template <typename scalar_t, typename accscalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_kernel(
  scalar_t* output,
  scalar_t* total_weight,
  const scalar_t* input,
  const int64_t* target,
  const scalar_t* weight,
  int n_classes,
  int map_nelem,
  int blocks_per_sample,
  int64_t ignore_index) {

  scalar_t cur_weight;
  accscalar_t input_sum = 0;
  accscalar_t acc_weight = 0;

  index_t sample = blockIdx.x / blocks_per_sample;
  index_t toffset = sample * map_nelem;
  index_t ioffset = sample * map_nelem * n_classes;
  int step = blockDim.x * blocks_per_sample;
  for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
       i < map_nelem;
       i += step) {
    index_t t = target[toffset + i];
    if (t != ignore_index) {
      CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
      cur_weight = weight != nullptr ? weight[t] : static_cast<scalar_t>(1);
      const auto input_index = ioffset + i + map_nelem * t;
      CUDA_KERNEL_ASSERT(input_index >= 0);
      input_sum -= input[input_index] * cur_weight;
      acc_weight += cur_weight;
    }
  }

  __shared__ accscalar_t acc_weight_smem[CUDA_NUM_THREADS];
  __shared__ accscalar_t input_sum_smem[CUDA_NUM_THREADS];

  auto acc_weight_ = cuda_utils::BlockReduceSum(acc_weight, acc_weight_smem);
  auto input_sum_ = cuda_utils::BlockReduceSum(input_sum, input_sum_smem);

  if (threadIdx.x == 0) {
    gpuAtomicAdd(total_weight, static_cast<scalar_t>(acc_weight_));
    gpuAtomicAdd(output, static_cast<scalar_t>(input_sum_));
  }
}

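// Runs with a single thread just to finish the 'mean' reduction on device
// (output /= total_weight), so the host does not need to read total_weight
// back.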
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_size_average_kernel(
  scalar_t* output,
  const scalar_t* total_weight
) {
  *output /= *total_weight;
}

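// No-reduce backward: one thread per (b, h, w) element, using the same flat
// index decomposition as the forward no-reduce kernel. Only the target class
// channel receives a gradient,
//   grad_input[b][t][h][w] = -weight[t] * grad_output[b][h][w],
// and ignored elements keep the zero that grad_input was initialized with.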
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_backward_no_reduce_kernel(
  int64_t n_threads,
  PackedTensorAccessor64<int64_t, 3> target,
  PackedTensorAccessor64<scalar_t, 3> grad_output,
  PackedTensorAccessor64<scalar_t, 4> grad_input,
  const scalar_t* weight,
  int64_t ignore_index
) {
  int64_t batch_size = target.size(0);
  int64_t H = target.size(1);
  int64_t W = target.size(2);

  CUDA_KERNEL_LOOP(index, n_threads) {
    const int64_t b = index % batch_size;
    const int64_t h = (index / batch_size) % H;
    const int64_t w = (index / (batch_size * H)) % W;

    int64_t cur_target = target[b][h][w];
    if (cur_target == ignore_index) {
      continue;
    }
    scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1));
    grad_input[b][cur_target][h][w] = value * grad_output[b][h][w];
  }
}

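// Reduced backward: the incoming gradient is a scalar, so every non-ignored
// element receives the same upstream value,
//   grad = -grad_output                  (sum reduction)
//   grad = -grad_output / total_weight   (mean reduction)
// scattered into its target channel and scaled by weights[t] when a weight
// tensor is present. The grid layout (blocks_per_sample blocks striding over
// each sample's H*W plane) mirrors the forward reduction kernel.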
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_backward_kernel(
  scalar_t* grad_input,
  const scalar_t* grad_output,
  const int64_t* target,
  const scalar_t* weights,
  const scalar_t* total_weight,
  bool size_average,
  int n_classes,
  int map_nelem,
  int blocks_per_sample,
  int64_t ignore_index
) {
  const auto grad = -(size_average ? *grad_output / *total_weight
                                   : *grad_output);

  const int sample = blockIdx.x / blocks_per_sample;
  const int step = blockDim.x * blocks_per_sample;

  const int toffset = sample * map_nelem;
  const auto* const target_thread = target + toffset;

  const int ioffset = sample * map_nelem * n_classes;
  auto* const grad_input_thread = grad_input + ioffset;

  for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
       i < map_nelem;
       i += step) {
    const int64_t t = target_thread[i];
    if (t != ignore_index) {
      CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
      const auto grad_input_index = i + map_nelem * t;
      CUDA_KERNEL_ASSERT(grad_input_index >= 0);
      grad_input_thread[i + map_nelem * t] = weights != nullptr ? weights[t] * grad
                                                                : grad;
    }
  }
}

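// Validates the shape contract shared by the forward and backward paths:
//   input:  [N, C, H, W] (per-class scores, typically log-probabilities)
//   target: [N, H, W]    (int64 class indices)
//   weight: [C], or undefined for the unweighted loss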
void check_inputs_nll_loss2d(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight) {
  TORCH_CHECK(
      target.dim() == 3,
      "only batches of spatial targets supported (3D tensors)"
      " but got targets of size: ",
      target.sizes());
  TORCH_CHECK(
      input.dim() == 4,
      "only batches of spatial inputs supported (4D tensors), "
      "but got input of size: ",
      input.sizes());
  TORCH_CHECK(
      !weight.defined() || weight.numel() == input.size(1),
      "weight tensor should be defined either for all or no classes");

  TORCH_CHECK(
      input.size(0) == target.size(0) && input.size(2) == target.size(1) &&
          input.size(3) == target.size(2),
      "input and target batch or spatial sizes don't match: target ",
      target.sizes(),
      ", input ",
      input.sizes());
}

void nll_loss2d_forward_out_cuda_template(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage in 'sum' or 'mean' reductions.
  if (reduction != at::Reduction::None) {
    at::globalContext().alertNotDeterministic("nll_loss2d_forward_out_cuda_template");
  }

  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  check_inputs_nll_loss2d(input, target, weight);
  total_weight.resize_({});

  if (reduction == at::Reduction::None) {
    int64_t batch_size = input.size(0);
    int64_t H = input.size(2);
    int64_t W = input.size(3);
    int64_t count = batch_size * H * W;

    at::native::resize_output(output, {batch_size, H, W});
    if (count == 0) {
      // This guards against unnecessary operations and launching a CUDA
      // kernel with 0 blocks.
      return;
    }
    auto weight_ = optional_contiguous(weight);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss2d_forward_no_reduce_kernel",
        [&] {
          nll_loss2d_forward_no_reduce_kernel<scalar_t>
              <<<GET_BLOCKS(count),
                 CUDA_NUM_THREADS,
                 0,
                 at::cuda::getCurrentCUDAStream()>>>(
                  count,
                  input.packed_accessor64<scalar_t, 4>(),
                  target.packed_accessor64<int64_t, 3>(),
                  output.packed_accessor64<scalar_t, 3>(),
                  optional_data<scalar_t>(weight_),
                  ignore_index);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    return;
  }

  // produce scalar outputs for the reduction case
  at::native::resize_output(output, {});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements.
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  auto input_ = input.contiguous();
  auto weight_ = optional_contiguous(weight);
  auto target_ = target.contiguous();

  output.zero_();
  total_weight.zero_();

  auto batch_size = target.size(0);
  int64_t map_nelem = target.numel() / batch_size;
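  // Launch heuristic for the reduction kernels: give each sample roughly
  // GET_BLOCKS(map_nelem) / 128 blocks (about 1/128 of a full grid over one
  // sample's H*W plane), but never fewer than one block per sample; each
  // block then strides through the plane in steps of
  // blockDim.x * blocks_per_sample.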
  int blocks_per_sample = GET_BLOCKS(map_nelem) / 128;
  blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
  int total_blocks = blocks_per_sample * batch_size;

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input.scalar_type(),
      "nll_loss2d_forward_kernel",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
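        // Use 32-bit offsets inside the kernel whenever every index fits in
        // an int (cheaper index arithmetic); otherwise fall back to int64_t.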
        AT_DISPATCH_INDEX_TYPES(
            at::native::canUse32BitIndexMath(input_, INT_MAX) ? ScalarType::Int : ScalarType::Long,
            "nll_loss2d_forward_launcher", [&] {
              nll_loss2d_forward_kernel<scalar_t, accscalar_t, index_t>
                  <<<total_blocks,
                     CUDA_NUM_THREADS,
                     0,
                     at::cuda::getCurrentCUDAStream()>>>(
                      output.mutable_data_ptr<scalar_t>(),
                      total_weight.mutable_data_ptr<scalar_t>(),
                      input_.const_data_ptr<scalar_t>(),
                      target_.const_data_ptr<int64_t>(),
                      optional_data<scalar_t>(weight_),
                      input_.size(1),
                      input_.size(2) * input_.size(3),
                      blocks_per_sample,
                      ignore_index);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
              // Divide by total_weight
              if (reduction == at::Reduction::Mean) {
                nll_loss2d_forward_size_average_kernel<scalar_t>
                    <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
                        output.mutable_data_ptr<scalar_t>(),
                        total_weight.const_data_ptr<scalar_t>());
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              }
            });
      });
}

void nll_loss2d_backward_out_cuda_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  check_inputs_nll_loss2d(input, target, weight);
  grad_input.resize_as_(input);
  grad_input.zero_();
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  TORCH_CHECK(
      total_weight.numel() == 1,
      "expected total_weight to be a single element tensor, got: ",
      total_weight.sizes(),
      " (",
      total_weight.numel(),
      " elements)");

  if (reduction == at::Reduction::None) {
    TORCH_CHECK(
        grad_output.dim() == 3,
        "grad_output must have same dimension as target (3) but got grad_output of size: ",
        grad_output.sizes());
    TORCH_CHECK(
        grad_output.size(0) == target.size(0) &&
            grad_output.size(1) == target.size(1) &&
            grad_output.size(2) == target.size(2),
        "grad_output sizes don't match target sizes: target ",
        target.sizes(),
        ", grad_output ",
        grad_output.sizes());
    int64_t batch_size = input.size(0);
    int64_t H = input.size(2);
    int64_t W = input.size(3);
    int64_t count = batch_size * H * W;

    if (count == 0) {
      // This guards against unnecessary operations and launching a CUDA
      // kernel with 0 blocks.
      return;
    }
    auto weight_ = optional_contiguous(weight);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss2d_backward_no_reduce_kernel",
        [&] {
          nll_loss2d_backward_no_reduce_kernel<scalar_t>
              <<<GET_BLOCKS(count),
                 CUDA_NUM_THREADS,
                 0,
                 at::cuda::getCurrentCUDAStream()>>>(
                  count,
                  target.packed_accessor64<int64_t, 3>(),
                  grad_output.packed_accessor64<scalar_t, 3>(),
                  grad_input.packed_accessor64<scalar_t, 4>(),
                  optional_data<scalar_t>(weight_),
                  ignore_index);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    return;
  }

  int64_t batch_size = target.size(0);
  auto target_numel = target.numel();
  if (batch_size != 0 && target_numel != 0) {
    // This guards against unnecessary operations and launching a CUDA kernel
    // that would have no work to do (at least 1 block per sample would still
    // be launched).
    auto target_ = target.contiguous();
    auto weight_ = optional_contiguous(weight);

    // Same launch heuristic as the forward reduction path.
    int64_t map_nelem = target_numel / batch_size;
    int blocks_per_sample = GET_BLOCKS(map_nelem) / 128;
    blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
    int total_blocks = blocks_per_sample * batch_size;

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss2d_backward_kernel",
        [&] {
          nll_loss2d_backward_kernel<scalar_t>
              <<<total_blocks,
                 CUDA_NUM_THREADS,
                 0,
                 at::cuda::getCurrentCUDAStream()>>>(
                  grad_input.mutable_data_ptr<scalar_t>(),
                  grad_output.const_data_ptr<scalar_t>(),
                  target_.const_data_ptr<int64_t>(),
                  optional_data<scalar_t>(weight_),
                  total_weight.const_data_ptr<scalar_t>(),
                  reduction == at::Reduction::Mean,
                  input.size(1),
                  map_nelem,
                  blocks_per_sample,
                  ignore_index);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  }
}
} // namespace

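// Illustrative usage (a sketch, not exercised in this file): for CUDA tensors
// with log-probabilities `lp = at::log_softmax(x, /*dim=*/1)` of shape
// [N, C, H, W] and int64 targets `t` of shape [N, H, W], the composite op
//   auto loss = at::nll_loss2d(lp, t, /*weight=*/{}, at::Reduction::Mean,
//                              /*ignore_index=*/-100);
// routes through nll_loss2d_forward_cuda below and yields a scalar loss;
// with Reduction::None it yields an [N, H, W] tensor of per-element losses.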
std::tuple<Tensor&, Tensor&> nll_loss2d_forward_out_cuda(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    Tensor& output,
    Tensor& total_weight) {
  nll_loss2d_forward_out_cuda_template(
      output, total_weight, self, target, weight_opt, reduction, ignore_index);
  return std::tuple<Tensor&, Tensor&>(output, total_weight);
}

std::tuple<Tensor, Tensor> nll_loss2d_forward_cuda(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  auto output = at::empty({0}, self.options());
  auto total_weight = at::empty({0}, self.options());
  nll_loss2d_forward_out_cuda_template(
      output, total_weight, self, target, weight_opt, reduction, ignore_index);
  return std::make_tuple(output, total_weight);
}

Tensor& nll_loss2d_backward_out_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight,
    Tensor& grad_input) {
  nll_loss2d_backward_out_cuda_template(
      grad_input,
      grad_output,
      self,
      target,
      weight_opt,
      reduction,
      ignore_index,
      total_weight);
  return grad_input;
}

Tensor nll_loss2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  auto grad_input = at::empty_like(self);
  nll_loss2d_backward_out_cuda_template(
      grad_input,
      grad_output,
      self,
      target,
      weight_opt,
      reduction,
      ignore_index,
      total_weight);
  return grad_input;
}

} // namespace at::native