#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/LossMulti.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/multi_margin_loss_backward_native.h>
#include <ATen/ops/multi_margin_loss_native.h>
#endif

namespace at::native {

namespace {

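// Per-sample multi-margin loss term:
//   sum_{d != target} max(0, margin - input[target] + input[d])^p
// optionally scaled by weight[target], then averaged over the class
// dimension (divided by dim).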
template <typename scalar_t>
inline scalar_t multi_margin_inner_sum_cpu(
    const scalar_t* input_data,
    const scalar_t* weight_data,
    const int p,
    const scalar_t margin,
    const int64_t dim,
    const int64_t target_idx) {
  const scalar_t input_target = input_data[target_idx];
  scalar_t sum = 0;
  for (const auto d : c10::irange(dim)) {
    if (d == target_idx) {
      continue;
    }

    const scalar_t z = margin - input_target + input_data[d];
    if (z > 0) {
      scalar_t h = (p == 1) ? z : z * z;
      if (weight_data != nullptr) {
        h *= weight_data[target_idx];
      }
      sum += h;
    }
  }

  sum /= dim;
  return sum;
}

inline int64_t target_index_checked(
    const int64_t* target_data,
    const int64_t index,
    const int64_t dim) {
  const int64_t idx = target_data[index];
  TORCH_CHECK(idx >= 0 && idx < dim, "target out of range");
  return idx;
}

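// Forward kernel: with Reduction::None (and a non-scalar output) each of the
// nframe rows gets its own loss value; otherwise the per-row losses are
// accumulated in accscalar_t and, for Reduction::Mean, divided by nframe
// before being written to the single output element.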
template <typename scalar_t>
static inline void multi_margin_loss_cpu_kernel(
    Tensor& output,
    const scalar_t* input_data,
    const int64_t* target_data,
    const int p,
    scalar_t margin,
    const scalar_t* weight_data,
    const int64_t nframe,
    const int64_t dim,
    const int64_t reduction) {
  using accscalar_t = at::acc_type<scalar_t, false>;

  // dim() != 0 check is for 1d input which produces a scalar output (that
  // cannot be handled by TensorAccessor)
  if (reduction == Reduction::None && output.dim() > 0) {
    auto output_acc = output.accessor<scalar_t, 1>();
    for (const auto t : c10::irange(nframe)) {
      const auto idx = target_index_checked(target_data, t, dim);
      auto sum = multi_margin_inner_sum_cpu(
          input_data, weight_data, p, margin, dim, idx);
      output_acc[t] = sum;
      input_data += dim;
    }
  } else {
    accscalar_t sum = 0;
    auto output_acc = output.data_ptr<scalar_t>();
    for (const auto t : c10::irange(nframe)) {
      const auto idx = target_index_checked(target_data, t, dim);
      sum += multi_margin_inner_sum_cpu(
          input_data, weight_data, p, margin, dim, idx);
      input_data += dim;
    }
    if (reduction == Reduction::Mean) {
      sum /= nframe;
    }
    output_acc[0] = sum;
  }
}

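// Validates p and the input/target/weight shapes, resizes the output
// ({nframe} for unreduced non-scalar targets, a scalar otherwise), and
// dispatches the typed forward kernel on contiguous copies of the inputs.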
void multi_margin_loss_out_cpu_template(
    Tensor& output,
    const Tensor& input,
    const Tensor& target,
    int p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  const auto ndims = input.dim();

  TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

  multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);

  // produce a scalar output for 1d input
  if (reduction == Reduction::None && target.dim() > 0) {
    output.resize_({nframe});
  } else {
    output.resize_({});
  }
  if (input.numel() == 0) {
    return;
  }

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  Tensor weight_contiguous;
  if (weight && weight->defined()) {
    weight_contiguous = weight->contiguous();
  }

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multi_margin_loss_cpu_kernel", [&] {
        auto input_data = input_contiguous.const_data_ptr<scalar_t>();
        auto target_data = target_contiguous.const_data_ptr<int64_t>();
        auto weight_data =
            weight_contiguous.defined() ? weight_contiguous.const_data_ptr<scalar_t>() : nullptr;
        multi_margin_loss_cpu_kernel<scalar_t>(
            output,
            input_data,
            target_data,
            p,
            margin.to<scalar_t>(),
            weight_data,
            nframe,
            dim,
            reduction);
      });
}

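// Backward kernel: for each row, every class d != target with a positive
// margin violation z = margin - input[target] + input[d] receives
//   h = g           (p == 1)
//   h = 2 * g * z   (p == 2)
// (times weight[target] if a weight is given), while the target class
// accumulates -sum(h). The caller folds the 1/dim (and 1/(nframe*dim) for
// Reduction::Mean) normalization into g; grad_output is applied afterwards.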
template <typename scalar_t>
static void multi_margin_loss_backward_cpu_kernel(
    scalar_t* grad_input_data,
    const Tensor& grad_output,
    const scalar_t* input_data,
    const int64_t* target_data,
    int p,
    scalar_t margin,
    scalar_t g,
    const scalar_t* weight_data,
    int64_t nframe,
    int64_t dim,
    int64_t reduction) {
  scalar_t* grad_input_row_data = grad_input_data;
  for (const auto t : c10::irange(nframe)) {
    int64_t target_idx = target_index_checked(target_data, t, dim);
    scalar_t input_target = input_data[target_idx];
    scalar_t grad_input_target = 0;
    for (const auto d : c10::irange(dim)) {
      scalar_t z = margin - input_target + input_data[d];
      if (d == target_idx) {
        continue;
      }

      if (z > 0) {
        scalar_t h = (p == 1) ? g : 2 * g * z;
        if (weight_data != nullptr) {
          h *= weight_data[target_idx];
        }
        grad_input_target -= h;
        grad_input_row_data[d] = h;
      } else {
        grad_input_row_data[d] = 0;
      }
    }
    grad_input_row_data[target_idx] = grad_input_target;

    input_data += dim;
    grad_input_row_data += dim;
  }

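  // Apply grad_output: a single scalar when the loss was reduced (or in the
  // 1d scalar fallback case), otherwise one value per row.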
  if (reduction != Reduction::None || grad_output.dim() == 0) {
    assert(
        reduction != Reduction::None || grad_output.dim() > 0 ||
        nframe == 1); // check 1d scalar fallback-case
    const auto d = *grad_output.const_data_ptr<scalar_t>();
    for (int64_t t = 0; t < nframe * dim; t++) {
      grad_input_data[t] *= d;
    }
  } else {
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    for (const auto t : c10::irange(nframe)) {
      for (const auto d : c10::irange(dim)) {
        grad_input_data[t * dim + d] *= grad_output_acc[t];
      }
    }
  }
}

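// Validates p and the shapes, resizes grad_input to match input, computes the
// normalization factor g, and dispatches the typed backward kernel on
// contiguous copies of the inputs.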
void multi_margin_loss_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int p,
    const Scalar& margin,
    const Tensor& weight,
    int64_t reduction) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  const auto ndims = input.dim();

  TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

  multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);
  grad_input.resize_as_(input);
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

  if (input.numel() == 0) {
    return;
  }

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  auto weight_contiguous = weight.contiguous();
  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multi_margin_loss_backward_cpu_kernel", [&] {
        auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
        auto input_data = input_contiguous.const_data_ptr<scalar_t>();
        auto target_data = target_contiguous.const_data_ptr<int64_t>();
        auto weight_data = weight_contiguous.defined()
            ? weight_contiguous.const_data_ptr<scalar_t>()
            : nullptr;
        scalar_t g = reduction == Reduction::Mean
            ? static_cast<scalar_t>(1. / (nframe * dim))
            : static_cast<scalar_t>(1. / dim);
        multi_margin_loss_backward_cpu_kernel<scalar_t>(
            grad_input_data,
            grad_output,
            input_data,
            target_data,
            p,
            margin.to<scalar_t>(),
            g,
            weight_data,
            nframe,
            dim,
            reduction);
      });
}

} // namespace

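// The functions below are the CPU entry points for the multi_margin_loss and
// multi_margin_loss_backward ops (the *_out variants take the result tensor
// as the trailing argument). A minimal illustrative call through the ATen
// API, assuming the usual dispatcher registration for these kernels:
//
//   at::Tensor input = at::randn({4, 10});
//   at::Tensor target = at::randint(0, 10, {4}, at::kLong);
//   at::Tensor loss = at::multi_margin_loss(
//       input, target, /*p=*/1, /*margin=*/1, /*weight=*/{},
//       at::Reduction::Mean);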
Tensor multi_margin_loss_cpu(
    const Tensor& input,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  auto output = at::empty({0}, input.options());
  multi_margin_loss_out_cpu_template(
      output, input, target, p.toInt(), margin, weight, reduction);
  return output;
}

Tensor& multi_margin_loss_cpu_out(const Tensor& input,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    Tensor& output) {
  multi_margin_loss_out_cpu_template(
      output, input, target, p.toInt(), margin, weight, reduction);
  return output;
}

Tensor multi_margin_loss_cpu_backward(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin, const std::optional<Tensor>& weight_opt,
    int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  auto grad_input = at::empty({0}, input.options());
  multi_margin_loss_backward_out_cpu_template(
      grad_input,
      grad_output,
      input,
      target,
      p.toInt(),
      margin,
      weight,
      reduction);
  return grad_input;
}

Tensor& multi_margin_loss_cpu_backward_out(const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    Tensor& grad_input) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  multi_margin_loss_backward_out_cpu_template(
      grad_input,
      grad_output,
      input,
      target,
      p.toInt(),
      margin,
      weight,
      reduction);
  return grad_input;
}

} // namespace at::native