#include <numeric>
#include <algorithm>
#include <c10/util/Exception.h>

#include <ATen/ATen.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/native/NonSymbolicBC.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_nested_from_padded_native.h>
#include <ATen/ops/narrow_native.h>
#endif

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorTransformerUtils.h>
#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>

#include <ATen/cuda/CUDAContext.h>

namespace at::native {
namespace {
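// Sums the element counts of all constituent tensors described by `sizes`,
// where each row of `sizes` holds the per-dimension sizes of one constituent.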
int64_t padded_tensor_numel(const Tensor& sizes) {
  const auto sizes_num_rows = sizes.sizes()[0];
  const auto sizes_row_length = sizes.sizes()[1];
  const auto* sizes_data = sizes.const_data_ptr<int64_t>();
  int64_t numel = 0;
  for (const auto row_num : c10::irange(sizes_num_rows)) {
    const auto* row_ptr = sizes_data + row_num * sizes_row_length;
    int64_t prod = 1;
    for (const auto idx : c10::irange(sizes_row_length)) {
      prod *= row_ptr[idx];
    }
    numel += prod;
  }
  return numel;
}
} // namespace
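// Converts a padded dense tensor into a nested tensor. Dedicated CUDA
// remove-padding kernels handle the supported layouts (3-d input, or 4-d
// input with the 0213 transform) for fp32/fp16; everything else falls back
// to the generic implementation.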
Tensor nested_from_padded_cuda(
    const Tensor& padded,
    const Tensor& sizes,
    bool do_transform_0213) {
  if (padded.dim() > 1 && padded.dim() < 5) {
    // Instead of erroring, call the generic version
    if (!(padded.dim() == 4 && do_transform_0213) &&
        !(padded.dim() == 3 && !do_transform_0213)) {
      return at::native::nested_from_padded_generic(padded, sizes, do_transform_0213);
    }
    if (padded.dtype() != kFloat && padded.dtype() != kHalf) {
      TORCH_WARN_ONCE(
          "nested_from_padded CUDA kernels only support fp32/fp16; falling "
          "back to slower generic kernel");
      return at::native::nested_from_padded_generic(padded, sizes, do_transform_0213);
    }
    Tensor target_offsets =
        NestedTensor_batch_offsets_from_size_tensor(sizes, 0);
    Tensor padded_sizes_tensor = at::tensor(padded.sizes());
    Tensor output = at::empty({padded_tensor_numel(sizes)}, padded.options());
    Tensor target_size_sizes = sizes.reshape(-1);

    Tensor metadata =
        at::cat({target_size_sizes, padded_sizes_tensor, target_offsets});
    metadata = metadata.to(at::Device(kCUDA), kInt, true, true);

    auto output_size_ptr = metadata.data_ptr<int>();
    auto input_size_ptr = output_size_ptr + target_size_sizes.numel();
    auto offsets_ptr = input_size_ptr + padded_sizes_tensor.numel();

    Tensor padded_contiguous = padded.contiguous();
    if (padded.dtype() == kFloat) {
      if (do_transform_0213) {
        remove_padding_transform0213_kernelLauncher(
            padded_contiguous.data_ptr<float>(),
            output.data_ptr<float>(),
            offsets_ptr,
            input_size_ptr,
            output_size_ptr,
            padded_contiguous.dim() - 2,
            padded_contiguous.sizes()[0]);
      } else {
        remove_padding_kernelLauncher(
            padded_contiguous.data_ptr<float>(),
            output.data_ptr<float>(),
            offsets_ptr,
            input_size_ptr,
            output_size_ptr,
            padded_contiguous.dim() - 1,
            padded_contiguous.sizes()[0]);
      }
    } else if (padded.dtype() == kHalf) {
      if (do_transform_0213) {
        remove_padding_transform0213_kernelLauncher(
            padded_contiguous.data_ptr<c10::Half>(),
            output.data_ptr<c10::Half>(),
            offsets_ptr,
            input_size_ptr,
            output_size_ptr,
            padded_contiguous.dim() - 2,
            padded_contiguous.sizes()[0]);
      } else {
        remove_padding_kernelLauncher(
            padded_contiguous.data_ptr<c10::Half>(),
            output.data_ptr<c10::Half>(),
            offsets_ptr,
            input_size_ptr,
            output_size_ptr,
            padded_contiguous.dim() - 1,
            padded_contiguous.sizes()[0]);
      }
    } else {
      TORCH_CHECK(false, "Only support fp32/fp16 for padded input");
    }
    return at::detail::make_tensor<NestedTensorImpl>(std::move(output), sizes);
  } else {
    return at::native::nested_from_padded_generic(padded, sizes);
  }
}

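// Returns a new int64 tensor of length batch_size + 1 holding the exclusive
// prefix sums of the per-constituent element counts in `ef_sizes`.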
Tensor batch_offsets_from_efficient_size(const Tensor& ef_sizes) {
  int64_t* nt_sizes_ptr = ef_sizes.data_ptr<int64_t>();
  int64_t ef_sizes_size_0 = ef_sizes.sizes()[0];
  Tensor offsets = at::empty({1 + ef_sizes_size_0}, at::kLong);
  int64_t* offsets_ptr = offsets.mutable_data_ptr<int64_t>();
  offsets_ptr[0] = 0;
  int64_t ef_sizes_size_1 = ef_sizes.sizes()[1];
  for (const auto i : c10::irange(ef_sizes_size_0)) {
    int64_t prod = 1;
    for (const auto j : c10::irange(ef_sizes_size_1)) {
      prod = prod * nt_sizes_ptr[i * ef_sizes_size_1 + j];
    }
    offsets_ptr[i + 1] = offsets_ptr[i] + prod;
  }
  return offsets;
}

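// CUDA path for to_padded_tensor on a contiguous nested tensor: copies each
// constituent into a dense output padded with `padding` up to the
// per-dimension maxima (or to an explicit `output_size`), using
// add_padding_kernelLauncher for 2-d to 4-d float/double/half inputs and the
// generic implementation otherwise.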
Tensor NestedTensor_to_padded_tensor_cuda(
    const Tensor& t,
    double padding,
    OptionalIntArrayRef output_size) {
  TORCH_CHECK(
      t.numel() > 0,
      "to_padded_tensor: at least one constituent tensor should have non-zero numel");
  int64_t t_dim = t.dim();
  if (t_dim >= 2 && t_dim <= 4 &&
      (t.dtype() == at::kFloat || t.dtype() == at::kDouble ||
       t.dtype() == at::kHalf)) {
    auto* nt_input = get_nested_tensor_impl(t);
    TORCH_CHECK(
        nested_tensor_impl_is_contiguous(nt_input),
        "for now to_padded_tensor only supports contiguous nested tensor");
    const auto& nt_buffer = nt_input->get_buffer();

    if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0) &&
        !(output_size.has_value())) {
      Tensor nt_sizes = nt_input->get_nested_sizes();
      Tensor sizes_dim1 = at::native::narrow_symint(nt_sizes, 1, 0, 1);
      Tensor sizes_dim2 = at::native::narrow_symint(nt_sizes, 1, 1, 1);
      Tensor result = at::detail::make_tensor<NestedTensorImpl>(
          nt_input->get_buffer(), sizes_dim1 * sizes_dim2[0]);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.dim() == 2);
      result =
          NestedTensor_to_padded_tensor_cuda(result, padding, output_size);
      return result.reshape({result.sizes()[0], -1, *nt_input->opt_size(2)});
    }

    Tensor nt_sizes = nt_input->get_nested_sizes();
    Tensor offsets = batch_offsets_from_efficient_size(nt_sizes);
    auto new_size = NestedTensor_get_max_size(*nt_input);
    new_size.insert(new_size.begin(), nt_sizes.sizes()[0]);

    // Pad output tensor to output_size if provided
    if (output_size.has_value()) {
      auto output_size_ = output_size.value();
      TORCH_CHECK(
          output_size_.size() == new_size.size(),
          "Length of output_size does not match NestedTensor dims. Broadcasting is not supported.");
      for (uint64_t i = 0; i < new_size.size(); i++) {
        TORCH_CHECK(
            output_size_[i] >= new_size[i],
            "Value in output_size is less than NestedTensor padded size. Truncation is not supported.");
        new_size[i] = output_size_[i];
      }
    }

    Tensor output = at::empty(IntArrayRef(new_size), nt_buffer.options());

    int64_t input_dim = nt_sizes.sizes()[1];
    int64_t batch_size = nt_sizes.sizes()[0];
    int64_t output_batch_size = new_size[0];
    // TODO: Remove need for cat here
    at::Tensor metadata = at::cat({offsets, nt_sizes.reshape(-1)});
    metadata = metadata.to(at::Device(kCUDA), at::kInt);

    std::vector<Tensor> split =
        at::split_with_sizes(metadata, {offsets.numel(), nt_sizes.numel()}, 0);

    offsets = split[0];
    nt_sizes = split[1];

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        nt_buffer.scalar_type(), "NestedTensor_to_padded_tensor_cuda", [&]() {
          add_padding_kernelLauncher(
              nt_buffer.data_ptr<scalar_t>(),
              output.data_ptr<scalar_t>(),
              (scalar_t)(padding),
              offsets.data_ptr<int>(),
              nt_sizes.data_ptr<int>(),
              input_dim,
              new_size,
              batch_size,
              output_batch_size);
        });
    return output;
  }
  return NestedTensor_to_padded_tensor_generic(t, padding, output_size);
}

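// Nested-tensor wrapper around _flash_attention_forward: packs the nested
// q/k/v into dense variable-length buffers plus cumulative sequence lengths
// via sdpa_nested_preprocessing, runs the flash kernel, and re-wraps the
// packed output as a nested tensor in (batch, num_heads, seq_len, head_dim)
// layout.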
std::tuple<
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    c10::SymInt,
    c10::SymInt,
    Tensor,
    Tensor,
    Tensor>
_scaled_dot_product_flash_attention_nestedtensor_cuda(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale) {
  auto [
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_kv,
      output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);

  auto [attention, logsumexp, philox_seed, philox_offset, debug_attn_mask] =
      at::_flash_attention_forward(
          query_buffer_reshaped,
          key_buffer_reshaped,
          value_buffer_reshaped,
          cumulative_sequence_length_q,
          cumulative_sequence_length_kv,
          max_seqlen_batch_q,
          max_seqlen_batch_kv,
          dropout_p,
          is_causal,
          return_debug_mask,
          scale,
          std::nullopt,
          std::nullopt);
  // Reshape output to convert nnz to batch_size and seq_len
  attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
  return std::make_tuple(
      attention,
      logsumexp,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_kv,
      philox_seed,
      philox_offset,
      debug_attn_mask);
}

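// Nested-tensor wrapper around _efficient_attention_forward (the
// memory-efficient SDPA backend): uses the same sdpa_nested_preprocessing
// packing as the flash path, selects the custom mask type from is_causal,
// and re-wraps the packed output as a nested tensor in
// (batch, num_heads, seq_len, head_dim) layout.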
std::tuple<Tensor, Tensor, Tensor, Tensor>
_scaled_dot_product_efficient_attention_nestedtensor_cuda(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const std::optional<at::Tensor>& attn_bias,
    bool compute_log_sumexp,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  auto [
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      cumulative_sequence_length_q,
      cumulative_sequence_length_kv,
      max_seqlen_batch_q,
      max_seqlen_batch_k,
      output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);

  sdp::CustomMaskType custom_mask_type = is_causal
      ? sdp::CustomMaskType::CausalFromTopLeft
      : sdp::CustomMaskType::NoCustomMask;

  // See Note [Seed and Offset] for description of seed and offset
  // Although max_seqlen_q and max_seqlen_batch_kv are also returned, we drop
  // these values.
  auto [attention, log_sumexp, seed, offset, max_seqlen_q, max_seqlen_batch_kv] =
      at::_efficient_attention_forward(
          query_buffer_reshaped.unsqueeze(0),
          key_buffer_reshaped.unsqueeze(0),
          value_buffer_reshaped.unsqueeze(0),
          std::nullopt,
          cumulative_sequence_length_q,
          cumulative_sequence_length_kv,
          max_seqlen_batch_q,
          max_seqlen_batch_k,
          dropout_p,
          static_cast<int64_t>(custom_mask_type),
          compute_log_sumexp,
          scale);

  // Reshape output to convert nnz to batch_size and seq_len
  attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
  return std::make_tuple(
      std::move(attention),
      std::move(log_sumexp),
      std::move(seed),
      std::move(offset));
}

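// Backward of the nested flash-attention path: re-packs grad_out, q, k, v,
// and out into dense variable-length buffers, calls _flash_attention_backward,
// and re-wraps each packed gradient as a nested tensor matching the
// corresponding input's nested size.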
std::tuple<at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_flash_attention_backward_nested(
    const at::Tensor& grad_out_,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    const Tensor& cumulative_sequence_length_q,
    const Tensor& cumulative_sequence_length_k,
    const int64_t max_seqlen_batch_q,
    const int64_t max_seqlen_batch_k,
    double dropout_p,
    bool is_causal,
    const at::Tensor& philox_seed,
    const at::Tensor& philox_offset,
    std::optional<double> scale) {
  if (!grad_out_.defined()) {
    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
  }
  auto [
      grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped] =
      preprocessing::sdpa_nested_preprocessing_backward(
          grad_out_,
          query,
          key,
          value,
          out,
          cumulative_sequence_length_q,
          cumulative_sequence_length_k,
          max_seqlen_batch_q,
          max_seqlen_batch_k);

  auto [grad_q, grad_k, grad_v] = at::_flash_attention_backward(
      grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped,
      logsumexp,
      cumulative_sequence_length_q,
      cumulative_sequence_length_k,
      max_seqlen_batch_q,
      max_seqlen_batch_k,
      dropout_p,
      is_causal,
      philox_seed,
      philox_offset,
      scale);

  grad_q =
      wrap_buffer(grad_q.view(-1), query.transpose(1, 2)._nested_tensor_size())
          .transpose(1, 2);
  grad_k =
      wrap_buffer(grad_k.view(-1), key.transpose(1, 2)._nested_tensor_size())
          .transpose(1, 2);
  grad_v =
      wrap_buffer(grad_v.view(-1), value.transpose(1, 2)._nested_tensor_size())
          .transpose(1, 2);

  return std::make_tuple(grad_q, grad_k, grad_v);
}

} // namespace at::native