// aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.cpp
#include <ATen/ATen.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/native/nested/NestedTensorTransformerUtils.h>
#include <tuple>

namespace at::native::preprocessing {

namespace {

/**
 * This builds up the cumulative sequence length for a batch of sequences.
 * This is not very DRY, but in the backward pass we already have the
 * cumulative_seq_len on device, and all we need on the CPU to launch the
 * kernel is Nnz. We could refactor the function below, but that would add
 * more complexity than I think is needed.
 */
int64_t get_nnz(const Tensor& nestedtensor) {
  auto* nt_impl = get_nested_tensor_impl(nestedtensor);
  const auto& sizes = nt_impl->get_nested_sizes();
  auto size_tensor_stride = sizes.stride(0);
  const int64_t batch_size = nestedtensor.size(0);
  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  int64_t cumulative_sequence_length = 0;
  for (const auto i : c10::irange(batch_size)) {
    // Calculate the cumulative sum of the sequence lengths
    int64_t current_seq_len = sizes_ptr[(i * size_tensor_stride)];
    cumulative_sequence_length += current_seq_len;
  }
  return cumulative_sequence_length;
}
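
// Illustrative sketch (editorial addition, not part of the original source):
// for a hypothetical nested tensor with three jagged sequences of lengths
// 2, 3, and 5 (shape [3, {2, 3, 5}, ...]), get_nnz simply sums the
// per-sequence lengths read from the nested sizes tensor:
//
//   int64_t nnz = get_nnz(nested_q_t);  // nnz == 2 + 3 + 5 == 10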

  /**
   * This function is used to calculate two pieces of metadata that are needed
   * for use with the flash-attention and efficient-attention kernels: the
   * cumulative sequence lengths over a batch of sequences and the maximum
   * sequence length.
   *
   * @return A tuple of the cumulative sequence lengths, the maximum sequence
   * length, and the last element of the cumulative sequence lengths (the
   * total number of tokens, Nnz).
   */
  std::tuple<Tensor, int64_t, int64_t> cumulative_and_max_seq_len_nnz(
      const Tensor& qkv) {
    TORCH_CHECK(
        qkv.is_nested(),
        "QKV must be nested for flash cumulative_seq_len calculation.")
    auto* nt_impl = get_nested_tensor_impl(qkv);
    const auto& sizes = nt_impl->get_nested_sizes();
    auto size_tensor_stride = sizes.stride(0);

    const int64_t batch_size = qkv.size(0);
    auto cumulative_seqlen = at::zeros(
        {batch_size + 1}, TensorOptions().device(at::kCPU).dtype(at::kInt));

    auto* sizes_ptr = sizes.data_ptr<int64_t>();
    auto* cumulative_seqlen_ptr = cumulative_seqlen.data_ptr<int32_t>();

    int64_t sum = 0;
    int64_t max_seqlen = -1;
    cumulative_seqlen_ptr[0] = static_cast<int32_t>(sum);
    for (const auto i : c10::irange(batch_size)) {
      // Calculate the cumulative sum of the sequence lengths
      auto current_seq_len = sizes_ptr[(i * size_tensor_stride)];
      sum += current_seq_len;
      cumulative_seqlen_ptr[i + 1] = static_cast<int32_t>(sum);

      // Find the max element while we traverse
      max_seqlen = std::max(max_seqlen, current_seq_len);
    }
    // Send to GPU; this is a pretty lightweight calculation for normal batch
    // sizes, but it may eventually need to be done on the GPU.
    cumulative_seqlen = cumulative_seqlen.to(TensorOptions().device(at::kCUDA));
    return std::tuple<Tensor, int64_t, int64_t>{
        cumulative_seqlen, max_seqlen, sum};
  }
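
  // Illustrative sketch (editorial addition): for a hypothetical nested
  // q_t of shape [3, {2, 3, 5}, num_heads, head_dim],
  //
  //   auto [cu_seqlens, max_seqlen, nnz] = cumulative_and_max_seq_len_nnz(q_t);
  //   // cu_seqlens -> [0, 2, 5, 10]  (int32, moved to CUDA)
  //   // max_seqlen -> 5
  //   // nnz        -> 10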

  /**
   * This function checks whether a nested tensor can be used with the
   * flash-attention and efficient-attention kernels without calling
   * contiguous on the nested tensor input.
   * It checks that the adjacent differences of the storage offsets are a
   * constant multiple of the numel of the previous tensor in the nested
   * tensor and that the strides are monotonically decreasing. This check is
   * done after calling transpose on the nested tensor, resulting in a
   * NestedTensor of shape [bsz, {seq_len}, num_heads, dim].
   *
   * @return A boolean that is true when the underlying storage can be used
   * directly, i.e. contiguous does not need to be called on the input.
   */
  bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {
    const int64_t* tensor_offsets_ptr =
        tensor->get_storage_offsets().data_ptr<int64_t>();
    const Tensor& tensor_sizes = tensor->get_nested_sizes();
    const Tensor& tensor_strides = tensor->get_nested_strides();

    const int64_t n_tensors = tensor_strides.size(0);
    constexpr int64_t n_dims = 3;
    // This is safe since head_dim is assured to be consistent
    const int64_t num_heads = tensor->opt_size(2).value();
    const int64_t tensor_stride_0 = tensor_strides.stride(0);

    if (n_tensors <= 1) {
      return true;
    }

    int64_t* previous_tensor_stride = tensor_strides.data_ptr<int64_t>();

    // Check initially that the first tensor's strides
    // are in strictly descending order.
    // NOTE: If num_heads is equal to 1 then we skip the checks involving the
    // num_heads stride. Why, you may ask? Because if n_heads == 1 then, as
    // long as the last stride == 1, it does not matter what the num_heads
    // stride is.
    if (num_heads == 1) {
      if (previous_tensor_stride[0] <= previous_tensor_stride[2]) {
        // This would mean that the last stride is greater than the seq_len
        // stride
        return false;
      }
    } else {
      for (int i{1}; i < n_dims; i++) {
        if (previous_tensor_stride[i - 1] <= previous_tensor_stride[i]) {
          return false;
        }
      }
      // Check that each tensor i in the nested tensor has the same strides
      for (int i{1}; i < n_tensors; i++) {
        for (const int64_t j : c10::irange(n_dims)) {
          if (previous_tensor_stride[j] !=
              previous_tensor_stride[i * tensor_stride_0 + j]) {
            return false;
          }
        }
      }
    }

    // Check that the offsets are a constant multiple of the previous tensor's
    // numel
    const int64_t* tensor_size_ptr = tensor_sizes.const_data_ptr<int64_t>();
    const int64_t* tensor_stride_ptr = tensor_strides.const_data_ptr<int64_t>();

    int64_t numel_0 = (tensor_size_ptr[0] * tensor_stride_ptr[0]);
    TORCH_INTERNAL_ASSERT(numel_0 > 0, "numels must be positive!");

    int64_t offset_constant =
        (tensor_offsets_ptr[1] - tensor_offsets_ptr[0]) / numel_0;
    for (int64_t i = 2; i < n_tensors; i++) {
      // TODO: When 0 seq_len nested tensors are allowed we need to guard
      // against this
      int64_t previous_numel = tensor_size_ptr[(i - 1) * tensor_stride_0] *
          tensor_stride_ptr[(i - 1) * tensor_stride_0];
      TORCH_INTERNAL_ASSERT(previous_numel > 0, "numels must be positive!");
      int64_t current_offset_constant =
          (tensor_offsets_ptr[i] - tensor_offsets_ptr[i - 1]) / previous_numel;
      if (current_offset_constant != offset_constant) {
        return false;
      }
    }
    // Congrats you made it!
    return true;
  }
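
  // Illustrative example (editorial addition): for a hypothetical nested
  // tensor whose three components of sizes {2, 3, 5} x num_heads x head_dim
  // are packed back-to-back in one buffer with identical, strictly descending
  // strides, each adjacent offset difference equals the previous component's
  // numel (a constant multiple of 1), so is_safe_to_get_storage_as_tensor
  // returns true and the caller may skip the call to contiguous. A sliced or
  // permuted layout that violates either property returns false instead.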

  /**
   * Process an individual NestedTensor to reshape and view as a DenseTensor.
   * Generally the approach for q, k, v is to:
   * (1) get the storage of the contiguous nested tensor
   * (2) view it with shape {output_batch_size, {*}_t.size(1), output_num_heads,
   *     head_dim_{*}} and stride {0, nnz_{*}_stride, head_{*}_stride,
   *     head_dim_stride}, where head_{*}_stride is 0 if
   *     {*}_num_heads_needs_broadcast
   * (3) collapse the first two dims by reshaping to
   *     {Nnz_{*}, output_num_heads, head_dim_{*}}; if {*}_t.size(1) == 1
   *     (i.e. the seq_len is 1) the reshape should be a view and should not
   *     incur a copy.
   */
  at::Tensor view_as_dense(
      const at::Tensor& input_nestedtensor,
      const int64_t Nnz,
      const int64_t num_heads,
      const int64_t head_dim,
      const bool batch_needs_broadcast = false,
      const bool num_heads_needs_broadcast = false) {
    const auto* tensor_impl = get_nested_tensor_impl(input_nestedtensor);
    Tensor storage_as_tensor = tensor_impl->get_unsafe_storage_as_tensor();

    constexpr int64_t head_dim_stride = 1;
    const int64_t* nt_strides =
        tensor_impl->get_nested_strides().data_ptr<int64_t>();
    const int64_t* nt_offsets_ptr =
        tensor_impl->get_storage_offsets().data_ptr<int64_t>();

    const int64_t nnz_stride = nt_strides[0];
    const int64_t head_stride = num_heads_needs_broadcast ? 0 : nt_strides[1];

    if (batch_needs_broadcast) {
      Tensor input_buffer_reshaped = storage_as_tensor.as_strided(
          {Nnz, input_nestedtensor.size(1), num_heads, head_dim},
          {0, nnz_stride, head_stride, head_dim_stride},
          nt_offsets_ptr[0]);
      return input_buffer_reshaped.reshape({-1, num_heads, head_dim});
    }
    return storage_as_tensor.as_strided(
        {Nnz, num_heads, head_dim},
        {nnz_stride, head_stride, head_dim_stride},
        nt_offsets_ptr[0]);
  }
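
  // Illustrative sketch (editorial addition): for a hypothetical contiguous
  // nested k_t of shape [3, {2, 3, 5}, num_heads, head_dim] with no
  // broadcasting,
  //
  //   Tensor k_dense = view_as_dense(k_t, /*Nnz=*/10, num_heads, head_dim);
  //
  // produces a dense [10, num_heads, head_dim] view over the nested tensor's
  // storage (no copy), which is the packed layout the fused attention kernels
  // consume together with the cumulative sequence lengths.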

  /**
   * This function is a helper that takes nested query, key, and value tensors
   * that require broadcasting on the batch or num_heads dimensions and
   * preprocesses them so they can be run with either the flash-attention or
   * efficient-attention kernels.
   * @return A tuple containing all the necessary data for running the fused
   * kernels
   */
  auto sdpa_nested_preprocessing_with_broadcast(
      const Tensor& query, const Tensor& key, const Tensor& value) {
    // Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
    // Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    const int64_t q_batch_size = query.size(0);
    const int64_t k_batch_size = key.size(0);
    const int64_t v_batch_size = value.size(0);

    const int64_t output_batch_size =
        std::max({q_batch_size, k_batch_size, v_batch_size});

    const int64_t q_num_heads = query.size(1);
    const int64_t k_num_heads = key.size(1);
    const int64_t v_num_heads = value.size(1);

    const int64_t output_num_heads =
        std::max({q_num_heads, k_num_heads, v_num_heads});

    const int64_t head_dim_qk = query.size(3);
    const int64_t head_dim_v = value.size(3);

    Tensor q_t = query.transpose(1, 2);
    Tensor k_t = key.transpose(1, 2);
    Tensor v_t = value.transpose(1, 2);

    // Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
    // output_batch_size/num_heads then they are 1
    bool q_batch_size_needs_broadcast = q_batch_size != output_batch_size;
    bool k_batch_size_needs_broadcast = k_batch_size != output_batch_size;
    bool v_batch_size_needs_broadcast = v_batch_size != output_batch_size;

    // If {*}_batch_size_needs_broadcast, then
    // (1) max_seqlen_batch_{*} is given by {*}_t.size(1),
    //     because needs_broadcast indicates that the batch_size is 1
    //     and hence there is only 1 value for seq_len
    // (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
    //     ..., output_batch_size * {*}_t.size(1)]
    // (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1);

    int64_t max_seqlen_batch_q = 0, Nnz_q = 0;
    Tensor cumulative_sequence_length_q;
    if (q_batch_size_needs_broadcast || !q_t.is_nested()) {
      max_seqlen_batch_q = q_t.size(1);
      cumulative_sequence_length_q = at::arange(
          0,
          (output_batch_size + 1) * max_seqlen_batch_q,
          max_seqlen_batch_q,
          TensorOptions().device(at::kCUDA).dtype(at::kInt));
      Nnz_q = output_batch_size * max_seqlen_batch_q;
    } else {
      auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len_nnz(q_t);
      cumulative_sequence_length_q =
          std::get<0>(cumulative_and_max_q_and_nnz_q);
      max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q);
      Nnz_q = std::get<2>(cumulative_and_max_q_and_nnz_q);
    }

    int64_t max_seqlen_batch_kv = 0, Nnz_kv = 0;
    Tensor cumulative_sequence_length_kv;
    if (k_batch_size_needs_broadcast && v_batch_size_needs_broadcast) {
      TORCH_CHECK(k_t.size(1) == v_t.size(1));
      max_seqlen_batch_kv = k_t.size(1);
      cumulative_sequence_length_kv = at::arange(
          0,
          (output_batch_size + 1) * max_seqlen_batch_kv,
          max_seqlen_batch_kv,
          TensorOptions().device(at::kCUDA).dtype(at::kInt));
      Nnz_kv = output_batch_size * max_seqlen_batch_kv;
    } else {
      auto cumulative_and_max_kv_and_nnz_kv = k_batch_size_needs_broadcast
          ? cumulative_and_max_seq_len_nnz(v_t)
          : cumulative_and_max_seq_len_nnz(k_t);
      cumulative_sequence_length_kv =
          std::get<0>(cumulative_and_max_kv_and_nnz_kv);
      max_seqlen_batch_kv = std::get<1>(cumulative_and_max_kv_and_nnz_kv);
      Nnz_kv = std::get<2>(cumulative_and_max_kv_and_nnz_kv);
    }

    bool q_num_heads_needs_broadcast = q_num_heads != output_num_heads;
    bool k_num_heads_needs_broadcast = k_num_heads != output_num_heads;
    bool v_num_heads_needs_broadcast = v_num_heads != output_num_heads;

    Tensor query_buffer_reshaped;
    Tensor key_buffer_reshaped;
    Tensor value_buffer_reshaped;

    if (!q_t.is_nested()) {
      query_buffer_reshaped = q_t.expand(
          {output_batch_size, q_t.size(1), output_num_heads, head_dim_qk});
      query_buffer_reshaped =
          query_buffer_reshaped.reshape({Nnz_q, output_num_heads, head_dim_qk});
    } else {
      const auto* query_impl = get_nested_tensor_impl(q_t);
      if (!q_t.is_contiguous() &&
          !is_safe_to_get_storage_as_tensor(query_impl)) {
        q_t = q_t.contiguous();
      }
      // If we are broadcasting then Nnz_q will be the output_batch_size since
      // seq_len is 1
      const int64_t effective_batch_size_q =
          q_batch_size_needs_broadcast ? output_batch_size : Nnz_q;
      query_buffer_reshaped = view_as_dense(
          q_t,
          effective_batch_size_q,
          output_num_heads,
          head_dim_qk,
          q_batch_size_needs_broadcast,
          q_num_heads_needs_broadcast);
    }

    const auto* key_impl = get_nested_tensor_impl(k_t);
    const auto* value_impl = get_nested_tensor_impl(v_t);

    // If the physical layout of the NestedTensor's storage
    // is not: batch, {seq_len}, num_heads, head_dim then we need
    // to call contiguous

    if (!k_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(key_impl)) {
      k_t = k_t.contiguous();
    }
    if (!v_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(value_impl)) {
      v_t = v_t.contiguous();
    }
    const int64_t effective_batch_size_k =
        k_batch_size_needs_broadcast ? output_batch_size : Nnz_kv;
    key_buffer_reshaped = view_as_dense(
        k_t,
        effective_batch_size_k,
        output_num_heads,
        head_dim_qk,
        k_batch_size_needs_broadcast,
        k_num_heads_needs_broadcast);

    const int64_t effective_batch_size_v =
        v_batch_size_needs_broadcast ? output_batch_size : Nnz_kv;
    value_buffer_reshaped = view_as_dense(
        v_t,
        effective_batch_size_v,
        output_num_heads,
        head_dim_v,
        v_batch_size_needs_broadcast,
        v_num_heads_needs_broadcast);

    Tensor output_shape;
    if (!q_batch_size_needs_broadcast) {
      output_shape = get_nested_sizes(q_t).clone();
      if (head_dim_v != head_dim_qk) {
        output_shape.select(1, -1).fill_(head_dim_v);
      }
      if (q_num_heads_needs_broadcast) {
        output_shape.select(1, 1).fill_(output_num_heads);
      }
    } else {
      output_shape = at::empty(
          {output_batch_size, 3}, TensorOptions().dtype(kLong).device(kCPU));
      output_shape.select(1, 0).fill_(q_t.size(1));
      output_shape.select(1, 1).fill_(output_num_heads);
      output_shape.select(1, 2).fill_(head_dim_v);
    }

    return std::make_tuple(
        query_buffer_reshaped,
        key_buffer_reshaped,
        value_buffer_reshaped,
        cumulative_sequence_length_q,
        cumulative_sequence_length_kv,
        max_seqlen_batch_q,
        max_seqlen_batch_kv,
        output_shape);
  }
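
  // Illustrative sketch (editorial addition): one hypothetical broadcasting
  // case is cross-attention where the key/value batch size is 1 while the
  // query batch size is 4, e.g.
  //
  //   query:     [4, 8, {q_len_i}, 64]  (nested)
  //   key/value: [1, 8, {kv_len},  64]
  //
  // Here output_batch_size == 4, the key/value cumulative sequence lengths
  // are synthesized with at::arange as [0, kv_len, 2 * kv_len, ..., 4 * kv_len],
  // and the key/value buffers are strided with a zero batch stride so every
  // query in the batch reads the same underlying key/value storage.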

} // namespace

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor>
sdpa_nested_preprocessing(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value) {
  // Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
  // Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
  // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
  const int64_t q_batch_size = query.size(0);
  const int64_t k_batch_size = key.size(0);
  const int64_t v_batch_size = value.size(0);

  const int64_t q_num_heads = query.size(1);
  const int64_t k_num_heads = key.size(1);
  const int64_t v_num_heads = value.size(1);

  if (!(q_batch_size == k_batch_size && q_batch_size == v_batch_size) ||
      !(q_num_heads == k_num_heads && k_num_heads == v_num_heads)) {
    return sdpa_nested_preprocessing_with_broadcast(query, key, value);
  }

  const int64_t num_heads = query.size(1);
  const int64_t head_dim_qk = query.size(3);
  const int64_t head_dim_v = value.size(3);

  Tensor q_t = query.transpose(1, 2);
  Tensor k_t = key.transpose(1, 2);
  Tensor v_t = value.transpose(1, 2);

  auto [cumulative_sequence_length_q, max_seqlen_batch_q, Nnz_q] =
      cumulative_and_max_seq_len_nnz(q_t);
  auto [cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv] =
      cumulative_and_max_seq_len_nnz(k_t);

  Tensor query_buffer_reshaped;
  Tensor key_buffer_reshaped;
  Tensor value_buffer_reshaped;

  const auto* query_impl = get_nested_tensor_impl(q_t);
  const auto* key_impl = get_nested_tensor_impl(k_t);
  const auto* value_impl = get_nested_tensor_impl(v_t);

  // If the physical layout of the NestedTensor's storage
  // is not: batch, {seq_len}, num_heads, head_dim then we need
  // to call contiguous
  if (!q_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(query_impl)) {
    q_t = q_t.contiguous();
  }
  if (!k_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(key_impl)) {
    k_t = k_t.contiguous();
  }
  if (!v_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(value_impl)) {
    v_t = v_t.contiguous();
  }

  query_buffer_reshaped = view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk);
  key_buffer_reshaped = view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk);
  value_buffer_reshaped = view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v);

  auto output_shape = get_nested_sizes(q_t).clone();
  if (head_dim_v != head_dim_qk) {
    output_shape.select(1, -1).fill_(head_dim_v);
  }

  return std::make_tuple(
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_kv,
      output_shape);
}
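
// Illustrative sketch (editorial addition): a hypothetical caller on the
// fused SDPA path would use the preprocessing like
//
//   auto [q, k, v, cu_seqlens_q, cu_seqlens_kv,
//         max_seqlen_q, max_seqlen_kv, output_shape] =
//       sdpa_nested_preprocessing(query, key, value);
//
// q/k/v are dense [Nnz, num_heads, head_dim] buffers that, together with the
// cumulative and max sequence lengths, can be handed to the flash-attention /
// efficient-attention kernels; output_shape is then used to re-wrap the
// kernel output as a nested tensor.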

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
sdpa_nested_preprocessing_backward(
    const at::Tensor& grad_out_,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const Tensor& cumulative_sequence_length_q,
    const Tensor& cumulative_sequence_length_kv,
    const int64_t max_seqlen_batch_q,
    const int64_t max_seqlen_batch_kv) {
  const int64_t q_batch_size = query.size(0);
  const int64_t k_batch_size = key.size(0);
  const int64_t v_batch_size = value.size(0);

  const int64_t q_num_heads = query.size(1);
  const int64_t k_num_heads = key.size(1);
  const int64_t v_num_heads = value.size(1);

  if (!(q_batch_size == k_batch_size && q_batch_size == v_batch_size) ||
      !(q_num_heads == k_num_heads && k_num_heads == v_num_heads)) {
    TORCH_CHECK(
        false,
        "Broadcasted NestedTensor inputs are currently not supported for backwards.");
  }

  const int64_t num_heads = query.size(1);
  const int64_t head_dim_qk = query.size(3);
  const int64_t head_dim_v = value.size(3);

  Tensor q_t = query.transpose(1, 2);
  Tensor k_t = key.transpose(1, 2);
  Tensor v_t = value.transpose(1, 2);
  Tensor grad_out_t = grad_out_.transpose(1, 2);
  Tensor out_t = out.transpose(1, 2);

  const int64_t Nnz_q = get_nnz(q_t);
  const int64_t Nnz_kv = get_nnz(k_t);

  Tensor query_buffer_reshaped;
  Tensor key_buffer_reshaped;
  Tensor value_buffer_reshaped;
  Tensor grad_out_buffer_reshaped;
  Tensor output_buffer_reshaped;

  const auto* query_impl = get_nested_tensor_impl(q_t);
  const auto* key_impl = get_nested_tensor_impl(k_t);
  const auto* value_impl = get_nested_tensor_impl(v_t);
  const auto* grad_out_impl = get_nested_tensor_impl(grad_out_t);
  const auto* out_impl = get_nested_tensor_impl(out_t);

  // If the physical layout of the NestedTensor's storage
  // is not: batch, {seq_len}, num_heads, head_dim then we need
  // to call contiguous
  if (!q_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(query_impl)) {
    q_t = q_t.contiguous();
  }
  if (!k_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(key_impl)) {
    k_t = k_t.contiguous();
  }
  if (!v_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(value_impl)) {
    v_t = v_t.contiguous();
  }
  if (!grad_out_t.is_contiguous() &&
      !is_safe_to_get_storage_as_tensor(grad_out_impl)) {
    grad_out_t = grad_out_t.contiguous();
  }
  if (!out_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(out_impl)) {
    out_t = out_t.contiguous();
  }

  query_buffer_reshaped = view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk);
  key_buffer_reshaped = view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk);
  value_buffer_reshaped = view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v);

  grad_out_buffer_reshaped =
      view_as_dense(grad_out_t, Nnz_q, num_heads, head_dim_v);
  output_buffer_reshaped = view_as_dense(out_t, Nnz_q, num_heads, head_dim_v);

  auto output_shape = get_nested_sizes(q_t).clone();
  if (head_dim_v != head_dim_qk) {
    output_shape.select(1, -1).fill_(head_dim_v);
  }

  return std::make_tuple(
      grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped);
}
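
// Illustrative sketch (editorial addition): in the fused SDPA backward, a
// hypothetical caller flattens the nested gradients and activations with
//
//   auto [grad_out_b, q_b, k_b, v_b, out_b] =
//       sdpa_nested_preprocessing_backward(
//           grad_out_, query, key, value, out,
//           cumulative_sequence_length_q, cumulative_sequence_length_kv,
//           max_seqlen_batch_q, max_seqlen_batch_kv);
//
// reusing the cumulative sequence lengths saved from the forward pass, which
// is why only get_nnz (and not cumulative_and_max_seq_len_nnz) is needed here.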

} // namespace at::native::preprocessing