#pragma once

#include <c10/util/irange.h>
#include <torch/types.h>

#include <utility>

namespace torch {
namespace nn {
namespace utils {
namespace rnn {

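/// Returns the inverse of a permutation tensor, i.e. an `output` tensor such
/// that `output[permutation[i]] == i`, or an undefined tensor if `permutation`
/// is itself undefined. For example, the permutation ``[2, 0, 1]`` inverts to
/// ``[1, 2, 0]``.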
inline Tensor invert_permutation(const Tensor& permutation) {
  if (!permutation.defined()) {
    return torch::Tensor();
  }
  Tensor output =
      torch::empty_like(permutation, torch::MemoryFormat::Contiguous);
  output.scatter_(
      0,
      permutation,
      torch::arange(0, permutation.numel(), permutation.device()));
  return output;
}

/// Holds the data and list of `batch_sizes` of a packed sequence.
///
/// All RNN modules accept packed sequences as inputs.
///
/// Note:
///     Instances of this class should never be created manually. They are meant
///     to be instantiated by functions like `pack_padded_sequence`.
///
///     Batch sizes represent the number of elements at each sequence step in
///     the batch, not the varying sequence lengths passed to
///     `pack_padded_sequence`.  For instance, given data ``abc`` and ``x``
///     the :class:`PackedSequence` would contain data ``axbc`` with
///     ``batch_sizes=[2,1,1]``.
///
/// Attributes:
///     data (Tensor): Tensor containing the packed sequence
///     batch_sizes (Tensor): Tensor of integers holding
///         information about the batch size at each sequence step
///     sorted_indices (Tensor, optional): Tensor of integers holding how this
///         :class:`PackedSequence` is constructed from sequences.
///     unsorted_indices (Tensor, optional): Tensor of integers holding how
///         to recover the original sequences in the correct order.
///
/// .. note::
///     `data` can be on an arbitrary device and of an arbitrary dtype.
///     `sorted_indices` and `unsorted_indices` must be ``torch::kInt64``
///     tensors on the same device as `data`.
///
///     However, `batch_sizes` should always be a CPU ``torch::kInt64`` tensor.
///
///     This invariant is maintained throughout the `PackedSequence` class,
///     and by all functions that construct a `PackedSequence` in libtorch
///     (i.e., they only pass in tensors conforming to this constraint).
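///
/// Example (a minimal sketch; the tensor values are illustrative):
/// ```
/// // Pack the two sequences [1, 2, 3] and [4] (lengths 3 and 1).
/// torch::Tensor a = torch::tensor({1, 2, 3});
/// torch::Tensor b = torch::tensor({4});
/// auto packed = torch::nn::utils::rnn::pack_sequence({a, b});
/// // packed.data() is [1, 4, 2, 3] and packed.batch_sizes() is [2, 1, 1].
/// ```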
class PackedSequence {
 public:
  explicit PackedSequence(
      Tensor data,
      Tensor batch_sizes,
      Tensor sorted_indices = {},
      Tensor unsorted_indices = {}) {
    // NB: if unsorted_indices is provided, it should be the inverse permutation
    // to sorted_indices. Don't assert it here because the PackedSequence ctor
    // should only be used internally.
    if (!unsorted_indices.defined()) {
      unsorted_indices = invert_permutation(sorted_indices);
    }
    TORCH_CHECK(
        batch_sizes.device().type() == kCPU,
        "batch_sizes should always be on CPU. "
        "Instances of PackedSequence should never be created manually. "
        "They should be instantiated by functions like pack_sequence "
        "and pack_padded_sequence in nn::utils::rnn. "
        "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence");
    data_ = std::move(data);
    batch_sizes_ = std::move(batch_sizes);
    sorted_indices_ = std::move(sorted_indices);
    unsorted_indices_ = std::move(unsorted_indices);
  }

  const Tensor& data() const {
    return data_;
  }

  const Tensor& batch_sizes() const {
    return batch_sizes_;
  }

  const Tensor& sorted_indices() const {
    return sorted_indices_;
  }

  const Tensor& unsorted_indices() const {
    return unsorted_indices_;
  }

  PackedSequence pin_memory() const {
    // Why not convert `batch_sizes`?
    // See NOTE [ device and dtype of a PackedSequence ]
    return PackedSequence(
        data_.pin_memory(),
        batch_sizes_,
        sorted_indices_.defined() ? sorted_indices_.pin_memory() : Tensor(),
        unsorted_indices_.defined() ? unsorted_indices_.pin_memory()
                                    : Tensor());
  }

  PackedSequence to(TensorOptions options) const {
    // Performs dtype and/or device conversion on `data_`.
    //
    // If the `data_` Tensor already has the correct `torch::Dtype`
    // and `torch::Device`, then `*this` is returned.
    // Otherwise, returns a copy with the desired configuration.

    // Why not convert `batch_sizes`?
    // See NOTE [ device and dtype of a PackedSequence ]
    Tensor data = data_.to(options);
    if (data.is_same(data_)) {
      return *this;
    } else {
      // Does not forward the device or dtype arguments; the device is taken
      // from data.device() and the indices keep their own dtype.
      Tensor sorted_indices = sorted_indices_.defined()
          ? sorted_indices_.to(
                options.device(data.device()).dtype(sorted_indices_.dtype()))
          : Tensor();
      Tensor unsorted_indices = unsorted_indices_.defined()
          ? unsorted_indices_.to(
                options.device(data.device()).dtype(unsorted_indices_.dtype()))
          : Tensor();
      return PackedSequence(
          std::move(data),
          batch_sizes_,
          std::move(sorted_indices),
          std::move(unsorted_indices));
    }
  }

  PackedSequence cuda() const {
    return to(kCUDA);
  }

  PackedSequence cpu() const {
    return to(kCPU);
  }

  /// Returns true if `data_` is stored on a GPU.
  bool is_cuda() const {
    return data_.is_cuda();
  }

  /// Returns true if `data_` is stored in pinned memory.
  bool is_pinned() const {
    return data_.is_pinned();
  }

 private:
  Tensor data_;
  Tensor batch_sizes_;
  Tensor sorted_indices_;
  Tensor unsorted_indices_;
};

/// Packs a Tensor containing padded sequences of variable length.
///
/// `input` can be of size ``T x B x *`` where `T` is the length of the
/// longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and
/// ``*`` is any number of dimensions (including 0). If ``batch_first`` is
/// ``true``, a ``B x T x *`` `input` is expected.
///
/// For unsorted sequences, use `enforce_sorted = false`. If `enforce_sorted` is
/// ``true``, the sequences should be sorted by length in decreasing order,
/// i.e. ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]``
/// the shortest one.
///
/// Note:
///     This function accepts any input that has at least two dimensions. You
///     can apply it to pack the labels, and use the output of the RNN with
///     them to compute the loss directly. A Tensor can be retrieved from
///     a `PackedSequence` object by calling its ``.data()`` function.
///
/// Arguments:
///     input (Tensor): padded batch of variable length sequences.
///     lengths (Tensor): list of sequence lengths of each batch element.
///     batch_first (bool, optional): if ``true``, the input is expected in
///         ``B x T x *`` format. Default: ``false``.
///     enforce_sorted (bool, optional): if ``true``, the input is expected to
///         contain sequences sorted by length in decreasing order. If
///         ``false``, this condition is not checked. Default: ``true``.
///
/// Returns:
///     a `PackedSequence` object
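///
/// Example (a minimal sketch; shapes and values are illustrative):
/// ```
/// // Two padded sequences of lengths 3 and 1, in ``T x B`` layout.
/// torch::Tensor input = torch::tensor({{1, 4}, {2, 0}, {3, 0}});
/// torch::Tensor lengths = torch::tensor({3, 1});
/// auto packed =
///     torch::nn::utils::rnn::pack_padded_sequence(input, lengths);
/// // packed.data() is [1, 4, 2, 3], packed.batch_sizes() is [2, 1, 1].
/// ```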
inline PackedSequence pack_padded_sequence(
    Tensor input,
    Tensor lengths,
    bool batch_first = false,
    bool enforce_sorted = true) {
  lengths = lengths.to(kInt64);
  Tensor sorted_indices;
  if (enforce_sorted) {
    sorted_indices = Tensor();
  } else {
    std::tie(lengths, sorted_indices) =
        torch::sort(lengths, /*dim=*/-1, /*descending=*/true);
    sorted_indices = sorted_indices.to(input.device());
    int64_t batch_dim = batch_first ? 0 : 1;
    input = input.index_select(batch_dim, sorted_indices);
  }

  auto [data, batch_sizes] =
      torch::_pack_padded_sequence(input, lengths, batch_first);
  return PackedSequence(
      std::move(data), std::move(batch_sizes), std::move(sorted_indices), {});
}

/// Pads a packed batch of variable length sequences.
///
/// It is the inverse operation to `pack_padded_sequence`.
///
/// The returned Tensor's data will be of size ``T x B x *``, where `T` is the
/// length of the longest sequence and `B` is the batch size. If ``batch_first``
/// is true, the data will be transposed into ``B x T x *`` format.
///
/// Batch elements will be ordered decreasingly by their length.
///
/// Arguments:
///     sequence (PackedSequence): batch to pad
///     batch_first (bool, optional): if ``true``, the output will be in
///         ``B x T x *`` format.
///     padding_value (double, optional): values for padded elements.
///     total_length (int64_t, optional): if specified, the output will be
///         padded to have length `total_length`. This method will throw an
///         error if `total_length` is less than the max sequence length in
///         `sequence`.
///
/// Returns:
///     Tuple of Tensor containing the padded sequence, and a Tensor
///     containing the list of lengths of each sequence in the batch.
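///
/// Example (a minimal round-trip sketch; values are illustrative):
/// ```
/// auto packed = torch::nn::utils::rnn::pack_sequence(
///     {torch::tensor({1, 2, 3}), torch::tensor({4})});
/// auto [padded, lengths] =
///     torch::nn::utils::rnn::pad_packed_sequence(packed);
/// // padded is [[1, 4], [2, 0], [3, 0]] (``T x B``), lengths is [3, 1].
/// ```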
inline std::tuple<Tensor, Tensor> pad_packed_sequence(
    PackedSequence sequence,
    bool batch_first = false,
    double padding_value = 0.0,
    std::optional<int64_t> total_length = torch::nullopt) {
  int64_t max_seq_length = sequence.batch_sizes().size(0);
  if (total_length.has_value()) {
    int64_t total_length_val = total_length.value();
    TORCH_CHECK(
        total_length_val >= max_seq_length,
        "Expected total_length to be at least the length "
        "of the longest sequence in input, but got "
        "total_length=",
        total_length_val,
        " and max sequence length being ",
        max_seq_length);
    max_seq_length = total_length_val;
  }
  auto [padded_output, lengths] = torch::_pad_packed_sequence(
      sequence.data(),
      sequence.batch_sizes(),
      batch_first,
      padding_value,
      max_seq_length);
  const Tensor& unsorted_indices = sequence.unsorted_indices();
  if (unsorted_indices.defined()) {
    int64_t batch_dim = batch_first ? 0 : 1;
    return std::make_tuple(
        padded_output.index_select(batch_dim, unsorted_indices),
        lengths.index({unsorted_indices.cpu()}));
  }
  return std::make_tuple(padded_output, lengths);
}

/// Pads a list of variable length Tensors with ``padding_value``.
///
/// ``pad_sequence`` stacks a list of Tensors along a new dimension,
/// and pads them to equal length. For example, if the input is a list of
/// sequences with size ``L x *``, the output is of size ``T x B x *`` if
/// ``batch_first`` is false, and ``B x T x *`` otherwise.
///
/// `B` is the batch size. It is equal to the number of elements in
/// ``sequences``.
/// `T` is the length of the longest sequence.
/// `L` is the length of the sequence.
/// `*` is any number of trailing dimensions, including none.
///
/// Note:
///     This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
///     where `T` is the length of the longest sequence. This function assumes
///     that the trailing dimensions and type of all the Tensors in sequences
///     are the same.
///
/// Arguments:
///     sequences (torch::ArrayRef<Tensor>): list of variable length sequences.
///     batch_first (bool, optional): output will be in ``B x T x *`` format if
///         true, or in ``T x B x *`` otherwise. Default: ``false``.
///     padding_value (double, optional): value for padded elements. Default: 0.
///     padding_side (str, optional): the side to pad the sequences on. Default:
///         "right".
///
/// Returns:
///     Tensor of size ``T x B x *`` if `batch_first` is ``false``, and
///     Tensor of size ``B x T x *`` otherwise.
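///
/// Example (a minimal sketch; values are illustrative):
/// ```
/// torch::Tensor a = torch::tensor({1, 2, 3});
/// torch::Tensor b = torch::tensor({4, 5});
/// torch::Tensor padded =
///     torch::nn::utils::rnn::pad_sequence({a, b}, /*batch_first=*/true);
/// // padded is [[1, 2, 3], [4, 5, 0]] with shape ``B x T`` == 2 x 3.
/// ```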
inline Tensor pad_sequence(
    ArrayRef<Tensor> sequences,
    bool batch_first = false,
    double padding_value = 0,
    c10::string_view padding_side = "right") {
  return at::pad_sequence(sequences, batch_first, padding_value, padding_side);
}

/// Packs a list of variable length Tensors.
///
/// ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
/// the length of a sequence and `*` is any number of trailing dimensions,
/// including zero.
///
/// For unsorted sequences, use `enforce_sorted = false`. If ``enforce_sorted``
/// is ``true``, the sequences should be sorted in the order of decreasing
/// length.
///
/// Arguments:
///     sequences (torch::ArrayRef<Tensor>): A list of sequences of decreasing
///         length.
///     enforce_sorted (bool, optional): if ``true``, checks that the input
///         contains sequences sorted by length in decreasing order. If
///         ``false``, this condition is not checked. Default: ``true``.
///
/// Returns:
///     a `PackedSequence` object
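///
/// Example (a minimal sketch; values are illustrative):
/// ```
/// // Unsorted sequences are allowed with enforce_sorted = false.
/// torch::Tensor a = torch::tensor({1});
/// torch::Tensor b = torch::tensor({2, 3});
/// auto packed = torch::nn::utils::rnn::pack_sequence(
///     {a, b}, /*enforce_sorted=*/false);
/// // packed.data() is [2, 1, 3], packed.batch_sizes() is [2, 1].
/// ```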
inline PackedSequence pack_sequence(
    ArrayRef<Tensor> sequences,
    bool enforce_sorted = true) {
  Tensor lengths = torch::empty({(int64_t)sequences.size()}, kInt64);
  for (const auto i : c10::irange(sequences.size())) {
    lengths[i] = sequences[i].size(0);
  }
  return pack_padded_sequence(
      at::pad_sequence(sequences),
      std::move(lengths),
      /*batch_first=*/false,
      /*enforce_sorted=*/enforce_sorted);
}

} // namespace rnn
} // namespace utils
} // namespace nn
} // namespace torch