#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/Reduction.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/PointwiseOps.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/Exception.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/binary_cross_entropy_backward_native.h>
#include <ATen/ops/binary_cross_entropy_native.h>
#include <ATen/ops/binary_cross_entropy_with_logits_native.h>
#include <ATen/ops/clamp_min.h>
#include <ATen/ops/cosine_embedding_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/exp.h>
#include <ATen/ops/hinge_embedding_loss_native.h>
#include <ATen/ops/huber_loss_backward.h>
#include <ATen/ops/huber_loss_backward_native.h>
#include <ATen/ops/huber_loss_native.h>
#include <ATen/ops/kl_div_native.h>
#include <ATen/ops/l1_loss_native.h>
#include <ATen/ops/log.h>
#include <ATen/ops/log_sigmoid.h>
#include <ATen/ops/margin_ranking_loss_native.h>
#include <ATen/ops/mean.h>
#include <ATen/ops/min.h>
#include <ATen/ops/mse_loss_backward.h>
#include <ATen/ops/mse_loss_backward_native.h>
#include <ATen/ops/mse_loss_meta.h>
#include <ATen/ops/mse_loss_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/neg.h>
#include <ATen/ops/pairwise_distance.h>
#include <ATen/ops/poisson_nll_loss_native.h>
#include <ATen/ops/smooth_l1_loss_backward.h>
#include <ATen/ops/smooth_l1_loss_backward_native.h>
#include <ATen/ops/smooth_l1_loss_meta.h>
#include <ATen/ops/smooth_l1_loss_native.h>
#include <ATen/ops/soft_margin_loss.h>
#include <ATen/ops/soft_margin_loss_backward.h>
#include <ATen/ops/soft_margin_loss_backward_native.h>
#include <ATen/ops/soft_margin_loss_native.h>
#include <ATen/ops/squeeze.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/triplet_margin_loss_native.h>
#include <ATen/ops/where.h>
#include <ATen/ops/xlogy.h>
#include <ATen/ops/zeros_like.h>
#endif

constexpr float EPSILON = 1e-12;

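// apply_loss_reduction: applies the requested reduction to an elementwise
// loss tensor -- mean() for Reduction::Mean, sum() for Reduction::Sum, and
// the unreduced tensor for Reduction::None.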
namespace {
  static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
    if (reduction == at::Reduction::Mean) {
      return unreduced.mean();
    } else if (reduction == at::Reduction::Sum) {
      return unreduced.sum();
    }
    return unreduced;
  }
}

namespace at::meta {

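// Structured-kernel meta functions: they build a borrowing binary
// TensorIterator to infer the output's shape and dtype, and resize the
// output to a 0-dim tensor when a Mean/Sum reduction will be applied by the
// corresponding IMPL.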
TORCH_META_FUNC(smooth_l1_loss)
(const Tensor& input, const Tensor& target, const int64_t reduction, double beta) {
  TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.")
  // TODO: Reduce this extra TensorIterator construction for Reduction::Mean & Sum.
  // We do another TensorIterator construction in the IMPL for the two cases.
  build_borrowing_binary_op(maybe_get_output(), input, target);
  if (reduction == Reduction::None) {
    return;
  }

  TORCH_INTERNAL_ASSERT(reduction == Reduction::Mean || reduction == Reduction::Sum);
  maybe_get_output().resize_({});
}

TORCH_META_FUNC(mse_loss)
(const Tensor& input, const Tensor& target, const int64_t reduction) {
  build_borrowing_binary_op(maybe_get_output(), input, target);
  if (reduction == Reduction::None) {
    return;
  }

  TORCH_INTERNAL_ASSERT(reduction == Reduction::Mean || reduction == Reduction::Sum);
  maybe_get_output().resize_({});
}

} // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(smooth_l1_stub);
DEFINE_DISPATCH(smooth_l1_backward_stub);
DEFINE_DISPATCH(huber_stub);
DEFINE_DISPATCH(huber_backward_stub);
DEFINE_DISPATCH(mse_stub);
DEFINE_DISPATCH(mse_backward_stub);

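// Structured-kernel impl functions: for Mean/Sum the unreduced loss is
// computed through a fresh TensorIterator and then reduced into `result`;
// for Reduction::None the stub runs directly on the iterator already built
// by the meta function (`*this`).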
TORCH_IMPL_FUNC(smooth_l1_loss_out)
(const Tensor& input, const Tensor& target, int64_t reduction, double beta, const Tensor& result) {
  if (reduction != Reduction::None) {
    Tensor loss;
    auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
    smooth_l1_stub(iter.device_type(), iter, beta);
    if (reduction == Reduction::Mean) {
      at::mean_out(const_cast<Tensor&>(result), iter.output(), IntArrayRef{});
    } else {
      at::sum_out(const_cast<Tensor&>(result), iter.output(), IntArrayRef{});
    }
  } else {
    smooth_l1_stub(device_type(), *this, beta);
  }
}

TORCH_IMPL_FUNC(mse_loss_out)
(const Tensor& input, const Tensor& target, int64_t reduction, const Tensor& result) {
  if (reduction != Reduction::None) {
    Tensor loss;
    auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
    mse_stub(iter.device_type(), iter);
    if (reduction == Reduction::Mean) {
      at::mean_out(const_cast<Tensor&>(result), iter.output(), IntArrayRef{});
    } else {
      at::sum_out(const_cast<Tensor&>(result), iter.output(), IntArrayRef{});
    }
  } else {
    mse_stub(device_type(), *this);
  }
}

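// Cosine embedding loss, computed per sample from the cosine similarity
// cos(x1, x2) (EPSILON is added to each squared norm for numerical stability):
//   loss = 1 - cos(x1, x2)               if target ==  1
//   loss = max(0, cos(x1, x2) - margin)  if target == -1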
Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
  auto targ_dim = target.dim();
  TORCH_CHECK(
      targ_dim == 1 || targ_dim == 0,
      "0D or 1D target tensor expected, multi-target not supported");
  if (targ_dim == 1) {
    TORCH_CHECK(
        input1.dim() == 2 && input2.dim() == 2,
        "1D target tensor expects 2D input tensors, but found inputs with sizes ",
        input1.sizes(),
        " and ",
        input2.sizes(),
        ".");
  } else {
    TORCH_CHECK(
        input1.dim() == 1 && input2.dim() == 1,
        "0D target tensor expects 1D input tensors, but found inputs with sizes ",
        input1.sizes(),
        " and ",
        input2.sizes(),
        ".");
  }

  auto prod_sum = (input1 * input2).sum(targ_dim);
  auto mag_square1 = (input1 * input1).sum(targ_dim) + EPSILON;
  auto mag_square2 = (input2 * input2).sum(targ_dim) + EPSILON;
  auto denom = (mag_square1 * mag_square2).sqrt_();
  auto cos = prod_sum / denom;

  auto zeros = at::zeros_like(cos, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto pos = 1 - cos;
  auto neg = (cos - margin).clamp_min_(0);
  auto output_pos = at::where(target == 1, pos, zeros);
  auto output_neg = at::where(target == -1, neg, zeros);
  auto output = output_pos + output_neg;
  return apply_loss_reduction(output, reduction);
}

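// Hinge embedding loss, per element:
//   loss = x                     if target ==  1
//   loss = max(0, margin - x)    if target == -1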
Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double margin, int64_t reduction) {
  auto zeros = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto margin_diff = (margin - self);
  // For Composite Compliance,
  // In Forward AD, if `margin_diff` is a CCT but its tangent isn't,
  // using inplace clamp_min doesn't work because we end up writing
  // the CCT in-place to the tangent
  auto margin_clamp = (margin_diff._fw_grad(/*level*/ 0).defined() &&
                       isTensorSubclassLike(margin_diff))
      ? margin_diff.clamp_min(0)
      : margin_diff.clamp_min_(0);
  auto output_margin = at::where(target != 1, margin_clamp, zeros);
  auto output_self = at::where(target != -1, self, zeros);
  auto output = output_margin + output_self;
  return apply_loss_reduction(output, reduction);
}

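// Triplet margin loss: with p-norm pairwise distances d(., .),
//   loss = max(d(anchor, positive) - d(anchor, negative) + margin, 0)
// When `swap` is set, d(anchor, negative) is replaced by
// min(d(anchor, negative), d(positive, negative)); see the comment below.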
Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const Tensor& negative, double margin,
                           double p, double eps, bool swap, int64_t reduction) {
  auto a_dim = anchor.dim();
  auto p_dim = positive.dim();
  auto n_dim = negative.dim();
  TORCH_CHECK(
      a_dim == p_dim && p_dim == n_dim,
      "The anchor, positive, and negative tensors are expected to have "
      "the same number of dimensions, but got: anchor ", a_dim, "D, "
      "positive ", p_dim, "D, and negative ", n_dim, "D inputs")

  auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
  auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
  // The distance swap is described in the paper "Learning shallow
  // convolutional feature descriptors with triplet losses" by V. Balntas,
  // E. Riba et al. If `swap` is true and the positive example is closer to
  // the negative example than the anchor is, the positive example and the
  // anchor are swapped in the loss computation.
  if (swap) {
    auto dist_swap = at::pairwise_distance(positive, negative, p, eps);
    dist_neg = at::min(dist_neg, dist_swap);
  }
  auto output = at::clamp_min(margin + dist_pos - dist_neg, 0);
  return apply_loss_reduction(output, reduction);
}

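// Margin ranking loss, per element:
//   loss = max(0, -target * (input1 - input2) + margin)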
Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
  auto unclamped_output = (-target * (input1 - input2) + margin);
  // For Composite Compliance,
  // In Forward AD, if `unclamped_output` is a CCT but its tangent isn't,
  // using inplace clamp_min doesn't work because we end up writing
  // the CCT in-place to the tangent
  auto output = (unclamped_output._fw_grad(/*level*/ 0).defined() &&
                 isTensorSubclassLike(unclamped_output))
      ? unclamped_output.clamp_min(0)
      : unclamped_output.clamp_min_(0);
  return apply_loss_reduction(output, reduction);
}

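// Pointwise KL divergence; `input` is expected in log-space. With
// log_target the target is also in log-space and the term is
// exp(t) * (t - x); otherwise it is y * (log y - x), written with xlogy so
// that y == 0 contributes 0.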
Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
  TORCH_CHECK(!input.is_complex() && !target.is_complex(),
              "kl_div: Complex inputs not supported.");
  TORCH_CHECK(!at::isIntegralType(input.scalar_type(), /*include_bool*/true) &&
              !at::isIntegralType(target.scalar_type(), /*include_bool*/true),
              "kl_div: Integral inputs not supported.");
  Tensor output;
  if (log_target) {
    output = at::exp(target) * (target - input);
  } else {
    output = at::xlogy(target, target) - target * input;
  }
  return apply_loss_reduction(output, reduction);
}

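// Binary cross entropy forward on CPU:
//   L = -w * (y * log(x) + (1 - y) * log(1 - x))
// Each log term is clamped at -100 in the kernel below so the loss stays
// finite when x reaches exactly 0 or 1.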
Tensor binary_cross_entropy_cpu(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

    Tensor loss = at::empty_like(input);
    return at::native::binary_cross_entropy_out_cpu(
        input, target, weight, reduction, loss);
}

Tensor& binary_cross_entropy_out_cpu(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& loss) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

    Tensor loss_squeezed = at::squeeze(loss);

    auto iter = TensorIteratorConfig()
      .add_output(loss_squeezed)
      .add_owned_const_input(at::squeeze(input))
      .add_owned_const_input(at::squeeze(target))
      .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        loss.scalar_type(),
        "binary_cross_entropy",
        [&] {
          at::native::cpu_kernel(
              iter, [](scalar_t input_val, scalar_t target_val) {
                TORCH_CHECK(
                    (input_val >= 0) && (input_val <= 1),
                    "all elements of input should be between 0 and 1");
                TORCH_CHECK(
                    (target_val >= 0) && (target_val <= 1),
                    "all elements of target should be between 0 and 1");

                // Binary cross entropy is defined by the equation:
                // L = -w (y ln(x) + (1-y) ln(1-x))
                return (target_val - scalar_t(1)) *
                    std::max(scalar_t(std::log1p(-input_val)), scalar_t(-100)) -
                    target_val *
                    std::max(scalar_t(std::log(input_val)), scalar_t(-100));
              });
        });

    if (weight.defined()) {
        loss.mul_(weight);
    }
    if (reduction != at::Reduction::None) {
        Tensor loss_reduced = apply_loss_reduction(loss, reduction);
        loss.resize_as_(loss_reduced).copy_(loss_reduced);
    }
    return loss;
}

Tensor binary_cross_entropy_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

    Tensor grad_input = at::empty_like(input);
    return at::native::binary_cross_entropy_backward_out_cpu(
        grad, input, target, weight, reduction, grad_input);
}

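// BCE backward: d(L)/d(x) = (x - y) / (x * (1 - x)), with the denominator
// clamped by EPSILON to avoid division by zero; the weight and the 1/numel
// factor of a Mean reduction are applied afterwards.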
Tensor& binary_cross_entropy_backward_out_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& grad_input) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

    Tensor grad_input_squeezed = at::squeeze(grad_input);

    auto iter = TensorIteratorConfig()
      .add_output(grad_input_squeezed)
      .add_owned_const_input(at::squeeze(grad))
      .add_owned_const_input(at::squeeze(input))
      .add_owned_const_input(at::squeeze(target))
      .build();

    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        grad_input.scalar_type(),
        "binary_cross_entropy_backward",
        [&] {
          at::native::cpu_kernel(
              iter,
              [](scalar_t grad_val, scalar_t input_val, scalar_t target_val) {
                // The gradient is the partial derivative of BCELoss
                // with respect to x
                // d(L)/d(x) = -w (y - x) / (x - x^2)
                return grad_val * (input_val - target_val) /
                    (scalar_t(std::max(
                        (scalar_t(1) - input_val) * input_val,
                        scalar_t(EPSILON))));
              });
        });

    if (weight.defined()) {
        grad_input.mul_(weight);
    }
    if (reduction == at::Reduction::Mean) {
        grad_input.div_(input.numel());
    }
    return grad_input;
}

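// Numerically stable BCE on raw logits:
//   loss = (1 - y) * x - log(sigmoid(x))
// With pos_weight p, the log-sigmoid term is scaled by (p - 1) * y + 1, and
// an optional elementwise weight is applied before reduction.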
Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& pos_weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> pos_weight_maybe_owned = at::borrow_from_optional_tensor(pos_weight_opt);
  const Tensor& pos_weight = *pos_weight_maybe_owned;

  auto log_sigmoid_input = at::log_sigmoid(input);
  if (pos_weight.defined()) {
      // pos_weight needs to be broadcast, so mul(target) is not in-place.
      auto log_weight = (pos_weight - 1).mul(target).add_(1);
      log_sigmoid_input.mul_(log_weight);
  }

  Tensor loss = (1 - target).mul_(input).sub_(log_sigmoid_input);

  if (weight.defined()) {
      loss.mul_(weight);
  }

  return apply_loss_reduction(loss, reduction);
}

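// Poisson negative log likelihood:
//   loss = exp(input) - target * input           if log_input
//   loss = input - target * log(input + eps)     otherwise
// When `full` is set, the Stirling approximation term
//   target * log(target) - target + 0.5 * log(2 * pi * target)
// is added for elements with target > 1.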
Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction)
{
    Tensor loss;
    if (log_input) {
        loss = at::exp(input) - target * input;
    } else {
        loss = input - target * at::log(input + eps);
    }

    if (full) {
        auto stirling_term = target * at::log(target) - target + 0.5 * at::log(2 * c10::pi<double> * target);
        loss += stirling_term.masked_fill(target <= 1, 0);
    }

    return apply_loss_reduction(loss, reduction);
}

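// Soft margin loss is log(1 + exp(-y * x)); its gradient is
//   d(loss)/d(x) = -y * z / (1 + z), with z = exp(-y * x),
// which the backward below computes in place, scaled by norm and grad_output.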
Tensor& soft_margin_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, Tensor& grad_input) {
  auto norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.;
  auto z = at::exp(-target * input);
  // inplace version of: grad_input = -norm * target * z / (1. + z) * grad_output;
  at::mul_out(grad_input, target, z).mul_(-norm);
  z.add_(1);
  grad_input.div_(z).mul_(grad_output);
  return grad_input;
}

Tensor soft_margin_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
  auto grad_input = at::empty({0}, input.options());
  at::soft_margin_loss_backward_out(grad_input, grad_output, input, target, reduction);
  return grad_input;
}

Tensor& soft_margin_loss_out(const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    Tensor& output) {
  // compute inplace variant of: output = at::log1p(at::exp(-input * target));
  at::neg_out(output, input).mul_(target).exp_().log1p_();
  if (reduction != Reduction::None) {
    auto tmp = apply_loss_reduction(output, reduction);
    output.resize_({});
    output.copy_(tmp);
  }
  return output;
}

Tensor soft_margin_loss(
    const Tensor& input,
    const Tensor& target,
    int64_t reduction) {
  auto output = at::empty({0}, input.options());
  at::soft_margin_loss_out(output, input, target, reduction);
  return output;
}

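// Smooth L1 backward: the stub scales the per-element derivative by `norm`
// (1 / numel for a Mean reduction, 1 otherwise) and by grad_output.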
Tensor& smooth_l1_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta, Tensor& grad_input) {
  auto norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.;
  auto iter = at::TensorIteratorConfig()
    .add_output(grad_input)
    .add_const_input(input)
    .add_const_input(target)
    .add_const_input(grad_output)
    .promote_inputs_to_common_dtype(true)
    .cast_common_dtype_to_outputs(true)
    .enforce_safe_casting_to_output(true)
    .build();
  smooth_l1_backward_stub(iter.device_type(), iter, norm, beta);
  return grad_input;
}

Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::smooth_l1_loss_backward_out(grad_input, grad_output, input, target, reduction, beta);
}

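// Huber loss, per element, with z = |input - target|:
//   loss = 0.5 * z^2                   if z <= delta
//   loss = delta * (z - 0.5 * delta)   otherwise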
Tensor huber_loss(const Tensor& input, const Tensor& target, int64_t reduction, double delta) {
  TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.")
  Tensor loss = at::empty_like(input);
  auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
  huber_stub(iter.device_type(), iter, delta);
  return apply_loss_reduction(loss, reduction);
}

Tensor& huber_loss_out(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& result) {
  TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.")
  auto iter = TensorIterator::borrowing_binary_op(result, input, target);
  huber_stub(iter.device_type(), iter, delta);
  if (reduction != Reduction::None) {
    auto reduced = apply_loss_reduction(result, reduction);
    result.resize_({});
    result.copy_(reduced);
  }
  return result;
}

Tensor huber_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double delta) {
  auto grad_input = at::zeros_like(input, MemoryFormat::Contiguous);
  return at::huber_loss_backward_out(grad_input, grad_output, input, target, reduction, delta);
}

Tensor& huber_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& grad_input) {
  auto norm = (reduction == Reduction::Mean) ? (1. / input.numel()) : 1.;
  auto iter = at::TensorIteratorConfig()
    .add_output(grad_input)
    .add_const_input(input)
    .add_const_input(target)
    .add_const_input(grad_output)
    .build();
  huber_backward_stub(iter.device_type(), iter, norm, delta);
  return grad_input;
}

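// MSE backward: d/dx (x - y)^2 = 2 * (x - y), so `norm` carries the factor
// of 2 (divided by numel for a Mean reduction) applied together with
// grad_output in the stub.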
Tensor mse_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
  Tensor grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::mse_loss_backward_out(grad_input, grad_output, input, target, reduction);
}

Tensor& mse_loss_backward_out(const Tensor& grad_output,
    const Tensor& input, const Tensor& target, int64_t reduction, Tensor& grad_input) {
  auto norm = reduction == Reduction::Mean ? 2. / input.numel() : 2.;
  auto iter = at::TensorIteratorConfig()
    .add_output(grad_input)
    .add_const_input(input)
    .add_const_input(target)
    .add_const_input(grad_output)
    .build();
  mse_backward_stub(iter.device_type(), iter, norm);
  return grad_input;
}

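// L1 loss as a composite of existing ops: the reduction applied to
// |input - target|.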
Tensor l1_loss(const Tensor& input, const Tensor& target, int64_t reduction) {
  return apply_loss_reduction((input - target).abs(), reduction);
}
}  // namespace at::native