xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/native/nested/NestedTensorUtils.h>

#include <c10/util/string_view.h>
#include <c10/util/Exception.h>
#include <optional>

namespace at::native {
namespace {

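// Shared shape/layout validation for nested_linear and NestedTensor_matmul:
// the nested input must be a contiguous 3-D nested tensor, the dense matrix
// must be a non-nested 2-D tensor, and the input's consistent last dimension
// must match the dense matrix's contraction dimension (dim 1 for "Linear",
// which transposes its weight, dim 0 for "Matmul").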
inline void check_nested_tensor_matrix_constraints(
    const Tensor& nested_tensor,
    const Tensor& dense_matrix,
    c10::string_view caller) {
  auto* nt_input = get_nested_tensor_impl(nested_tensor);
  TORCH_INTERNAL_ASSERT(nt_input != nullptr);
  TORCH_CHECK(
      !dense_matrix.is_nested(),
      caller,
      " does not support nested weight when input is a nested tensor.");
  // TODO: support the noncontiguous case; error out for now.
  TORCH_CHECK(
      nested_tensor_impl_is_contiguous(nt_input),
      "for now linear only supports contiguous nested tensor");
  TORCH_CHECK(
      nested_tensor.dim() == 3 && dense_matrix.dim() == 2,
      caller,
      " requires nested_tensor.dim == 3 and dense_matrix.dim == 2."
      " Nested tensor dim: ",
      nested_tensor.dim(),
      ". Dense tensor dim: ",
      dense_matrix.dim());
  const auto last_dim = get_consistent_last_dim_of_nested_tensor(*nt_input);
  // We check the second dimension for Linear because it transposes the weight
  // before the matrix multiply.
  int64_t dim_constraint = (caller == "Linear") ? 1 : 0;
  auto dense_size = dense_matrix.size(dim_constraint);
  TORCH_CHECK(
      last_dim == dense_size,
      "Shape mismatch for NestedTensor ",
      caller,
      ": Expected input's (a nested tensor) 'last_dim' to equal 'weight.size(",
      dim_constraint,
      ")', but got: last_dim = ",
      last_dim,
      ", and weight.size(",
      dim_constraint,
      ") = ",
      dense_size);
}
} // namespace

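// Linear over a contiguous 3-D nested tensor: the underlying buffer is viewed
// as a 2-D [total_rows, in_features] matrix, run through at::linear, and
// rewrapped as a nested tensor whose nested sizes keep their leading entries
// but have the last dimension replaced by weight.size(0) (out_features).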
Tensor nested_linear(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias_opt) {
  check_nested_tensor_matrix_constraints(input, weight, c10::string_view{"Linear"});
  auto* nt_input = get_nested_tensor_impl(input);
  const Tensor& input_buffer = nt_input->get_buffer();
  Tensor result_buffer =
      at::linear(input_buffer.reshape({-1, weight.size(1)}), weight, bias_opt);
  result_buffer = result_buffer.reshape({-1});
  int64_t weight_size_0 = weight.size(0);
  Tensor new_sizes = nt_input->get_nested_sizes().clone();
  // Now the last entry in every row of new_sizes should be weight_size_0.
  new_sizes.index_put_({at::indexing::Slice(), -1}, weight_size_0);
  return wrap_buffer(result_buffer, new_sizes);
}

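// Matrix product of a contiguous 3-D nested tensor with a dense 2-D matrix.
// Mirrors nested_linear: flatten the nested buffer to 2-D, call at::mm, and
// set the last entry of every row of the nested size tensor to other.size(1).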
Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) {
  check_nested_tensor_matrix_constraints(self, other, c10::string_view{"Matmul"});
  auto* nt_self = get_nested_tensor_impl_or_null(self);
  TORCH_INTERNAL_ASSERT(nt_self != nullptr);
  const Tensor& self_buffer = nt_self->get_buffer();
  Tensor result_buffer =
      at::mm(self_buffer.reshape({-1, other.sizes()[0]}), other);
  result_buffer = result_buffer.reshape({-1});
  int64_t other_size_1 = other.sizes()[1];
  Tensor new_sizes = nt_self->get_nested_sizes().clone();
  // Now the last entry in every row of new_sizes should be other_size_1.
  new_sizes.index_put_({at::indexing::Slice(), -1}, other_size_1);
  return wrap_buffer(result_buffer, new_sizes);
}

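// Computes beta * self + alpha * (mat1 @ mat2) where mat1 is a contiguous 3-D
// nested tensor and self/mat2 are dense, optionally fusing a GELU activation
// via at::_addmm_activation when use_gelu is provided.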
Tensor NestedTensor_times_Tensor_plus_Tensor_addmm(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const c10::Scalar& beta,
    const c10::Scalar& alpha,
    std::optional<bool> use_gelu) {
  // Interesting case: alpha * NT * T + beta * T
  const auto* nt_mat1 = get_nested_tensor_impl_or_null(mat1);
  TORCH_INTERNAL_ASSERT(nt_mat1 != nullptr);
  TORCH_INTERNAL_ASSERT(!mat2.is_nested());
  TORCH_INTERNAL_ASSERT(!self.is_nested());
  TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(nt_mat1));
  TORCH_INTERNAL_ASSERT(mat1.dim() == 3 && mat2.dim() == 2);
  TORCH_INTERNAL_ASSERT(
      get_consistent_last_dim_of_nested_tensor(*nt_mat1) == mat2.sizes()[0]);
  const Tensor& mat1_buffer = nt_mat1->get_buffer();
  Tensor result_buffer = !use_gelu.has_value()
      ? at::addmm(
            self, mat1_buffer.reshape({-1, mat2.sizes()[0]}), mat2, beta, alpha)
      : at::_addmm_activation(
            self,
            mat1_buffer.reshape({-1, mat2.sizes()[0]}),
            mat2,
            beta,
            alpha,
            *use_gelu);
  result_buffer = result_buffer.reshape({-1});
  int64_t other_size_1 = mat2.sizes()[1];
  Tensor new_sizes = nt_mat1->get_nested_sizes().clone();
  new_sizes.index_put_({at::indexing::Slice(), -1}, other_size_1);
  return at::detail::make_tensor<NestedTensorImpl>(
      std::move(result_buffer), std::move(new_sizes));
}

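// In-place elementwise add of two contiguous nested tensors with identical
// nested sizes; the addition runs directly on the flattened buffers.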
Tensor NestedTensor_add_NestedTensor_in_place(
    const Tensor& self,
    const Tensor& other) {
  TORCH_INTERNAL_ASSERT(self.is_nested() && other.is_nested());
  const auto& nt_self = *get_nested_tensor_impl(self);
  const auto& nt_other = *get_nested_tensor_impl(other);

  const auto& self_sizes = nt_self.get_nested_sizes();
  const auto& other_sizes = nt_other.get_nested_sizes();

  TORCH_CHECK(at::equal(self_sizes, other_sizes));
  TORCH_INTERNAL_ASSERT(
      nested_tensor_impl_is_contiguous(&nt_self) &&
      nested_tensor_impl_is_contiguous(&nt_other));
  nt_self.get_buffer().view({-1}).add_(nt_other.get_buffer().view({-1}));
  return self;
}

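// For each entry of the nested query, applies softmax over the last dimension
// of the corresponding valid [seq_len, seq_len] block of `self` and zeroes the
// padded rows and columns. Note: despite the name, this path performs no
// dropout.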
Tensor NestedTensor_softmax_dropout(const Tensor& self, const Tensor& query) {
  const auto* query_nt = get_nested_tensor_impl_or_null(query);
  TORCH_INTERNAL_ASSERT(query_nt != nullptr);
  TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(query_nt));

  const Tensor& sizes = query_nt->get_nested_sizes();
  const auto num_tensors = sizes.sizes()[0];

  auto output = at::empty_like(self, {}, at::MemoryFormat::Contiguous);
  TORCH_INTERNAL_ASSERT(output.is_contiguous());

  const auto max_seq_len = self.sizes()[2];

  for (int64_t i = 0; i < num_tensors; i++) {
    auto seq_len = sizes.index({i, 0}).item<int64_t>();
    auto subseq = self.index(
        {i,
         indexing::Slice(),
         indexing::Slice(0, seq_len),
         indexing::Slice(0, seq_len)});
    auto subscores = at::softmax(subseq, subseq.dim() - 1);
    output.index_put_(
        {i,
         indexing::Slice(),
         indexing::Slice(0, seq_len),
         indexing::Slice(0, seq_len)},
        subscores);
    output.index_put_(
        {i,
         indexing::Slice(),
         indexing::Slice(0, seq_len),
         indexing::Slice(seq_len, max_seq_len)},
        0);
    output.index_put_(
        {i,
         indexing::Slice(),
         indexing::Slice(seq_len, max_seq_len),
         indexing::Slice(0, max_seq_len)},
        0);
  }
  return output;
}

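// CUDA variant: builds a [B, T] padding mask from the nested query via
// NestedTensor_to_mask and defers to _masked_softmax on the padded input.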
Tensor NestedTensor_softmax_dropout_cuda(const Tensor& self, const Tensor& query) {
  std::optional<Tensor> attn_mask;

  // NestedTensor_to_mask produces a BxT mask.
  attn_mask = NestedTensor_to_mask(query, 2, self.size(2));
  attn_mask = attn_mask->to(query.device(), /*non_blocking=*/true);
  return _masked_softmax(self, *attn_mask, self.dim() - 1, /*mask_type=*/1);
}

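// Turns a [num_tensors, dims] nested size tensor into an int32 tensor of
// cumulative element offsets (prefix sums of per-tensor numel), with
// extra_elements unused slots appended. For example, sizes [[2, 3], [4, 3]]
// with extra_elements == 0 yield offsets [0, 6, 18].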
Tensor NestedTensor_batch_offsets_from_size_tensor(
    const Tensor& sizes,
    int64_t extra_elements) {
  int64_t* const sizes_ptr = sizes.data_ptr<int64_t>();
  Tensor offsets = at::empty({1 + sizes.size(0) + extra_elements}, at::kInt);
  int32_t* const offsets_ptr = offsets.mutable_data_ptr<int32_t>();
  offsets_ptr[0] = 0;
  const auto sizes_size_1 = sizes.size(1);
  const auto sizes_size_0 = sizes.size(0);
  for (const auto i : c10::irange(sizes_size_0)) {
    int64_t prod = 1;
    for (const auto j : c10::irange(sizes_size_1)) {
      prod *= sizes_ptr[i * sizes_size_1 + j];
    }
    offsets_ptr[i + 1] = offsets_ptr[i] + static_cast<int32_t>(prod);
  }
  return offsets;
}

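// Builds a boolean padding mask from a contiguous nested tensor. Only the
// mask_dim == 2, nt.dim() == 3 case is implemented: the result has shape
// [num_tensors, mask_dim_length or max first-dim size] and is false at valid
// positions, true at padding.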
Tensor NestedTensor_to_mask(const Tensor& nt, std::optional<int64_t> mask_dim, std::optional<int64_t> mask_dim_length) {
  auto* nt_impl = get_nested_tensor_impl(nt);
  TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_impl), "to_mask only works on contiguous NestedTensors.");
  TORCH_CHECK(
      !mask_dim || *mask_dim < nt.dim(),
      "Requested mask dimension ",
      *mask_dim,
      " is bigger than dimension ",
      nt.dim(),
      " of given NestedTensor.");

  // TODO: port optimization for 1x1 tensors from
  // pytorch/nestedtensor's version.

  TORCH_CHECK(
      mask_dim && *mask_dim == 2 && nt.dim() == 3,
      "Only the special case of mask_dim == 2 on a 3-D NestedTensor is supported right now.");
  const auto& sizes = nt_impl->get_nested_sizes();
  // Shape: # of tensors in our NestedTensor by max size along first dim
  // TODO: calculate this without allocating a std::vector.
  const auto result_size_1 = mask_dim_length ? *mask_dim_length : NestedTensor_get_max_size(*nt_impl)[0];
  auto result = at::ones({sizes.sizes()[0], result_size_1}, at::kBool);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  auto* result_data = result.data_ptr<bool>();
  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  const auto sizes_size_1 = sizes.sizes()[1];
  for (const auto ii : c10::irange(sizes.sizes()[0])) {
    auto length = sizes_ptr[ii * sizes_size_1];
    for (const auto jj : c10::irange(length)) {
      result_data[ii * result_size_1 + jj] = false;
    }
  }
  return result;
}

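// CPU kernel for the jagged -> padded dense conversion with a single jagged
// dimension: allocates a [batch, max_length, ...] tensor filled with
// padding_value and copies each batch item's rows into it, truncating to
// max_length. For example, offsets [0, 2, 5] with max_lengths {3} and values
// of shape [5, D] produce a padded tensor of shape [2, 3, D].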
Tensor _jagged_to_padded_dense_forward_cpu(
    const Tensor& values,
    TensorList offsets_list,
    c10::IntArrayRef max_lengths,
    const double padding_value) {
  // TODO: Make this kernel more efficient using TensorIterator or something.
  TORCH_INTERNAL_ASSERT(
      offsets_list.size() == 1 && max_lengths.size() == 1,
      "_jagged_to_padded_dense_forward(): only a single jagged dim is supported for now");

  // allocate appropriately-sized padded tensor
  const auto& offsets = offsets_list[0];
  TORCH_CHECK(
      offsets.dim() == 1,
      "_jagged_to_padded_dense_forward(): expected 1D offsets, but got offsets.dim() == ",
      offsets.dim());

  auto batch_size = offsets.size(0) - 1;
  auto max_length = max_lengths[0];
  auto values_shape = values.sizes().vec();
  std::vector<int64_t> padded_shape;
  padded_shape.reserve(values.dim() + 1);
  padded_shape.push_back(batch_size);
  padded_shape.push_back(max_length);
  padded_shape.insert(padded_shape.end(), values_shape.begin() + 1, values_shape.end());
  Tensor padded = values.new_full(padded_shape, padding_value);

  // copy data to padded tensor
  for (auto i : c10::irange(batch_size)) {
    auto start_offset = offsets[i].item<int64_t>();
    auto end_offset = offsets[i + 1].item<int64_t>();
    auto length = end_offset - start_offset;
    // NB: truncate to max length to match CUDA kernel behavior.
    length = std::min(length, max_length);
    auto source = values.slice(0, start_offset, start_offset + length);
    auto dst = padded.select(0, i).slice(0, 0, length);
    dst.copy_(source);
  }

  return padded;
}

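// Inverse of the above for a single jagged dimension: packs the valid rows of
// a [batch, max_length, ...] padded tensor back into a [total_L, ...] values
// tensor according to offsets, checking that each batch item fits within the
// padded length.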
Tensor _padded_dense_to_jagged_forward_cpu(
    const Tensor& padded,
    TensorList offsets_list,
    std::optional<int64_t> total_L) {
  // TODO: Make this kernel more efficient using TensorIterator or something.
  TORCH_INTERNAL_ASSERT(
      offsets_list.size() == 1,
      "_padded_dense_to_jagged_forward(): only a single jagged dim is supported for now");

  // allocate appropriately-sized values tensor
  const auto& offsets = offsets_list[0];
  TORCH_CHECK(
      offsets.dim() == 1,
      "_padded_dense_to_jagged_forward(): expected 1D offsets, but got offsets.dim() == ",
      offsets.dim());

  auto final_offset = offsets[-1].item<int64_t>();
  int64_t total_L_val = total_L.has_value() ? (*total_L) : final_offset;
  if (total_L.has_value()) {
    // error if the offsets try to index past the end of the packed dimension
    TORCH_CHECK(
        final_offset == total_L_val,
        "_padded_dense_to_jagged_forward(): final offset should match total_L value");
  }

  TORCH_CHECK(
      padded.dim() >= 2,
      "_padded_dense_to_jagged_forward(): expected padded dim >= 2, but padded.dim() == ",
      padded.dim());

  std::vector<int64_t> values_shape;
  values_shape.reserve(padded.dim() - 1);
  values_shape.push_back(total_L_val);
  auto padded_shape = padded.sizes();
  values_shape.insert(values_shape.end(), padded_shape.begin() + 2, padded_shape.end());
  Tensor values = padded.new_empty(values_shape);

  // copy data to values tensor
  auto batch_size = offsets.size(0) - 1;
  for (auto i : c10::irange(batch_size)) {
    auto start_offset = offsets[i].item<int64_t>();
    auto end_offset = offsets[i + 1].item<int64_t>();
    auto length = end_offset - start_offset;

    TORCH_CHECK(
        length <= padded_shape[1],
        "_padded_dense_to_jagged_forward(): found batch item of length ", length,
        " when max length specified by padded input is ", padded_shape[1]);

    auto dst = values.slice(0, start_offset, end_offset);
    auto source = padded.select(0, i).slice(0, 0, length);
    dst.copy_(source);
  }

  return values;
}

} // namespace at::native