#pragma once

#include <c10/util/irange.h>
#include <torch/types.h>

#include <utility>
#include <vector>
7
8 namespace torch {
9 namespace nn {
10 namespace utils {
11 namespace rnn {
12
/// Computes the inverse of a 1-D index permutation, i.e. a tensor `inv`
/// such that `inv[permutation[i]] == i` for every position `i`.
/// An undefined input yields an undefined result.
inline Tensor invert_permutation(const Tensor& permutation) {
  if (!permutation.defined()) {
    return torch::Tensor();
  }
  auto inverse =
      torch::empty_like(permutation, torch::MemoryFormat::Contiguous);
  // Scatter position indices through the permutation: each target slot
  // receives the index it originally came from.
  inverse.scatter_(
      0,
      permutation,
      torch::arange(0, permutation.numel(), permutation.device()));
  return inverse;
}
25
26 /// Holds the data and list of `batch_sizes` of a packed sequence.
27 ///
28 /// All RNN modules accept packed sequences as inputs.
29 ///
30 /// Note:
31 /// Instances of this class should never be created manually. They are meant
32 /// to be instantiated by functions like `pack_padded_sequence`.
33 ///
/// Batch sizes represent the number of elements at each sequence step in
35 /// the batch, not the varying sequence lengths passed to
36 /// `pack_padded_sequence`. For instance, given data ``abc`` and ``x``
37 /// the :class:`PackedSequence` would contain data ``axbc`` with
38 /// ``batch_sizes=[2,1,1]``.
39 ///
40 /// Attributes:
41 /// data (Tensor): Tensor containing packed sequence
42 /// batch_sizes (Tensor): Tensor of integers holding
43 /// information about the batch size at each sequence step
44 /// sorted_indices (Tensor, optional): Tensor of integers holding how this
45 /// :class:`PackedSequence` is constructed from sequences.
///   unsorted_indices (Tensor, optional): Tensor of integers holding how
///     to recover the original sequences with correct order.
48 ///
49 /// .. note::
50 /// `data` can be on arbitrary device and of arbitrary dtype.
51 /// `sorted_indices` and `unsorted_indices` must be ``torch::kInt64``
52 /// tensors on the same device as `data`.
53 ///
54 /// However, `batch_sizes` should always be a CPU ``torch::kInt64`` tensor.
55 ///
56 /// This invariant is maintained throughout `PackedSequence` class,
57 /// and all functions that construct a `PackedSequence` in libtorch
58 /// (i.e., they only pass in tensors conforming to this constraint).
59 class PackedSequence {
60 public:
61 explicit PackedSequence(
62 Tensor data,
63 Tensor batch_sizes,
64 Tensor sorted_indices = {},
65 Tensor unsorted_indices = {}) {
66 // NB: if unsorted_indices is provided, it should be the inverse permutation
67 // to sorted_indices. Don't assert it here because the PackedSequence ctor
68 // should only be used internally.
69 if (!unsorted_indices.defined()) {
70 unsorted_indices = invert_permutation(sorted_indices);
71 }
72 TORCH_CHECK(
73 batch_sizes.device().type() == kCPU,
74 "batch_sizes should always be on CPU. "
75 "Instances of PackedSequence should never be created manually. "
76 "They should be instantiated by functions like pack_sequence "
77 "and pack_padded_sequences in nn::utils::rnn. "
78 "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence");
79 data_ = std::move(data);
80 batch_sizes_ = std::move(batch_sizes);
81 sorted_indices_ = std::move(sorted_indices);
82 unsorted_indices_ = std::move(unsorted_indices);
83 }
84
data()85 const Tensor& data() const {
86 return data_;
87 }
88
batch_sizes()89 const Tensor& batch_sizes() const {
90 return batch_sizes_;
91 }
92
sorted_indices()93 const Tensor& sorted_indices() const {
94 return sorted_indices_;
95 }
96
unsorted_indices()97 const Tensor& unsorted_indices() const {
98 return unsorted_indices_;
99 }
100
pin_memory()101 PackedSequence pin_memory() const {
102 // Why not convert `batch_sizes`?
103 // See NOTE [ device and dtype of a PackedSequence ]
104 return PackedSequence(
105 data_.pin_memory(),
106 batch_sizes_,
107 sorted_indices_.defined() ? sorted_indices_.pin_memory() : Tensor(),
108 unsorted_indices_.defined() ? unsorted_indices_.pin_memory()
109 : Tensor());
110 }
111
to(TensorOptions options)112 PackedSequence to(TensorOptions options) const {
113 // Performs dtype and/or device conversion on `data_`.
114 //
115 // If the ``data_`` Tensor already has the correct `torch::Dtype`
116 // and `torch::Device`, then ``self`` is returned.
117 // Otherwise, returns a copy with the desired configuration.
118
119 // Why not convert `batch_sizes`?
120 // See NOTE [ device and dtype of a PackedSequence ]
121 Tensor data = data_.to(options);
122 if (data.is_same(data_)) {
123 return *this;
124 } else {
125 // Does not forward device or dtype args, device is set from data.device()
126 Tensor sorted_indices = sorted_indices_.defined()
127 ? sorted_indices_.to(
128 options.device(data.device()).dtype(sorted_indices_.dtype()))
129 : Tensor();
130 Tensor unsorted_indices = unsorted_indices_.defined()
131 ? unsorted_indices_.to(
132 options.device(data.device()).dtype(unsorted_indices_.dtype()))
133 : Tensor();
134 return PackedSequence(
135 std::move(data),
136 batch_sizes_,
137 std::move(sorted_indices),
138 std::move(unsorted_indices));
139 }
140 }
141
cuda()142 PackedSequence cuda() const {
143 return to(kCUDA);
144 }
145
cpu()146 PackedSequence cpu() const {
147 return to(kCPU);
148 }
149
150 /// Returns true if `data_` stored on a gpu
is_cuda()151 bool is_cuda() const {
152 return data_.is_cuda();
153 }
154
155 /// Returns true if `data_` stored on in pinned memory
is_pinned()156 bool is_pinned() const {
157 return data_.is_pinned();
158 }
159
160 private:
161 Tensor data_;
162 Tensor batch_sizes_;
163 Tensor sorted_indices_;
164 Tensor unsorted_indices_;
165 };
166
167 /// Packs a Tensor containing padded sequences of variable length.
168 ///
169 /// `input` can be of size ``T x B x *`` where `T` is the length of the
170 /// longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and
171 /// ``*`` is any number of dimensions (including 0). If ``batch_first`` is
172 /// ``true``, ``B x T x *`` `input` is expected.
173 ///
174 /// For unsorted sequences, use `enforce_sorted = false`. If `enforce_sorted` is
175 /// ``true``, the sequences should be sorted by length in a decreasing order,
176 /// i.e.
177 /// ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the
178 /// shortest one.
179 ///
180 /// Note:
181 /// This function accepts any input that has at least two dimensions. You
182 /// can apply it to pack the labels, and use the output of the RNN with
183 /// them to compute the loss directly. A Tensor can be retrieved from
184 /// a `PackedSequence` object by calling its ``.data()`` function.
185 ///
186 /// Arguments:
187 /// input (Tensor): padded batch of variable length sequences.
188 /// lengths (Tensor): list of sequences lengths of each batch element.
189 /// batch_first (bool, optional): if ``true``, the input is expected in ``B
190 /// x T x *``
191 /// format. Default: ``false``.
192 /// enforce_sorted (bool, optional): if ``true``, the input is expected to
193 /// contain sequences sorted by length in a decreasing order. If
194 /// ``false``, this condition is not checked. Default: ``true``.
195 ///
196 /// Returns:
197 /// a `PackedSequence` object
198 inline PackedSequence pack_padded_sequence(
199 Tensor input,
200 Tensor lengths,
201 bool batch_first = false,
202 bool enforce_sorted = true) {
203 lengths = lengths.to(kInt64);
204 Tensor sorted_indices;
205 if (enforce_sorted) {
206 sorted_indices = Tensor();
207 } else {
208 std::tie(lengths, sorted_indices) =
209 torch::sort(lengths, /*dim=*/-1, /*descending=*/true);
210 sorted_indices = sorted_indices.to(input.device());
211 int64_t batch_dim = batch_first ? 0 : 1;
212 input = input.index_select(batch_dim, sorted_indices);
213 }
214
215 auto [data, batch_sizes] =
216 torch::_pack_padded_sequence(input, lengths, batch_first);
217 return PackedSequence(
218 std::move(data), std::move(batch_sizes), std::move(sorted_indices), {});
219 }
220
221 /// Pads a packed batch of variable length sequences.
222 ///
223 /// It is an inverse operation to `pack_padded_sequence`.
224 ///
225 /// The returned Tensor's data will be of size ``T x B x *``, where `T` is the
226 /// length of the longest sequence and `B` is the batch size. If ``batch_first``
227 /// is true, the data will be transposed into ``B x T x *`` format.
228 ///
229 /// Batch elements will be ordered decreasingly by their length.
230 ///
231 /// Arguments:
232 /// sequence (PackedSequence): batch to pad
233 /// batch_first (bool, optional): if ``true``, the output will be in ``B x T
234 /// x *``
235 /// format.
236 /// padding_value (double, optional): values for padded elements.
237 /// total_length (int64_t, optional): if specified, the output will be
238 /// padded to
239 /// have length `total_length`. This method will throw error
240 /// if `total_length` is less than the max sequence length in
241 /// `sequence`.
242 ///
243 /// Returns:
244 /// Tuple of Tensor containing the padded sequence, and a Tensor
245 /// containing the list of lengths of each sequence in the batch.
246 inline std::tuple<Tensor, Tensor> pad_packed_sequence(
247 PackedSequence sequence,
248 bool batch_first = false,
249 double padding_value = 0.0,
250 std::optional<int64_t> total_length = torch::nullopt) {
251 int64_t max_seq_length = sequence.batch_sizes().size(0);
252 if (total_length.has_value()) {
253 int64_t total_length_val = total_length.value();
254 TORCH_CHECK(
255 total_length_val >= max_seq_length,
256 "Expected total_length to be at least the length "
257 "of the longest sequence in input, but got "
258 "total_length=",
259 total_length_val,
260 " and max sequence length being ",
261 max_seq_length);
262 max_seq_length = total_length_val;
263 }
264 auto [padded_output, lengths] = torch::_pad_packed_sequence(
265 sequence.data(),
266 sequence.batch_sizes(),
267 batch_first,
268 padding_value,
269 max_seq_length);
270 const Tensor& unsorted_indices = sequence.unsorted_indices();
271 if (unsorted_indices.defined()) {
272 int64_t batch_dim = batch_first ? 0 : 1;
273 return std::make_tuple(
274 padded_output.index_select(batch_dim, unsorted_indices),
275 lengths.index({unsorted_indices.cpu()}));
276 }
277 return std::make_tuple(padded_output, lengths);
278 }
279
280 /// Pad a list of variable length Tensors with ``padding_value``
281 ///
282 /// ``pad_sequence`` stacks a list of Tensors along a new dimension,
283 /// and pads them to equal length. For example, if the input is list of
284 /// sequences with size ``L x *`` and if batch_first is false, and ``T x B x *``
285 /// otherwise.
286 ///
287 /// `B` is batch size. It is equal to the number of elements in ``sequences``.
288 /// `T` is length of the longest sequence.
289 /// `L` is length of the sequence.
290 /// `*` is any number of trailing dimensions, including none.
291 ///
292 /// Note:
293 /// This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
294 /// where `T` is the length of the longest sequence. This function assumes
295 /// trailing dimensions and type of all the Tensors in sequences are same.
296 ///
297 /// Arguments:
298 /// sequences (torch::ArrayRef<Tensor>): list of variable length sequences.
299 /// batch_first (bool, optional): output will be in ``B x T x *`` if true,
300 /// or in
301 /// ``T x B x *`` otherwise
302 /// padding_value (double, optional): value for padded elements. Default: 0.
303 /// padding_side (str, optional): the side to pad the sequences on. Default:
304 /// "right".
305 ///
306 /// Returns:
307 /// Tensor of size ``T x B x *`` if `batch_first` is ``false``.
308 /// Tensor of size ``B x T x *`` otherwise
309 inline Tensor pad_sequence(
310 ArrayRef<Tensor> sequences,
311 bool batch_first = false,
312 double padding_value = 0,
313 c10::string_view padding_side = "right") {
314 return at::pad_sequence(sequences, batch_first, padding_value, padding_side);
315 }
316
317 /// Packs a list of variable length Tensors
318 ///
319 /// ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
320 /// the length of a sequence and `*` is any number of trailing dimensions,
321 /// including zero.
322 ///
323 /// For unsorted sequences, use `enforce_sorted = false`. If ``enforce_sorted``
324 /// is ``true``, the sequences should be sorted in the order of decreasing
325 /// length.
326 ///
327 ///
328 /// Arguments:
329 /// sequences (torch::ArrayRef<Tensor>): A list of sequences of decreasing
330 /// length. enforce_sorted (bool, optional): if ``true``, checks that the
331 /// input
332 /// contains sequences sorted by length in a decreasing order. If
333 /// ``false``, this condition is not checked. Default: ``true``.
334 ///
335 /// Returns:
336 /// a `PackedSequence` object
337 inline PackedSequence pack_sequence(
338 ArrayRef<Tensor> sequences,
339 bool enforce_sorted = true) {
340 Tensor lengths = torch::empty({(int64_t)sequences.size()}, kInt64);
341 for (const auto i : c10::irange(sequences.size())) {
342 lengths[i] = sequences[i].size(0);
343 }
344 return pack_padded_sequence(
345 at::pad_sequence(sequences),
346 std::move(lengths),
347 /*batch_first=*/false,
348 /*enforce_sorted=*/enforce_sorted);
349 }
350
351 } // namespace rnn
352 } // namespace utils
353 } // namespace nn
354 } // namespace torch
355