#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/LossMulti.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/multi_margin_loss_backward_native.h>
#include <ATen/ops/multi_margin_loss_native.h>
#endif

namespace at::native {

namespace {

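// Computes the (optionally weighted) multi-margin hinge sum for a single
// sample: sum over d != target_idx of max(0, margin - x[target_idx] + x[d])^p,
// with each violating term scaled by the target's class weight, divided by dim.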
template <typename scalar_t>
inline scalar_t multi_margin_inner_sum_cpu(
    const scalar_t* input_data,
    const scalar_t* weight_data,
    const int p,
    const scalar_t margin,
    const int64_t dim,
    const int64_t target_idx) {
  const scalar_t input_target = input_data[target_idx];
  scalar_t sum = 0;
  for (const auto d : c10::irange(dim)) {
    if (d == target_idx) {
      continue;
    }

    const scalar_t z = margin - input_target + input_data[d];
    if (z > 0) {
      scalar_t h = (p == 1) ? z : z * z;
      if (weight_data != nullptr) {
        h *= weight_data[target_idx];
      }
      sum += h;
    }
  }

  sum /= dim;
  return sum;
}

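// Reads target_data[index] and checks that it is a valid class index in [0, dim).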
inline int64_t target_index_checked(
    const int64_t* target_data,
    const int64_t index,
    const int64_t dim) {
  const int64_t idx = target_data[index];
  TORCH_CHECK(idx >= 0 && idx < dim, "target out of range");
  return idx;
}

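// Forward kernel: with Reduction::None (and a non-scalar output) the per-sample
// losses are written to `output`; otherwise the losses are accumulated in
// acc_type precision and the sum (or mean) is written to the single output
// element.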
template <typename scalar_t>
static inline void multi_margin_loss_cpu_kernel(
    Tensor& output,
    const scalar_t* input_data,
    const int64_t* target_data,
    const int p,
    scalar_t margin,
    const scalar_t* weight_data,
    const int64_t nframe,
    const int64_t dim,
    const int64_t reduction) {
  using accscalar_t = at::acc_type<scalar_t, false>;

  // dim() != 0 check is for 1d input which produces a scalar output (that
  // cannot be handled by TensorAccessor)
  if (reduction == Reduction::None && output.dim() > 0) {
    auto output_acc = output.accessor<scalar_t, 1>();
    for (const auto t : c10::irange(nframe)) {
      const auto idx = target_index_checked(target_data, t, dim);
      auto sum = multi_margin_inner_sum_cpu(
          input_data, weight_data, p, margin, dim, idx);
      output_acc[t] = sum;
      input_data += dim;
    }
  } else {
    accscalar_t sum = 0;
    auto output_acc = output.data_ptr<scalar_t>();
    for (const auto t : c10::irange(nframe)) {
      const auto idx = target_index_checked(target_data, t, dim);
      sum += multi_margin_inner_sum_cpu(
          input_data, weight_data, p, margin, dim, idx);
      input_data += dim;
    }
    if (reduction == Reduction::Mean) {
      sum /= nframe;
    }
    output_acc[0] = sum;
  }
}

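// Validates the inputs, resizes the output to {nframe} (Reduction::None with a
// batched target) or to a scalar, makes the inputs contiguous, and dispatches
// the forward kernel over the floating-point dtypes.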
void multi_margin_loss_out_cpu_template(
    Tensor& output,
    const Tensor& input,
    const Tensor& target,
    int p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  const auto ndims = input.dim();

  TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

  multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);

  // produce a scalar output for 1d input
  if (reduction == Reduction::None && target.dim() > 0) {
    output.resize_({nframe});
  } else {
    output.resize_({});
  }
  if (input.numel() == 0) {
    return;
  }

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  Tensor weight_contiguous;
  if (weight && weight->defined()) {
    weight_contiguous = weight->contiguous();
  }

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multi_margin_loss_cpu_kernel", [&] {
        auto input_data = input_contiguous.const_data_ptr<scalar_t>();
        auto target_data = target_contiguous.const_data_ptr<int64_t>();
        auto weight_data =
            weight_contiguous.defined() ? weight_contiguous.const_data_ptr<scalar_t>() : nullptr;
        multi_margin_loss_cpu_kernel<scalar_t>(
            output,
            input_data,
            target_data,
            p,
            margin.to<scalar_t>(),
            weight_data,
            nframe,
            dim,
            reduction);
      });
}

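// Backward kernel: for each violating class d (z = margin - x[target] + x[d] > 0)
// the gradient w.r.t. x[d] is g for p == 1 and 2 * g * z for p == 2, scaled by
// the target's class weight, and the same amount is subtracted from the gradient
// w.r.t. x[target]. The result is then multiplied by grad_output: a single
// scalar for reduced losses, one value per sample for Reduction::None.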
template <typename scalar_t>
static void multi_margin_loss_backward_cpu_kernel(
    scalar_t* grad_input_data,
    const Tensor& grad_output,
    const scalar_t* input_data,
    const int64_t* target_data,
    int p,
    scalar_t margin,
    scalar_t g,
    const scalar_t* weight_data,
    int64_t nframe,
    int64_t dim,
    int64_t reduction) {
  scalar_t* grad_input_row_data = grad_input_data;
  for (const auto t : c10::irange(nframe)) {
    int64_t target_idx = target_index_checked(target_data, t, dim);
    scalar_t input_target = input_data[target_idx];
    scalar_t grad_input_target = 0;
    for (const auto d : c10::irange(dim)) {
      scalar_t z = margin - input_target + input_data[d];
      if (d == target_idx) {
        continue;
      }

      if (z > 0) {
        scalar_t h = (p == 1) ? g : 2 * g * z;
        if (weight_data != nullptr) {
          h *= weight_data[target_idx];
        }
        grad_input_target -= h;
        grad_input_row_data[d] = h;
      } else {
        grad_input_row_data[d] = 0;
      }
    }
    grad_input_row_data[target_idx] = grad_input_target;

    input_data += dim;
    grad_input_row_data += dim;
  }

  if (reduction != Reduction::None || grad_output.dim() == 0) {
    assert(
        reduction != Reduction::None || grad_output.dim() > 0 ||
        nframe == 1); // check 1d scalar fallback-case
    const auto d = *grad_output.const_data_ptr<scalar_t>();
    for (int64_t t = 0; t < nframe * dim; t++) {
      grad_input_data[t] *= d;
    }
  } else {
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    for (const auto t : c10::irange(nframe)) {
      for (const auto d : c10::irange(dim)) {
        grad_input_data[t * dim + d] *= grad_output_acc[t];
      }
    }
  }
}

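// Resizes grad_input to match input, validates shapes, and dispatches the
// backward kernel. The per-element scale g folds in the 1/dim normalization of
// the forward pass (and the additional 1/nframe for Reduction::Mean).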
void multi_margin_loss_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int p,
    const Scalar& margin,
    const Tensor& weight,
    int64_t reduction) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  const auto ndims = input.dim();

  TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

  multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);
  grad_input.resize_as_(input);
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

  if (input.numel() == 0) {
    return;
  }

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  auto weight_contiguous = weight.contiguous();
  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multi_margin_loss_backward_cpu_kernel", [&] {
        auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
        auto input_data = input_contiguous.const_data_ptr<scalar_t>();
        auto target_data = target_contiguous.const_data_ptr<int64_t>();
        auto weight_data = weight_contiguous.defined()
            ? weight_contiguous.const_data_ptr<scalar_t>()
            : nullptr;
        scalar_t g = reduction == Reduction::Mean
            ? static_cast<scalar_t>(1. / (nframe * dim))
            : static_cast<scalar_t>(1. / dim);
        multi_margin_loss_backward_cpu_kernel<scalar_t>(
            grad_input_data,
            grad_output,
            input_data,
            target_data,
            p,
            margin.to<scalar_t>(),
            g,
            weight_data,
            nframe,
            dim,
            reduction);
      });
}

} // namespace

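// Public CPU entry points for the multi_margin_loss forward and backward
// operators (functional and out variants), delegating to the templates above.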
Tensor multi_margin_loss_cpu(
    const Tensor& input,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  auto output = at::empty({0}, input.options());
  multi_margin_loss_out_cpu_template(
      output, input, target, p.toInt(), margin, weight, reduction);
  return output;
}

Tensor& multi_margin_loss_cpu_out(const Tensor& input,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    Tensor& output) {
  multi_margin_loss_out_cpu_template(
      output, input, target, p.toInt(), margin, weight, reduction);
  return output;
}

Tensor multi_margin_loss_cpu_backward(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin, const std::optional<Tensor>& weight_opt,
    int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  auto grad_input = at::empty({0}, input.options());
  multi_margin_loss_backward_out_cpu_template(
      grad_input,
      grad_output,
      input,
      target,
      p.toInt(),
      margin,
      weight,
      reduction);
  return grad_input;
}

Tensor& multi_margin_loss_cpu_backward_out(const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    Tensor& grad_input) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  multi_margin_loss_backward_out_cpu_template(
      grad_input,
      grad_output,
      input,
      target,
      p.toInt(),
      margin,
      weight,
      reduction);
  return grad_input;
}

} // namespace at::native