#include <numeric>
#include <algorithm>
#include <c10/util/Exception.h>

#include <ATen/ATen.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/native/NonSymbolicBC.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_nested_from_padded_native.h>
#include <ATen/ops/narrow_native.h>
#endif

#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorTransformerUtils.h>
#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>

#include <ATen/cuda/CUDAContext.h>

namespace at::native {
namespace {
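// Computes the total number of elements the packed nested buffer must hold by
// summing, over every row of `sizes`, the product of that row's entries
// (each row describes the shape of one constituent tensor).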
int64_t padded_tensor_numel(const Tensor& sizes) {
  const auto sizes_num_rows = sizes.sizes()[0];
  const auto sizes_row_length = sizes.sizes()[1];
  const auto* sizes_data = sizes.const_data_ptr<int64_t>();
  int64_t numel = 0;
  for (const auto row_num : c10::irange(sizes_num_rows)) {
    const auto* row_ptr = sizes_data + row_num * sizes_row_length;
    int64_t prod = 1;
    for (const auto idx : c10::irange(sizes_row_length)) {
      prod *= row_ptr[idx];
    }
    numel += prod;
  }
  return numel;
}
} // namespace
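// Converts a padded dense tensor into a nested tensor using the dedicated CUDA
// "remove padding" kernels. The fast path only handles fp32/fp16 inputs of
// dim 3 (do_transform_0213 == false) or dim 4 (do_transform_0213 == true);
// everything else falls back to nested_from_padded_generic.
// Example (hypothetical shapes): a [B, max_T, D] fp32 input with
// do_transform_0213 == false yields a nested tensor whose i-th component has
// the shape given by row i of `sizes`, i.e. [T_i, D].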
Tensor nested_from_padded_cuda(
    const Tensor& padded,
    const Tensor& sizes,
    bool do_transform_0213) {
  if (padded.dim() > 1 && padded.dim() < 5) {
    // Instead of erroring, call the generic version
    if (!(padded.dim() == 4 && do_transform_0213) &&
        !(padded.dim() == 3 && !do_transform_0213)) {
      return at::native::nested_from_padded_generic(padded, sizes, do_transform_0213);
    }
    if (padded.dtype() != kFloat && padded.dtype() != kHalf) {
      TORCH_WARN_ONCE(
          "nested_from_padded CUDA kernels only support fp32/fp16; falling "
          "back to slower generic kernel");
      return at::native::nested_from_padded_generic(padded, sizes, do_transform_0213);
    }
    Tensor target_offsets =
        NestedTensor_batch_offsets_from_size_tensor(sizes, 0);
    Tensor padded_sizes_tensor = at::tensor(padded.sizes());
    Tensor output = at::empty({padded_tensor_numel(sizes)}, padded.options());
    Tensor target_size_sizes = sizes.reshape(-1);

    Tensor metadata =
        at::cat({target_size_sizes, padded_sizes_tensor, target_offsets});
    metadata = metadata.to(at::Device(kCUDA), kInt, true, true);

    auto output_size_ptr = metadata.data_ptr<int>();
    auto input_size_ptr = output_size_ptr + target_size_sizes.numel();
    auto offsets_ptr = input_size_ptr + padded_sizes_tensor.numel();

    Tensor padded_contiguous = padded.contiguous();
    if (padded.dtype() == kFloat) {
      if (do_transform_0213) {
        remove_padding_transform0213_kernelLauncher(
            padded_contiguous.data_ptr<float>(),
            output.data_ptr<float>(),
            offsets_ptr,
            input_size_ptr,
            output_size_ptr,
            padded_contiguous.dim() - 2,
            padded_contiguous.sizes()[0]);
      } else {
        remove_padding_kernelLauncher(
            padded_contiguous.data_ptr<float>(),
            output.data_ptr<float>(),
            offsets_ptr,
            input_size_ptr,
            output_size_ptr,
            padded_contiguous.dim() - 1,
            padded_contiguous.sizes()[0]);
      }
    } else if (padded.dtype() == kHalf) {
      if (do_transform_0213) {
        remove_padding_transform0213_kernelLauncher(
            padded_contiguous.data_ptr<c10::Half>(),
            output.data_ptr<c10::Half>(),
            offsets_ptr,
            input_size_ptr,
            output_size_ptr,
            padded_contiguous.dim() - 2,
            padded_contiguous.sizes()[0]);
      } else {
        remove_padding_kernelLauncher(
            padded_contiguous.data_ptr<c10::Half>(),
            output.data_ptr<c10::Half>(),
            offsets_ptr,
            input_size_ptr,
            output_size_ptr,
            padded_contiguous.dim() - 1,
            padded_contiguous.sizes()[0]);
      }
    } else {
      AT_ERROR("Only fp32/fp16 are supported for padded input");
    }
    return at::detail::make_tensor<NestedTensorImpl>(std::move(output), sizes);
  } else {
    return at::native::nested_from_padded_generic(padded, sizes);
  }
}

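// Builds a CPU int64 tensor of prefix-sum offsets into the flat nested buffer:
// offsets[0] == 0 and offsets[i + 1] == offsets[i] + numel of the i-th
// constituent (the product of row i of ef_sizes).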
Tensor batch_offsets_from_efficient_size(const Tensor& ef_sizes) {
  int64_t* nt_sizes_ptr = ef_sizes.data_ptr<int64_t>();
  int64_t ef_sizes_size_0 = ef_sizes.sizes()[0];
  Tensor offsets = at::empty({1 + ef_sizes_size_0}, at::kLong);
  int64_t* offsets_ptr = offsets.mutable_data_ptr<int64_t>();
  offsets_ptr[0] = 0;
  int64_t ef_sizes_size_1 = ef_sizes.sizes()[1];
  for (const auto i : c10::irange(ef_sizes_size_0)) {
    int64_t prod = 1;
    for (const auto j : c10::irange(ef_sizes_size_1)) {
      prod = prod * nt_sizes_ptr[i * ef_sizes_size_1 + j];
    }
    offsets_ptr[i + 1] = offsets_ptr[i] + prod;
  }
  return offsets;
}

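// Pads a contiguous nested tensor out to a dense tensor on CUDA, filling the
// gaps with `padding`. Dims 2-4 with fp32/fp64/fp16 buffers use the fused
// add_padding kernel; the 3-D case with a regular last dim and no explicit
// output_size is first collapsed to 2-D and reshaped back afterwards.
// Everything else falls back to the generic implementation.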
Tensor NestedTensor_to_padded_tensor_cuda(
    const Tensor& t,
    double padding,
    OptionalIntArrayRef output_size) {
  TORCH_CHECK(
      t.numel() > 0,
      "to_padded_tensor: at least one constituent tensor should have non-zero numel");
  int64_t t_dim = t.dim();
  if (t_dim >= 2 && t_dim <= 4 &&
      (t.dtype() == at::kFloat || t.dtype() == at::kDouble ||
       t.dtype() == at::kHalf)) {
    auto* nt_input = get_nested_tensor_impl(t);
    TORCH_CHECK(
        nested_tensor_impl_is_contiguous(nt_input),
        "for now to_padded_tensor only supports contiguous nested tensor");
    const auto& nt_buffer = nt_input->get_buffer();

    if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0) &&
        !(output_size.has_value())) {
      Tensor nt_sizes = nt_input->get_nested_sizes();
      Tensor sizes_dim1 = at::native::narrow_symint(nt_sizes, 1, 0, 1);
      Tensor sizes_dim2 = at::native::narrow_symint(nt_sizes, 1, 1, 1);
      Tensor result = at::detail::make_tensor<NestedTensorImpl>(
          nt_input->get_buffer(), sizes_dim1 * sizes_dim2[0]);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.dim() == 2);
      result =
          NestedTensor_to_padded_tensor_cuda(result, padding, output_size);
      return result.reshape({result.sizes()[0], -1, *nt_input->opt_size(2)});
    }

    Tensor nt_sizes = nt_input->get_nested_sizes();
    Tensor offsets = batch_offsets_from_efficient_size(nt_sizes);
    auto new_size = NestedTensor_get_max_size(*nt_input);
    new_size.insert(new_size.begin(), nt_sizes.sizes()[0]);

    // Pad output tensor to output_size if provided
    if (output_size.has_value()) {
      auto output_size_ = output_size.value();
      TORCH_CHECK(
          output_size_.size() == new_size.size(),
          "Length of output_size does not match NestedTensor dims. Broadcasting is not supported.");
      for (uint64_t i = 0; i < new_size.size(); i++) {
        TORCH_CHECK(
            output_size_[i] >= new_size[i],
            "Value in output_size is less than NestedTensor padded size. Truncation is not supported.");
        new_size[i] = output_size_[i];
      }
    }

    Tensor output = at::empty(IntArrayRef(new_size), nt_buffer.options());

    int64_t input_dim = nt_sizes.sizes()[1];
    int64_t batch_size = nt_sizes.sizes()[0];
    int64_t output_batch_size = new_size[0];
    // TODO: Remove need for cat here
    at::Tensor metadata = at::cat({offsets, nt_sizes.reshape(-1)});
    metadata = metadata.to(at::Device(kCUDA), at::kInt);

    std::vector<Tensor> split =
        at::split_with_sizes(metadata, {offsets.numel(), nt_sizes.numel()}, 0);

    offsets = split[0];
    nt_sizes = split[1];

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        nt_buffer.scalar_type(), "NestedTensor_to_padded_tensor_cuda", [&]() {
          add_padding_kernelLauncher(
              nt_buffer.data_ptr<scalar_t>(),
              output.data_ptr<scalar_t>(),
              (scalar_t)(padding),
              offsets.data_ptr<int>(),
              nt_sizes.data_ptr<int>(),
              input_dim,
              new_size,
              batch_size,
              output_batch_size);
        });
    return output;
  }
  return NestedTensor_to_padded_tensor_generic(t, padding, output_size);
}

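// Nested-tensor front end for flash attention: flattens the nested q/k/v into
// packed (varlen) buffers plus cumulative sequence lengths, runs
// at::_flash_attention_forward, then wraps the packed result back into a
// nested tensor with output_shape and transposes dims 1 and 2 back to the
// caller's layout.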
std::tuple<
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    c10::SymInt,
    c10::SymInt,
    Tensor,
    Tensor,
    Tensor>
_scaled_dot_product_flash_attention_nestedtensor_cuda(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale) {
  auto [
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_kv,
      output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);

  auto
      [attention,
       logsumexp,
       philox_seed,
       philox_offset,
       debug_attn_mask] =
      at::_flash_attention_forward(
          query_buffer_reshaped,
          key_buffer_reshaped,
          value_buffer_reshaped,
          cumulative_sequence_length_q,
          cumulative_sequence_length_kv,
          max_seqlen_batch_q,
          max_seqlen_batch_kv,
          dropout_p,
          is_causal,
          return_debug_mask,
          scale,
          std::nullopt,
          std::nullopt);
  // Reshape output to convert nnz to batch_size and seq_len
  attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
  return std::make_tuple(
      attention,
      logsumexp,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_kv,
      philox_seed,
      philox_offset,
      debug_attn_mask);
}

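// Nested-tensor front end for the memory-efficient attention kernel. Like the
// flash path above, q/k/v are flattened into packed varlen buffers; the packed
// buffers are unsqueezed to a leading batch dimension of 1, with the
// cumulative sequence lengths describing the ragged layout for
// _efficient_attention_forward.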
std::tuple<Tensor, Tensor, Tensor, Tensor>
_scaled_dot_product_efficient_attention_nestedtensor_cuda(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const std::optional<at::Tensor>& attn_bias,
    bool compute_log_sumexp,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  auto [
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_k,
      output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);

  sdp::CustomMaskType custom_mask_type = is_causal
      ? sdp::CustomMaskType::CausalFromTopLeft
      : sdp::CustomMaskType::NoCustomMask;

  // See Note [Seed and Offset] for a description of seed and offset.
  // Although max_seqlen_q and max_seqlen_batch_kv are returned, we drop these values.
  auto [attention, log_sumexp, seed, offset, max_seqlen_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
      query_buffer_reshaped.unsqueeze(0),
      key_buffer_reshaped.unsqueeze(0),
      value_buffer_reshaped.unsqueeze(0),
      std::nullopt,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_k,
      dropout_p,
      static_cast<int64_t>(custom_mask_type),
      compute_log_sumexp,
      scale);

  // Reshape output to convert nnz to batch_size and seq_len
  attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
  return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
}

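// Backward pass for the nested flash attention path: reshapes grad_out and the
// saved forward tensors into the packed varlen layout, calls
// at::_flash_attention_backward, and re-wraps the resulting gradients as
// nested tensors matching the original query/key/value layouts.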
std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
    const at::Tensor& grad_out_,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    const Tensor& cumulative_sequence_length_q,
    const Tensor& cumulative_sequence_length_k,
    const int64_t max_seqlen_batch_q,
    const int64_t max_seqlen_batch_k,
    double dropout_p,
    bool is_causal,
    const at::Tensor& philox_seed,
    const at::Tensor& philox_offset,
    std::optional<double> scale) {
  if (!grad_out_.defined()) {
    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
  }
  auto [
      grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped] =
      preprocessing::sdpa_nested_preprocessing_backward(
          grad_out_,
          query,
          key,
          value,
          out,
          cumulative_sequence_length_q,
          cumulative_sequence_length_k,
          max_seqlen_batch_q,
          max_seqlen_batch_k);

  auto [grad_q, grad_k, grad_v] = at::_flash_attention_backward(
      grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped,
      logsumexp,
      cumulative_sequence_length_q,
      cumulative_sequence_length_k,
      max_seqlen_batch_q,
      max_seqlen_batch_k,
      dropout_p,
      is_causal,
      philox_seed,
      philox_offset,
      scale);

  grad_q = wrap_buffer(grad_q.view(-1), query.transpose(1, 2)._nested_tensor_size()).transpose(1, 2);
  grad_k = wrap_buffer(grad_k.view(-1), key.transpose(1, 2)._nested_tensor_size()).transpose(1, 2);
  grad_v = wrap_buffer(grad_v.view(-1), value.transpose(1, 2)._nested_tensor_size()).transpose(1, 2);

  return std::make_tuple(grad_q, grad_k, grad_v);
}

} // namespace at::native