// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Licensed under the BSD-3-Clause license
// This is the CPU implementation of the Connectionist Temporal Loss.
// We mostly follow Graves.
// 1. Graves et al.: http://www.cs.toronto.edu/~graves/icml_2006.pdf
// We use the equations from the above link, but note that [1] has 1-based indexing and we (of course) use 0-based.
// Graves et al. call the probabilities y; we use log_probs (also calling them inputs).
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/Fill.h>
#include <c10/util/irange.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_ctc_loss.h>
#include <ATen/ops/_ctc_loss_backward.h>
#include <ATen/ops/_ctc_loss_backward_native.h>
#include <ATen/ops/_ctc_loss_native.h>
#include <ATen/ops/_cudnn_ctc_loss.h>
#include <ATen/ops/_use_cudnn_ctc_loss.h>
#include <ATen/ops/ctc_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/full_like.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#endif

#include <type_traits>
#include <utility>

namespace at::native {

namespace {

// This ad-hoc function converts from targets (l in [1]) to augmented targets (l' in [1]); note that no bound-checking is done.
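// For example, for a target l = [a, b] the augmented target is l' = [BLANK, a, BLANK, b, BLANK]:
// even indices idx map to BLANK and odd indices map to target[idx / 2].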
template<typename target_t>
static inline int64_t get_target_prime(target_t* target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
  if (idx % 2 == 0) {
    return BLANK;
  } else {
    return target[offset + stride * (idx / 2)];
  }
}

template<typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor, size_t, std::vector<int64_t>> ctc_loss_allocate_outputs(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
  // log_probs: input_len x batch_size x num_labels
  // targets [int64]: batch_size x target_length OR sum(target_lengths)
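  // Returns (neg_log_likelihood [batch_size], log_alpha [batch_size x input_len x 2*max_target_length+1],
  // tg_target_stride, tg_batch_offsets); the last two describe how to index into targets for each batch item.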

  CheckedFrom c = "ctc_loss_allocate_outputs";
  auto log_probs_arg = TensorArg(log_probs, "log_probs", 1);
  auto targets_arg = TensorArg(targets, "targets", 2);
  checkScalarType(c, targets_arg, target_scalar_type);
  checkDim(c, log_probs_arg, 3);
  checkDimRange(c, targets_arg, 1, 3);

  int64_t batch_size = log_probs.size(1);
  int64_t num_labels = log_probs.size(2);
  TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range");
  TORCH_CHECK((int64_t) input_lengths.size() == batch_size, "input_lengths must be of size batch_size");
  TORCH_CHECK((int64_t) target_lengths.size() == batch_size, "target_lengths must be of size batch_size");

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  size_t tg_target_stride;
  int64_t max_target_length = 0;
  std::vector<int64_t> tg_batch_offsets(batch_size);
  if (targets.dim() == 1) { // concatenated targets
    int64_t pos = 0;
    for (const auto i : c10::irange(batch_size)) {
      TORCH_CHECK(target_lengths[i] >= 0,
                  "Expected target_lengths to have value at least ", 0, ", but got value ", target_lengths[i],
                  " (while checking arguments for ", c, ")");
      tg_batch_offsets[i] = pos;
      pos += target_lengths[i];
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(0);
    checkSize(c, targets_arg, 0, pos);
  }
  else { // batch x max_target_length
    // dim is 2
    int64_t tg_batch_stride = targets.stride(0);
    for (const auto i : c10::irange(batch_size)) {
      TORCH_CHECK(target_lengths[i] >= 0,
                  "Expected target_lengths to have value at least ", 0, ", but got value ", target_lengths[i],
                  " (while checking arguments for ", c, ")");
      tg_batch_offsets[i] = i * tg_batch_stride;
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(1);
    checkSize(c, targets_arg, 0, batch_size);
    TORCH_CHECK(targets.size(1) >= max_target_length,
                "Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg,
                " (while checking arguments for ", c, ")");
  }
  int64_t max_input_length = log_probs.size(0);
  for (const auto b : c10::irange(batch_size)) {
    TORCH_CHECK(input_lengths[b] >= 0,
                "Expected input_lengths to have value at least ", 0, ", but got value ", input_lengths[b],
                " (while checking arguments for ", c, ")");
    TORCH_CHECK(input_lengths[b] <= max_input_length,
                "Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b],
                " (while checking arguments for ", c, ")");
  }
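  // log_alpha has one column per position of the augmented target l' (targets interleaved with blanks,
  // plus leading and trailing blanks), i.e. 2*max_target_length+1 entries per time step.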
  Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2*max_target_length+1}, log_probs.options());
  Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());

  return std::make_tuple(neg_log_likelihood, log_alpha, tg_target_stride, tg_batch_offsets);
}

// This kernel is a relatively straightforward implementation of the alpha calculation in the forward-backward algorithm (section 4.1).
// A (minor) twist is that we are using log-calculations to enhance numerical stability (log_probs and log_alpha).
// The function returns the loss and the alphas; the alphas are kept for the backward step. The wrapper (ctc_loss below) hides
// the alphas from the user by only returning the loss.
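// In the notation of [1], with l' the augmented target, the recursion computed below is
//   alpha_t(s) = (alpha_{t-1}(s) + alpha_{t-1}(s-1) + [l'_s != l'_{s-2}] * alpha_{t-1}(s-2)) * y_t(l'_s)
// (eq (6)/(7)); we work entirely in log space, so the sum becomes a logsumexp and the product an addition.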
template<typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor> ctc_loss_cpu_template(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
  // log_probs: input_len x batch_size x num_labels
  // targets [int64]: batch_size x target_length OR sum(target_lengths)
  constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
  using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;

  Tensor neg_log_likelihood, log_alpha;
  size_t tg_target_stride;
  std::vector<int64_t> tg_batch_offsets;

  if (targets.scalar_type() == kLong) {
    std::tie(neg_log_likelihood, log_alpha, tg_target_stride, tg_batch_offsets) =
        ctc_loss_allocate_outputs<scalar_t, kLong>(
            log_probs, targets, input_lengths, target_lengths, BLANK);
  } else {
    std::tie(neg_log_likelihood, log_alpha, tg_target_stride, tg_batch_offsets) =
        ctc_loss_allocate_outputs<scalar_t, kInt>(
            log_probs, targets, input_lengths, target_lengths, BLANK);
  }

  int64_t batch_size = log_probs.size(1);
  auto lpp = log_probs.permute({1,0,2});
  auto log_probs_a_global = lpp.accessor<const scalar_t, 3>();
  auto log_alpha_a_global = log_alpha.accessor<scalar_t, 3>();
  auto targets_data = targets.const_data_ptr<target_t>();
  auto neg_log_likelihood_a = neg_log_likelihood.accessor<scalar_t, 1>();

  // alpha calculation for the first row, the three equations for alpha_1 above eq (6)
  // first the default
  log_alpha.narrow(1, 0, 1).fill_(neginf);
  at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      int64_t input_length = input_lengths[b];
      int64_t target_length = target_lengths[b];
      auto log_probs_a = log_probs_a_global[b];
      auto log_alpha_a = log_alpha_a_global[b];
      int64_t tg_batch_offset = tg_batch_offsets[b];

      if (input_length == 0) {
        scalar_t log_likelihood = target_length == 0 ? 0 : neginf;
        neg_log_likelihood_a[b] = -log_likelihood;
        continue;
      }

      // the first two items of alpha_1 above eq (6)
      log_alpha_a[0][0] = log_probs_a[0][BLANK];
      if (target_length > 0)
        log_alpha_a[0][1] = log_probs_a[0][get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];

      // now the loop over the inputs
      for (const auto t : c10::irange(1, input_length)) {
        for (const auto s : c10::irange(2*target_length+1)) {
          auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
          // this loop over s could be parallel/vectorized, too, but the required items are one index apart
          // alternatively, one might consider moving s to the outer loop to cache current_target_prime more (but then it needs to be descending)
          // for the cuda implementation, that gave a speed boost.
          // This is eq (6) and (7); la1, la2, la3 are the three summands. We keep track of the maximum for the logsumexp calculation.
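          // logsumexp trick: log(exp(a) + exp(b) + exp(c)) = m + log(exp(a-m) + exp(b-m) + exp(c-m)) with m = max(a, b, c),
          // which avoids overflow; if all summands are -inf we set m = 0 below so the result stays -inf.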

          scalar_t la1 = log_alpha_a[t-1][s];
          scalar_t lamax = la1;
          scalar_t la2, la3;
          if (s > 0) {
            la2 = log_alpha_a[t-1][s-1];
            if (la2 > lamax)
              lamax = la2;
          } else {
            la2 = neginf;
          }
          if ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) !=
                          current_target_prime)) {
            la3 = log_alpha_a[t-1][s-2];
            if (la3 > lamax)
              lamax = la3;
          } else {
            la3 = neginf;
          }
          if (lamax == neginf) // cannot do neginf-neginf
            lamax = 0;
          // this is the assignment of eq (6)
          log_alpha_a[t][s] = std::log(std::exp(la1-lamax)+std::exp(la2-lamax)+std::exp(la3-lamax))+lamax + log_probs_a[t][current_target_prime];
        }
      }
      // the likelihood is the sum of the last two alphas, eq (8); the loss is the negative log likelihood
      if (target_length == 0) {
        // if the target is empty then there is no preceding BLANK state and hence there is no path to merge
        neg_log_likelihood_a[b] = -log_alpha_a[input_length-1][0];
      } else {
        scalar_t l1 = log_alpha_a[input_length-1][target_length*2];
        scalar_t l2 = log_alpha_a[input_length-1][target_length*2-1];
        scalar_t m = std::max(l1, l2);
        m = ((m == neginf) ? 0 : m);
        scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
        neg_log_likelihood_a[b] = -log_likelihood;
      }
    }
  });

  return std::make_tuple(neg_log_likelihood, log_alpha);
}

// This is the backward. It consists of two phases:
// a) computing the betas analogous to the alphas in the forward (backward half of the forward-backward algorithm) (eq (10) and (11))
// b) collecting the per-activation contributions for all s and wrapping up the gradient (eq (16); the collection is the sum)
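// As implemented here, beta includes the emission at time t, so the recursion mirrors the alpha one:
//   beta_t(s) = (beta_{t+1}(s) + beta_{t+1}(s+1) + [l'_s != l'_{s+2}] * beta_{t+1}(s+2)) * y_t(l'_s),
// again carried out in log space. The gradient of eq (16) is then assembled from log_alpha + log_beta per
// augmented-target position and combined with the loss (nll) and log_probs in the final per-(t, c) loop below.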
template<typename scalar_t, ScalarType target_scalar_type>
Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                                      const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
  constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
  using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
  int64_t max_input_length = log_probs.size(0);
  int64_t batch_size = log_probs.size(1);
  int64_t num_labels = log_probs.size(2);
  Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // at this point, this is log of empty sum

  // The admin bits. We don't do much checking and assume that the forward did.
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t tg_target_stride;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t max_target_length;
  std::vector<int64_t> tg_batch_offsets(batch_size);

  if (targets.dim() == 1) { // concatenated targets
    int64_t pos = 0;
    max_target_length = 0;
    for (const auto i : c10::irange(batch_size)) {
      tg_batch_offsets[i] = pos;
      pos += target_lengths[i];
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(0);
  }
  else { // batch x max_target_length
    // dim is 2
    int64_t tg_batch_stride = targets.stride(0);
    for (const auto i : c10::irange(batch_size)) {
      tg_batch_offsets[i] = i * tg_batch_stride;
    }
    tg_target_stride = targets.stride(1);
    max_target_length = targets.size(1);
  }

  Tensor log_beta = at::empty_like(log_alpha, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // could be optimized to use only 2 rows
  auto lpp = log_probs.permute({1,0,2});
  auto log_probs_a_global = lpp.accessor<const scalar_t, 3>();
  auto log_alpha_a_global = log_alpha.accessor<const scalar_t, 3>();
  auto log_beta_a_global = log_beta.accessor<scalar_t, 3>();
  auto gp = grad.permute({1,0,2});
  auto grad_a_global = gp.accessor<scalar_t, 3>();
  auto targets_data = targets.const_data_ptr<target_t>();
  auto grad_out_a = grad_out.accessor<const scalar_t, 1>();

  auto create_fill_iterator = [](const Tensor& tensor, IntArrayRef squash_dims) {
    return TensorIteratorConfig()
        .set_check_mem_overlap(false)  // Fill is idempotent, so overlap is okay
        .check_all_same_dtype(false)
        .add_output(tensor)
        .resize_outputs(false)
        .declare_static_shape(tensor.sizes(), squash_dims)
        .build();
  };
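  // Prebuilding these TensorIterators once (with the batch/time dims squashed to a static shape) lets the
  // parallel loop below fill per-sample and per-timestep slices by just swapping the operand's data pointer
  // via unsafe_replace_operand, instead of constructing a new iterator for every fill.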
  const auto fill_iter = create_fill_iterator(grad, /*squash_dims=*/1);
  const auto fill_1d_iter = create_fill_iterator(grad, /*squash_dims=*/{0, 1});
  const auto fill_log_beta_1d_iter = create_fill_iterator(log_beta, /*squash_dims=*/{0, 1});

  at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
    TensorIterator fill_iter_local(fill_iter);
    TensorIterator fill_1d_iter_local(fill_1d_iter);
    TensorIterator fill_log_beta_1d_iter_local(fill_log_beta_1d_iter);

    for (const auto b : c10::irange(start, end)) {
      scalar_t nll = neg_log_likelihood.accessor<scalar_t, 1>()[b];
      auto grad_a = grad_a_global[b];
      if (zero_infinity && nll == std::numeric_limits<scalar_t>::infinity()) {
        // grad_batch.zero_();
        fill_iter_local.unsafe_replace_operand(0, grad_a.data());
        fill_stub(kCPU, fill_iter_local, 0);
        continue;
      }

      auto log_probs_a = log_probs_a_global[b];
      auto log_alpha_a = log_alpha_a_global[b];
      auto log_beta_a = log_beta_a_global[b];
      int64_t input_length = input_lengths[b];
      int64_t target_length = target_lengths[b];
      int64_t tg_batch_offset = tg_batch_offsets[b];

      // the initialization of beta before eq (10)
      // here we do the fill for each batch item separately, as the input lengths will differ, so the t in which
      // we start varies
      if (input_length > 0) {
        // log_beta.select(0, b).select(1, input_length-1).fill_(neginf);
        fill_log_beta_1d_iter_local.unsafe_replace_operand(
            0, log_beta_a[input_length - 1].data());
        fill_stub(kCPU, fill_log_beta_1d_iter_local, neginf);

        log_beta_a[input_length-1][2*target_length] = log_probs_a[input_length-1][BLANK];
        grad_a[input_length-1][BLANK] = log_alpha_a[input_length-1][2*target_length] + log_beta_a[input_length-1][2*target_length];

        if (target_length > 0) {
          auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 2*target_length-1, BLANK);
          log_beta_a[input_length-1][2*target_length-1] = log_probs_a[input_length-1][current_target_prime];

          // the first two are a blank and a non-blank, so we know they are different and we don't need to do log+
          grad_a[input_length-1][current_target_prime] = log_alpha_a[input_length-1][2*target_length-1] + log_beta_a[input_length-1][2*target_length-1];
        }
      }

      // now loop applying eq (10) / (11)
      for (int64_t t=input_length-2; t>=0; t--) {
        // this loop over s could be parallel/vectorized and doesn't really need to be descending...
        // alternatively, one might consider moving s to the outer loop to cache current_target_prime more (but then it needs to be descending)
        // for the cuda implementation, that gave a speed boost.
        for (int64_t s=2*target_length; s>=0; s--) {
          scalar_t lb1 = log_beta_a[t+1][s];
          scalar_t lbmax = lb1;
          scalar_t lb2, lb3;
          auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
          if (s < 2*target_length) {
            lb2 = log_beta_a[t+1][s+1];
            if (lb2 > lbmax)
              lbmax = lb2;
          } else {
            lb2 = neginf;
          }
          if ((s < 2*target_length-1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s+2, BLANK) !=
                                          current_target_prime)) {
            lb3 = log_beta_a[t+1][s+2];
            if (lb3 > lbmax)
              lbmax = lb3;
          } else {
            lb3 = neginf;
          }
          if (lbmax == neginf)
            lbmax = 0;

          log_beta_a[t][s] = std::log(std::exp(lb1-lbmax)+std::exp(lb2-lbmax)+std::exp(lb3-lbmax))+lbmax + log_probs_a[t][current_target_prime];
          // one might check whether one can vectorize this better when done after the t-loop...
          // now that we have beta, we fill in the sum of alpha*beta in eq (16)
          // in contrast to the cuda implementation, we only parallelize over the batch, so we don't have a concurrency
          // issue (several s can map to the same target character)
          // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
          scalar_t log_alpha_beta = log_alpha_a[t][s] + log_beta_a[t][s];
          scalar_t &lcab = grad_a[t][current_target_prime];
          if (lcab == neginf) {
            lcab = log_alpha_beta;
          } else {
            scalar_t max = std::max(lcab, log_alpha_beta);
            lcab = std::log(std::exp(lcab-max)+std::exp(log_alpha_beta-max))+max;
          }
        }
      }

      // now grad has the sum of eq (16)
      // now we wrap up the calculation by adding in the remaining items of eq (16)
      // this could be a great target for further vectorization.
      // grad is the output gradient, nll is the loss. Note that -nll is the log of the Z of eq (16)
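      // Concretely, with lcab = log(sum_s alpha_t(s) * beta_t(s)) for character c (both alpha and beta include
      // the emission at t, so the product carries y_t(c) twice), the gradient w.r.t. log_probs[t][c] is
      //   exp(lp) - exp(lcab + nll - lp)  =  y_t(c) - (1 / (Z * y_t(c))) * sum_s alpha_t(s) * beta_t(s),
      // scaled by the incoming grad_out for this batch item.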
      scalar_t gr = grad_out_a[b];
      for (const auto t : c10::irange(input_length)) { // or go for the full thing?
        for (const auto c : c10::irange(num_labels)) {
          scalar_t& res = grad_a[t][c];
          scalar_t lp = log_probs_a[t][c];
          res = (std::exp(lp)-std::exp(res + nll - lp)) * gr;
        }
      }

      // zero the remainder
      for (auto l : c10::irange(input_length, max_input_length)) {
        // grad_batch.select(0, l).zero_();
        fill_1d_iter_local.unsafe_replace_operand(0, grad_a[l].data());
        fill_stub(kCPU, fill_1d_iter_local, 0);
      }
    }
  });
  return grad;
}

} // namespace

std::tuple<Tensor, Tensor> ctc_loss_meta(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
  (void)zero_infinity; // only used for backwards
  return AT_DISPATCH_FLOATING_TYPES(
      log_probs.scalar_type(), "ctc_loss_meta", [&] {
        Tensor neg_log_likelihood, log_alpha;
        if (targets.scalar_type() == kLong) {
          std::tie(neg_log_likelihood, log_alpha, std::ignore, std::ignore) = ctc_loss_allocate_outputs<scalar_t, kLong>(
              log_probs, targets, input_lengths, target_lengths, BLANK);
        } else {
          std::tie(neg_log_likelihood, log_alpha, std::ignore, std::ignore) = ctc_loss_allocate_outputs<scalar_t, kInt>(
              log_probs, targets, input_lengths, target_lengths, BLANK);
        }
        return std::make_tuple(neg_log_likelihood, log_alpha);
      });
}

std::tuple<Tensor, Tensor> ctc_loss_cpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
  (void)zero_infinity; // only used for backwards
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_cpu", [&] {
    if (targets.scalar_type() == kLong) {
      return ctc_loss_cpu_template<scalar_t, kLong>(log_probs, targets, input_lengths, target_lengths, BLANK);
    } else {
      return ctc_loss_cpu_template<scalar_t, kInt>(log_probs, targets, input_lengths, target_lengths, BLANK);
    }
  });
}


std::tuple<Tensor, Tensor> ctc_loss_tensor(const Tensor& log_probs, const Tensor& targets, const Tensor& input_lengths, const Tensor& target_lengths, int64_t BLANK, bool zero_infinity) {
  TORCH_CHECK(isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false), "input_lengths must be integral");
  TORCH_CHECK(isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false), "target_lengths must be integral");

  Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  IntArrayRef il(ilc.const_data_ptr<int64_t>(), ilc.numel());
  IntArrayRef tl(tlc.const_data_ptr<int64_t>(), tlc.numel());

  return at::_ctc_loss(log_probs, targets, il, tl, BLANK, zero_infinity);
}

Tensor ctc_loss_backward_cpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                             const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_backward_cpu", [&] {
    if (targets.scalar_type() == kLong) {
      return ctc_loss_backward_cpu_template<scalar_t, kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
    } else {
      return ctc_loss_backward_cpu_template<scalar_t, kInt>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
    }
  });
}

Tensor ctc_loss_backward_tensor(
    const Tensor& grad,
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    const Tensor& neg_log_likelihood,
    const Tensor& log_alpha,
    int64_t BLANK,
    bool zero_infinity) {
  TORCH_CHECK(
      isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false),
      "input_lengths must be integral");
  TORCH_CHECK(isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false), "target_lengths must be integral");

  Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
  IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
  return at::_ctc_loss_backward(grad, log_probs, targets, il, tl, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
}

namespace {

Tensor get_clamped_target_length(
    IntArrayRef target_lengths,
    const TensorOptions& options) {
  return at::tensor(target_lengths, options).clamp_min(1);
}

Tensor get_clamped_target_length(
    const Tensor& target_lengths,
    const TensorOptions& options) {
  return target_lengths.clamp_min(1);
}

// this wrapper function dispatches to the native and cudnn implementations and hides the alpha/grad from the user (by just returning the loss)
// the gradient is implemented for _cudnn_ctc_loss (just in derivatives.yaml) and _ctc_loss, so this function has automatic gradients
// it also handles the reduction if desired
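// Reduction::Mean divides each sample's loss by its target length (clamped to at least 1 to avoid division
// by zero for empty targets) and then averages over the batch; Reduction::Sum just sums; otherwise the
// per-sample losses are returned unreduced.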
template <typename LengthsType>
Tensor ctc_loss_impl(const Tensor& log_probs_, const Tensor& targets, LengthsType input_lengths, LengthsType target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
  auto is_batched = log_probs_.dim() == 3;
  Tensor log_probs = is_batched ? log_probs_ : log_probs_.unsqueeze(1);
  bool use_cudnn =
      (log_probs.device().type() == at::kCUDA) &&
      at::_use_cudnn_ctc_loss(
          log_probs, targets, input_lengths, target_lengths, BLANK);

  Tensor res;
  if (use_cudnn) {
    // non-deterministic ctc loss on cudnn disabled due to inconsistent results
    // see: https://github.com/pytorch/pytorch/issues/21680
    res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, /*deterministic=*/true, zero_infinity));
  } else {
    // the targets may live on CPU (as CuDNN requires); for the native kernel we move them to the
    // device of log_probs as a service for the user
    res = std::get<0>(at::_ctc_loss(
        log_probs,
        targets.to(log_probs.device(), kLong),
        input_lengths,
        target_lengths,
        BLANK,
        zero_infinity));
    if (zero_infinity) {
      res = at::where(res == Scalar(std::numeric_limits<double>::infinity()), at::zeros({}, res.options()), res);
    }
  }
  if (reduction == at::Reduction::Mean) {
    auto target_lengths_t = get_clamped_target_length(target_lengths, res.options());
    return (res / target_lengths_t).mean();
  } else if (reduction == at::Reduction::Sum) {
    return res.sum();
  }
  return is_batched ? std::move(res) : res.squeeze(0);
}

} // namespace

Tensor ctc_loss(const Tensor& log_probs_, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
  return ctc_loss_impl(log_probs_, targets, input_lengths, target_lengths, BLANK, reduction, zero_infinity);
}

// Convenience function accepting Tensors
Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, const Tensor& input_lengths, const Tensor& target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
  // we don't want to convert to IntArrayRef if we can dispatch to cuDNN (this allows graph-capturable ctc_loss)
  bool use_cudnn =
      (log_probs.device().type() == at::kCUDA) &&
      at::_use_cudnn_ctc_loss(
          log_probs, targets, input_lengths, target_lengths, BLANK);
  if (at::areAnyTensorSubclassLike(
          {log_probs, targets, input_lengths, target_lengths}) || use_cudnn) {
    // Composite Compliant path for TensorSubclasses
    return ctc_loss_impl(log_probs, targets, input_lengths, target_lengths, BLANK, reduction, zero_infinity);
  }
  // Fast path (which accesses data_ptr) and fewer operator dispatches for
  // regular tensors
  TORCH_CHECK(isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false), "input_lengths must be integral");
  TORCH_CHECK(isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false), "target_lengths must be integral");

  Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  IntArrayRef il(ilc.const_data_ptr<int64_t>(), ilc.numel());
  IntArrayRef tl(tlc.const_data_ptr<int64_t>(), tlc.numel());
  return at::native::ctc_loss(log_probs, targets, il, tl, BLANK, reduction, zero_infinity);
}

} // namespace at::native
565