#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/LossMulti.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/multilabel_margin_loss_backward_native.h>
#include <ATen/ops/multilabel_margin_loss_forward.h>
#include <ATen/ops/multilabel_margin_loss_forward_native.h>
#include <ATen/ops/multilabel_margin_loss_native.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::native {

namespace {
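// Per-sample forward kernel. The valid targets are the leading non-negative
// entries of target_data (the list stops at the first negative index). For each
// valid target j and each non-target class d, the hinge term
//   max(0, 1 - input[target[j]] + input[d])
// is accumulated (normalization by dim happens in the caller), matching the
// criterion documented for nn.MultiLabelMarginLoss. As a side effect,
// is_target_data is set to 1 at every target index so the backward pass can
// reuse the mask.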
template <typename scalar_t>
inline scalar_t multilabel_margin_loss_forward_inner_sum_cpu(
    const scalar_t* input_data,
    const int64_t* target_data,
    scalar_t* is_target_data,
    int64_t dim) {
  using accscalar_t = at::acc_type<scalar_t, false>;
  accscalar_t sum = 0;
  for (const auto ddt : c10::irange(dim)) {
    int64_t target_idx = target_data[ddt];
    if (target_idx < 0) {
      break;
    }
    is_target_data[target_idx] = 1;
  }
  for (const auto dt : c10::irange(dim)) {
    int64_t target_idx = target_data[dt];
    if (target_idx < 0) {
      break;
    }

    scalar_t input_target = input_data[target_idx];
    for (const auto d : c10::irange(dim)) {
      if (!is_target_data[d]) {
        scalar_t z = 1 - input_target + input_data[d];
        if (z > 0) {
          sum += z;
        }
      }
    }
  }

  return sum;
}

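// Forward kernel over a batch of nframe samples laid out contiguously with
// stride dim. With a reduction (or a 0-dim output) the per-sample sums are
// accumulated into a single scalar, divided by dim, and additionally divided
// by nframe for Reduction::Mean; with Reduction::None each per-sample sum is
// divided by dim and written to output[t].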
template <typename scalar_t>
static void multilabel_margin_loss_forward_out_frame(
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    Tensor& output,
    Tensor& is_target,
    int64_t reduction,
    int64_t nframe,
    int64_t dim) {
  using accscalar_t = at::acc_type<scalar_t, false>;
  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();
  scalar_t* is_target_data = is_target.data_ptr<scalar_t>();

  if (reduction != Reduction::None || output.dim() == 0) {
    scalar_t* output_data = output.data_ptr<scalar_t>();

    accscalar_t sum = 0;

    for (C10_UNUSED const auto t : c10::irange(nframe)) {
      sum += multilabel_margin_loss_forward_inner_sum_cpu(
          input_data, target_data, is_target_data, dim);

      input_data += dim;
      target_data += dim;
      is_target_data += dim;
    }

    sum /= dim;
    if (reduction == Reduction::Mean) {
      sum /= nframe;
    }

    *output_data = sum; // write scalar output value
  } else {
    auto output_acc = output.accessor<scalar_t, 1>();

    for (const auto t : c10::irange(nframe)) {
      scalar_t sum = multilabel_margin_loss_forward_inner_sum_cpu(
          input_data, target_data, is_target_data, dim);

      sum /= dim;
      output_acc[t] = sum;

      input_data += dim;
      target_data += dim;
      is_target_data += dim;
    }
  }
}

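// Validates shapes, resizes output (0-dim for reductions or 1-d targets,
// otherwise {nframe}) and is_target, checks that every target index lies in
// [-1, dim), and dispatches the typed frame kernel for floating point inputs.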
static void multilabel_margin_loss_forward_out_cpu_template(
    const Tensor& input,
    const Tensor& target,
    Tensor& output,
    Tensor& is_target,
    int64_t reduction) {
#ifndef STRIP_ERROR_MESSAGES
  auto target_arg = TensorArg(target, "target", 2);
#endif
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  const int64_t ndims = input.dim();
  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);

  // special case target.dim() <= 1: produce scalar output for scalar inputs
  // even if reduction == Reduction::None
  if (reduction != Reduction::None || target.dim() <= 1) {
    output.resize_({});
  } else {
    output.resize_({nframe});
  }

  is_target.resize_as_(target);
  TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous");
  is_target.zero_();

  if (input.numel() == 0) {
    return;
  }

  TORCH_CHECK(
      target.min().item<int64_t>() >= -1, target_arg, " is out of range");
  TORCH_CHECK(
      target.max().item<int64_t>() < dim, target_arg, " is out of range");

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multilabel_margin_loss_forward_out_frame", [&] {
        multilabel_margin_loss_forward_out_frame<scalar_t>(
            input_contiguous, target_contiguous, output, is_target, reduction, nframe, dim);
      });
}

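// Backward kernel. For every active hinge term (z > 0 in the forward pass) the
// gradient w.r.t. the target logit decreases by g and the gradient w.r.t. the
// non-target logit increases by g, where g is 1/dim (or 1/(nframe*dim) for
// Reduction::Mean). The accumulated gradient is then scaled by grad_output,
// either a single scalar or one value per sample for Reduction::None.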
template <typename scalar_t>
static void multilabel_margin_loss_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    int64_t reduction,
    const Tensor& is_target_contiguous,
    int64_t nframe,
    int64_t dim) {
#ifndef STRIP_ERROR_MESSAGES
  auto is_target_arg = TensorArg(is_target_contiguous, "is_target", 5);
#endif

  TORCH_CHECK(
      is_target_contiguous.min().item<scalar_t>() >= 0, is_target_arg, " is out of range");
  TORCH_CHECK(
      is_target_contiguous.max().item<scalar_t>() <= 1, is_target_arg, " is out of range");

  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();
  const scalar_t* is_target_data = is_target_contiguous.const_data_ptr<scalar_t>();
  scalar_t g = static_cast<scalar_t>(
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      reduction == Reduction::Mean ? 1. / (nframe * dim) : 1. / dim);

  scalar_t* grad_input_row_data = grad_input.mutable_data_ptr<scalar_t>();
  for (C10_UNUSED const auto t : c10::irange(nframe)) {
    for (const auto dt : c10::irange(dim)) {
      int64_t target_idx = target_data[dt];
      if (target_idx < 0) {
        break;
      }

      scalar_t input_target = input_data[target_idx];
      for (const auto d : c10::irange(dim)) {
        if (!is_target_data[d]) {
          scalar_t z = 1 - input_target + input_data[d];
          if (z > 0) {
            grad_input_row_data[target_idx] -= g;
            grad_input_row_data[d] += g;
          }
        }
      }
    }
    input_data += dim;
    target_data += dim;
    is_target_data += dim;
    grad_input_row_data += dim;
  }

  scalar_t* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  if (reduction != Reduction::None || grad_output.dim() == 0) {
    assert(
        reduction != Reduction::None || grad_output.dim() > 0 || nframe == 1);
    const auto d = *grad_output.const_data_ptr<scalar_t>();
    for (int64_t t = 0; t < nframe * dim; t++) {
      grad_input_data[t] *= d;
    }
  } else {
    check_dim_size(grad_output, 1, 0, nframe);
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    for (const auto t : c10::irange(nframe)) {
      for (const auto d : c10::irange(dim)) {
        grad_input_data[t * dim + d] *= grad_output_acc[t];
      }
    }
  }
}

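// Checks shapes (including that is_target matches target), resizes and zeroes
// grad_input, validates the target range, and dispatches the typed backward
// frame kernel on contiguous copies of the inputs.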
static void multilabel_margin_loss_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  CheckedFrom c = "multilabel_margin_loss_backward_cpu_template";
  auto target_arg = TensorArg(target, "target", 3);
  auto is_target_arg = TensorArg(is_target, "is_target", 5);
  const int64_t ndims = input.dim();

  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);
  checkSameSize(c, target_arg, is_target_arg);

  grad_input.resize_as_(input);
  if (grad_input.numel() == 0) {
    return;
  }

  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  grad_input.zero_();

  TORCH_CHECK(
      target.min().item<int64_t>() >= -1, target_arg, " is out of range");
  TORCH_CHECK(
      target.max().item<int64_t>() < dim, target_arg, " is out of range");

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  auto is_target_contiguous = is_target.contiguous();

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multilabel_margin_loss_backward_out_frame", [&] {
        multilabel_margin_loss_backward_out_frame<scalar_t>(
            grad_input,
            grad_output,
            input_contiguous,
            target_contiguous,
            reduction,
            is_target_contiguous,
            nframe,
            dim);
      });
}

} // namespace

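// Public CPU entry points. These are presumably the kernels bound to the
// multilabel_margin_loss_forward / _backward ops for the CPU dispatch key; the
// *_out variants write into caller-provided tensors, the others allocate their
// results.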
std::tuple<Tensor&, Tensor&> multilabel_margin_loss_forward_out_cpu(const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    Tensor& output,
    Tensor& is_target) {
  multilabel_margin_loss_forward_out_cpu_template(
      self, target, output, is_target, reduction);
  return std::tuple<Tensor&, Tensor&>(output, is_target);
}

std::tuple<Tensor, Tensor> multilabel_margin_loss_forward_cpu(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  auto output = at::empty({0}, self.options());
  auto is_target = at::empty({0}, self.options());
  at::native::multilabel_margin_loss_forward_out_cpu(
      self, target, reduction, output, is_target);
  return std::make_tuple(output, is_target);
}

Tensor& multilabel_margin_loss_backward_cpu_out(const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target,
    Tensor& grad_input) {
  multilabel_margin_loss_backward_out_cpu_template(
      grad_input, grad_output, self, target, reduction, is_target);
  return grad_input;
}

Tensor multilabel_margin_loss_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  auto grad_input = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  at::native::multilabel_margin_loss_backward_cpu_out(
      grad_output, self, target, reduction, is_target, grad_input);
  return grad_input;
}

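// Convenience wrappers around the forward op: they allocate a scratch
// is_target buffer where needed and return only the loss tensor.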
Tensor & multilabel_margin_loss_out(const Tensor & self, const Tensor & target, int64_t reduction, Tensor & output) {
  Tensor is_target = at::empty({0}, self.options());
  return std::get<0>(at::multilabel_margin_loss_forward_out(output, is_target, self, target, reduction));
}

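// Illustrative use (a sketch, not part of this file; assumes the usual ATen
// factory functions and the at::Reduction enum):
//   at::Tensor input = at::randn({3, 5});
//   at::Tensor target = at::full({3, 5}, /*fill_value=*/-1, at::kLong);
//   target.select(1, 0).fill_(2);  // class 2 is the single target of every row
//   at::Tensor loss = at::multilabel_margin_loss(input, target, at::Reduction::Mean);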
Tensor multilabel_margin_loss(const Tensor & self, const Tensor & target, int64_t reduction) {
  return std::get<0>(at::multilabel_margin_loss_forward(self, target, reduction));
}

} // namespace at::native