xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/PackedSequence.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 
4 #ifndef AT_PER_OPERATOR_HEADERS
5 #include <ATen/Functions.h>
6 #include <ATen/NativeFunctions.h>
7 #else
8 #include <ATen/ops/_pack_padded_sequence_backward_native.h>
9 #include <ATen/ops/_pack_padded_sequence_native.h>
10 #include <ATen/ops/_pad_packed_sequence_native.h>
11 #include <ATen/ops/cat.h>
12 #include <ATen/ops/empty.h>
13 #include <ATen/ops/full.h>
14 #include <ATen/ops/pad_sequence_native.h>
15 #include <ATen/ops/zeros.h>
16 #include <ATen/ops/zeros_like_ops.h>
17 #endif
18 
19 #include <c10/util/irange.h>
20 
21 namespace at::native {
22 
checkLongTensor(const Tensor & tensor)23 static void checkLongTensor(const Tensor& tensor) {
24   TORCH_CHECK(tensor.dim() == 1 && tensor.device().type() == at::kCPU && tensor.scalar_type() == at::kLong,
25            "'lengths' argument should be a 1D CPU int64 tensor, but got ",
26             tensor.dim(), "D ", tensor.device().str(), " ", tensor.scalar_type(), " tensor");
27 }
28 
29 // This method returns `(data, batch_sizes)`, which are then passed into a
30 // `PackedSequence` constructor.
31 // `data` can be on arbitrary device and of arbitrary dtype, but `batch_sizes`
32 // must be a CPU int64 tensor.
33 // See NOTE [ device and dtype of a PackedSequence ]
_pack_padded_sequence(const Tensor & _input,const Tensor & _lengths,bool batch_first)34 std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Tensor& _lengths, bool batch_first) {
35   TORCH_CHECK(_input.numel() > 0, "Cannot pack empty tensors.");
36   auto input = batch_first ? _input.transpose(0, 1) : _input;
37   auto lengths_t = _lengths.contiguous();
38   checkLongTensor(lengths_t);
39 
40   int64_t batch_size = input.size(1);
41   int64_t * lengths = lengths_t.data_ptr<int64_t>();
42 
43   TORCH_CHECK(lengths_t.size(0) == batch_size,
44            "Expected `len(lengths)` to be equal to batch_size, but got ", lengths_t.size(0),
45            " (batch_size=", batch_size, ")");
46   TORCH_CHECK(lengths[batch_size - 1] > 0,
47            "Length of all samples has to be greater than 0, but found an element "
48            "in 'lengths' that is <= 0");
49   for (const auto i : c10::irange(batch_size - 1)) {
50     if (lengths[batch_size - 1 - i] > lengths[batch_size - 2 - i]) {
51       // NB: enforce_sorted is implemented at a Python level, but the sortedness
52       // check lives here. If enforce_sorted=False then this error should never
53       // get called.
54       AT_ERROR("`lengths` array must be sorted in decreasing order when "
55                "`enforce_sorted` is True. You can pass `enforce_sorted=False` "
56                "to pack_padded_sequence and/or pack_sequence to sidestep this "
57                "requirement if you do not need ONNX exportability.");
58     }
59   }
60 
61   std::vector<at::Tensor> steps;
62   steps.reserve(batch_size);
63   at::Tensor batch_sizes_t = at::empty(lengths[0], _lengths.options());
64   int64_t * batch_sizes = batch_sizes_t.mutable_data_ptr<int64_t>();
65 
66   std::vector<int64_t> step_shape; // == [-1, *input.shape[2:]]
67   {
68     auto input_sizes = input.sizes();
69     step_shape.reserve(input_sizes.size());
70     auto s_input_sizes = input_sizes.slice(2);
71     step_shape.push_back(-1);
72     step_shape.insert(step_shape.end(), s_input_sizes.begin(), s_input_sizes.end());
73   }
74 
75   // To understand what's going on in this loop imagine that the input is a padded 2D
76   // array that looks like this (x = valid entry, . = padding)
77   //
78   //  1 1 1 1 1
79   //  2 2 2 . .
80   //  2 2 2 . .
81   //  4 . . . .
82   //  4 . . . .
83   //
84   // Where the vertical dimension corresponds to time, and horizontal dim to batch.
85   // In this example, the lengths array will be equal to [5, 3, 3, 1, 1], and we will
86   // iterate over them in reverse order (from the rightmost column to the left).
87   // We want to avoid eager slicing of the input at every time step, and wait for
88   // the moments where the length increases. In this example, that will happen at the
89   // first, second and fourth steps. Then, we slice out the whole block of the input
90   // that corresponds to this length, and hasn't been sliced yet (the steps at which each
91   // element is sliced are annotated in the array above).  You can think of this as if we
92   // were scanning the sequences from the shortest one, and every time we realize there's
93   // more elements below in our column, we lower the counter (prev_l), and append the new
94   // block to the output.
95   int64_t prev_l = 0;
96   for (const auto i : c10::irange(batch_size)) {
97     int64_t l = lengths[batch_size - 1 - i];
98     if (l > prev_l) {
99       auto current_batch_size = batch_size - i;
100       steps.push_back(input.slice(0, prev_l, l).slice(1, 0, current_batch_size).contiguous().view(step_shape));
101       for (int64_t j = 0; j < (l - prev_l); ++j) {
102         (*batch_sizes++) = current_batch_size;
103       }
104       prev_l = l;
105     }
106     TORCH_CHECK(l >= prev_l);
107   }
108 
109   return std::make_tuple(at::cat(steps), batch_sizes_t);
110 }
111 
112 // `grad` could be on arbitrary device and of arbitrary dtype, but `_batch_sizes`
113 // is guaranteed to be a CPU int64 tensor.
114 // See NOTE [ device and dtype of a PackedSequence ]
_pack_padded_sequence_backward_symint(const Tensor & grad,c10::SymIntArrayRef input_size,const Tensor & _batch_sizes,bool batch_first)115 Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArrayRef input_size, const Tensor& _batch_sizes, bool batch_first) {
116   std::vector<c10::SymInt> input_size_after_t = input_size.vec();
117   if (batch_first) {
118     TORCH_CHECK(input_size.size() >= 2);
119     std::swap(input_size_after_t[0], input_size_after_t[1]);
120   }
121   auto grad_input = at::zeros_symint(input_size_after_t, grad.options());
122   auto batch_sizes_t = _batch_sizes.contiguous();
123   checkLongTensor(batch_sizes_t);
124 
125   int64_t offset = 0;
126   // NOTE: this op advertises as CompositeImplicitAutograd, but uses data_ptr().
127   // we should fix this.
128   auto max_seq_len = batch_sizes_t.size(0);
129   int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
130   for (const auto i : c10::irange(max_seq_len)) {
131     grad_input[i].slice(0, 0, batch_sizes[i]).copy_(grad.slice(0, offset, offset + batch_sizes[i]));
132     offset += batch_sizes[i];
133   }
134 
135   if (batch_first) {
136     grad_input = grad_input.transpose(0, 1);
137   }
138 
139   return grad_input;
140 }
141 
_pad_packed_sequence(const Tensor & data,const Tensor & _batch_sizes,bool batch_first,const Scalar & padding_value,int64_t total_length)142 std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
143   auto batch_sizes_t = _batch_sizes.contiguous();
144   checkLongTensor(batch_sizes_t);
145 
146   int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
147   int64_t max_batch_size = batch_sizes[0];
148   int64_t max_real_seq_length = batch_sizes_t.size(0);
149   int64_t max_seq_length = max_real_seq_length;
150   if (total_length > 0) {
151     TORCH_CHECK(total_length >= max_seq_length,
152              "Expected total_length to be at least the length of the longest "
153              "sequence in input, but got total_length=", total_length, " and "
154              "max sequence length being ", max_seq_length);
155     max_seq_length = total_length;
156   }
157 
158   std::vector<int64_t> output_size; // == [max_seq_length, max_batch_size, *var_data.size()[1:]]
159   {
160     output_size.reserve(data.dim() + 1);
161     output_size.push_back(max_seq_length);
162     output_size.push_back(max_batch_size);
163     auto s_data_size = data.sizes().slice(1);
164     output_size.insert(output_size.end(), s_data_size.begin(), s_data_size.end());
165   }
166   auto output = at::full(output_size, padding_value, data.options());
167 
168   // This will be modified at every iteration, but we reserve memory for it now.
169   std::vector<int64_t> tmp_view_size = std::move(output_size); // == [-1, -1, *var_data.size()[1:]]
170 
171   at::Tensor lengths_t = at::empty(max_batch_size, batch_sizes_t.options());
172   int64_t * lengths = lengths_t.mutable_data_ptr<int64_t>() + max_batch_size - 1;
173   int64_t data_offset = 0;
174   int64_t prev_batch_size = max_batch_size;
175   int64_t prev_i = 0;
176   for (int64_t i = 0; i <= max_real_seq_length; ++i) {
177     int64_t batch_size = i != max_real_seq_length ? batch_sizes[i] : 0;
178     if (batch_size != prev_batch_size) {
179       int64_t l = prev_batch_size * (i - prev_i);
180       // The lines below are equivalent to this:
181       // output[prev_i:i, :prev_batch_size] = tmp.view(i - prev_i, prev_batch_size, *input.shape[2:])
182       auto tmp = data.slice(0, data_offset, data_offset + l);
183       tmp_view_size[0] = i - prev_i;
184       tmp_view_size[1] = prev_batch_size;
185       output.slice(0, prev_i, i).slice(1, 0, prev_batch_size).copy_(tmp.view(tmp_view_size));
186       data_offset += l;
187       prev_i = i;
188     }
189     int64_t dec = prev_batch_size - batch_size;
190     if (dec > 0) {
191       for (C10_UNUSED const auto j : c10::irange(dec)) {
192         (*lengths--) = i;
193       }
194     }
195     prev_batch_size = batch_size;
196   }
197 
198   if (batch_first) {
199     output = output.transpose(0, 1);
200   }
201 
202   return std::make_tuple(output, lengths_t);
203 }
204 
pad_sequence(TensorList sequences,bool batch_first,double padding_value,const c10::string_view padding_side)205 Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value, const c10::string_view padding_side) {
206   const int64_t sequences_size = sequences.size();
207   TORCH_CHECK(sequences_size > 0, "received an empty list of sequences");
208   TORCH_CHECK(padding_side == "left" || padding_side == "right",
209               "Expected padding_side to be one of left or right, but got ", padding_side, ".");
210   IntArrayRef max_size = sequences[0].sizes();
211   IntArrayRef trailing_dims = max_size.slice(1);
212   int64_t max_len = std::max_element(
213     sequences.begin(),
214     sequences.end(),
215     [](const Tensor &a, const Tensor &b) {
216       return a.size(0) < b.size(0);
217     }
218   )->size(0);
219 
220   DimVector out_dims;
221   if (batch_first) {
222     out_dims = {sequences_size, max_len};
223   } else {
224     out_dims = {max_len, sequences_size};
225   }
226   out_dims.insert(out_dims.end(), trailing_dims.begin(), trailing_dims.end());
227 
228   Tensor out = at::full(out_dims, padding_value, sequences[0].options());
229   for (const auto i : c10::irange(sequences_size)) {
230     const Tensor& currseq = sequences[i];
231     const int64_t length_i = currseq.size(0);
232     const int64_t start = padding_side == "left" ? max_len - length_i : 0;
233     // use index notation to prevent duplicate references to the tensor
234     if (batch_first) {
235       out.select(0, i).narrow(0, start, length_i).copy_(currseq);
236     } else {
237       out.narrow(0, start, length_i).select(1, i).copy_(currseq);
238     }
239   }
240   return out;
241 }
242 
243 } // namespace at::native
244