#include <ATen/native/nested/NestedTensorMath.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/nested/NestedTensorUtils.h>

#include <tuple>
#include <utility>

namespace at::native {
namespace {

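// Note: despite its name, num_bytes() below returns an element count, not a
// byte count. Assuming contiguous strides it computes
// 1 + sum_i (sizes[i] - 1) * stride_i, which equals the product of the sizes
// when all of them are positive and 0 when any size is 0.
// Worked example (illustrative): sizes [2, 3] give 1 + (3 - 1) * 1 + (2 - 1) * 3 = 6.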
int64_t num_bytes(IntArrayRef sizes) {
  // 0-dim Tensors have a torch.Size of .size() 0, but occupy 1 element of
  // memory. Empty 1-dim Tensors (torch.tensor([])) have a torch.Size of
  // .size() 1, but occupy 0 elements of memory.
  int64_t result = 1;
  int64_t stride = 1;
  for (int64_t ii = static_cast<int64_t>(sizes.size()) - 1; ii >= 0; --ii) {
    result += (sizes[ii] - 1) * stride;
    // TODO: accept strides as input when we support them instead of
    // assuming contiguous.
    stride *= sizes[ii];
  }
  return result;
}

Tensor pad_tensor_to_shape(
    const Tensor& t,
    IntArrayRef goal_shape,
    double value = 0) {
  std::vector<int64_t> padd;
  auto tup = t.sizes();
  TORCH_CHECK(
      t.dim() == (int64_t)(goal_shape.size()),
      "dimension ",
      t.dim(),
      " doesn't match length ",
      goal_shape.size(),
      " of goal shape.");
  for (int64_t i = static_cast<int64_t>(tup.size()) - 1; i >= 0; i--) {
    padd.push_back(0);
    padd.push_back(goal_shape[i] - tup[i]);
  }
  Tensor new_tensor = at::constant_pad_nd(t, IntArrayRef(padd), value);
  new_tensor = new_tensor.reshape(goal_shape);
  return new_tensor;
}
} // namespace

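// Converts a padded (N, L, D) tensor plus a boolean padding mask of shape
// (N, L) into a nested tensor: the per-row count of True entries becomes each
// component's length. Illustrative example (not from the original source):
//   mask = [[1, 1, 0],
//           [1, 0, 0]]
// gives lengths [2, 1], so the nested size matrix passed to
// _nested_from_padded is [[2, D], [1, D]]. The argmin-over-cat trick below
// finds the position of the first False in each row; when mask_check is true,
// comparing it against the cumulative sum enforces that the mask is
// left-aligned (no True after the first False).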
Tensor NestedTensor_nested_tensor_from_mask(const Tensor& t, const Tensor& mask, bool mask_check) {
  TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool, "Expected mask to be of ScalarType Bool, but got ", mask.scalar_type(), " instead.");
  TORCH_CHECK(mask.dim() == 2, "Padding mask should be 2D");
  TORCH_CHECK(t.dim() == 3, "Input should be a 3D tensor, N * L * D");
  auto N = t.size(0), L = t.size(1), D = t.size(2);
  auto NN = mask.size(0), LL = mask.size(1);
  TORCH_CHECK(N == NN && L == LL, "Mask size should match input size");

  // N * L
  Tensor sizes = mask;
  Tensor tmp_pad = at::zeros({N, 1}, mask.options());
  // Make sure padding is only added at the end of mask
  Tensor nums = at::cat({sizes, tmp_pad}, 1).to(kInt).argmin(1);

  // N, ([size1, size2, ... sizeN])
  sizes = sizes.cumsum(1).select(1, L - 1);
  nums = nums.to(sizes.options());

  if (mask_check)
    TORCH_CHECK(sizes.equal(nums), "Mask must be left-aligned without gaps");

  sizes = sizes.reshape({N, 1});
  // N, ([d1=D, d2=D, ... dN=D])
  Tensor d = at::full_like(sizes, D);

  // N * 2, ([[size1, D], [size2, D], ..., [sizeN, D]])
  sizes = at::cat({sizes, d}, 1).to(kCPU);

  return at::_nested_from_padded(t, sizes, false);
}

bool NestedTensor_nested_tensor_from_mask_left_aligned(const Tensor& t, const Tensor& mask) {
  TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool, "Expected mask to be of ScalarType Bool, but got ", mask.scalar_type(), " instead.");
  TORCH_CHECK(mask.dim() == 2, "Padding mask should be 2D");
  TORCH_CHECK(t.dim() == 3, "Input should be a 3D tensor, N * L * D");
  auto N = t.size(0), L = t.size(1);
  auto NN = mask.size(0), LL = mask.size(1);
  TORCH_CHECK(N == NN && L == LL, "Mask size should match input size");

  // N * L
  Tensor sizes = mask;
  Tensor tmp_pad = at::zeros({N, 1}, mask.options());
  // Make sure padding is only added at the end of mask
  Tensor nums = at::cat({sizes, tmp_pad}, 1).to(kInt).argmin(1);

  // N, ([size1, size2, ... sizeN])
  sizes = sizes.cumsum(1).select(1, L - 1);
  nums = nums.to(sizes.options());

  return sizes.equal(nums);
}

Tensor _nested_tensor_from_tensor_list(
    TensorList list,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  for (const auto i : c10::irange(list.size())) {
    if (i > 0) {
      int64_t dim_i = list[i].dim();
      int64_t dim_prev = list[i - 1].dim();
      TORCH_CHECK(
          dim_i == dim_prev,
          "All Tensors given to nested_tensor must have the same dimension. ",
          "Found dimension ",
          dim_i,
          " for Tensor at index ",
          i,
          " and dimension ",
          dim_prev,
          " for Tensor at index ",
          i - 1,
          ".");
    }
  }
  return impl::wrap_tensor_node(
      impl::TensorNode(list),
      dtype,
      layout,
      device,
      pin_memory);
}

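// Layer norm on a contiguous nested tensor is computed directly on the flat
// buffer: _check_nested_layer_norm_inputs collapses the nested structure into
// an (M, N) problem (M rows, N normalized elements per row), the dense
// LayerNormKernel is dispatched on the buffer with those dimensions, and the
// output buffer is re-wrapped with the original nested sizes.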
std::tuple<Tensor, Tensor, Tensor> nested_layer_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt,
    double eps) {
  TORCH_CHECK(weight_opt && bias_opt, "NestedTensor layer_norm requires weight and bias");
  const auto& weight = *weight_opt;
  const auto& bias = *bias_opt;
  TORCH_CHECK(!weight.is_nested(), "NestedTensor weight not supported for layer_norm");
  TORCH_CHECK(!bias.is_nested(), "NestedTensor bias not supported for layer_norm");
  auto* nt_input = get_nested_tensor_impl(input);
  TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_input));
  const auto& input_buffer = nt_input->get_buffer();
  auto M_N = _check_nested_layer_norm_inputs(*nt_input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  const auto weight_contig = weight.expect_contiguous();
  const auto bias_contig = bias.expect_contiguous();
  auto output_buffer = at::native::empty_like(
      input_buffer,
      std::nullopt /* dtype */,
      std::nullopt /* layout */,
      std::nullopt /* device */,
      std::nullopt /* pin_memory */,
      at::MemoryFormat::Contiguous);
  auto options = input_buffer.options();
  if (input_buffer.is_cuda()) {
    auto acc_type = at::toAccumulateType(input_buffer.scalar_type(), true);
    options = options.dtype(acc_type);
  }
  Tensor mean = at::empty({M}, options);
  Tensor rstd = at::empty({M}, options);
  LayerNormKernel(
      input_buffer.is_cuda() ? kCUDA : kCPU,
      input_buffer,
      *weight_contig,
      *bias_contig,
      M,
      N,
      eps,
      &output_buffer,
      &mean,
      &rstd);
  return std::make_tuple(
      wrap_buffer(output_buffer, nt_input->get_nested_sizes()),
      mean,
      rstd);
}

Tensor NestedTensor_from_padded_and_nested_example(
    const Tensor& padded,
    const Tensor& nt_example) {
  return _nested_from_padded(padded, get_nested_tensor_impl(nt_example)->get_nested_sizes());
}

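// Generic fallback for building a nested tensor from a padded dense tensor.
// For each component, a boolean mask of that component's shape is created,
// padded with False up to the per-dimension max size, and stacked; a single
// masked_select() then extracts exactly the non-padding values into the flat
// nested buffer. When do_transform_0213 is set, the padded input is first
// permuted (0, 2, 1, 3) and its last two dims merged before masking.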
Tensor nested_from_padded_generic(
    const Tensor& padded,
    const Tensor& sizes,
    const bool do_transform_0213) {
  // Check and do transform 0213
  auto padded_transformed = padded;
  if (do_transform_0213) {
    padded_transformed = padded.permute({0, 2, 1, 3})
        .contiguous()
        .view(
            {padded.size(0),
             padded.size(2),
             padded.size(1) * padded.size(3)});
  }
  auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
  // There may be extra padding on padded beyond the max size in the nested tensor.
  // Make the mask size match.
  const size_t dim = padded_transformed.dim();
  TORCH_CHECK(dim - 1 == target_size.size(), "dim: ", dim, " target_size: ", target_size.size());
  for (size_t ii = 0; ii < dim - 1; ++ii) {
    const auto padded_size_i = padded_transformed.sizes()[ii + 1];
    if (target_size[ii] < padded_size_i) {
      target_size[ii] = padded_size_i;
    }
  }
  IntArrayRef target_size_arr(target_size);
  std::vector<at::Tensor> masks;
  std::vector<at::Tensor> all_sizes = sizes.unbind();
  for (const auto& size : all_sizes) {
    IntArrayRef sizes_i(
        size.data_ptr<int64_t>(), size.data_ptr<int64_t>() + size.numel());
    at::Tensor mask_i = padded_transformed.new_full(
        sizes_i, true, kBool, std::nullopt, std::nullopt, std::nullopt);
    masks.push_back(pad_tensor_to_shape(mask_i, target_size_arr));
  }
  at::Tensor final_mask = at::stack(masks);
  at::Tensor new_buffer = padded_transformed.masked_select(final_mask).to(padded.device());
  return at::detail::make_tensor<NestedTensorImpl>(
      std::move(new_buffer), sizes);
}

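// Generic to_padded_tensor: the flat contiguous buffer is split into one chunk
// per component (split sizes come from num_bytes() applied to each row of the
// nested size matrix), each chunk is reshaped to its component shape and padded
// up to the per-dimension max size with pad_tensor_to_shape(), and the padded
// components are stacked into a dense (ntensors, *max_size) result. If
// output_size is provided, the result is padded further to that shape.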
Tensor NestedTensor_to_padded_tensor_generic(
    const Tensor& t,
    double padding,
    OptionalIntArrayRef output_size) {
  // TODO: support noncontiguous case
  // error out for now
  TORCH_CHECK(
      nested_tensor_impl_is_contiguous(get_nested_tensor_impl(t)),
      "for now to_padded_tensor only supports contiguous nested tensor");
  // TODO: skipped optimization for case of all 1x1 tensors
  auto& nt = *get_nested_tensor_impl(t);
  auto max_size = NestedTensor_get_max_size(nt);
  auto sizes = nt.get_nested_sizes();

  if (sizes.numel() == 0 || sizes.dim() == 0) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(nt.get_buffer().numel() == 0);
    return nt.get_buffer().clone();
  }
  TORCH_CHECK(
      t.numel() > 0,
      "to_padded_tensor: at least one constituent tensor should have non-zero numel");

  // TODO: doesn't handle empty/scalar entries because we don't need
  // it for transformers; see to_padded_tensor in
  // pytorch/nestedtensor's masking.cpp.

  const auto sizes_num_rows = sizes.sizes()[0];
  const auto sizes_num_columns = sizes.sizes()[1];
  const auto sizes_data_start = sizes.data_ptr<int64_t>();
  const auto sizes_data_end = sizes_data_start + sizes.numel();
  std::vector<int64_t> split_sizes;
  split_sizes.reserve(sizes_num_rows);
  for (auto sizes_data = sizes_data_start; sizes_data != sizes_data_end;
       sizes_data += sizes_num_columns) {
    split_sizes.push_back(
        num_bytes(IntArrayRef(sizes_data, sizes_num_columns)));
  }
  std::vector<int64_t> nonzero_split_sizes;
  for (const auto split_size : split_sizes) {
    if (split_size > 0) {
      nonzero_split_sizes.push_back(split_size);
    }
  }
  const auto buffer = nt.get_buffer();
  std::vector<Tensor> buffers_;
  if (!nonzero_split_sizes.empty()) {
    buffers_ = at::split_with_sizes(buffer, nonzero_split_sizes, 0);
  }

  std::vector<Tensor> buffers;
  buffers.reserve(split_sizes.size());
  int64_t next_buffer = 0;
  auto sizes_ptr = sizes_data_start;
  for (const auto split_size : split_sizes) {
    Tensor to_pad;
    IntArrayRef tensor_sizes(sizes_ptr, sizes_num_columns);
    if (split_size > 0) {
      to_pad = buffers_[next_buffer++].reshape(tensor_sizes);
    } else {
      to_pad = at::empty(tensor_sizes, buffer.options());
    }
    buffers.push_back(pad_tensor_to_shape(to_pad, max_size, padding));
    sizes_ptr += sizes_num_columns;
  }
  auto ret_val = at::stack(buffers);

  // Pad output tensor to output_size if provided
  if (output_size.has_value()) {
    auto output_size_ = output_size.value();
    TORCH_CHECK(
        (int64_t)output_size_.size() == ret_val.dim(),
        "Length of output_size does not match NestedTensor dims. Broadcasting is not supported.");
    for (int64_t i = 0; i < (int64_t)ret_val.dim(); i++) {
      TORCH_CHECK(
          output_size_[i] >= ret_val.size(i),
          "Value in output_size is less than NestedTensor padded size. Truncation is not supported.");
    }
    return pad_tensor_to_shape(ret_val, output_size_, padding);
  }
  return ret_val;
}

Tensor NestedTensor_embedding(
    const Tensor& weight,
    const Tensor& indices,
    int64_t padding_idx,
    bool scale_grad_by_freq,
    bool sparse) {
  const auto* nt_indices = get_nested_tensor_impl(indices);
  TORCH_CHECK(
      !weight.is_nested(), "NestedTensor weight not supported for embedding");
  TORCH_CHECK(indices.dim() < 3);
  TORCH_CHECK(indices.dim() > 0, "NestedTensor embedding doesn't support empty indices.");
  TORCH_CHECK(weight.dim() == 2);
  TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_indices));
  TORCH_CHECK(weight.is_contiguous());

  const auto& indices_buffer = nt_indices->get_buffer();
  auto result_buffer = at::embedding(
      weight, indices_buffer, padding_idx, scale_grad_by_freq, sparse);
  const auto& sizes = nt_indices->get_nested_sizes();
  auto new_sizes = at::empty({sizes.size(0)}, sizes.options());
  new_sizes.fill_(weight.sizes()[1]);
  new_sizes = new_sizes.reshape({new_sizes.size(0), 1});
  new_sizes = at::cat({sizes, new_sizes}, 1);
  return at::detail::make_tensor<NestedTensorImpl>(
      result_buffer.reshape({-1}), std::move(new_sizes));
}

// Very rudimentary sum_dim for prototyping with torch_scatter.segment_reduce.
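// The reduction walks the flat contiguous buffer segment by segment: for each
// component, num_segments = prod(sizes[:-1]) and segment_length = sizes[-1],
// so summing every consecutive run of segment_length elements reduces the last
// dimension, which is kept as size 1 in the output size matrix (keepdim=True).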
Tensor NestedTensor_sum_dim_CPU(
    const Tensor& self,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    std::optional<ScalarType> dtype) {
  // Only allow reductions across the last dim
  auto dims = opt_dims.value_or(IntArrayRef{});
  TORCH_CHECK(
      dims.size() == 1,
      "NestedTensor only allows reduction of a single dimension for now.");
  auto dim = maybe_wrap_dim(dims[0], self.dim());
  TORCH_CHECK(
      dim == self.dim() - 1,
      "NestedTensor can only be reduced across the last dimension for now ",
      "got dimension ",
      dim,
      " instead.");
  // Always keep reduced dim for now
  // This is to avoid the case where the nested tensors are 1D and keepdim=False
  // making the nested tensors -> elements (e.g. sum(nt([1, 2 ,3], [4, 5]), -1) -> nt(6, 9))
  TORCH_CHECK(keepdim, "NestedTensor always requires keepdim=True for now.");
  // acc_dtype is not supported for now
  TORCH_CHECK(!dtype, "NestedTensor does not support dtype argument for now.");

  auto nt_input = get_nested_tensor_impl(self);
  TORCH_CHECK(
      nested_tensor_impl_is_contiguous(nt_input),
      "NestedTensor does not support reductions when the input is noncontiguous for now.");
  int64_t ntensors = nt_input->size(0);
  if (ntensors == 0) {
    return self;
  }
  const Tensor& buffer = nt_input->get_buffer();

  auto sizemat = nt_input->get_nested_sizes();
  // create output size tensor for keepdim=True
  auto output_sizemat = sizemat.clone();
  output_sizemat.select(1, -1).fill_(1);

  auto num_segments = at::prod(output_sizemat, -1);
  auto segment_lengths = sizemat.select(1, -1);
  const int64_t new_numel = at::sum(num_segments).item<int64_t>();
  auto output_buffer = buffer.new_empty(IntArrayRef(new_numel));

  // This logic assumes for now that
  // (1) all the nested tensors are contiguous
  // (2) the nested tensors are stored contiguously in the buffer
  AT_DISPATCH_ALL_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, buffer.scalar_type(), "nested_sum_dim_cpu", [&]() {
        auto* output_data = output_buffer.data_ptr<scalar_t>();
        const auto* input_data = buffer.const_data_ptr<scalar_t>();
        int64_t out_idx = 0, in_idx = 0;
        for (const auto i : c10::irange(ntensors)) {
          int64_t segments = num_segments[i].item<int64_t>();
          int64_t segment_length = segment_lengths[i].item<int64_t>();
          for (auto j = 0; j < segments; j++) {
            scalar_t res = 0;
            for (auto k = 0; k < segment_length; k++) {
              res += input_data[in_idx];
              in_idx += 1;
            }
            output_data[out_idx] = res;
            out_idx += 1;
          }
        }
      });

  return wrap_buffer(output_buffer, output_sizemat);
}

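// select() on a nested tensor never copies data. For dim == 0 it returns the
// selected component as a regular strided view into the underlying storage
// (using that component's size, stride, and storage offset). For dim > 0 it
// builds new per-component size/stride/offset metadata with the selected
// dimension removed and returns a nested view over the same buffer.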
Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) {
  auto self_ptr = get_nested_tensor_impl(self);
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
                           strides = NestedTensor_get_strides(self_ptr);
  int64_t* offsets_ptr = self_ptr->get_storage_offsets().data_ptr<int64_t>();
  const at::Tensor& buffer = self_ptr->get_unsafe_storage_as_tensor();
  int64_t positive_dim = at::maybe_wrap_dim(dim, self_ptr->dim());
  int64_t ntensors = self_ptr->size(0);
  TORCH_CHECK_INDEX(ntensors > 0, "You can only select when the NT is not empty.");
  int64_t ndims = static_cast<long>(sizes[0].size());
  if (positive_dim == 0) {
    TORCH_CHECK_INDEX(
        index >= -ntensors && index < ntensors,
        "index ",
        index,
        " is out of bounds for dimension 0 with size ",
        ntensors);
    int64_t positive_index = index < 0 ? index + ntensors : index;
    return buffer.as_strided(
        sizes[positive_index],
        strides[positive_index],
        offsets_ptr[positive_index]);
  } else {
    auto new_sizes = at::empty({ntensors, ndims - 1}, TensorOptions().dtype(kLong));
    auto new_strides = at::empty({ntensors, ndims - 1}, TensorOptions().dtype(kLong));
    auto new_offsets = at::empty({ntensors}, TensorOptions().dtype(kLong));
    for (int64_t i : c10::irange(ntensors)) {
      int64_t* size_ptr = new_sizes[i].data_ptr<int64_t>();
      int64_t* stride_ptr = new_strides[i].data_ptr<int64_t>();

      int64_t dim_idx = 0;
      for (int64_t j : c10::irange(ndims)) {
        if (j != dim - 1) {
          size_ptr[dim_idx] = sizes[i][j];
          stride_ptr[dim_idx] = strides[i][j];
          ++dim_idx;
        } else {
          TORCH_CHECK_INDEX(
              index >= 0 && index < sizes[i][j],
              "index ",
              index,
              " is out of bounds for dimension ",
              j,
              " of the ",
              i,
              "th constituent tensor with size ",
              sizes[i][j]);
          new_offsets[i] = offsets_ptr[i] + index * strides[i][j];
        }
      }
    }
    return create_nested_view_tensor(self, new_sizes, new_strides, new_offsets);
  }
}

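// Dropout is applied once to the whole underlying buffer; the size, stride,
// and offset metadata are cloned unchanged, so a non-contiguous input yields
// an equally non-contiguous output and mask (mirroring dense dropout, which
// reuses the input's sizes and strides).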
std::tuple<Tensor, Tensor> native_dropout_nested(const Tensor& input, double p, std::optional<bool> train) {
  auto input_ptr = get_nested_tensor_impl(input);
  const Tensor& input_buffer = input_ptr->get_unsafe_storage_as_tensor(),
              & sizemat = input_ptr->get_nested_sizes(),
              & stridemat = input_ptr->get_nested_strides();
  const auto offsets = input_ptr->get_storage_offsets();
  Tensor output_buffer, mask_buffer;
  if (input_buffer.numel() == 0) {
    output_buffer = input_buffer.clone();
    mask_buffer = input_buffer.clone();
  } else {
    std::tie(output_buffer, mask_buffer) = at::native_dropout(input_buffer, p, train);
  }
  // regular tensor dropout reuses input size and stride
  // i.e. if input is not contiguous, then output is also discontiguous
  Tensor output = wrap_buffer(output_buffer, sizemat.clone(), stridemat.clone(), offsets.clone()),
         mask = wrap_buffer(mask_buffer, sizemat.clone(), stridemat.clone(), offsets.clone());
  return std::make_tuple(output, mask);
}

Tensor softmax_nested(
    const Tensor& input,
    const int64_t dim,
    const bool half_to_float) {
  auto input_ptr = get_nested_tensor_impl(input);
  int64_t ntensors = input_ptr->size(0);
  if (ntensors == 0) {
    return input.clone();
  }
  int64_t positive_dim = at::maybe_wrap_dim(dim, input_ptr->dim());
  TORCH_CHECK(
      positive_dim >= 1,
      "Cannot apply softmax across nested dimension 0");
  // create a contiguous output
  // TODO: We would ideally use an empty_like here, but that is not supported
  // for nested tensors yet. Since we are only using the buffer for the options
  // and size, it is okay to use unsafe_storage_as_tensor here.
  const Tensor& buffer = input_ptr->get_unsafe_storage_as_tensor(),
              & sizemat = input_ptr->get_nested_sizes();
  Tensor output_buffer = buffer.new_empty(buffer.sizes());
  Tensor output = wrap_buffer(output_buffer, sizemat.clone());
  // call tensor softmax
  // TODO: for cpu, maybe use `parallel_for` if benchmarks show necessity.
  // To do that we would have to merge `softmax_kernel` from
  // `aten/src/ATen/native/cpu/SoftMaxKernel.cpp`, because:
  // 1. it already uses `parallel_for`, and multi-threading cannot be nested
  // 2. we cannot dispatch (in this case at::_softmax_out) from inside a
  //    multi-threaded region
  std::vector<Tensor> input_unbind = input.unbind(),
                      output_unbind = output.unbind();
  for (int64_t i = 0; i < ntensors; i++) {
    at::_softmax_out(
        output_unbind[i],
        input_unbind[i],
        positive_dim - 1,
        half_to_float);
  }
  return output;
}

Tensor NestedTensor_all(
    const Tensor& input,
    const int64_t dim,
    const bool keepdim) {
  auto input_ptr = get_nested_tensor_impl(input);
  int64_t ntensors = input_ptr->size(0);
  if (ntensors == 0) {
    return input.clone();
  }
  int64_t positive_dim = at::maybe_wrap_dim(dim, input_ptr->dim());
  TORCH_CHECK(
      positive_dim >= 1,
      "Cannot apply all across nested dimension 0");
  const Tensor& buffer = input_ptr->get_unsafe_storage_as_tensor(),
              & sizemat = input_ptr->get_nested_sizes();

  Tensor output_buffer = buffer.new_empty(buffer.sizes());

  Tensor output_size = sizemat.clone();
  if (keepdim) {
    output_size.select(1, positive_dim - 1).fill_(1);
  } else {
    output_size = output_size.slice(1, 0, positive_dim - 1);
  }

  Tensor output = wrap_buffer(output_buffer, output_size.contiguous());

  std::vector<Tensor> input_unbind = input.unbind(),
                      output_unbind = output.unbind();
  for (int64_t i = 0; i < ntensors; i++) {
    at::all_out(
        output_unbind[i],
        input_unbind[i],
        positive_dim - 1,
        keepdim);
  }
  return output;
}

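// transpose() on a nested tensor is pure metadata: the two requested columns
// of the nested size matrix and the nested stride matrix are swapped via an
// index_select, and a view over the unchanged buffer is returned. Dimension 0
// (the implicit batch dimension) can never be transposed.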
Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) {
  auto self_ptr = get_nested_tensor_impl(self);
  // check input dimensions
  int64_t ndims = self_ptr->dim();
  int64_t positive_dim0 = at::maybe_wrap_dim(dim0, ndims),
          positive_dim1 = at::maybe_wrap_dim(dim1, ndims);
  if (positive_dim0 == positive_dim1) {
    return self;
  }
  TORCH_CHECK(positive_dim0 > 0 && positive_dim1 > 0, "Nested tensor dimension 0 cannot be transposed");
  // -- to exclude the implicit batch dimension
  ndims--;
  positive_dim0--;
  positive_dim1--;
  // transpose = switch `dim0` and `dim1` columns of `sizemat` and `stridemat`
  const Tensor& sizemat = self_ptr->get_nested_sizes(),
              & stridemat = self_ptr->get_nested_strides();
  Tensor column_indices = sizemat.new_empty(ndims);
  int64_t* column_indices_ptr = column_indices.data_ptr<int64_t>();
  std::iota(column_indices_ptr, column_indices_ptr + ndims, 0);
  column_indices_ptr[positive_dim0] = positive_dim1;
  column_indices_ptr[positive_dim1] = positive_dim0;
  // create transposed `sizemat` and `stridemat`
  Tensor sizemat_transposed = at::index_select(sizemat, 1, column_indices),
         stridemat_transposed = at::index_select(stridemat, 1, column_indices);
  return create_nested_view_tensor(
      self, sizemat_transposed, stridemat_transposed, self_ptr->get_storage_offsets().clone());
}

Tensor squeeze_nested(const Tensor& self) {
  TORCH_CHECK(false,
      "squeeze(): For nested tensors, squeeze without the dim argument is not supported ",
      "at the moment, however you can use squeeze(Tensor self, int dim) instead ",
      "if you need this feature, please open an issue on github describing your use case.");
  return self;
}

Tensor squeeze_dim_nested(const Tensor& self, IntArrayRef dims) {
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t ndim = self_ptr->dim();
  auto mask = at::dim_list_to_bitset(dims, ndim);
  TORCH_CHECK(!mask.test(0),
      "squeeze(): For nested tensors, squeezing dimension 0 is not supported at the moment ",
      "if you need this feature, please open an issue on github describing your use case.");
  const Tensor& sizemat = self_ptr->get_nested_sizes();
  const Tensor& stridemat = self_ptr->get_nested_strides();
  // if tensor.size(dim) != 1, torch.squeeze returns the input unchanged; we do the same here
  for (const auto d : c10::irange(ndim)) {
    if (mask.test(d)) {
      std::optional<int64_t> size_dim = self_ptr->opt_size(d);
      if (!(size_dim.has_value() && *size_dim == 1)) {
        mask.reset(d);
      }
    }
  }

  if (!mask.any()) {
    // detach to avoid triggering throw_error_if_base_and_tensor_are_same
    return self.detach();
  }
  // if ndim == 2 and we pass the above if statement we should have a
  // nested tensor of singleton tensors
  TORCH_CHECK(ndim > static_cast<int64_t>(1 + dims.size()),
      "squeeze(): For nested tensors, squeezing a nested tensor of singleton tensors is not ",
      "supported at the moment, if you need this feature, please open an issue on github ",
      "describing your use case.");
  const auto new_ndim = ndim - mask.count();
  auto column_indices = sizemat.new_empty(static_cast<int64_t>(new_ndim) - 1);
  int64_t* column_indices_ptr = column_indices.data_ptr<int64_t>();
  for (const auto d : c10::irange(1, ndim)) {
    if (!mask.test(d)) {
      *column_indices_ptr++ = d - 1;
    }
  }
  auto sizemat_squeezed = at::index_select(sizemat, 1, column_indices);
  auto stridemat_squeezed = at::index_select(stridemat, 1, column_indices);
  return create_nested_view_tensor(
      self, sizemat_squeezed, stridemat_squeezed, self_ptr->get_storage_offsets().clone());
}

Tensor squeeze_dim_nested(const Tensor& self, int64_t dim) {
  return squeeze_dim_nested(self, IntArrayRef{dim});
}

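// unsqueeze() inserts a column of ones into the nested size matrix at the
// requested position. The matching stride column is 1 when the new dimension
// is appended last; otherwise it is stride[mat_dim] * size[mat_dim], which
// mirrors the convention dense unsqueeze uses for the stride of a size-1
// dimension (the value does not affect indexing since the dimension has
// extent 1).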
Tensor unsqueeze_nested(const Tensor& self, int64_t dim) {
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t ndim = self_ptr->dim();
  int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim + 1);
  TORCH_CHECK(wrapped_dim > 0,
      "unsqueeze(): For nested tensors, unsqueezing dimension 0 is not supported at the moment ",
      "if you need this feature, please open an issue on github describing your use case.");
  const Tensor& sizemat = self_ptr->get_nested_sizes();
  const Tensor& stridemat = self_ptr->get_nested_strides();
  auto mat_dim = wrapped_dim - 1;
  Tensor new_size = sizemat.new_ones({sizemat.size(0), 1});
  Tensor sizemat_unsqueezed = at::cat({sizemat.slice(1, 0, mat_dim),
                                       new_size,
                                       sizemat.slice(1, mat_dim, ndim)}, 1);
  Tensor new_stride;
  if (wrapped_dim == ndim) {
    new_stride = stridemat.new_ones({stridemat.size(0), 1});
  } else {
    new_stride = (stridemat.select(1, mat_dim) * sizemat.select(1, mat_dim)).unsqueeze(-1);
  }
  Tensor stridemat_unsqueezed = at::cat({stridemat.slice(1, 0, mat_dim),
                                         new_stride,
                                         stridemat.slice(1, mat_dim, ndim)}, 1);
  return create_nested_view_tensor(
      self, sizemat_unsqueezed, stridemat_unsqueezed, self_ptr->get_storage_offsets().clone());
}

// utilities supporting `view_nested` and `reshape_nested`
namespace {
// Args:
//   sizes: the sizes of original nested tensor
//   strides: the strides of original nested tensor
//   proposed_shape: user proposed new shape
//   op: the options for new size and stride matrices
// Returns:
//   whether viewable
//   size matrix after reshape
//   stride matrix after reshape (not fully populated if not viewable)
inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
    const std::vector<IntArrayRef>& sizes,
    const std::vector<IntArrayRef>& strides,
    const IntArrayRef& proposed_shape,
    const c10::TensorOptions& op) {
  int64_t ntensors = static_cast<int64_t>(sizes.size());
  int64_t ndims_underlying = static_cast<int64_t>(sizes[0].size());
  int64_t ndims_underlying_reshaped = static_cast<int64_t>(proposed_shape.size() - 1);
  bool viewable = true;
  Tensor sizemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op),
         stridemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op);
  int64_t* sizemat_reshaped_ptr = sizemat_reshaped.mutable_data_ptr<int64_t>(),
         * stridemat_reshaped_ptr = stridemat_reshaped.mutable_data_ptr<int64_t>();
  for (int64_t itensor = 0; itensor < ntensors; itensor++) {
    const IntArrayRef& size = sizes[itensor],
                     & stride = strides[itensor];
    // compute reshaped size
    std::vector<int64_t> size_reshaped_vector(proposed_shape.begin() + 1, proposed_shape.end());
    // only allow one pre-existing dimension to have proposed shape == -1
    int64_t infer_index_old = -1;
    // some negative sizes remain to be inferred
    if (ndims_underlying < ndims_underlying_reshaped) {
      int64_t numel = 1, numel_reshaped = 1;
      // replace negative sizes for old dimensions with old sizes
      for (int64_t idim = 0; idim < ndims_underlying; idim++) {
        int64_t& size_reshaped = size_reshaped_vector[idim];
        TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
        if (size_reshaped == -1) {
          TORCH_CHECK(infer_index_old == -1, "only one dimension can be inferred");
          size_reshaped = size[idim];
          infer_index_old = idim;
        }
        numel *= size[idim];
        numel_reshaped *= size_reshaped;
      }
      // infer negative size for new dimension
      int64_t infer_index = -1;
      for (int64_t idim = ndims_underlying; idim < ndims_underlying_reshaped; idim++) {
        const int64_t& size_reshaped = size_reshaped_vector[idim];
        if (size_reshaped >= 0) {
          numel_reshaped *= size_reshaped;
        } else if (size_reshaped == -1) {
          if (infer_index > -1) {
            throw std::runtime_error("only one dimension can be inferred");
          } else {
            infer_index = idim;
          }
        } else {
          AT_ERROR("invalid shape dimension ", size_reshaped);
        }
      }
      // See Note [Special size rule for nested tensor]
      TORCH_CHECK(infer_index == -1, "nested tensor does not infer shape");
      TORCH_CHECK(
          numel == numel_reshaped,
          "shape '", proposed_shape, "' ",
          "is invalid for input of size ", numel);
    }
    // all negative sizes can be replaced
    else {
      int64_t numel = 1, numel_reshaped = 1;
      for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
        int64_t& size_reshaped = size_reshaped_vector[idim];
        TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
        if (size_reshaped == -1) {
          size_reshaped = size[idim];
        }
        numel *= size[idim];
        numel_reshaped *= size_reshaped;
      }
      for (int64_t idim = ndims_underlying_reshaped; idim < ndims_underlying; idim++) {
        numel *= size[idim];
      }
      TORCH_CHECK(
          numel == numel_reshaped,
          "shape '", proposed_shape, "' ",
          "is invalid for input of size ", numel);
    }
    IntArrayRef size_reshaped(size_reshaped_vector);
    // compute reshaped stride
    auto opt_stride_reshaped = at::detail::computeStride(size, stride, size_reshaped);
    // reshape as view is possible
    if (opt_stride_reshaped.has_value()) {
      const IntArrayRef& stride_reshaped = *opt_stride_reshaped;
      // fill reshaped size and stride into sizemat and stridemat
      for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
        sizemat_reshaped_ptr[idim] = size_reshaped[idim];
        stridemat_reshaped_ptr[idim] = stride_reshaped[idim];
      }
      sizemat_reshaped_ptr += ndims_underlying_reshaped;
      stridemat_reshaped_ptr += ndims_underlying_reshaped;
    }
    // reshape as view is impossible
    else {
      viewable = false;
      // fill reshaped size into sizemat
      for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
        sizemat_reshaped_ptr[idim] = size_reshaped[idim];
      }
      sizemat_reshaped_ptr += ndims_underlying_reshaped;
    }
  }
  return std::make_tuple(viewable, sizemat_reshaped, stridemat_reshaped);
}
} // namespace

// Note [Special size rule for nested tensor]
// Instead of inferring size, -1 means "inherit the old size", so:
// * negative size is legal for a ragged dimension
// * however, we only allow one -1
// In principle we could still infer a dimension,
// we are designing a better semantics to include both inheritance and inference
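//
// Illustrative example (not from the original source): a nested tensor with
// component shapes (2, 6) and (3, 6) can be viewed via nt.view({2, -1, 2, 3});
// the -1 sits on the ragged dimension and inherits each component's old size
// (2 and 3 respectively), yielding components of shape (2, 2, 3) and
// (3, 2, 3) rather than inferring the size from the element count.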
Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
  TORCH_CHECK(
      !proposed_shape.empty(),
      "shape '[]' is invalid for a nested tensor");
  auto self_ptr = get_nested_tensor_impl(self);
  // basic information before reshaping
  int64_t ntensors = self_ptr->size(0);
  TORCH_CHECK(
      ntensors > 0,
      "empty nested tensor cannot be reshaped");
  // basic information after reshaping
  int64_t ntensors_reshaped = proposed_shape[0];
  TORCH_CHECK(
      ntensors == ntensors_reshaped,
      "view: For now nested view cannot change or infer the implicit batch dimension");
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
                           strides = NestedTensor_get_strides(self_ptr);
  // reshaping underlying tensor dimensions does not change offset
  // determine reshaped size and stride
  const Tensor& sizemat = self_ptr->get_nested_sizes();
  auto [viewable, sizemat_reshaped, stridemat_reshaped] = NestedTensor_compute_size_stride(
      sizes, strides, proposed_shape, sizemat.options());
  TORCH_CHECK(
      viewable,
      "view size is not compatible with input tensor's size and stride "
      "(at least one dimension spans across two contiguous subspaces). "
      "Use .reshape(...) instead.");
  return create_nested_view_tensor(self, sizemat_reshaped, stridemat_reshaped, self_ptr->get_storage_offsets().clone());
}

/**
 * Create a buffer tensor that is a view of self
 *
 * This serves as the boundary between nested and non-nested tensor
 * view conversions
 *
 * @return Returns a new non-nested tensor that
 * aliases the same storage as self
 */
Tensor values_nested(const Tensor& self) {
  TORCH_INTERNAL_ASSERT(self.is_nested(), "Can only create a buffer from Nested Tensor");
  auto* nt_self = get_nested_tensor_impl(self);
  return nt_self->get_unsafe_storage_as_tensor();
}

/**
 * Create a nested tensor that is a view of a buffer
 *
 * This serves as the boundary between non-nested tensor and nested
 * view conversions
 *
 * @return Returns a nested tensor that
 * aliases the same storage as buffer
 */
Tensor _nested_view_from_buffer(
    const Tensor& buffer,
    const Tensor& nested_sizes,
    const Tensor& nested_strides,
    const Tensor& storage_offsets) {
  TORCH_INTERNAL_ASSERT(
      !buffer.is_nested(),
      "Can only create a Nested Tensor from a normal tensor buffer");
  TORCH_INTERNAL_ASSERT(buffer.dim() == 1, "The input buffer must be flat");
  TORCH_INTERNAL_ASSERT(nested_sizes.dim() == 2, "Expected the nested size tensor to be two dimensional.");
  uint64_t num_elements_nested_size = at::prod(nested_sizes, 1).sum().item<int64_t>();
  uint64_t buffer_storage_size = buffer.storage().nbytes() / buffer.dtype().itemsize();
  TORCH_INTERNAL_ASSERT(
      buffer_storage_size == num_elements_nested_size,
      "The number of elements in the buffer must equal the nested tensor size but buffer size: ",
      buffer_storage_size,
      " and nested tensor size: ",
      num_elements_nested_size,
      ".");

  TORCH_INTERNAL_ASSERT(nested_strides.dim() == 2, "Expected the nested stride tensor to be two dimensional.");
  TORCH_INTERNAL_ASSERT(nested_sizes.size(0) == nested_strides.size(0), "Expected the first dimension of nested size and nested stride tensor to be equal.");
  TORCH_INTERNAL_ASSERT(nested_strides.size(0) == storage_offsets.size(0), "Expected the first dimension of nested stride tensor to equal the length of offsets.");
  return at::detail::make_tensor<NestedTensorImpl>(
      c10::TensorImpl::VIEW,
      buffer,
      nested_sizes,
      nested_strides,
      storage_offsets);
}

std::tuple<Tensor, Tensor> _nested_compute_contiguous_strides_offsets(const Tensor& nested_size) {
  return std::make_tuple(
      construct_nested_strides(nested_size),
      construct_offsets(nested_size));
}

// See Note [Special size rule for nested tensor]
Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
  TORCH_CHECK(
      !proposed_shape.empty(),
      "shape '[]' is invalid for a nested tensor");
  auto self_ptr = get_nested_tensor_impl(self);
  // basic information before reshaping
  int64_t ntensors = self_ptr->size(0);
  TORCH_CHECK(
      ntensors > 0,
      "empty nested tensor cannot be reshaped");
  // basic information after reshaping
  int64_t ntensors_reshaped = proposed_shape[0];
  TORCH_CHECK(
      ntensors == ntensors_reshaped,
      "reshape: For now nested reshape cannot change or infer the implicit batch dimension");
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
                           strides = NestedTensor_get_strides(self_ptr);
  // reshaping underlying tensor dimensions does not change offset
  // determine reshaped size and stride
  const Tensor& sizemat = self_ptr->get_nested_sizes();
  auto [viewable, sizemat_reshaped, stridemat_reshaped] = NestedTensor_compute_size_stride(
      sizes, strides, proposed_shape, sizemat.options());
  if (viewable) {
    return self.view(proposed_shape);
  } else {
    return self.clone(at::MemoryFormat::Contiguous).view(proposed_shape);
  }
}

Tensor reshape_nested_symint(const Tensor& self, SymIntArrayRef proposed_shape) {
  // Jagged layout NT decomp
  if (self.layout() == at::kJagged) {
    // TODO: Expand decomp to handle other viewable cases
    bool viewable = self.is_contiguous();
    return (
        viewable ? self.view_symint(proposed_shape) :
        self.clone(at::MemoryFormat::Contiguous).view_symint(proposed_shape));
  }

  return reshape_nested(self, C10_AS_INTARRAYREF_SLOW(proposed_shape));
}

Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
  // Jagged layout NT decomp
  if (self.layout() == at::kJagged) {
    return self.reshape_symint(other.sym_sizes());
  }

  auto other_ptr = get_nested_tensor_impl(other);
  // TODO: this is to reproduce other_ptr->opt_sizes_
  // if an accessor is provided in the future, can replace this
  std::vector<int64_t> sizes;
  for (int64_t i = 0; i < other_ptr->dim(); i++) {
    std::optional<int64_t> opt_size = other_ptr->opt_size(i);
    if (opt_size.has_value()) {
      sizes.push_back(*opt_size);
    } else {
      sizes.push_back(-1);
    }
  }
  // reshape with other.opt_sizes_
  return self.reshape(sizes);
}

Tensor& normal_nested_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
  self_buf.normal_(mean, std, std::move(gen));
  return self;
}

// returns true if the sizes are compatible to be concatenated along the specified dim
// sizes should match outside of the dim of concatenation
static bool can_cat_nested_sizes(const Tensor& nested_sizes1, const Tensor& nested_sizes2, int64_t cat_dim) {
  if (nested_sizes1.sizes() != nested_sizes2.sizes()) {
    return false;
  }

  auto nested_sizes1_ptr = nested_sizes1.data_ptr<int64_t>();
  auto nested_sizes2_ptr = nested_sizes2.data_ptr<int64_t>();
  const auto num_components = nested_sizes1.size(0);
  const auto num_dims = nested_sizes1.size(1);
  for (auto c : c10::irange(num_components)) {
    for (auto d : c10::irange(num_dims)) {
      // subtract 1 to account for batch dim
      auto component_cat_dim = cat_dim - 1;
      if (d == component_cat_dim) {
        continue;
      }
      if (nested_sizes1_ptr[c * num_dims + d] != nested_sizes2_ptr[c * num_dims + d]) {
        return false;
      }
    }
  }

  return true;
}

// cat a list of NTs that are representable as jagged
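// Each input must have shape (B, *, D_0, D_1, ...), i.e. at most one ragged
// dimension, and it must sit right next to the batch dimension. Every
// component buffer is viewed as a dense jagged tensor of shape
// (-1, D_0, D_1, ...), the jagged views are concatenated with at::cat along
// dim - 1, and the result is re-wrapped with an updated nested size matrix in
// which the concatenated dimension is replaced by the new combined size.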
static Tensor cat_nested_as_jagged(
    const MaterializedITensorListRef& tensors,
    int64_t dim) {
  const auto first_item = tensors[0].get();
  const auto first_item_dim = first_item.dim();
  const auto first_item_batch_size = first_item.size(0);
  std::vector<Tensor> jagged_views;
  for (auto i : c10::irange(tensors.size())) {
    auto t = tensors[i].get();
    TORCH_CHECK(t.is_nested(),
        "cat(): expected each tensor in given list to be nested");
    TORCH_CHECK(t.is_contiguous(),
        "cat(): only contiguous nested tensors are supported");
    if (i > 0) {
      TORCH_CHECK(
          can_cat_nested_sizes(
              get_nested_tensor_impl(first_item)->get_nested_sizes(),
              get_nested_tensor_impl(t)->get_nested_sizes(),
              dim),
          "cat(): expected all nested tensors to have matching ragged structures outside of the concatenated dim");
    }
    // only support inputs in the form (B, *, D_0, D_1, ...)
    // i.e. require at most a single ragged dim next to the batch dim
    auto* nt_impl = get_nested_tensor_impl(t);
    std::vector<int64_t> jagged_size;
    jagged_size.push_back(-1);
    for (auto d : c10::irange(first_item_dim - 2)) {
      TORCH_CHECK(nt_impl->opt_size(d + 2).has_value(),
          "cat(): only nested tensors with a single ragged dim next to the batch dim are supported");
      jagged_size.push_back(nt_impl->size(d + 2));
    }
    auto jagged = nt_impl->get_buffer().view(jagged_size);
    jagged_views.push_back(jagged);
  }

  // view each of the NTs as jagged for the cat() call
  auto new_buffer = at::cat(jagged_views, dim - 1);

  // wrap result into nested tensor
  const auto component_dim = first_item_dim - 1;
  auto new_dim_size = new_buffer.size(dim - 1);
  auto new_sizes = get_nested_tensor_impl(tensors[0].get())->get_nested_sizes().clone();
  auto new_sizes_ptr = new_sizes.data_ptr<int64_t>();
  for (const auto i : c10::irange(first_item_batch_size)) {
    new_sizes_ptr[i * component_dim + (dim - 1)] = new_dim_size;
  }
  return at::detail::make_tensor<NestedTensorImpl>(
      new_buffer.view(-1), new_sizes);
}

static Tensor cat_nested_impl(
    const MaterializedITensorListRef& tensors,
    int64_t dim) {
  dim = maybe_wrap_dim(dim, tensors[0].get());
  if (dim == 0) {
    // handle simple case of dim=0: concat NT components
    std::vector<at::Tensor> buffers;
    std::vector<at::Tensor> sizes;
    for (const auto i : c10::irange(tensors.size())) {
      const Tensor& t = tensors[i];
      TORCH_CHECK(
          t.is_nested(), "Expected each tensor in given list to be nested.");
      TORCH_CHECK(
          t.is_contiguous(),
          "Expected each tensor in given list to be contiguous.");
      auto t_ptr = get_nested_tensor_impl(t);
      buffers.push_back(t_ptr->get_buffer().view({-1}));
      sizes.push_back(t_ptr->get_nested_sizes());
    }
    return at::detail::make_tensor<NestedTensorImpl>(
        at::cat(buffers).view({-1}), at::cat(sizes, 0));
  }

  // NB: support for other dims is restricted to nested tensors representable as jagged
  return cat_nested_as_jagged(tensors, dim);
}

Tensor cat_nested(const ITensorListRef& tensors, int64_t dim) {
  auto materialized = tensors.materialize();
  return cat_nested_impl(materialized, at::legacy_cat_wrap_dim(dim, materialized));
}

} // namespace at::native