#include <ATen/ATen.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/native/nested/NestedTensorTransformerUtils.h>
#include <tuple>

namespace at::native::preprocessing {

namespace {

/**
 * This builds up the cumulative sequence length for a batch of sequences.
 * This is not very DRY, but in the backward pass we already have
 * cumulative_seq_len on device, and all we need on the CPU to launch the
 * kernel is Nnz. We could refactor the function below, but it adds more
 * complexity than is needed.
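 *
 * For example, a nested tensor with sequence lengths [2, 3, 5] returns
 * Nnz = 10.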
 */
int64_t get_nnz(const Tensor& nestedtensor) {
  auto* nt_impl = get_nested_tensor_impl(nestedtensor);
  const auto& sizes = nt_impl->get_nested_sizes();
  auto size_tensor_stride = sizes.stride(0);
  const int64_t batch_size = nestedtensor.size(0);
  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  int64_t cumulative_sequence_length = 0;
  for (const auto i : c10::irange(batch_size)) {
    // Calculate the cumulative sum of the sequence lengths
    int64_t current_seq_len = sizes_ptr[(i * size_tensor_stride)];
    cumulative_sequence_length += current_seq_len;
  }
  return cumulative_sequence_length;
}

/**
 * This function calculates two pieces of metadata that are needed for use
 * with the flash-attention and efficient-attention kernels: the cumulative
 * sequence length over a batch of sequences and the maximum sequence length.
 *
 * @return A tuple of the cumulative sequence lengths, the maximum sequence
 * length, and the last element of the cumulative sequence lengths (i.e. the
 * total Nnz).
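 *
 * For example, sequence lengths [2, 3, 5] produce
 * cumulative_seqlen = [0, 2, 5, 10] (int32, moved to CUDA),
 * max_seqlen = 5, and Nnz = 10.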
 */
std::tuple<Tensor, int64_t, int64_t> cumulative_and_max_seq_len_nnz(
    const Tensor& qkv) {
  TORCH_CHECK(
      qkv.is_nested(),
      "QKV must be nested for flash cumulative_seq_len calculation.");
  auto* nt_impl = get_nested_tensor_impl(qkv);
  const auto& sizes = nt_impl->get_nested_sizes();
  auto size_tensor_stride = sizes.stride(0);

  const int64_t batch_size = qkv.size(0);
  auto cumulative_seqlen = at::zeros(
      {batch_size + 1}, TensorOptions().device(at::kCPU).dtype(at::kInt));

  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  auto* cumulative_seqlen_ptr = cumulative_seqlen.data_ptr<int32_t>();

  int64_t sum = 0;
  int64_t max_seqlen = -1;
  cumulative_seqlen_ptr[0] = static_cast<int32_t>(sum);
  for (const auto i : c10::irange(batch_size)) {
    // Calculate the cumulative sum of the sequence lengths
    auto current_seq_len = sizes_ptr[(i * size_tensor_stride)];
    sum += current_seq_len;
    cumulative_seqlen_ptr[i + 1] = static_cast<int32_t>(sum);

    // Find the max element while we traverse
    max_seqlen = std::max(max_seqlen, current_seq_len);
  }
  // Send to GPU; this is a fairly lightweight calculation for typical batch
  // sizes, but it may eventually make sense to do it on the GPU.
  cumulative_seqlen = cumulative_seqlen.to(TensorOptions().device(at::kCUDA));
  return std::tuple<Tensor, int64_t, int64_t>{
      cumulative_seqlen, max_seqlen, sum};
}

/**
 * This function checks whether a nested tensor is valid for use with the
 * flash-attention and efficient-attention kernels without needing to call
 * contiguous() on the nested tensor input.
 * It checks that the adjacent differences of the storage offsets are a
 * constant multiple of the previous tensor's numel and that the strides
 * are monotonically decreasing. This check is done after calling transpose
 * on the nested tensor, resulting in an NT of shape
 * [bsz, {seq_len}, num_heads, dim].
 *
 * @return true if it is safe to view the storage directly, false if
 * contiguous() needs to be called on the input
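 *
 * For example, a fully contiguous buffer laid out as
 * [bsz, {seq_len}, num_heads, dim] has offsets that advance by exactly the
 * previous tensor's numel, so the offset check below passes with a constant
 * of 1.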
 */
bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {
  const int64_t* tensor_offsets_ptr =
      tensor->get_storage_offsets().data_ptr<int64_t>();
  const Tensor& tensor_sizes = tensor->get_nested_sizes();
  const Tensor& tensor_strides = tensor->get_nested_strides();

  const int64_t n_tensors = tensor_strides.size(0);
  constexpr int64_t n_dims = 3;
  // This is safe since num_heads is assured to be consistent
  const int64_t num_heads = tensor->opt_size(2).value();
  const int64_t tensor_stride_0 = tensor_strides.stride(0);

  if (n_tensors <= 1) {
    return true;
  }

  int64_t* previous_tensor_stride = tensor_strides.data_ptr<int64_t>();

  // Check initially that the first tensor's strides
  // are in strictly descending order.
  // NOTE: If num_heads is equal to 1 then we only compare stride[0]
  // against stride[2]. Why, you may ask? Because when n_heads == 1,
  // as long as the last stride == 1, it does not matter what the stride
  // of the num_heads dimension is.
  //
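  // For example, a packed [seq_len, num_heads, dim] layout with dim > 1 has
  // strictly descending strides (num_heads * dim, dim, 1) and passes this
  // check.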
  if (num_heads == 1) {
    if (previous_tensor_stride[0] <= previous_tensor_stride[2]) {
      // This would mean that the last stride is greater than or equal to
      // the seq_len stride
      return false;
    }
  } else {
    for (int i{1}; i < n_dims; i++) {
      if (previous_tensor_stride[i - 1] <= previous_tensor_stride[i]) {
        return false;
      }
    }
    // Check that each tensor i in the nested tensor has the same strides
    for (int i{1}; i < n_tensors; i++) {
      for (const int64_t j : c10::irange(n_dims)) {
        if (previous_tensor_stride[j] !=
            previous_tensor_stride[i * tensor_stride_0 + j]) {
          return false;
        }
      }
    }
  }

  // Check that the offsets are a constant multiple of the previous tensor's
  // numel
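  // (For a fully contiguous nested buffer the offsets advance by exactly the
  // previous tensor's numel, so offset_constant is 1.)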
  const int64_t* tensor_size_ptr = tensor_sizes.const_data_ptr<int64_t>();
  const int64_t* tensor_stride_ptr = tensor_strides.const_data_ptr<int64_t>();

  int64_t numel_0 = (tensor_size_ptr[0] * tensor_stride_ptr[0]);
  TORCH_INTERNAL_ASSERT(numel_0 > 0, "numels must be positive!");

  int64_t offset_constant =
      (tensor_offsets_ptr[1] - tensor_offsets_ptr[0]) / numel_0;
  for (int64_t i = 2; i < n_tensors; i++) {
    // TODO: When 0 seq_len nested tensors are allowed we need to guard
    // against this
    int64_t previous_numel = tensor_size_ptr[(i - 1) * tensor_stride_0] *
        tensor_stride_ptr[(i - 1) * tensor_stride_0];
    TORCH_INTERNAL_ASSERT(previous_numel > 0, "numels must be positive!");
    int64_t current_offset_constant =
        (tensor_offsets_ptr[i] - tensor_offsets_ptr[i - 1]) / previous_numel;
    if (current_offset_constant != offset_constant) {
      return false;
    }
  }
  // Congrats, you made it!
  return true;
}

/**
 * Process an individual NestedTensor to reshape and view it as a dense
 * tensor. Generally the approach for q, k, v is to:
 * (1) get the storage of the (contiguous) nested tensor;
 * (2) view it with shape {output_batch_size, {*}_t.size(1), output_num_heads,
 *     head_dim_{*}} and stride {0, nnz_{*}_stride, head_{*}_stride,
 *     head_dim_stride}, where head_{*}_stride is 0 if
 *     {*}_num_heads_needs_broadcast;
 * (3) collapse the first two dims by reshaping to
 *     {Nnz_{*}, output_num_heads, head_dim_{*}}. If {*}_t.size(1) == 1
 *     (i.e. the seq_len is 1), the reshape is a view and does not incur a
 *     copy.
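 *
 * In the common (non-broadcast) case this simply returns a strided
 * [Nnz, num_heads, head_dim] view over the nested tensor's underlying
 * buffer.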
 */
at::Tensor view_as_dense(
    const at::Tensor& input_nestedtensor,
    const int64_t Nnz,
    const int64_t num_heads,
    const int64_t head_dim,
    const bool batch_needs_broadcast = false,
    const bool num_heads_needs_broadcast = false) {
  const auto* tensor_impl = get_nested_tensor_impl(input_nestedtensor);
  Tensor storage_as_tensor = tensor_impl->get_unsafe_storage_as_tensor();

  constexpr int64_t head_dim_stride = 1;
  const int64_t* nt_strides =
      tensor_impl->get_nested_strides().data_ptr<int64_t>();
  const int64_t* nt_offsets_ptr =
      tensor_impl->get_storage_offsets().data_ptr<int64_t>();

  const int64_t nnz_stride = nt_strides[0];
  const int64_t head_stride = num_heads_needs_broadcast ? 0 : nt_strides[1];

  if (batch_needs_broadcast) {
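    // A batch stride of 0 broadcasts the single nested sequence across the
    // output batch size; the reshape below then flattens (batch, seq_len)
    // into the Nnz dimension.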
    Tensor input_buffer_reshaped = storage_as_tensor.as_strided(
        {Nnz, input_nestedtensor.size(1), num_heads, head_dim},
        {0, nnz_stride, head_stride, head_dim_stride},
        nt_offsets_ptr[0]);
    return input_buffer_reshaped.reshape({-1, num_heads, head_dim});
  }
  return storage_as_tensor.as_strided(
      {Nnz, num_heads, head_dim},
      {nnz_stride, head_stride, head_dim_stride},
      nt_offsets_ptr[0]);
}

/**
 * This function is a helper that takes nested query, key, and value tensors
 * that require broadcasting on the batch or num_heads dimensions
 * and preprocesses them in order to run with either
 * the flash-attention or efficient-attention kernels.
 * @return A tuple containing all the data necessary for running the fused
 * kernels: the reshaped q/k/v buffers, the cumulative sequence lengths for
 * q and kv, the max sequence lengths for q and kv, and the sizes for the
 * output nested tensor.
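 *
 * The dense buffers produced here have shapes:
 *   query_buffer_reshaped: [Nnz_q,  output_num_heads, head_dim_qk]
 *   key_buffer_reshaped:   [Nnz_kv, output_num_heads, head_dim_qk]
 *   value_buffer_reshaped: [Nnz_kv, output_num_heads, head_dim_v]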
 */
auto sdpa_nested_preprocessing_with_broadcast(
    const Tensor& query, const Tensor& key, const Tensor& value) {
  // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
  // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
  // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
  const int64_t q_batch_size = query.size(0);
  const int64_t k_batch_size = key.size(0);
  const int64_t v_batch_size = value.size(0);

  const int64_t output_batch_size =
      std::max({q_batch_size, k_batch_size, v_batch_size});

  const int64_t q_num_heads = query.size(1);
  const int64_t k_num_heads = key.size(1);
  const int64_t v_num_heads = value.size(1);

  const int64_t output_num_heads =
      std::max({q_num_heads, k_num_heads, v_num_heads});

  const int64_t head_dim_qk = query.size(3);
  const int64_t head_dim_v = value.size(3);

  Tensor q_t = query.transpose(1, 2);
  Tensor k_t = key.transpose(1, 2);
  Tensor v_t = value.transpose(1, 2);

  // Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
  // output_batch_size/output_num_heads then they are 1
  bool q_batch_size_needs_broadcast = q_batch_size != output_batch_size;
  bool k_batch_size_needs_broadcast = k_batch_size != output_batch_size;
  bool v_batch_size_needs_broadcast = v_batch_size != output_batch_size;

  // If {*}_batch_size_needs_broadcast, then
  // (1) max_seqlen_batch_{*} is given by {*}_t.size(1),
  //     because needs_broadcast indicates that the batch_size is 1
  //     and hence there is only one value for seq_len;
  // (2) the cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
  //     ..., output_batch_size * {*}_t.size(1)];
  // (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1).
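  //
  // For example, with output_batch_size = 4 and {*}_t.size(1) = 3, the
  // cum_seq_lens produced by the arange below are [0, 3, 6, 9, 12] and
  // Nnz_{*} = 12.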

  int64_t max_seqlen_batch_q = 0, Nnz_q = 0;
  Tensor cumulative_sequence_length_q;
  if (q_batch_size_needs_broadcast || !q_t.is_nested()) {
    max_seqlen_batch_q = q_t.size(1);
    cumulative_sequence_length_q = at::arange(
        0,
        (output_batch_size + 1) * max_seqlen_batch_q,
        max_seqlen_batch_q,
        TensorOptions().device(at::kCUDA).dtype(at::kInt));
    Nnz_q = output_batch_size * max_seqlen_batch_q;
  } else {
    auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len_nnz(q_t);
    cumulative_sequence_length_q =
        std::get<0>(cumulative_and_max_q_and_nnz_q);
    max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q);
    Nnz_q = std::get<2>(cumulative_and_max_q_and_nnz_q);
  }

  int64_t max_seqlen_batch_kv = 0, Nnz_kv = 0;
  Tensor cumulative_sequence_length_kv;
  if (k_batch_size_needs_broadcast && v_batch_size_needs_broadcast) {
    TORCH_CHECK(k_t.size(1) == v_t.size(1));
    max_seqlen_batch_kv = k_t.size(1);
    cumulative_sequence_length_kv = at::arange(
        0,
        (output_batch_size + 1) * max_seqlen_batch_kv,
        max_seqlen_batch_kv,
        TensorOptions().device(at::kCUDA).dtype(at::kInt));
    Nnz_kv = output_batch_size * max_seqlen_batch_kv;
  } else {
    auto cumulative_and_max_kv_and_nnz_kv = k_batch_size_needs_broadcast
        ? cumulative_and_max_seq_len_nnz(v_t)
        : cumulative_and_max_seq_len_nnz(k_t);
    cumulative_sequence_length_kv =
        std::get<0>(cumulative_and_max_kv_and_nnz_kv);
    max_seqlen_batch_kv = std::get<1>(cumulative_and_max_kv_and_nnz_kv);
    Nnz_kv = std::get<2>(cumulative_and_max_kv_and_nnz_kv);
  }


  bool q_num_heads_needs_broadcast = q_num_heads != output_num_heads;
  bool k_num_heads_needs_broadcast = k_num_heads != output_num_heads;
  bool v_num_heads_needs_broadcast = v_num_heads != output_num_heads;

  Tensor query_buffer_reshaped;
  Tensor key_buffer_reshaped;
  Tensor value_buffer_reshaped;

  if (!q_t.is_nested()) {
    query_buffer_reshaped = q_t.expand(
        {output_batch_size, q_t.size(1), output_num_heads, head_dim_qk});
    query_buffer_reshaped =
        query_buffer_reshaped.reshape({Nnz_q, output_num_heads, head_dim_qk});
  } else {
    const auto* query_impl = get_nested_tensor_impl(q_t);
    if (!q_t.is_contiguous() &&
        !is_safe_to_get_storage_as_tensor(query_impl)) {
      q_t = q_t.contiguous();
    }
    // If we are broadcasting then Nnz_q will be the output_batch_size since
    // seq_len is 1
    const int64_t effective_batch_size_q =
        q_batch_size_needs_broadcast ? output_batch_size : Nnz_q;
    query_buffer_reshaped = view_as_dense(
        q_t,
        effective_batch_size_q,
        output_num_heads,
        head_dim_qk,
        q_batch_size_needs_broadcast,
        q_num_heads_needs_broadcast);
  }

  const auto* key_impl = get_nested_tensor_impl(k_t);
  const auto* value_impl = get_nested_tensor_impl(v_t);

  // If the physical layout of the NestedTensor's storage
  // is not: batch, {seq_len}, num_heads, head_dim, then we need
  // to call contiguous

  if (!k_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(key_impl)) {
    k_t = k_t.contiguous();
  }
  if (!v_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(value_impl)) {
    v_t = v_t.contiguous();
  }
  const int64_t effective_batch_size_k =
      k_batch_size_needs_broadcast ? output_batch_size : Nnz_kv;
  key_buffer_reshaped = view_as_dense(
      k_t,
      effective_batch_size_k,
      output_num_heads,
      head_dim_qk,
      k_batch_size_needs_broadcast,
      k_num_heads_needs_broadcast);

  const int64_t effective_batch_size_v =
      v_batch_size_needs_broadcast ? output_batch_size : Nnz_kv;
  value_buffer_reshaped = view_as_dense(
      v_t,
      effective_batch_size_v,
      output_num_heads,
      head_dim_v,
      v_batch_size_needs_broadcast,
      v_num_heads_needs_broadcast);

  Tensor output_shape;
  if (!q_batch_size_needs_broadcast) {
    output_shape = get_nested_sizes(q_t).clone();
    if (head_dim_v != head_dim_qk) {
      output_shape.select(1, -1).fill_(head_dim_v);
    }
    if (q_num_heads_needs_broadcast) {
      output_shape.select(1, 1).fill_(output_num_heads);
    }
  } else {
    output_shape = at::empty(
        {output_batch_size, 3}, TensorOptions().dtype(kLong).device(kCPU));
    output_shape.select(1, 0).fill_(q_t.size(1));
    output_shape.select(1, 1).fill_(output_num_heads);
    output_shape.select(1, 2).fill_(head_dim_v);
  }

  return std::make_tuple(
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_kv,
      output_shape);
}

} // namespace

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor>
sdpa_nested_preprocessing(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value) {
  // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
  // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
  // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
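  //
  // Each nested input is flattened into a dense [Nnz, num_heads, head_dim]
  // buffer, and the cumulative / max sequence-length metadata expected by
  // the fused kernels is computed alongside it.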
  const int64_t q_batch_size = query.size(0);
  const int64_t k_batch_size = key.size(0);
  const int64_t v_batch_size = value.size(0);

  const int64_t q_num_heads = query.size(1);
  const int64_t k_num_heads = key.size(1);
  const int64_t v_num_heads = value.size(1);

  if (!(q_batch_size == k_batch_size && q_batch_size == v_batch_size) ||
      !(q_num_heads == k_num_heads && k_num_heads == v_num_heads)) {
    return sdpa_nested_preprocessing_with_broadcast(query, key, value);
  }

  const int64_t num_heads = query.size(1);
  const int64_t head_dim_qk = query.size(3);
  const int64_t head_dim_v = value.size(3);

  Tensor q_t = query.transpose(1, 2);
  Tensor k_t = key.transpose(1, 2);
  Tensor v_t = value.transpose(1, 2);

  auto [cumulative_sequence_length_q, max_seqlen_batch_q, Nnz_q] =
      cumulative_and_max_seq_len_nnz(q_t);
  auto [cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv] =
      cumulative_and_max_seq_len_nnz(k_t);

  Tensor query_buffer_reshaped;
  Tensor key_buffer_reshaped;
  Tensor value_buffer_reshaped;

  const auto* query_impl = get_nested_tensor_impl(q_t);
  const auto* key_impl = get_nested_tensor_impl(k_t);
  const auto* value_impl = get_nested_tensor_impl(v_t);

  // If the physical layout of the NestedTensor's storage
  // is not: batch, {seq_len}, num_heads, head_dim, then we need
  // to call contiguous
  if (!q_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(query_impl)) {
    q_t = q_t.contiguous();
  }
  if (!k_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(key_impl)) {
    k_t = k_t.contiguous();
  }
  if (!v_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(value_impl)) {
    v_t = v_t.contiguous();
  }

  query_buffer_reshaped = view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk);
  key_buffer_reshaped = view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk);
  value_buffer_reshaped = view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v);

  auto output_shape = get_nested_sizes(q_t).clone();
  if (head_dim_v != head_dim_qk) {
    output_shape.select(1, -1).fill_(head_dim_v);
  }

  return std::make_tuple(
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_kv,
      output_shape);
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
sdpa_nested_preprocessing_backward(
    const at::Tensor& grad_out_,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const Tensor& cumulative_sequence_length_q,
    const Tensor& cumulative_sequence_length_kv,
    const int64_t max_seqlen_batch_q,
    const int64_t max_seqlen_batch_kv) {
  const int64_t q_batch_size = query.size(0);
  const int64_t k_batch_size = key.size(0);
  const int64_t v_batch_size = value.size(0);

  const int64_t q_num_heads = query.size(1);
  const int64_t k_num_heads = key.size(1);
  const int64_t v_num_heads = value.size(1);

  if (!(q_batch_size == k_batch_size && q_batch_size == v_batch_size) ||
      !(q_num_heads == k_num_heads && k_num_heads == v_num_heads)) {
    TORCH_CHECK(
        false,
        "Broadcasted NestedTensor inputs are currently not supported for backwards.");
  }

  const int64_t num_heads = query.size(1);
  const int64_t head_dim_qk = query.size(3);
  const int64_t head_dim_v = value.size(3);

  Tensor q_t = query.transpose(1, 2);
  Tensor k_t = key.transpose(1, 2);
  Tensor v_t = value.transpose(1, 2);
  Tensor grad_out_t = grad_out_.transpose(1, 2);
  Tensor out_t = out.transpose(1, 2);

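  // Note: the cumulative sequence lengths were already computed on device in
  // the forward pass (see the comment above get_nnz); only the total Nnz per
  // tensor is needed here, which is cheap to recompute on the CPU.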
  const int64_t Nnz_q = get_nnz(q_t);
  const int64_t Nnz_kv = get_nnz(k_t);

  Tensor query_buffer_reshaped;
  Tensor key_buffer_reshaped;
  Tensor value_buffer_reshaped;
  Tensor grad_out_buffer_reshaped;
  Tensor output_buffer_reshaped;

  const auto* query_impl = get_nested_tensor_impl(q_t);
  const auto* key_impl = get_nested_tensor_impl(k_t);
  const auto* value_impl = get_nested_tensor_impl(v_t);
  const auto* grad_out_impl = get_nested_tensor_impl(grad_out_t);
  const auto* out_impl = get_nested_tensor_impl(out_t);

  // If the physical layout of the NestedTensor's storage
  // is not: batch, {seq_len}, num_heads, head_dim, then we need
  // to call contiguous
  if (!q_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(query_impl)) {
    q_t = q_t.contiguous();
  }
  if (!k_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(key_impl)) {
    k_t = k_t.contiguous();
  }
  if (!v_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(value_impl)) {
    v_t = v_t.contiguous();
  }
  if (!grad_out_t.is_contiguous() &&
      !is_safe_to_get_storage_as_tensor(grad_out_impl)) {
    grad_out_t = grad_out_t.contiguous();
  }
  if (!out_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(out_impl)) {
    out_t = out_t.contiguous();
  }

  query_buffer_reshaped = view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk);
  key_buffer_reshaped = view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk);
  value_buffer_reshaped = view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v);

  grad_out_buffer_reshaped =
      view_as_dense(grad_out_t, Nnz_q, num_heads, head_dim_v);
  output_buffer_reshaped = view_as_dense(out_t, Nnz_q, num_heads, head_dim_v);

  auto output_shape = get_nested_sizes(q_t).clone();
  if (head_dim_v != head_dim_qk) {
    output_shape.select(1, -1).fill_(head_dim_v);
  }

  return std::make_tuple(
      grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped);
}

} // namespace at::native::preprocessing