#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/block_reduce.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/nll_loss2d_forward_native.h>
#include <ATen/ops/nll_loss2d_backward_native.h>
#endif

namespace at::native {

namespace {

// Returns a contiguous tensor if the source tensor
// is defined. Otherwise returns the undefined
// source tensor unmodified.
inline Tensor optional_contiguous(const Tensor& source) {
  return source.defined() ? source.contiguous() : source;
}

// Returns the address of the first element of a tensor
// or nullptr if the tensor is undefined.
template <typename scalar_t>
inline const scalar_t* optional_data(const Tensor& source) {
  return source.defined() ? source.const_data_ptr<scalar_t>() : nullptr;
}

using at::cuda::detail::CUDA_NUM_THREADS;
using at::cuda::detail::GET_BLOCKS;

// TODO(crcrpar): Think about introducing `canUse32BitIndexMath` and choose int or int64_t for `target`.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_no_reduce_kernel(
    int64_t n_threads,
    PackedTensorAccessor64<scalar_t, 4> input,
    PackedTensorAccessor64<int64_t, 3> target,
    PackedTensorAccessor64<scalar_t, 3> output,
    const scalar_t* weight,
    int64_t ignore_index
) {
  int64_t batch_size = input.size(0);
  int64_t n_classes = input.size(1);
  int64_t H = input.size(2);
  int64_t W = input.size(3);

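  // One grid-stride loop iteration per output element. The flat index is
  // decomposed into (b, h, w) with the batch dimension varying fastest,
  // matching count = batch_size * H * W at the launch site. The per-pixel
  // loss is -weight[target] * input[b, target, h, w]; positions whose target
  // equals ignore_index are written as 0.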
  CUDA_KERNEL_LOOP(index, n_threads) {
    const int64_t b = index % batch_size;
    const int64_t h = (index / batch_size) % H;
    const int64_t w = (index / (batch_size * H)) % W;

    int64_t cur_target = target[b][h][w];
    if (cur_target == ignore_index) {
      output[b][h][w] = static_cast<scalar_t>(0);
      continue;
    }
    CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
    scalar_t value = input[b][cur_target][h][w];
    scalar_t cur_weight = weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1);
    output[b][h][w] = -value * cur_weight;
  }
}

template <typename scalar_t, typename accscalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_kernel(
    scalar_t* output,
    scalar_t* total_weight,
    const scalar_t* input,
    const int64_t* target,
    const scalar_t* weight,
    int n_classes,
    int map_nelem,
    int blocks_per_sample,
    int64_t ignore_index) {

  scalar_t cur_weight;
  accscalar_t input_sum = 0;
  accscalar_t acc_weight = 0;

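  // Each sample in the batch owns `blocks_per_sample` consecutive blocks of
  // the grid. The threads of those blocks stride together over the sample's
  // map_nelem spatial positions, accumulating the weighted negative
  // log-likelihood and the weight sum in thread-local accumulators.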
  index_t sample = blockIdx.x / blocks_per_sample;
  index_t toffset = sample * map_nelem;
  index_t ioffset = sample * map_nelem * n_classes;
  int step = blockDim.x * blocks_per_sample;
  for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
       i < map_nelem;
       i += step) {
    index_t t = target[toffset + i];
    if (t != ignore_index) {
      CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
      cur_weight = weight != nullptr ? weight[t] : static_cast<scalar_t>(1);
      const auto input_index = ioffset + i + map_nelem * t;
      CUDA_KERNEL_ASSERT(input_index >= 0);
      input_sum -= input[input_index] * cur_weight;
      acc_weight += cur_weight;
    }
  }

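  // Reduce the per-thread partial sums across the block, then let thread 0 of
  // each block atomically fold them into the scalar output and total_weight
  // accumulators. These atomics are why the reducing path is flagged as
  // nondeterministic in the host code below.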
  __shared__ accscalar_t acc_weight_smem[CUDA_NUM_THREADS];
  __shared__ accscalar_t input_sum_smem[CUDA_NUM_THREADS];

  auto acc_weight_ = cuda_utils::BlockReduceSum(acc_weight, acc_weight_smem);
  auto input_sum_ = cuda_utils::BlockReduceSum(input_sum, input_sum_smem);

  if (threadIdx.x == 0) {
    gpuAtomicAdd(total_weight, static_cast<scalar_t>(acc_weight_));
    gpuAtomicAdd(output, static_cast<scalar_t>(input_sum_));
  }
}

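// Finalizes the mean reduction on a single thread: the accumulated (negative)
// log-likelihood sum is divided by the accumulated total weight.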
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_size_average_kernel(
    scalar_t* output,
    const scalar_t* total_weight
) {
  *output /= *total_weight;
}

template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_backward_no_reduce_kernel(
    int64_t n_threads,
    PackedTensorAccessor64<int64_t, 3> target,
    PackedTensorAccessor64<scalar_t, 3> grad_output,
    PackedTensorAccessor64<scalar_t, 4> grad_input,
    const scalar_t* weight,
    int64_t ignore_index
) {
  int64_t batch_size = target.size(0);
  int64_t H = target.size(1);
  int64_t W = target.size(2);

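  // Mirror of the forward no-reduce kernel: for each (b, h, w) the only
  // nonzero gradient entry is at the target class,
  //   d(loss) / d(input[b, t, h, w]) = -weight[t] * grad_output[b, h, w].
  // Ignored positions leave grad_input at its zero initialization.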
  CUDA_KERNEL_LOOP(index, n_threads) {
    const int64_t b = index % batch_size;
    const int64_t h = (index / batch_size) % H;
    const int64_t w = (index / (batch_size * H)) % W;

    int64_t cur_target = target[b][h][w];
    if (cur_target == ignore_index) {
      continue;
    }
    scalar_t value = -(weight != nullptr ? weight[cur_target] : static_cast<scalar_t>(1));
    grad_input[b][cur_target][h][w] = value * grad_output[b][h][w];
  }
}

template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_backward_kernel(
    scalar_t* grad_input,
    const scalar_t* grad_output,
    const int64_t* target,
    const scalar_t* weights,
    const scalar_t* total_weight,
    bool size_average,
    int n_classes,
    int map_nelem,
    int blocks_per_sample,
    int64_t ignore_index
) {
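  // With a reduction, every surviving position shares the same upstream scalar
  // gradient; for mean reduction it is additionally divided by the total
  // weight accumulated in the forward pass.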
  const auto grad = -(size_average ? *grad_output / *total_weight
                                   : *grad_output);

  const int sample = blockIdx.x / blocks_per_sample;
  const int step = blockDim.x * blocks_per_sample;

  const int toffset = sample * map_nelem;
  const auto* const target_thread = target + toffset;

  const int ioffset = sample * map_nelem * n_classes;
  auto* const grad_input_thread = grad_input + ioffset;

  for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
       i < map_nelem;
       i += step) {
    const int64_t t = target_thread[i];
    if (t != ignore_index) {
      CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
      const auto grad_input_index = i + map_nelem * t;
      CUDA_KERNEL_ASSERT(grad_input_index >= 0);
      grad_input_thread[grad_input_index] = weights != nullptr ? weights[t] * grad
                                                               : grad;
    }
  }
}

void check_inputs_nll_loss2d(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight) {
  TORCH_CHECK(
      target.dim() == 3,
      "only batches of spatial targets supported (3D tensors)"
      " but got targets of size: ",
      target.sizes());
  TORCH_CHECK(
      input.dim() == 4,
      "only batches of spatial inputs supported (4D tensors), "
      "but got input of size: ",
      input.sizes());
  TORCH_CHECK(
      !weight.defined() || weight.numel() == input.size(1),
      "weight tensor should be defined either for all or no classes");

  TORCH_CHECK(
      input.size(0) == target.size(0) && input.size(2) == target.size(1) &&
          input.size(3) == target.size(2),
      "input and target batch or spatial sizes don't match: target ",
      target.sizes(),
      ", input ",
      input.sizes());
}

void nll_loss2d_forward_out_cuda_template(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage in 'sum' or 'mean' reductions.
  if (reduction != at::Reduction::None) {
    at::globalContext().alertNotDeterministic("nll_loss2d_forward_out_cuda_template");
  }

  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  check_inputs_nll_loss2d(input, target, weight);
  total_weight.resize_({});

  if (reduction == at::Reduction::None) {
    int64_t batch_size = input.size(0);
    int64_t H = input.size(2);
    int64_t W = input.size(3);
    int64_t count = batch_size * H * W;

    at::native::resize_output(output, {batch_size, H, W});
    if (count == 0) {
      // This guards against unnecessary work and launching a CUDA kernel with
      // zero blocks.
      return;
    }
    auto weight_ = optional_contiguous(weight);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss2d_forward_no_reduce_kernel",
        [&] {
          nll_loss2d_forward_no_reduce_kernel<scalar_t>
              <<<GET_BLOCKS(count),
                 CUDA_NUM_THREADS,
                 0,
                 at::cuda::getCurrentCUDAStream()>>>(
                  count,
                  input.packed_accessor64<scalar_t, 4>(),
                  target.packed_accessor64<int64_t, 3>(),
                  output.packed_accessor64<scalar_t, 3>(),
                  optional_data<scalar_t>(weight_),
                  ignore_index);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    return;
  }

  // produce scalar outputs for the reduction case
  at::native::resize_output(output, {});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  auto input_ = input.contiguous();
  auto weight_ = optional_contiguous(weight);
  auto target_ = target.contiguous();

  output.zero_();
  total_weight.zero_();

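  // Give each sample blocks_per_sample blocks of the grid: roughly
  // GET_BLOCKS(map_nelem) / 128, and at least one, so the total launch stays
  // small; each thread then strides over the sample's H*W map with step
  // blockDim.x * blocks_per_sample.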
  auto batch_size = target.size(0);
  int64_t map_nelem = target.numel() / batch_size;
  int blocks_per_sample = GET_BLOCKS(map_nelem) / 128;
  blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
  int total_blocks = blocks_per_sample * batch_size;

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input.scalar_type(),
      "nll_loss2d_forward_kernel",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        AT_DISPATCH_INDEX_TYPES(
            at::native::canUse32BitIndexMath(input_, INT_MAX) ? ScalarType::Int : ScalarType::Long,
            "nll_loss2d_forward_launcher", [&] {
              nll_loss2d_forward_kernel<scalar_t, accscalar_t, index_t>
                  <<<total_blocks,
                     CUDA_NUM_THREADS,
                     0,
                     at::cuda::getCurrentCUDAStream()>>>(
                      output.mutable_data_ptr<scalar_t>(),
                      total_weight.mutable_data_ptr<scalar_t>(),
                      input_.const_data_ptr<scalar_t>(),
                      target_.const_data_ptr<int64_t>(),
                      optional_data<scalar_t>(weight_),
                      input_.size(1),
                      input_.size(2) * input_.size(3),
                      blocks_per_sample,
                      ignore_index);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
              // Divide by total_weight
              if (reduction == at::Reduction::Mean) {
                nll_loss2d_forward_size_average_kernel<scalar_t>
                    <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
                        output.mutable_data_ptr<scalar_t>(),
                        total_weight.const_data_ptr<scalar_t>());
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              }
            });
      });
}

void nll_loss2d_backward_out_cuda_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  check_inputs_nll_loss2d(input, target, weight);
  grad_input.resize_as_(input);
  grad_input.zero_();
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  TORCH_CHECK(
      total_weight.numel() == 1,
      "expected total_weight to be a single element tensor, got: ",
      total_weight.sizes(),
      " (",
      total_weight.numel(),
      " elements)");

  if (reduction == at::Reduction::None) {
    TORCH_CHECK(
        grad_output.dim() == 3,
        "grad_output must have same dimension as target (3) but got dimension: ",
        grad_output.sizes());
    TORCH_CHECK(
        grad_output.size(0) == target.size(0) &&
            grad_output.size(1) == target.size(1) &&
            grad_output.size(2) == target.size(2),
        "grad_output sizes don't match target sizes: target ",
        target.sizes(),
        ", grad_output ",
        grad_output.sizes());
    int64_t batch_size = input.size(0);
    int64_t H = input.size(2);
    int64_t W = input.size(3);
    int64_t count = batch_size * H * W;

    if (count == 0) {
      // This guards against unnecessary work and launching a CUDA kernel with
      // zero blocks.
      return;
    }
    auto weight_ = optional_contiguous(weight);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss2d_backward_no_reduce_kernel",
        [&] {
          nll_loss2d_backward_no_reduce_kernel<scalar_t>
              <<<GET_BLOCKS(count),
                 CUDA_NUM_THREADS,
                 0,
                 at::cuda::getCurrentCUDAStream()>>>(
                  count,
                  target.packed_accessor64<int64_t, 3>(),
                  grad_output.packed_accessor64<scalar_t, 3>(),
                  grad_input.packed_accessor64<scalar_t, 4>(),
                  optional_data<scalar_t>(weight_),
                  ignore_index);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    return;
  }

  int64_t batch_size = target.size(0);
  auto target_numel = target.numel();
  if (batch_size != 0 && target_numel != 0) {
    // Skip the launch when there is nothing to compute (an empty batch or an
    // empty target map).
    auto target_ = target.contiguous();
    auto weight_ = optional_contiguous(weight);

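    // Same per-sample grid partitioning as in the forward reduction path:
    // blocks_per_sample blocks per sample, with each thread striding over the
    // sample's spatial map.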
    int64_t map_nelem = target_numel / batch_size;
    int blocks_per_sample = GET_BLOCKS(map_nelem) / 128;
    blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
    int total_blocks = blocks_per_sample * batch_size;

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "nll_loss2d_backward_kernel",
        [&] {
          nll_loss2d_backward_kernel<scalar_t>
              <<<total_blocks,
                 CUDA_NUM_THREADS,
                 0,
                 at::cuda::getCurrentCUDAStream()>>>(
                  grad_input.mutable_data_ptr<scalar_t>(),
                  grad_output.const_data_ptr<scalar_t>(),
                  target_.const_data_ptr<int64_t>(),
                  optional_data<scalar_t>(weight_),
                  total_weight.const_data_ptr<scalar_t>(),
                  reduction == at::Reduction::Mean,
                  input.size(1),
                  map_nelem,
                  blocks_per_sample,
                  ignore_index);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
  }
}
} // namespace

std::tuple<Tensor&, Tensor&> nll_loss2d_forward_out_cuda(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    Tensor& output,
    Tensor& total_weight) {
  nll_loss2d_forward_out_cuda_template(
      output, total_weight, self, target, weight_opt, reduction, ignore_index);
  return std::tuple<Tensor&, Tensor&>(output, total_weight);
}

std::tuple<Tensor, Tensor> nll_loss2d_forward_cuda(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  auto output = at::empty({0}, self.options());
  auto total_weight = at::empty({0}, self.options());
  nll_loss2d_forward_out_cuda_template(
      output, total_weight, self, target, weight_opt, reduction, ignore_index);
  return std::make_tuple(output, total_weight);
}

Tensor& nll_loss2d_backward_out_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight,
    Tensor& grad_input) {
  nll_loss2d_backward_out_cuda_template(
      grad_input,
      grad_output,
      self,
      target,
      weight_opt,
      reduction,
      ignore_index,
      total_weight);
  return grad_input;
}

Tensor nll_loss2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  auto grad_input = at::empty_like(self);
  nll_loss2d_backward_out_cuda_template(
      grad_input,
      grad_output,
      self,
      target,
      weight_opt,
      reduction,
      ignore_index,
      total_weight);
  return grad_input;
}

} // namespace at::native