1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/Reduction.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorIterator.h>
6 #include <ATen/TensorMeta.h>
7 #include <ATen/TensorOperators.h>
8 #include <ATen/native/BinaryOps.h>
9 #include <ATen/native/PointwiseOps.h>
10 #include <ATen/native/cpu/Loops.h>
11 #include <c10/util/Exception.h>
12 #include <ATen/TensorSubclassLikeUtils.h>
13
14 #ifndef AT_PER_OPERATOR_HEADERS
15 #include <ATen/Functions.h>
16 #include <ATen/NativeFunctions.h>
17 #else
18 #include <ATen/ops/binary_cross_entropy_backward_native.h>
19 #include <ATen/ops/binary_cross_entropy_native.h>
20 #include <ATen/ops/binary_cross_entropy_with_logits_native.h>
21 #include <ATen/ops/clamp_min.h>
22 #include <ATen/ops/cosine_embedding_loss_native.h>
23 #include <ATen/ops/empty.h>
24 #include <ATen/ops/empty_like.h>
25 #include <ATen/ops/exp.h>
26 #include <ATen/ops/hinge_embedding_loss_native.h>
27 #include <ATen/ops/huber_loss_backward.h>
28 #include <ATen/ops/huber_loss_backward_native.h>
29 #include <ATen/ops/huber_loss_native.h>
30 #include <ATen/ops/kl_div_native.h>
31 #include <ATen/ops/l1_loss_native.h>
32 #include <ATen/ops/log.h>
33 #include <ATen/ops/log_sigmoid.h>
34 #include <ATen/ops/margin_ranking_loss_native.h>
35 #include <ATen/ops/mean.h>
36 #include <ATen/ops/min.h>
37 #include <ATen/ops/mse_loss_backward.h>
38 #include <ATen/ops/mse_loss_backward_native.h>
39 #include <ATen/ops/mse_loss_meta.h>
40 #include <ATen/ops/mse_loss_native.h>
41 #include <ATen/ops/mul.h>
42 #include <ATen/ops/neg.h>
43 #include <ATen/ops/pairwise_distance.h>
44 #include <ATen/ops/poisson_nll_loss_native.h>
45 #include <ATen/ops/smooth_l1_loss_backward.h>
46 #include <ATen/ops/smooth_l1_loss_backward_native.h>
47 #include <ATen/ops/smooth_l1_loss_meta.h>
48 #include <ATen/ops/smooth_l1_loss_native.h>
49 #include <ATen/ops/soft_margin_loss.h>
50 #include <ATen/ops/soft_margin_loss_backward.h>
51 #include <ATen/ops/soft_margin_loss_backward_native.h>
52 #include <ATen/ops/soft_margin_loss_native.h>
53 #include <ATen/ops/squeeze.h>
54 #include <ATen/ops/sum.h>
55 #include <ATen/ops/triplet_margin_loss_native.h>
56 #include <ATen/ops/where.h>
57 #include <ATen/ops/xlogy.h>
58 #include <ATen/ops/zeros_like.h>
59 #endif
60
// Small positive constant that keeps denominators away from zero. It is
// added to the squared magnitudes in cosine_embedding_loss and used as the
// lower clamp of the denominator in binary_cross_entropy_backward_out_cpu.
// NOTE(review): 1e-12 is below the smallest positive Half value, so
// scalar_t(EPSILON) presumably rounds to 0 when scalar_t is Half — confirm.
constexpr float EPSILON = 1e-12;
62
63 namespace {
apply_loss_reduction(const at::Tensor & unreduced,int64_t reduction)64 static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
65 if (reduction == at::Reduction::Mean) {
66 return unreduced.mean();
67 } else if (reduction == at::Reduction::Sum) {
68 return unreduced.sum();
69 }
70 return unreduced;
71 }
72 }
73
74 namespace at::meta {
75
TORCH_META_FUNC(smooth_l1_loss)76 TORCH_META_FUNC(smooth_l1_loss)
77 (const Tensor& input, const Tensor& target, const int64_t reduction, double beta) {
78 TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.")
79 // TODO: Reduce this extra TensorIterator construction for Reduction::Mean & Sum.
80 // We do another TensorIterator construction in the IMPL for the two cases.
81 build_borrowing_binary_op(maybe_get_output(), input, target);
82 if (reduction == Reduction::None) {
83 return;
84 }
85
86 TORCH_INTERNAL_ASSERT(reduction == Reduction::Mean || reduction == Reduction::Sum);
87 maybe_get_output().resize_({});
88 }
89
TORCH_META_FUNC(mse_loss)90 TORCH_META_FUNC(mse_loss)
91 (const Tensor& input, const Tensor& target, const int64_t reduction) {
92 build_borrowing_binary_op(maybe_get_output(), input, target);
93 if (reduction == Reduction::None) {
94 return;
95 }
96
97 TORCH_INTERNAL_ASSERT(reduction == Reduction::Mean || reduction == Reduction::Sum);
98 maybe_get_output().resize_({});
99 }
100
101 } // namespace at::meta
102
103 namespace at::native {
104
// Dispatch stubs for the elementwise forward/backward kernels used by the
// losses in this file. The per-device kernel implementations are registered
// outside this file (not visible here).
DEFINE_DISPATCH(smooth_l1_stub);
DEFINE_DISPATCH(smooth_l1_backward_stub);
DEFINE_DISPATCH(huber_stub);
DEFINE_DISPATCH(huber_backward_stub);
DEFINE_DISPATCH(mse_stub);
DEFINE_DISPATCH(mse_backward_stub);
111
TORCH_IMPL_FUNC(smooth_l1_loss_out)
(const Tensor& input, const Tensor& target, int64_t reduction, double beta, const Tensor& result) {
  // Unreduced case: the meta function already set up *this as the binary
  // iterator writing into `result`, so run the kernel on it directly.
  if (reduction == Reduction::None) {
    smooth_l1_stub(device_type(), *this, beta);
    return;
  }

  // Reduced case: compute the elementwise loss into an iterator-allocated
  // temporary, then reduce it into the 0-dim `result`.
  Tensor elementwise;
  auto loss_iter = TensorIterator::borrowing_binary_op(elementwise, input, target);
  smooth_l1_stub(loss_iter.device_type(), loss_iter, beta);
  Tensor& out = const_cast<Tensor&>(result);
  const Tensor& unreduced = loss_iter.output();
  if (reduction == Reduction::Mean) {
    at::mean_out(out, unreduced, IntArrayRef{});
  } else {
    at::sum_out(out, unreduced, IntArrayRef{});
  }
}
127
TORCH_IMPL_FUNC(mse_loss_out)
(const Tensor& input, const Tensor& target, int64_t reduction, const Tensor& result) {
  // Unreduced case: *this was configured by the meta function to write the
  // squared error straight into `result`.
  if (reduction == Reduction::None) {
    mse_stub(device_type(), *this);
    return;
  }

  // Reduced case: squared error into an iterator-allocated temporary, then
  // mean/sum it into the 0-dim `result`.
  Tensor elementwise;
  auto loss_iter = TensorIterator::borrowing_binary_op(elementwise, input, target);
  mse_stub(loss_iter.device_type(), loss_iter);
  Tensor& out = const_cast<Tensor&>(result);
  const Tensor& unreduced = loss_iter.output();
  if (reduction == Reduction::Mean) {
    at::mean_out(out, unreduced, IntArrayRef{});
  } else {
    at::sum_out(out, unreduced, IntArrayRef{});
  }
}
143
cosine_embedding_loss(const Tensor & input1,const Tensor & input2,const Tensor & target,double margin,int64_t reduction)144 Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
145 auto targ_dim = target.dim();
146 TORCH_CHECK(
147 targ_dim == 1 || targ_dim == 0,
148 "0D or 1D target tensor expected, multi-target not supported");
149 if (targ_dim == 1) {
150 TORCH_CHECK(
151 input1.dim() == 2 && input2.dim() == 2,
152 "1D target tensor expects 2D input tensors, but found inputs with sizes ",
153 input1.sizes(),
154 " and ",
155 input2.sizes(),
156 ".");
157 } else {
158 TORCH_CHECK(
159 input1.dim() == 1 && input2.dim() == 1,
160 "0D target tensor expects 1D input tensors, but found inputs with sizes ",
161 input1.sizes(),
162 " and ",
163 input2.sizes(),
164 ".");
165 }
166
167 auto prod_sum = (input1 * input2).sum(targ_dim);
168 auto mag_square1 = (input1 * input1).sum(targ_dim) + EPSILON;
169 auto mag_square2 = (input2 * input2).sum(targ_dim) + EPSILON;
170 auto denom = (mag_square1 * mag_square2).sqrt_();
171 auto cos = prod_sum / denom;
172
173 auto zeros = at::zeros_like(cos, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
174 auto pos = 1 - cos;
175 auto neg = (cos - margin).clamp_min_(0);
176 auto output_pos = at::where(target == 1, pos, zeros);
177 auto output_neg = at::where(target == -1, neg, zeros);
178 auto output = output_pos + output_neg;
179 return apply_loss_reduction(output, reduction);
180 }
181
hinge_embedding_loss(const Tensor & self,const Tensor & target,double margin,int64_t reduction)182 Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double margin, int64_t reduction) {
183 auto zeros = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
184 auto margin_diff = (margin - self);
185 // For Composite Compliance,
186 // In Forward AD, if `margin_diff` is a CCT but its tangent isn't,
187 // using inplace clamp_min doesn't work because we end up writing
188 // the CCT in-place to the tangent
189 auto margin_clamp = (margin_diff._fw_grad(/*level*/ 0).defined() &&
190 isTensorSubclassLike(margin_diff))
191 ? margin_diff.clamp_min(0)
192 : margin_diff.clamp_min_(0);
193 auto output_margin = at::where(target != 1, margin_clamp, zeros);
194 auto output_self = at::where(target != -1, self, zeros);
195 auto output = output_margin + output_self;
196 return apply_loss_reduction(output, reduction);
197 }
198
triplet_margin_loss(const Tensor & anchor,const Tensor & positive,const Tensor & negative,double margin,double p,double eps,bool swap,int64_t reduction)199 Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const Tensor& negative, double margin,
200 double p, double eps, bool swap, int64_t reduction) {
201 auto a_dim = anchor.dim();
202 auto p_dim = positive.dim();
203 auto n_dim = negative.dim();
204 TORCH_CHECK(
205 a_dim == p_dim && p_dim == n_dim,
206 "The anchor, positive, and negative tensors are expected to have "
207 "the same number of dimensions, but got: anchor ", a_dim, "D, "
208 "positive ", p_dim, "D, and negative ", n_dim, "D inputs")
209
210 auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
211 auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
212 // The distance swap is described in the paper "Learning shallow
213 // convolutional feature descriptors with triplet losses" by V. Balntas, E.
214 // Riba et al. If True, and if the positive example is closer to the
215 // negative example than the anchor is, swaps the positive example and the
216 // anchor in the loss computation.
217 if (swap) {
218 auto dist_swap = at::pairwise_distance(positive, negative, p, eps);
219 dist_neg = at::min(dist_neg, dist_swap);
220 }
221 auto output = at::clamp_min(margin + dist_pos - dist_neg, 0);
222 return apply_loss_reduction(output, reduction);
223 }
224
margin_ranking_loss(const Tensor & input1,const Tensor & input2,const Tensor & target,double margin,int64_t reduction)225 Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
226 auto unclamped_output = (-target * (input1 - input2) + margin);
227 // For Composite Compliance,
228 // In Forward AD, if `margin_diff` is a CCT but its tangent isn't,
229 // using inplace clamp_min doesn't work because we end up writing
230 // the CCT in-place to the tangent
231 auto output = (unclamped_output._fw_grad(/*level*/ 0).defined() &&
232 isTensorSubclassLike(unclamped_output))
233 ? unclamped_output.clamp_min(0)
234 : unclamped_output.clamp_min_(0);
235 return apply_loss_reduction(output, reduction);
236 }
237
kl_div(const Tensor & input,const Tensor & target,int64_t reduction,bool log_target)238 Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
239 TORCH_CHECK(!input.is_complex() && !target.is_complex(),
240 "kl_div: Complex inputs not supported.");
241 TORCH_CHECK(!at::isIntegralType(input.scalar_type(), /*include_bool*/true) &&
242 !at::isIntegralType(target.scalar_type(), /*include_bool*/true),
243 "kl_div: Integral inputs not supported.");
244 Tensor output;
245 if (log_target) {
246 output = at::exp(target) * (target - input);
247 } else {
248 output = at::xlogy(target, target) - target * input;
249 }
250 return apply_loss_reduction(output, reduction);
251 }
252
binary_cross_entropy_cpu(const Tensor & input,const Tensor & target,const std::optional<Tensor> & weight_opt,int64_t reduction)253 Tensor binary_cross_entropy_cpu(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
254 // See [Note: hacky wrapper removal for optional tensor]
255 c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
256 const Tensor& weight = *weight_maybe_owned;
257
258 Tensor loss = at::empty_like(input);
259 return at::native::binary_cross_entropy_out_cpu(
260 input, target, weight, reduction, loss);
261 }
262
binary_cross_entropy_out_cpu(const Tensor & input,const Tensor & target,const std::optional<Tensor> & weight_opt,int64_t reduction,Tensor & loss)263 Tensor& binary_cross_entropy_out_cpu(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& loss) {
264 // See [Note: hacky wrapper removal for optional tensor]
265 c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
266 const Tensor& weight = *weight_maybe_owned;
267
268 Tensor loss_squeezed = at::squeeze(loss);
269
270 auto iter = TensorIteratorConfig()
271 .add_output(loss_squeezed)
272 .add_owned_const_input(at::squeeze(input))
273 .add_owned_const_input(at::squeeze(target))
274 .build();
275
276 AT_DISPATCH_FLOATING_TYPES_AND2(
277 ScalarType::Half,
278 ScalarType::BFloat16,
279 loss.scalar_type(),
280 "binary_cross_entropy",
281 [&] {
282 at::native::cpu_kernel(
283 iter, [](scalar_t input_val, scalar_t target_val) {
284 TORCH_CHECK(
285 (input_val >= 0) && (input_val <= 1),
286 "all elements of input should be between 0 and 1");
287 TORCH_CHECK(
288 (target_val >= 0) && (target_val <= 1),
289 "all elements of target should be between 0 and 1");
290
291 // Binary cross entropy tensor is defined by the equation:
292 // L = -w (y ln(x) + (1-y) ln(1-x))
293 return (target_val - scalar_t(1)) *
294 std::max(scalar_t(std::log1p(-input_val)), scalar_t(-100)) -
295 target_val *
296 std::max(scalar_t(std::log(input_val)), scalar_t(-100));
297 });
298 });
299
300 if (weight.defined()) {
301 loss.mul_(weight);
302 }
303 if (reduction != at::Reduction::None) {
304 Tensor loss_reduced = apply_loss_reduction(loss, reduction);
305 loss.resize_as_(loss_reduced).copy_(loss_reduced);
306 }
307 return loss;
308 }
309
binary_cross_entropy_backward_cpu(const Tensor & grad,const Tensor & input,const Tensor & target,const std::optional<Tensor> & weight_opt,int64_t reduction)310 Tensor binary_cross_entropy_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
311 // See [Note: hacky wrapper removal for optional tensor]
312 c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
313 const Tensor& weight = *weight_maybe_owned;
314
315 Tensor grad_input = at::empty_like(input);
316 return at::native::binary_cross_entropy_backward_out_cpu(
317 grad, input, target, weight, reduction, grad_input);
318 }
319
binary_cross_entropy_backward_out_cpu(const Tensor & grad,const Tensor & input,const Tensor & target,const std::optional<Tensor> & weight_opt,int64_t reduction,Tensor & grad_input)320 Tensor& binary_cross_entropy_backward_out_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& grad_input) {
321 // See [Note: hacky wrapper removal for optional tensor]
322 c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
323 const Tensor& weight = *weight_maybe_owned;
324
325 Tensor grad_input_squeezed = at::squeeze(grad_input);
326
327 auto iter = TensorIteratorConfig()
328 .add_output(grad_input_squeezed)
329 .add_owned_const_input(at::squeeze(grad))
330 .add_owned_const_input(at::squeeze(input))
331 .add_owned_const_input(at::squeeze(target))
332 .build();
333
334 AT_DISPATCH_FLOATING_TYPES_AND2(
335 ScalarType::Half,
336 ScalarType::BFloat16,
337 grad_input.scalar_type(),
338 "binary_cross_entropy_backward",
339 [&] {
340 at::native::cpu_kernel(
341 iter,
342 [](scalar_t grad_val, scalar_t input_val, scalar_t target_val) {
343 // The gradient is the partial derivative of BCELoss
344 // with respect to x
345 // d(L)/d(x) = -w (y - x) / (x - x^2)
346 return grad_val * (input_val - target_val) /
347 (scalar_t(std::max(
348 (scalar_t(1) - input_val) * input_val,
349 scalar_t(EPSILON))));
350 });
351 });
352
353 if (weight.defined()) {
354 grad_input.mul_(weight);
355 }
356 if (reduction == at::Reduction::Mean) {
357 grad_input.div_(input.numel());
358 }
359 return grad_input;
360 }
361
binary_cross_entropy_with_logits(const Tensor & input,const Tensor & target,const std::optional<Tensor> & weight_opt,const std::optional<Tensor> & pos_weight_opt,int64_t reduction)362 Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& pos_weight_opt, int64_t reduction) {
363 // See [Note: hacky wrapper removal for optional tensor]
364 c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
365 const Tensor& weight = *weight_maybe_owned;
366 c10::MaybeOwned<Tensor> pos_weight_maybe_owned = at::borrow_from_optional_tensor(pos_weight_opt);
367 const Tensor& pos_weight = *pos_weight_maybe_owned;
368
369 auto log_sigmoid_input = at::log_sigmoid(input);
370 if (pos_weight.defined()) {
371 // pos_weight need to be broadcasted, thus mul(target) is not inplace.
372 auto log_weight = (pos_weight - 1).mul(target).add_(1);
373 log_sigmoid_input.mul_(log_weight);
374 }
375
376 Tensor loss = (1 - target).mul_(input).sub_(log_sigmoid_input);
377
378 if (weight.defined()) {
379 loss.mul_(weight);
380 }
381
382 return apply_loss_reduction(loss, reduction);
383 }
384
poisson_nll_loss(const Tensor & input,const Tensor & target,const bool log_input,const bool full,const double eps,const int64_t reduction)385 Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction)
386 {
387 Tensor loss;
388 if (log_input) {
389 loss = at::exp(input) - target * input;
390 } else {
391 loss = input - target * at::log(input + eps);
392 }
393
394 if (full) {
395 auto stirling_term = target * at::log(target) - target + 0.5 * at::log(2 * c10::pi<double> * target);
396 loss += stirling_term.masked_fill(target <= 1, 0);
397 }
398
399 return apply_loss_reduction(loss, reduction);
400 }
401
soft_margin_loss_backward_out(const Tensor & grad_output,const Tensor & input,const Tensor & target,int64_t reduction,Tensor & grad_input)402 Tensor& soft_margin_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, Tensor& grad_input) {
403 auto norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.;
404 auto z = at::exp(-target * input);
405 // inplace version of: grad_input = -norm * target * z / (1. + z) * grad_output;
406 at::mul_out(grad_input, target, z).mul_(-norm);
407 z.add_(1);
408 grad_input.div_(z).mul_(grad_output);
409 return grad_input;
410 }
411
soft_margin_loss_backward(const Tensor & grad_output,const Tensor & input,const Tensor & target,int64_t reduction)412 Tensor soft_margin_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
413 auto grad_input = at::empty({0}, input.options());
414 at::soft_margin_loss_backward_out(grad_input, grad_output, input, target, reduction);
415 return grad_input;
416 }
417
soft_margin_loss_out(const Tensor & input,const Tensor & target,int64_t reduction,Tensor & output)418 Tensor& soft_margin_loss_out(const Tensor& input,
419 const Tensor& target,
420 int64_t reduction,
421 Tensor& output) {
422 // compute inplace variant of: output = at::log1p(at::exp(-input * target));
423 at::neg_out(output, input).mul_(target).exp_().log1p_();
424 if (reduction != Reduction::None) {
425 auto tmp = apply_loss_reduction(output, reduction);
426 output.resize_({});
427 output.copy_(tmp);
428 }
429 return output;
430 }
431
soft_margin_loss(const Tensor & input,const Tensor & target,int64_t reduction)432 Tensor soft_margin_loss(
433 const Tensor& input,
434 const Tensor& target,
435 int64_t reduction) {
436 auto output = at::empty({0}, input.options());
437 at::soft_margin_loss_out(output, input, target, reduction);
438 return output;
439 }
440
smooth_l1_loss_backward_out(const Tensor & grad_output,const Tensor & input,const Tensor & target,int64_t reduction,double beta,Tensor & grad_input)441 Tensor& smooth_l1_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta, Tensor& grad_input) {
442 auto norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.;
443 auto iter = at::TensorIteratorConfig()
444 .add_output(grad_input)
445 .add_const_input(input)
446 .add_const_input(target)
447 .add_const_input(grad_output)
448 .promote_inputs_to_common_dtype(true)
449 .cast_common_dtype_to_outputs(true)
450 .enforce_safe_casting_to_output(true)
451 .build();
452 smooth_l1_backward_stub(iter.device_type(), iter, norm, beta);
453 return grad_input;
454 }
455
smooth_l1_loss_backward(const Tensor & grad_output,const Tensor & input,const Tensor & target,int64_t reduction,double beta)456 Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta) {
457 auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
458 return at::smooth_l1_loss_backward_out(grad_input, grad_output, input, target, reduction, beta);
459 }
460
huber_loss(const Tensor & input,const Tensor & target,int64_t reduction,double delta)461 Tensor huber_loss(const Tensor& input, const Tensor& target, int64_t reduction, double delta) {
462 TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.")
463 Tensor loss = at::empty_like(input);
464 auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
465 huber_stub(iter.device_type(), iter, delta);
466 return apply_loss_reduction(loss, reduction);
467 }
468
huber_loss_out(const Tensor & input,const Tensor & target,int64_t reduction,double delta,Tensor & result)469 Tensor& huber_loss_out(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& result) {
470 TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.")
471 auto iter = TensorIterator::borrowing_binary_op(result, input, target);
472 huber_stub(iter.device_type(), iter, delta);
473 if (reduction != Reduction::None) {
474 auto reduced = apply_loss_reduction(result, reduction);
475 result.resize_({});
476 result.copy_(reduced);
477 }
478 return result;
479 }
480
huber_loss_backward(const Tensor & grad_output,const Tensor & input,const Tensor & target,int64_t reduction,double delta)481 Tensor huber_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double delta) {
482 auto grad_input = at::zeros_like(input, MemoryFormat::Contiguous);
483 return at::huber_loss_backward_out(grad_input, grad_output, input, target, reduction, delta);
484 }
485
huber_loss_backward_out(const Tensor & grad_output,const Tensor & input,const Tensor & target,int64_t reduction,double delta,Tensor & grad_input)486 Tensor& huber_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& grad_input) {
487 auto norm = (reduction == Reduction::Mean) ? (1. / input.numel()) : 1.;
488 auto iter = at::TensorIteratorConfig()
489 .add_output(grad_input)
490 .add_const_input(input)
491 .add_const_input(target)
492 .add_const_input(grad_output)
493 .build();
494 huber_backward_stub(iter.device_type(), iter, norm, delta);
495 return grad_input;
496 }
497
mse_loss_backward(const Tensor & grad_output,const Tensor & input,const Tensor & target,int64_t reduction)498 Tensor mse_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
499 Tensor grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
500 return at::mse_loss_backward_out(grad_input, grad_output, input, target, reduction);
501 }
502
mse_loss_backward_out(const Tensor & grad_output,const Tensor & input,const Tensor & target,int64_t reduction,Tensor & grad_input)503 Tensor& mse_loss_backward_out(const Tensor& grad_output,
504 const Tensor& input, const Tensor& target, int64_t reduction, Tensor& grad_input) {
505 auto norm = reduction == Reduction::Mean ? 2. / input.numel() : 2.;
506 auto iter = at::TensorIteratorConfig()
507 .add_output(grad_input)
508 .add_const_input(input)
509 .add_const_input(target)
510 .add_const_input(grad_output)
511 .build();
512 mse_backward_stub(iter.device_type(), iter, norm);
513 return grad_input;
514 }
515
l1_loss(const Tensor & input,const Tensor & target,int64_t reduction)516 Tensor l1_loss(const Tensor& input, const Tensor& target, int64_t reduction) {
517 return apply_loss_reduction((input - target).abs(), reduction);
518 }
519 } // namespace at::native
520