// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Licensed under the BSD-3-Clause license
// This is the CPU implementation of the Connectionist Temporal Loss.
// We mostly follow Graves.
// 1. Graves et al.: http://www.cs.toronto.edu/~graves/icml_2006.pdf
// We use the equations from the link above, but note that [1] uses 1-based indexing while we (of course) use 0-based indexing.
// Graves et al. call the probabilities y; we use log_probs (also calling them inputs).
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/Fill.h>
#include <c10/util/irange.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_ctc_loss.h>
#include <ATen/ops/_ctc_loss_backward.h>
#include <ATen/ops/_ctc_loss_backward_native.h>
#include <ATen/ops/_ctc_loss_native.h>
#include <ATen/ops/_cudnn_ctc_loss.h>
#include <ATen/ops/_use_cudnn_ctc_loss.h>
#include <ATen/ops/ctc_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/full_like.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#endif

#include <type_traits>
#include <utility>

namespace at::native {

namespace {

// This ad-hoc helper converts from targets (l in [1]) to augmented targets (l' in [1]). Note that no bound-checking is done.
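// Illustrative example (added for clarity): for a target sequence l = [c, a, t], the
// augmented sequence is l' = [BLANK, c, BLANK, a, BLANK, t, BLANK], which has length
// 2 * target_length + 1. get_target_prime(..., idx, BLANK) returns l'[idx] without ever
// materializing l': even indices are BLANK, odd indices map to l[idx / 2].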
template<typename target_t>
static inline int64_t get_target_prime(target_t* target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
  if (idx % 2 == 0) {
    return BLANK;
  } else {
    return target[offset + stride * (idx / 2)];
  }
}

template<typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor, size_t, std::vector<int64_t>> ctc_loss_allocate_outputs(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
  // log_probs: input_len x batch_size x num_labels
  // targets [int64]: batch_size x target_length OR sum(target_lengths)

  CheckedFrom c = "ctc_loss_allocate_outputs";
  auto log_probs_arg = TensorArg(log_probs, "log_probs", 1);
  auto targets_arg = TensorArg(targets, "targets", 2);
  checkScalarType(c, targets_arg, target_scalar_type);
  checkDim(c, log_probs_arg, 3);
  checkDimRange(c, targets_arg, 1, 3);

  int64_t batch_size = log_probs.size(1);
  int64_t num_labels = log_probs.size(2);
  TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range");
  TORCH_CHECK((int64_t) input_lengths.size() == batch_size, "input_lengths must be of size batch_size");
  TORCH_CHECK((int64_t) target_lengths.size() == batch_size, "target_lengths must be of size batch_size");

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  size_t tg_target_stride;
  int64_t max_target_length = 0;
  std::vector<int64_t> tg_batch_offsets(batch_size);
  if (targets.dim() == 1) { // concatenated targets
    int64_t pos = 0;
    for (const auto i : c10::irange(batch_size)) {
      TORCH_CHECK(target_lengths[i] >= 0,
                  "Expected target_lengths to have value at least ", 0, ", but got value ", target_lengths[i],
                  " (while checking arguments for ", c, ")");
      tg_batch_offsets[i] = pos;
      pos += target_lengths[i];
      if (max_target_length < target_lengths[i])
         max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(0);
    checkSize(c, targets_arg, 0, pos);
  }
  else { // batch x max_target_length
    // dim is 2
    int64_t tg_batch_stride = targets.stride(0);
    for (const auto i : c10::irange(batch_size)) {
      TORCH_CHECK(target_lengths[i] >= 0,
                  "Expected target_lengths to have value at least ", 0, ", but got value ", target_lengths[i],
                  " (while checking arguments for ", c, ")");
      tg_batch_offsets[i] = i * tg_batch_stride;
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(1);
    checkSize(c, targets_arg, 0, batch_size);
    TORCH_CHECK(targets.size(1) >= max_target_length,
             "Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg,
             " (while checking arguments for ", c, ")");
  }
  int64_t max_input_length = log_probs.size(0);
  for (const auto b : c10::irange(batch_size)) {
    TORCH_CHECK(input_lengths[b] >= 0,
             "Expected input_lengths to have value at least ", 0, ", but got value ", input_lengths[b],
             " (while checking arguments for ", c, ")");
    TORCH_CHECK(input_lengths[b] <= max_input_length,
             "Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b],
             " (while checking arguments for ", c, ")");
  }

  Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2*max_target_length+1}, log_probs.options());
  Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());

  return std::make_tuple(neg_log_likelihood, log_alpha, tg_target_stride, tg_batch_offsets);
}

// This kernel is a relatively straightforward implementation of the alpha calculation in the forward-backward algorithm (section 4.1).
// A (minor) twist is that we are using log-calculations to enhance numerical stability (log_probs and log_alpha).
// The function returns the loss and the alphas; the alphas are kept for the backward step. The wrapper (ctc_loss below) hides
// the alphas from the user by only returning the loss.
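// As a sketch (added for clarity, notation as in the code below): with l' the augmented
// target sequence, the recursion computed here in log space is roughly
//   log_alpha[t][s] = logsumexp(log_alpha[t-1][s],
//                               log_alpha[t-1][s-1],     // if s > 0
//                               log_alpha[t-1][s-2])     // if s > 1 and l'[s] != l'[s-2]
//                     + log_probs[t][l'[s]]
// which corresponds to eqs. (6)/(7) of [1].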
template<typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor> ctc_loss_cpu_template(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
  // log_probs: input_len x batch_size x num_labels
  // targets [int64]: batch_size x target_length OR sum(target_lengths)
  constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
  using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;

  Tensor neg_log_likelihood, log_alpha;
  size_t tg_target_stride;
  std::vector<int64_t> tg_batch_offsets;

  if (targets.scalar_type() == kLong) {
    std::tie(neg_log_likelihood, log_alpha, tg_target_stride, tg_batch_offsets) =
        ctc_loss_allocate_outputs<scalar_t, kLong>(
            log_probs, targets, input_lengths, target_lengths, BLANK);
  } else {
    std::tie(neg_log_likelihood, log_alpha, tg_target_stride, tg_batch_offsets) =
        ctc_loss_allocate_outputs<scalar_t, kInt>(
            log_probs, targets, input_lengths, target_lengths, BLANK);
  }

  int64_t batch_size = log_probs.size(1);
  auto lpp  = log_probs.permute({1,0,2});
  auto log_probs_a_global = lpp.accessor<const scalar_t, 3>();
  auto log_alpha_a_global = log_alpha.accessor<scalar_t, 3>();
  auto targets_data = targets.const_data_ptr<target_t>();
  auto neg_log_likelihood_a = neg_log_likelihood.accessor<scalar_t, 1>();

  // alpha calculation for the first row, the three equations for alpha_1 above eq (6)
  // first the default
  log_alpha.narrow(1, 0, 1).fill_(neginf);
  at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      int64_t input_length = input_lengths[b];
      int64_t target_length = target_lengths[b];
      auto log_probs_a = log_probs_a_global[b];
      auto log_alpha_a = log_alpha_a_global[b];
      int64_t tg_batch_offset = tg_batch_offsets[b];

      if (input_length == 0) {
        scalar_t log_likelihood = target_length == 0 ? 0 : neginf;
        neg_log_likelihood_a[b] = -log_likelihood;
        continue;
      }

      // the first two items of alpha_t above eq (6)
      log_alpha_a[0][0] = log_probs_a[0][BLANK];
      if (target_length > 0)
        log_alpha_a[0][1] = log_probs_a[0][get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];

      // now the loop over the inputs
      for (const auto t : c10::irange(1, input_length)) {
        for (const auto s : c10::irange(2*target_length+1)) {
          auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
          // this loop over s could be parallel/vectorized, too, but the required items are one index apart
          // alternatively, one might consider moving s to the outer loop to cache current_target_prime more (but then it needs to be descending)
          // for the cuda implementation, that gave a speed boost.
          // This is eq (6) and (7); la1, la2, la3 are the three summands. We keep track of the maximum for the logsumexp calculation.

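          // Numerical-stability note (added for clarity): the logsumexp below is computed as
          //   lamax + log(exp(la1-lamax) + exp(la2-lamax) + exp(la3-lamax)),
          // where lamax = max(la1, la2, la3). If all three summands are -inf, lamax is reset
          // to 0 so that the result is log(0) = -inf rather than NaN from -inf - (-inf).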
          scalar_t la1 = log_alpha_a[t-1][s];
          scalar_t lamax = la1;
          scalar_t la2, la3;
          if (s > 0) {
            la2 = log_alpha_a[t-1][s-1];
            if (la2 > lamax)
              lamax = la2;
          } else {
            la2 = neginf;
          }
          if ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) !=
                          current_target_prime)) {
            la3 = log_alpha_a[t-1][s-2];
            if (la3 > lamax)
              lamax = la3;
          } else {
            la3 = neginf;
          }
          if (lamax == neginf) // cannot do neginf-neginf
            lamax = 0;
          // this is the assignment of eq (6)
          log_alpha_a[t][s] = std::log(std::exp(la1-lamax)+std::exp(la2-lamax)+std::exp(la3-lamax))+lamax + log_probs_a[t][current_target_prime];
        }
      }
      // the likelihood is the sum of the last two alphas (eq (8)); the loss is the negative log likelihood
      if (target_length == 0) {
        // if the target is empty then there is no preceding BLANK state and hence there is no path to merge
        neg_log_likelihood_a[b] = -log_alpha_a[input_length-1][0];
      } else {
        scalar_t l1 = log_alpha_a[input_length-1][target_length*2];
        scalar_t l2 = log_alpha_a[input_length-1][target_length*2-1];
        scalar_t m = std::max(l1, l2);
        m = ((m == neginf) ? 0 : m);
        scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
        neg_log_likelihood_a[b] = -log_likelihood;
      }
    }
  });

  return std::make_tuple(neg_log_likelihood, log_alpha);
}

// This is the backward. It consists of two phases:
// a) computing the betas analogous to the alphas in the forward pass (the backward half of the forward-backward algorithm) (eq (10) and (11))
// b) collecting the per-activation contributions over all s that map to the same character and wrapping up the gradient (eq (16); the collection is the sum)
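// As a sketch (added for clarity, notation as in the code below): the beta recursion of
// eqs. (10)/(11) mirrors the alpha recursion but runs backwards in time:
//   log_beta[t][s] = logsumexp(log_beta[t+1][s],
//                              log_beta[t+1][s+1],      // if s < 2*target_length
//                              log_beta[t+1][s+2])      // if s < 2*target_length-1 and l'[s] != l'[s+2]
//                    + log_probs[t][l'[s]]
// The per-(t, character) sums of log_alpha + log_beta accumulated along the way feed eq (16).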
template<typename scalar_t, ScalarType target_scalar_type>
Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                                      const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
  constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
  using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
  int64_t max_input_length = log_probs.size(0);
  int64_t batch_size = log_probs.size(1);
  int64_t num_labels = log_probs.size(2);
  Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // at this point, this is log of empty sum

  // The admin bits. We don't do much checking and assume that the forward did.
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t tg_target_stride;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t max_target_length;
  std::vector<int64_t> tg_batch_offsets(batch_size);

  if (targets.dim() == 1) { // concatenated targets
    int64_t pos = 0;
    max_target_length = 0;
    for (const auto i : c10::irange(batch_size)) {
      tg_batch_offsets[i] = pos;
      pos += target_lengths[i];
      if (max_target_length < target_lengths[i])
        max_target_length = target_lengths[i];
    }
    tg_target_stride = targets.stride(0);
  }
  else { // batch x max_target_length
    // dim is 2
    int64_t tg_batch_stride = targets.stride(0);
    for (const auto i : c10::irange(batch_size)) {
      tg_batch_offsets[i] = i * tg_batch_stride;
    }
    tg_target_stride = targets.stride(1);
    max_target_length = targets.size(1);
  }

  Tensor log_beta = at::empty_like(log_alpha, LEGACY_CONTIGUOUS_MEMORY_FORMAT);  // could be optimized to use only 2 rows
  auto lpp  = log_probs.permute({1,0,2});
  auto log_probs_a_global = lpp.accessor<const scalar_t, 3>();
  auto log_alpha_a_global = log_alpha.accessor<const scalar_t, 3>();
  auto log_beta_a_global = log_beta.accessor<scalar_t, 3>();
  auto gp = grad.permute({1,0,2});
  auto grad_a_global = gp.accessor<scalar_t, 3>();
  auto targets_data = targets.const_data_ptr<target_t>();
  auto grad_out_a = grad_out.accessor<const scalar_t, 1>();

  auto create_fill_iterator = [](const Tensor& tensor, IntArrayRef squash_dims) {
    return TensorIteratorConfig()
        .set_check_mem_overlap(false)  // Fill is idempotent, so overlap is okay
        .check_all_same_dtype(false)
        .add_output(tensor)
        .resize_outputs(false)
        .declare_static_shape(tensor.sizes(), squash_dims)
        .build();
  };
  const auto fill_iter = create_fill_iterator(grad, /*squash_dims=*/1);
  const auto fill_1d_iter = create_fill_iterator(grad, /*squash_dims=*/{0, 1});
  const auto fill_log_beta_1d_iter = create_fill_iterator(log_beta, /*squash_dims=*/{0, 1});

  at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
    TensorIterator fill_iter_local(fill_iter);
    TensorIterator fill_1d_iter_local(fill_1d_iter);
    TensorIterator fill_log_beta_1d_iter_local(fill_log_beta_1d_iter);

    for (const auto b : c10::irange(start, end)) {
      scalar_t nll = neg_log_likelihood.accessor<scalar_t, 1>()[b];
      auto grad_a = grad_a_global[b];
      if (zero_infinity && nll == std::numeric_limits<scalar_t>::infinity()) {
        // grad_batch.zero_();
        fill_iter_local.unsafe_replace_operand(0, grad_a.data());
        fill_stub(kCPU, fill_iter_local, 0);
        continue;
      }

      auto log_probs_a = log_probs_a_global[b];
      auto log_alpha_a = log_alpha_a_global[b];
      auto log_beta_a = log_beta_a_global[b];
      int64_t input_length = input_lengths[b];
      int64_t target_length = target_lengths[b];
      int64_t tg_batch_offset = tg_batch_offsets[b];

      // the initialization of beta before eq (10)
      // here we do the fill for each batch item separately, as the input lengths will differ, so the t in which
      // we start varies
      if (input_length > 0) {
        // log_beta.select(0, b).select(1, input_length-1).fill_(neginf);
        fill_log_beta_1d_iter_local.unsafe_replace_operand(
            0, log_beta_a[input_length - 1].data());
        fill_stub(kCPU, fill_log_beta_1d_iter_local, neginf);

        log_beta_a[input_length-1][2*target_length] = log_probs_a[input_length-1][BLANK];
        grad_a[input_length-1][BLANK] = log_alpha_a[input_length-1][2*target_length] + log_beta_a[input_length-1][2*target_length];

        if (target_length > 0) {
          auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 2*target_length-1, BLANK);
          log_beta_a[input_length-1][2*target_length-1] = log_probs_a[input_length-1][current_target_prime];

          // the first two are a blank and a non-blank, so we know they are different and we don't need to do log+
          grad_a[input_length-1][current_target_prime] = log_alpha_a[input_length-1][2*target_length-1] + log_beta_a[input_length-1][2*target_length-1];
        }
      }

      // now loop applying eq (10) / (11)
      for (int64_t t=input_length-2; t>=0; t--) {
        // this loop over s could be parallel/vectorized and doesn't really need to be descending...
        // alternatively, one might consider moving s to the outer loop to cache current_target_prime more (but then it needs to be descending)
        // for the cuda implementation, that gave a speed boost.
        for (int64_t s=2*target_length; s>=0; s--) {
          scalar_t lb1 = log_beta_a[t+1][s];
          scalar_t lbmax = lb1;
          scalar_t lb2, lb3;
          auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
          if (s < 2*target_length) {
            lb2 = log_beta_a[t+1][s+1];
            if (lb2 > lbmax)
              lbmax = lb2;
          } else {
            lb2 = neginf;
          }
          if ((s < 2*target_length-1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s+2, BLANK) !=
                                          current_target_prime)) {
            lb3 = log_beta_a[t+1][s+2];
            if (lb3 > lbmax)
              lbmax = lb3;
          } else {
            lb3 = neginf;
          }
          if (lbmax == neginf)
            lbmax = 0;

          log_beta_a[t][s] = std::log(std::exp(lb1-lbmax)+std::exp(lb2-lbmax)+std::exp(lb3-lbmax))+lbmax + log_probs_a[t][current_target_prime];
          // one might check whether one can vectorize this better when done after the t-loop...
          // now that we have beta, we fill in the sum of alpha*beta in eq (16)
          // in contrast to the cuda implementation, we only parallelize over the batch, so we don't have a concurrency
          // issue (several s can map to the same target character)
          // collected[b, t, target'[s]] "log+=" log_alpha[t, s]+log_beta[t, s]
          scalar_t log_alpha_beta =  log_alpha_a[t][s] + log_beta_a[t][s];
          scalar_t &lcab = grad_a[t][current_target_prime];
          if (lcab == neginf) {
            lcab = log_alpha_beta;
          } else {
            scalar_t max = std::max(lcab, log_alpha_beta);
            lcab = std::log(std::exp(lcab-max)+std::exp(log_alpha_beta-max))+max;
          }
        }
      }

      // At this point grad holds the sum of eq (16).
      // Now we wrap up the calculation by adding in the remaining items of eq (16);
      // this could be a great target for further vectorization.
      // grad is the output gradient; nll is the loss. Note that -nll is the log of the Z of eq (16).
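      // In formula form (added for clarity): with lp = log_probs[t][c] and res currently
      // holding log sum_{s : l'[s] == c} exp(log_alpha[t][s] + log_beta[t][s]), the loop
      // below computes
      //   grad[t][c] = (exp(lp) - exp(res + nll - lp)) * grad_out[b]
      // i.e. eq (16) evaluated in log space for numerical stability, scaled by the incoming gradient.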
      scalar_t gr = grad_out_a[b];
      for (const auto t : c10::irange(input_length)) { // or go for the full thing?
        for (const auto c : c10::irange(num_labels)) {
          scalar_t& res = grad_a[t][c];
          scalar_t lp = log_probs_a[t][c];
          res = (std::exp(lp)-std::exp(res + nll - lp)) * gr;
        }
      }

      // zero the remainder
      for (auto l : c10::irange(input_length, max_input_length)) {
        // grad_batch.select(0, l).zero_();
        fill_1d_iter_local.unsafe_replace_operand(0, grad_a[l].data());
        fill_stub(kCPU, fill_1d_iter_local, 0);
      }
    }
  });
  return grad;
}

} // namespace

std::tuple<Tensor, Tensor> ctc_loss_meta(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
  (void)zero_infinity; // only used for backwards
  return AT_DISPATCH_FLOATING_TYPES(
      log_probs.scalar_type(), "ctc_loss_meta", [&] {
        Tensor neg_log_likelihood, log_alpha;
        if (targets.scalar_type() == kLong) {
          std::tie(neg_log_likelihood, log_alpha, std::ignore, std::ignore) = ctc_loss_allocate_outputs<scalar_t, kLong>(
              log_probs, targets, input_lengths, target_lengths, BLANK);
        } else {
          std::tie(neg_log_likelihood, log_alpha, std::ignore, std::ignore) = ctc_loss_allocate_outputs<scalar_t, kInt>(
              log_probs, targets, input_lengths, target_lengths, BLANK);
        }
        return std::make_tuple(neg_log_likelihood, log_alpha);
      });
}

std::tuple<Tensor, Tensor> ctc_loss_cpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
  (void)zero_infinity; // only used for backwards
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_cpu", [&] {
      if (targets.scalar_type() == kLong) {
        return ctc_loss_cpu_template<scalar_t, kLong>(log_probs, targets, input_lengths, target_lengths, BLANK);
      } else {
        return ctc_loss_cpu_template<scalar_t, kInt>(log_probs, targets, input_lengths, target_lengths, BLANK);
      }
  });
}


std::tuple<Tensor, Tensor> ctc_loss_tensor(const Tensor& log_probs, const Tensor& targets, const Tensor& input_lengths, const Tensor& target_lengths, int64_t BLANK, bool zero_infinity) {
  TORCH_CHECK(isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false), "input_lengths must be integral");
  TORCH_CHECK(isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false), "target_lengths must be integral");

  Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  IntArrayRef il(ilc.const_data_ptr<int64_t>(), ilc.numel());
  IntArrayRef tl(tlc.const_data_ptr<int64_t>(), tlc.numel());

  return at::_ctc_loss(log_probs, targets, il, tl, BLANK, zero_infinity);
}

Tensor ctc_loss_backward_cpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                             const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_backward_cpu", [&] {
      if (targets.scalar_type() == kLong) {
        return ctc_loss_backward_cpu_template<scalar_t,kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
      } else {
        return ctc_loss_backward_cpu_template<scalar_t,kInt>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
      }
  });
}

Tensor ctc_loss_backward_tensor(
    const Tensor& grad,
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    const Tensor& neg_log_likelihood,
    const Tensor& log_alpha,
    int64_t BLANK,
    bool zero_infinity) {
  TORCH_CHECK(
      isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false),
      "input_lengths must be integral");
  TORCH_CHECK(isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false), "target_lengths must be integral");

  Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
  IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
  return at::_ctc_loss_backward(grad, log_probs, targets, il, tl, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
}

namespace {

Tensor get_clamped_target_length(
    IntArrayRef target_lengths,
    const TensorOptions& options) {
  return at::tensor(target_lengths, options).clamp_min(1);
}

Tensor get_clamped_target_length(
    const Tensor & target_lengths,
    const TensorOptions& options) {
  return target_lengths.clamp_min(1);
}

// This wrapper function dispatches to the native and cudnn implementations and hides the alpha/grad from the user (by just returning the loss).
// The gradient is implemented for _cudnn_ctc_loss (just in derivatives.yaml) and _ctc_loss, and this function has automatic gradients.
// It also handles the reduction if desired.
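// Reduction semantics (added for clarity, mirroring the code below): with Reduction::Mean,
// each per-sample loss is divided by its target length clamped to at least 1 before averaging
// over the batch; Reduction::Sum sums the per-sample losses; otherwise the per-sample losses
// are returned as-is (squeezed for the unbatched case).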
template <typename LengthsType>
Tensor ctc_loss_impl(const Tensor& log_probs_, const Tensor& targets, LengthsType input_lengths, LengthsType target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
  auto is_batched = log_probs_.dim() == 3;
  Tensor log_probs = is_batched ? log_probs_ : log_probs_.unsqueeze(1);
  bool use_cudnn =
      (log_probs.device().type() == at::kCUDA) &&
      at::_use_cudnn_ctc_loss(
          log_probs, targets, input_lengths, target_lengths, BLANK);

  Tensor res;
  if (use_cudnn) {
    // non-deterministic ctc loss on cudnn disabled due to inconsistent results
    // see: https://github.com/pytorch/pytorch/issues/21680
    res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, /*deterministic=*/true, zero_infinity));
  } else {
    // if the targets are on CPU (which you need for CuDNN), let's move them to
    // the GPU as a service for the user
    res = std::get<0>(at::_ctc_loss(
        log_probs,
        targets.to(log_probs.device(), kLong),
        input_lengths,
        target_lengths,
        BLANK,
        zero_infinity));
    if (zero_infinity) {
      res = at::where(res == Scalar(std::numeric_limits<double>::infinity()), at::zeros({}, res.options()), res);
    }
  }
  if (reduction == at::Reduction::Mean) {
    auto target_lengths_t = get_clamped_target_length(target_lengths, res.options());
    return (res / target_lengths_t).mean();
  } else if (reduction == at::Reduction::Sum) {
    return res.sum();
  }
  return is_batched ? std::move(res) : res.squeeze(0);
}

} // namespace

Tensor ctc_loss(const Tensor& log_probs_, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
  return ctc_loss_impl(log_probs_, targets, input_lengths, target_lengths, BLANK, reduction, zero_infinity);
}

// Convenience function accepting Tensors
Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, const Tensor& input_lengths, const Tensor& target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
  // we don't want to convert to IntArrayRef if we can dispatch to cuDNN (this allows graph-capturable ctc_loss)
  bool use_cudnn =
      (log_probs.device().type() == at::kCUDA) &&
      at::_use_cudnn_ctc_loss(
          log_probs, targets, input_lengths, target_lengths, BLANK);
  if (at::areAnyTensorSubclassLike(
          {log_probs, targets, input_lengths, target_lengths}) || use_cudnn) {
    // Composite Compliant path for TensorSubclasses
    return ctc_loss_impl(log_probs, targets, input_lengths, target_lengths, BLANK, reduction, zero_infinity);
  }
  // Fast path (which accesses data_ptr) and fewer operator dispatches for
  // regular tensors
  TORCH_CHECK(isIntegralType(input_lengths.scalar_type(), /*includeBool=*/false), "input_lengths must be integral");
  TORCH_CHECK(isIntegralType(target_lengths.scalar_type(), /*includeBool=*/false), "target_lengths must be integral");

  Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  IntArrayRef il(ilc.const_data_ptr<int64_t>(), ilc.numel());
  IntArrayRef tl(tlc.const_data_ptr<int64_t>(), tlc.numel());
  return at::native::ctc_loss(log_probs, targets, il, tl, BLANK, reduction, zero_infinity);
}

} // at::native