#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/Resize.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cross_entropy_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/log_softmax.h>
#include <ATen/ops/nll_loss.h>
#include <ATen/ops/nll_loss2d.h>
#include <ATen/ops/nll_loss_backward_native.h>
#include <ATen/ops/nll_loss_forward.h>
#include <ATen/ops/nll_loss_forward_native.h>
#include <ATen/ops/nll_loss_native.h>
#include <ATen/ops/nll_loss_nd.h>
#include <ATen/ops/nll_loss_nd_native.h>
#endif

#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <limits>
#include <numeric>
#include <type_traits>
#include <utility>

namespace at::meta {
TORCH_META_FUNC(nll_loss_forward)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index) {
  const Tensor& weight = weight_opt.getTensorRef();

  TORCH_CHECK(
      self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
  TORCH_CHECK(
      target.dim() <= 1,
      "0D or 1D target tensor expected, multi-target not supported");

  auto no_batch_dim = self.dim() == 1 && target.dim() == 0;
  TORCH_CHECK(
      no_batch_dim || (self.size(0) == target.size(0)),
      "size mismatch (got input: ",
      self.sizes(),
      ", target: ",
      target.sizes(),
      ")");

  const auto n_classes = self.size(-1);

  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
      "weight tensor should be defined either for all ",
      n_classes,
      " classes or no classes"
      " but got weight tensor of shape: ",
      weight.sizes());

  const auto n_dims = self.dim();
  const auto batch_size = self.size(0);

  if (reduction == Reduction::None && n_dims == 2) {
    set_output_raw_strided(0, {batch_size}, {}, self.options());
  } else {
    // produce scalar output when reducing or input is 1d
    set_output_raw_strided(0, {}, {}, self.options());
  }

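  // The second output stores total_weight: for the reduced cases it is the
  // sum of the weights of all non-ignored targets (or simply the count of
  // non-ignored targets when no weight is given), which the backward pass
  // uses to normalize mean-reduced gradients.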
  set_output_raw_strided(1, {}, {}, self.options());
}

TORCH_META_FUNC(nll_loss_backward)
(const Tensor& grad_output,
 const Tensor& self,
 const Tensor& target,
 OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& total_weight) {
  TORCH_CHECK(
      self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
  TORCH_CHECK(
      target.dim() <= 1,
      "0D or 1D target tensor expected, multi-target not supported");

  auto no_batch_dim = self.dim() == 1 && target.dim() == 0;
  TORCH_CHECK(
      no_batch_dim || (self.size(0) == target.size(0)),
      "size mismatch (got input: ",
      self.sizes(),
      ", target: ",
      target.sizes(),
      ")");
  TORCH_CHECK(
      total_weight.numel() == 1,
      "expected total_weight to be a single element tensor, got: ",
      total_weight.sizes(),
      " (",
      total_weight.numel(),
      " elements)");

  const auto& weight = weight_opt.getTensorRef();

  TORCH_CHECK(
      !weight.defined() || weight.numel() == self.size(-1),
      "weight tensor should be defined either for all or no classes");

  const auto n_dims = self.dim();

  if (reduction == Reduction::None && n_dims == 2) {
    const auto batch_size = self.size(0);
    check_dim_size(grad_output, 1, 0, batch_size);
  } else {
    TORCH_CHECK(
        grad_output.dim() <= 1 && grad_output.numel() == 1,
        "Expected a single element grad_output tensor, but got: ",
        grad_output.sizes());
  }

  set_output_raw_strided(
      0,
      self.sizes(),
      {},
      self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
} // namespace at::meta

namespace at::native {

namespace {

// Returns a contiguous tensor if the source tensor
// is defined. Otherwise returns the undefined
// source tensor unmodified.
inline Tensor optional_contiguous(const Tensor& source) {
  return source.defined() ? source.contiguous() : source;
}

// Returns the address of the first element of a tensor
// or nullptr if the tensor is undefined.
template <typename scalar_t>
inline scalar_t* optional_data(const Tensor& source) {
  if constexpr (std::is_const<scalar_t>::value) {
    return source.defined() ? source.const_data_ptr<scalar_t>() : nullptr;
  } else {
    return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
  }
}

template <typename scalar_t, typename target_t>
static void nll_loss_out_frame(
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  const auto n_dims = input.dim();
  const auto n_classes = input.size(-1);

  scalar_t* total_weight_data = total_weight.data_ptr<scalar_t>();
  *total_weight_data = 0;

  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == Reduction::None && n_dims == 2) {
    const auto batch_size = input.size(0);
    at::native::resize_output(output, {batch_size});

    auto input_acc = input.accessor<const scalar_t, 2>();
    auto target_acc = target.accessor<const target_t, 1>();
    auto output_acc = output.accessor<scalar_t, 1>();

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        const auto cur_target = target_acc[i];

        if (cur_target == ignore_index) {
          output_acc[i] = 0;
          continue;
        }

        TORCH_CHECK_INDEX(
            cur_target >= 0 && cur_target < n_classes,
            "Target ",
            cur_target,
            " is out of bounds.");

        scalar_t cur_weight = weight_data != nullptr ? weight_data[cur_target]
                                                     : static_cast<scalar_t>(1);
        output_acc[i] = -input_acc[i][cur_target] * cur_weight;
      }
    });

    return;
  }

  // produce scalar outputs for the reduction case
  at::native::resize_output(output, {});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const target_t* target_data = target_contiguous.const_data_ptr<target_t>();

  const int64_t ndim = input.dim();
  const int64_t batch_size = ndim == 1 ? 1 : input.size(0);

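  // Accumulate the per-sample losses (and weights) with a cascade/pairwise
  // summation: partial sums are promoted to the next level every
  // `level_step` samples, which keeps the floating-point rounding error much
  // smaller than a single running sum, especially for large batches and for
  // low-precision dtypes such as half/bfloat16.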
  constexpr int64_t cascade_sum_num_levels = 8;
  const int64_t level_power =
      std::max(int64_t(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
  const int64_t level_step = (1 << level_power);
  const int64_t level_mask = level_step - 1;

  int64_t num_ignored = 0;

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
  for (const auto b : c10::irange(batch_size)) {
    const int64_t cur_target = target_data[b];
    if (cur_target == ignore_index) {
      ++num_ignored;
      continue;
    }

    TORCH_CHECK_INDEX(
        cur_target >= 0 && cur_target < n_classes,
        "Target ",
        cur_target,
        " is out of bounds.");

    const auto data = input_data[b * n_classes + cur_target];
    if (weight_data) {
      const scalar_t weight_val = weight_data[cur_target];
      loss_partial_sums[0] -= data * weight_val;
      weight_partial_sums[0] += weight_val;
    } else {
      loss_partial_sums[0] -= data;
    }

    for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
      const auto mask = (level_mask << (j * level_power));
      if (C10_LIKELY((b & mask) != 0)) {
        break;
      }

      weight_partial_sums[j + 1] += weight_partial_sums[j];
      loss_partial_sums[j + 1] += loss_partial_sums[j];

      weight_partial_sums[j] = 0;
      loss_partial_sums[j] = 0;
    }
  }

  const scalar_t total_weight_val = !weight_data
      ? static_cast<scalar_t>(batch_size - num_ignored)
      : std::accumulate(std::begin(weight_partial_sums),
                        std::end(weight_partial_sums),
                        scalar_t{0});

  scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
                                        std::end(loss_partial_sums),
                                        scalar_t{0});

  if (reduction == Reduction::Mean) {
    output_val /= total_weight_val;
  }

  // write result to output tensors
  *output.data_ptr<scalar_t>() = output_val;
  *total_weight_data = total_weight_val;
}

void nll_loss_forward_out_cpu_template(
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss_out_frame",
      [&] {
        if (target.scalar_type() == kByte) {
          nll_loss_out_frame<scalar_t, uint8_t>(
              output,
              total_weight,
              input,
              target,
              weight,
              reduction,
              ignore_index);
        } else {
          // assumed to be int64
          nll_loss_out_frame<scalar_t, int64_t>(
              output,
              total_weight,
              input,
              target,
              weight,
              reduction,
              ignore_index);
        }
      });
}

template <typename scalar_t, typename target_t>
static void nll_loss_backward_out_frame(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  const auto n_dims = input.dim();
  const auto n_classes = input.size(-1);

  auto target_ = target;
  if (target.dim() == 0) {
    target_ = target.unsqueeze(0);
  }
  auto target_acc = target_.accessor<const target_t, 1>();

  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

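  // For each sample, NLL loss only touches the entry selected by the target,
  // so the gradient is sparse: grad_input[i][target[i]] = -w[target[i]] *
  // grad_output (further divided by total_weight for mean reduction). Every
  // other entry, and every ignored sample, keeps the zero written by the
  // caller.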
  if (reduction == Reduction::None && n_dims == 2) {
    const auto batch_size = input.size(0);
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        auto cur_target = target_acc[i];
        if (cur_target == ignore_index) {
          continue;
        }
        const scalar_t w =
            weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
        grad_input_acc[i][cur_target] = -w * grad_output_acc[i];
      }
    });
    return;
  }

  const scalar_t total_weight_value = *total_weight.const_data_ptr<scalar_t>();

  const scalar_t grad_output_value = *grad_output.const_data_ptr<scalar_t>();

  if (input.dim() == 1) {
    auto grad_input_acc = grad_input.accessor<scalar_t, 1>();

    const auto t = target_acc[0];
    if (t != ignore_index) {
      TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
      const auto grad = -(reduction == Reduction::Mean
                              ? grad_output_value / total_weight_value
                              : grad_output_value);
      grad_input_acc[t] = weight_data != nullptr ? weight_data[t] * grad : grad;
    }
  } else if (input.dim() == 2) {
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    const auto grad = -(reduction == Reduction::Mean
                            ? grad_output_value / total_weight_value
                            : grad_output_value);

    const auto batch_size = input.size(0);

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        const auto t = target_acc[i];
        if (t != ignore_index) {
          TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
          grad_input_acc[i][t] = weight_data != nullptr ? weight_data[t] * grad : grad;
        }
      }
    });
  }
}

void nll_loss_backward_out_cpu_template(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  grad_input.zero_();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss_backward_out_frame",
      [&] {
        if (target.scalar_type() == kByte) {
          nll_loss_backward_out_frame<scalar_t, uint8_t>(
              grad_input,
              grad_output,
              input,
              target,
              weight,
              reduction,
              ignore_index,
              total_weight);
        } else {
          // assumed to be int64
          nll_loss_backward_out_frame<scalar_t, int64_t>(
              grad_input,
              grad_output,
              input,
              target,
              weight,
              reduction,
              ignore_index,
              total_weight);
        }
      });
}

} // namespace

TORCH_IMPL_FUNC(nll_loss_forward_out_cpu)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& output,
 const Tensor& total_weight) {
  const Tensor& weight = weight_opt.getTensorRef();
  nll_loss_forward_out_cpu_template(
      output, total_weight, self, target, weight, reduction, ignore_index);
}

TORCH_IMPL_FUNC(nll_loss_backward_out_cpu)
(const Tensor& grad_output,
 const Tensor& self,
 const Tensor& target,
 OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& total_weight,
 const Tensor& grad_input) {
  const Tensor& weight = weight_opt.getTensorRef();
  nll_loss_backward_out_cpu_template(
      grad_input,
      grad_output,
      self,
      target,
      weight,
      reduction,
      ignore_index,
      total_weight);
}

static Tensor cross_entropy_loss_prob_target(
    const Tensor& self,
    const Tensor& target_,
    const Tensor& weight,
    int64_t reduction,
    double label_smoothing) {
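  // Cross entropy with class-probability (soft) targets:
  //   loss = -sum_c target[c] * log_softmax(input)[c]   (optionally weighted)
  // Mean reduction divides by the number of batch elements
  // (sym_numel() / n_classes) rather than by the sum of the weights.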
  const auto class_dim = self.dim() == 1 ? 0 : 1;
  const auto n_classes = self.size(class_dim);
  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
      "cross_entropy: weight tensor should be defined either for all ",
      n_classes,
      " classes or no classes"
      " but got weight tensor of shape: ",
      weight.sizes());

  auto input = at::log_softmax(self, class_dim, self.scalar_type());
  Tensor target;

  if (label_smoothing > 0.0) {
    TORCH_CHECK(
        label_smoothing <= 1.0,
        "label_smoothing must be between 0.0 and 1.0. Got: ",
        label_smoothing);
    target = target_ * (1 - label_smoothing) + label_smoothing / n_classes;
  } else {
    target = target_;
  }

  if (weight.defined()) {
    // Expand weight to the correct number of dims for broadcasting with input / target
    Tensor weight_ = weight;
    if (input.dim() > 1) {
      auto weight_broadcast_shape = SmallBuffer<int64_t, 5>(input.dim());
      std::fill(weight_broadcast_shape.begin(), weight_broadcast_shape.end(), 1);
      weight_broadcast_shape[1] = weight.size(0);
      weight_ = weight.view(weight_broadcast_shape);
    }

    switch (reduction) {
      case Reduction::Mean:
        if (input.sym_numel() == 0) {
          return -(input * target * weight_).sum().fill_(std::numeric_limits<double>::quiet_NaN());
        } else {
          return -(input * target * weight_).sum() / (input.sym_numel() / n_classes);
        }
      case Reduction::Sum:
        return -(input * target * weight_).sum();
      case Reduction::None:
        return -(input * target * weight_).sum(class_dim);
      default:
        TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
    }
  } else {
    switch (reduction) {
      case Reduction::Mean:
        if (input.sym_numel() == 0) {
          return -(input * target).sum().fill_(std::numeric_limits<double>::quiet_NaN());
        } else {
          return -(input * target).sum() / (input.sym_numel() / n_classes);
        }
      case Reduction::Sum:
        return -(input * target).sum();
      case Reduction::None:
        return -(input * target).sum(class_dim);
      default:
        TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
    }
  }
}

static Tensor cross_entropy_loss_label_smoothing(
    const Tensor& self,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    c10::SymInt ignore_index,
    double label_smoothing) {
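  // Label smoothing with class-index targets is computed as a blend of the
  // ordinary NLL loss and the cross entropy against a uniform distribution:
  //   (1 - eps) * nll_loss(log_softmax(x), target)
  //     + (eps / n_classes) * sum_c -log_softmax(x)[c]
  // with ignored samples masked out of the smoothing term.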
  auto class_dim = self.dim() == 1 ? 0 : 1;
  auto input = at::log_softmax(self, class_dim, self.scalar_type());
  auto nllloss = at::nll_loss_nd_symint(input, target, weight, reduction, ignore_index);

  auto n_classes = input.sym_size(class_dim);

  Tensor smooth_loss;
  if (weight.defined()) {
    // Expand weight to the correct number of dims for broadcasting with input / target
    auto weight_broadcast_shape = SmallBuffer<int64_t, 5>(input.dim());
    std::fill(weight_broadcast_shape.begin(), weight_broadcast_shape.end(), 1);
    weight_broadcast_shape[class_dim] = weight.size(0);
    Tensor weight_ = weight.view(weight_broadcast_shape);

    smooth_loss = -(input * weight_).sum(class_dim);
  } else {
    smooth_loss = -input.sum(class_dim);
  }

  auto ignore_mask = target == std::move(ignore_index);
  smooth_loss.masked_fill_(ignore_mask, 0.0);

  Tensor ret;
  switch (reduction) {
    case Reduction::Mean:
      if (weight.defined()) {
        if (isTensorSubclassLike(weight)) {
          // we will collect weights from 0 index which is always valid
          // and mask them out if they are ignored
          auto filtered_target = target.masked_fill(ignore_mask, 0);
          auto tgt_weights = weight.gather(0, filtered_target.flatten());
          auto weight_sum =
              tgt_weights.masked_fill_(ignore_mask.flatten(), 0).sum();
          ret = smooth_loss.sum() / weight_sum;
        } else {
          // TODO: This code path can be removed if #61309 is resolved
          // loss is normalized by the weights to be consistent with
          // nll_loss_nd
          ret = smooth_loss.sum() /
              weight.gather(0, target.masked_select(~ignore_mask).flatten())
                  .sum();
        }
      } else {
        auto true_mask = ~ignore_mask;
        ret = smooth_loss.sum() / true_mask.sum();
      }
      break;
    case Reduction::Sum:
      ret = smooth_loss.sum();
      break;
    case Reduction::None:
      ret = smooth_loss;
      break;
    default:
      TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
  }
  return (1 - label_smoothing) * nllloss + ret * (label_smoothing / n_classes);
}

Tensor cross_entropy_loss_symint(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    c10::SymInt ignore_index,
    double label_smoothing) {
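  // Dispatches to one of three paths: soft (probability) targets when input
  // and target shapes match, the label-smoothing path, or the common case of
  // class-index targets, which is simply nll_loss over log_softmax. For
  // illustration only, with a 2D float input `x`, an int64 class-index
  // target `t`, and default arguments:
  //   auto loss = at::nll_loss_nd(at::log_softmax(x, /*dim=*/1), t);
  // computes the same value as at::cross_entropy_loss(x, t).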
  Tensor ret;
  if (self.sym_sizes() == target.sym_sizes()) {
    // Assume soft targets when input and target shapes are the same
    TORCH_CHECK(
        at::isFloatingType(target.scalar_type()),
        "Expected floating point type for target with class probabilities, got ",
        target.scalar_type());
    TORCH_CHECK(
        ignore_index < 0,
        "ignore_index is not supported for floating point target");

    // See [Note: hacky wrapper removal for optional tensor]
    c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight);
    const Tensor& weight_ = *weight_maybe_owned;
    ret = cross_entropy_loss_prob_target(self, target, weight_, reduction, label_smoothing);
  } else if (label_smoothing > 0.0) {
    TORCH_CHECK(
        label_smoothing <= 1.0,
        "label_smoothing must be between 0.0 and 1.0. Got: ",
        label_smoothing);

    // See [Note: hacky wrapper removal for optional tensor]
    c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight);
    const Tensor& weight_ = *weight_maybe_owned;
    ret = cross_entropy_loss_label_smoothing(
        self, target, weight_, reduction, std::move(ignore_index), label_smoothing);
  } else {
    auto class_dim = self.dim() == 1 ? 0 : 1;
    ret = at::nll_loss_nd_symint(
        at::log_softmax(self, class_dim, self.scalar_type()),
        target,
        weight,
        reduction,
        std::move(ignore_index));
  }
  return ret;
}

Tensor& nll_loss_out(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor total_weight = at::empty({0}, self.options());
  return std::get<0>(at::nll_loss_forward_out(
      output, total_weight, self, target, weight, reduction, ignore_index));
}

Tensor nll_loss_symint(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    c10::SymInt ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  return std::get<0>(at::nll_loss_forward_symint(
      self, target, weight, reduction, std::move(ignore_index)));
}

Tensor nll_loss_nd_symint(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    c10::SymInt ignore_index) {
  if (self.dim() < 1) {
    TORCH_CHECK_VALUE(
        false, "Expected 1 or more dimensions (got ", self.dim(), ")");
  }

  if (self.dim() != 1 && self.sym_sizes()[0] != target.sym_sizes()[0]) {
    TORCH_CHECK_VALUE(
        false,
        "Expected input batch_size (",
        self.sym_sizes()[0],
        ") to match target batch_size (",
        target.sym_sizes()[0],
        ").");
  }

  Tensor ret;
  Tensor input_ = self;
  Tensor target_ = target;
  if (input_.dim() == 1 || input_.dim() == 2) {
    ret = at::nll_loss_symint(input_, target_, weight, reduction, std::move(ignore_index));
  } else if (input_.dim() == 4) {
    ret = at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
  } else {
    // dim == 3 or dim > 4
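    // For 3D inputs and inputs with more than 4 dims, fold all trailing
    // spatial dimensions into a single one so the loss can be computed by
    // nll_loss2d, then restore the original shape for Reduction::None.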
    auto n = input_.sym_sizes()[0];
    auto c = input_.sym_sizes()[1];
    auto out_size = input_.sym_sizes().slice(2).vec();
    out_size.insert(out_size.begin(), n);
    if (target_.sym_sizes().slice(1) != input_.sym_sizes().slice(2)) {
      TORCH_CHECK(
          false,
          "Expected target size ",
          SymIntArrayRef(out_size),
          ", got ",
          target_.sym_sizes());
    }
    input_ = input_.contiguous();
    target_ = target_.contiguous();
    // support empty batches, see #15870
    if (input_.sym_numel() > 0) {
      input_ = input_.view_symint({n, std::move(c), 1, -1});
    } else {
      input_ = input_.view_symint({n, std::move(c), 0, 0});
    }
    if (target_.sym_numel() > 0) {
      target_ = target_.view_symint({std::move(n), 1, -1});
    } else {
      target_ = target_.view_symint({std::move(n), 0, 0});
    }
    if (reduction != Reduction::None) {
      ret = at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
    } else {
      auto out =
          at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
      ret = out.view_symint(out_size);
    }
  }
  return ret;
}

} // namespace at::native