// aten/src/ATen/native/nested/NestedTensorMath.cpp
#include <ATen/native/nested/NestedTensorMath.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/nested/NestedTensorUtils.h>

#include <tuple>
#include <utility>


namespace at::native {
namespace {

int64_t num_bytes(IntArrayRef sizes) {
  // 0-dim Tensors have a torch.Size of length 0, but occupy 1 element of memory.
  // Empty 1-dim Tensors (torch.tensor([])) have a torch.Size of length 1,
  // but occupy 0 elements of memory.
  int64_t result = 1;
  int64_t stride = 1;
  for (int64_t ii = static_cast<int64_t>(sizes.size()) - 1; ii >= 0; --ii) {
    result += (sizes[ii] - 1) * stride;
    // TODO: accept strides as input when we support them instead of
    // assuming contiguous.
    stride *= sizes[ii];
  }
  return result;
}
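// Worked example for num_bytes (illustrative): for sizes = {2, 3} the loop
// computes 1 + (3 - 1) * 1 + (2 - 1) * 3 = 6, i.e. the element count of a
// contiguous 2x3 tensor; it is used below to split the flat nested buffer,
// so the value is an element count despite the name.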

Tensor pad_tensor_to_shape(
    const Tensor& t,
    IntArrayRef goal_shape,
    double value = 0) {
  std::vector<int64_t> padd;
  auto tup = t.sizes();
  TORCH_CHECK(
      t.dim() == (int64_t)(goal_shape.size()),
      "dimension ",
      t.dim(),
      " doesn't match length ",
      goal_shape.size(),
      " of goal shape.");
  for (int64_t i = static_cast<int64_t>(tup.size()) - 1; i >= 0; i--) {
    padd.push_back(0);
    padd.push_back(goal_shape[i] - tup[i]);
  }
  Tensor new_tensor = at::constant_pad_nd(t, IntArrayRef(padd), value);
  new_tensor = new_tensor.reshape(goal_shape);
  return new_tensor;
}
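// Illustrative example for pad_tensor_to_shape: padding a 2x3 tensor to
// goal_shape = {2, 5} builds padd = {0, 2, 0, 0} (innermost dimension first,
// as at::constant_pad_nd expects), pads two extra columns with `value`, and
// reshapes the result to the goal shape.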
} // namespace


Tensor NestedTensor_nested_tensor_from_mask(const Tensor& t, const Tensor& mask, bool mask_check) {
    TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool, "Expected mask to be of ScalarType Bool, but got ", mask.scalar_type(), " instead.");
    TORCH_CHECK(mask.dim() == 2, "Padding mask should be 2D");
    TORCH_CHECK(t.dim() == 3, "Input should be a 3D tensor, N * L * D");
    auto N = t.size(0), L = t.size(1), D = t.size(2);
    auto NN = mask.size(0), LL = mask.size(1);
    TORCH_CHECK(N == NN && L == LL, "Mask size should match input size");

    // N * L
    Tensor sizes = mask;
    Tensor tmp_pad = at::zeros({N, 1}, mask.options());
    // Make sure padding is only added at the end of mask
    Tensor nums = at::cat({sizes, tmp_pad}, 1).to(kInt).argmin(1);

    // N, ([size1, size2, ... sizeN])
    sizes = sizes.cumsum(1).select(1, L - 1);
    nums = nums.to(sizes.options());

    if (mask_check)
      TORCH_CHECK(sizes.equal(nums), "Mask must be left-aligned without gaps");

    sizes = sizes.reshape({N, 1});
    // N, ([d1=D, d2=D, ... dN=D])
    Tensor d = at::full_like(sizes, D);

    // N * 2, ([[size1, D], [size2, D], ..., [sizeN, D]])
    sizes = at::cat({sizes, d}, 1).to(kCPU);

    return at::_nested_from_padded(t, sizes, false);
}
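// Illustrative example for NestedTensor_nested_tensor_from_mask (values
// assumed): with N = 2, L = 4, D = 8 and
//   mask = [[1, 1, 1, 0],
//           [1, 1, 0, 0]]
// the computed `sizes` tensor is [[3, 8], [2, 8]], so _nested_from_padded
// extracts a 3x8 and a 2x8 slice from the padded input.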

bool NestedTensor_nested_tensor_from_mask_left_aligned(const Tensor& t, const Tensor& mask) {
    TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool, "Expected mask to be of ScalarType Bool, but got ", mask.scalar_type(), " instead.");
    TORCH_CHECK(mask.dim() == 2, "Padding mask should be 2D");
    TORCH_CHECK(t.dim() == 3, "Input should be a 3D tensor, N * L * D");
    auto N = t.size(0), L = t.size(1);
    auto NN = mask.size(0), LL = mask.size(1);
    TORCH_CHECK(N == NN && L == LL, "Mask size should match input size");

    // N * L
    Tensor sizes = mask;
    Tensor tmp_pad = at::zeros({N, 1}, mask.options());
    // Make sure padding is only added at the end of mask
    Tensor nums = at::cat({sizes, tmp_pad}, 1).to(kInt).argmin(1);

    // N, ([size1, size2, ... sizeN])
    sizes = sizes.cumsum(1).select(1, L - 1);
    nums = nums.to(sizes.options());

    return sizes.equal(nums);
}
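// Illustrative example for the left-alignment check (values assumed): a mask
// row such as [1, 0, 1, 0] has a cumulative count of 2 but its first zero is
// at index 1, so sizes.equal(nums) fails and the mask is reported as not
// left-aligned.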

Tensor _nested_tensor_from_tensor_list(
    TensorList list,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  for (const auto i : c10::irange(list.size())) {
    if (i > 0) {
      int64_t dim_i = list[i].dim();
      int64_t dim_prev = list[i - 1].dim();
      TORCH_CHECK(
          dim_i == dim_prev,
          "All Tensors given to nested_tensor must have the same dimension. ",
          "Found dimension ",
          dim_i,
          " for Tensor at index ",
          i,
          " and dimension ",
          dim_prev,
          " for Tensor at index ",
          i - 1,
          ".");
    }
  }
  return impl::wrap_tensor_node(
      impl::TensorNode(list),
      dtype,
      layout,
      device,
      pin_memory);
}

std::tuple<Tensor, Tensor, Tensor> nested_layer_norm(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt,
    const std::optional<Tensor>& bias_opt,
    double eps) {
  TORCH_CHECK(weight_opt && bias_opt, "NestedTensor layer_norm requires weight and bias");
  const auto& weight = *weight_opt;
  const auto& bias = *bias_opt;
  TORCH_CHECK(!weight.is_nested(), "NestedTensor weight not supported for layer_norm");
  TORCH_CHECK(!bias.is_nested(), "NestedTensor bias not supported for layer_norm");
  auto* nt_input = get_nested_tensor_impl(input);
  TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_input));
  const auto& input_buffer = nt_input->get_buffer();
  auto M_N = _check_nested_layer_norm_inputs(*nt_input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  const auto weight_contig = weight.expect_contiguous();
  const auto bias_contig = bias.expect_contiguous();
  auto output_buffer = at::native::empty_like(
      input_buffer,
      std::nullopt /* dtype */,
      std::nullopt /* layout */,
      std::nullopt /* device */,
      std::nullopt /* pin_memory */,
      at::MemoryFormat::Contiguous);
  auto options = input_buffer.options();
  if (input_buffer.is_cuda()) {
    auto acc_type = at::toAccumulateType(input_buffer.scalar_type(), true);
    options = options.dtype(acc_type);
  }
  Tensor mean = at::empty({M}, options);
  Tensor rstd = at::empty({M}, options);
  LayerNormKernel(
      input_buffer.is_cuda() ? kCUDA : kCPU,
      input_buffer,
      *weight_contig,
      *bias_contig,
      M,
      N,
      eps,
      &output_buffer,
      &mean,
      &rstd);
  return std::make_tuple(
    wrap_buffer(output_buffer, nt_input->get_nested_sizes()),
    mean,
    rstd
  );
}
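// Sketch of the layout assumed by nested_layer_norm: the contiguous nested
// buffer is treated as an M x N matrix, where N is the number of elements
// covered by normalized_shape and M is the number of such rows summed over
// all constituents; LayerNormKernel normalizes each of the M rows
// independently, and the output buffer is re-wrapped with the original
// nested sizes.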

Tensor NestedTensor_from_padded_and_nested_example(
    const Tensor& padded,
    const Tensor& nt_example) {
  return _nested_from_padded(padded, get_nested_tensor_impl(nt_example)->get_nested_sizes());
}

Tensor nested_from_padded_generic(
    const Tensor& padded,
    const Tensor& sizes,
    const bool do_transform_0213) {
  // Check and do transform 0213
  auto padded_transformed = padded;
  if (do_transform_0213) {
    padded_transformed = padded.permute({0, 2, 1, 3})
      .contiguous()
      .view(
          {padded.size(0),
           padded.size(2),
           padded.size(1) * padded.size(3)});
  }
  auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
  // There may be extra padding on padded beyond the max size in the nested tensor.
  // Make the mask size match.
  const size_t dim = padded_transformed.dim();
  TORCH_CHECK(dim - 1 == target_size.size(), "dim: ", dim, "target_size: ", target_size.size());
  for (size_t ii = 0; ii < dim - 1; ++ii) {
    const auto padded_size_i = padded_transformed.sizes()[ii + 1];
    if (target_size[ii] < padded_size_i) {
      target_size[ii] = padded_size_i;
    }
  }
  IntArrayRef target_size_arr(target_size);
  std::vector<at::Tensor> masks;
  std::vector<at::Tensor> all_sizes = sizes.unbind();
  for (const auto& size : all_sizes) {
    IntArrayRef sizes_i(
        size.data_ptr<int64_t>(), size.data_ptr<int64_t>() + size.numel());
    at::Tensor mask_i = padded_transformed.new_full(
        sizes_i, true, kBool, std::nullopt, std::nullopt, std::nullopt);
    masks.push_back(pad_tensor_to_shape(mask_i, target_size_arr));
  }
  at::Tensor final_mask = at::stack(masks);
  at::Tensor new_buffer = padded_transformed.masked_select(final_mask).to(padded.device());
  return at::detail::make_tensor<NestedTensorImpl>(
      std::move(new_buffer), sizes);
}
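// Illustrative example for nested_from_padded_generic (values assumed): for
// a padded input of shape [2, 5, 8] and sizes = [[3, 8], [2, 8]], each
// boolean mask marks the top-left 3x8 (resp. 2x8) region of its 5x8 slice;
// masked_select then copies exactly those elements into the flat nested
// buffer.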

Tensor NestedTensor_to_padded_tensor_generic(
    const Tensor& t,
    double padding,
    OptionalIntArrayRef output_size) {
  // TODO: support noncontiguous case
  // error out for now
  TORCH_CHECK(
      nested_tensor_impl_is_contiguous(get_nested_tensor_impl(t)),
      "for now to_padded_tensor only supports contiguous nested tensor");
  // TODO: skipped optimization for case of all 1x1 tensors
  auto& nt = *get_nested_tensor_impl(t);
  auto max_size = NestedTensor_get_max_size(nt);
  auto sizes = nt.get_nested_sizes();

  if (sizes.numel() == 0 || sizes.dim() == 0) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(nt.get_buffer().numel() == 0);
    return nt.get_buffer().clone();
  }
  TORCH_CHECK(
      t.numel() > 0,
      "to_padded_tensor: at least one constituent tensor should have non-zero numel"
  );

  // TODO: doesn't handle empty/scalar entries because we don't need
  // it for transformers; see to_padded_tensor in
  // pytorch/nestedtensor's masking.cpp.

  const auto sizes_num_rows = sizes.sizes()[0];
  const auto sizes_num_columns = sizes.sizes()[1];
  const auto sizes_data_start = sizes.data_ptr<int64_t>();
  const auto sizes_data_end = sizes_data_start + sizes.numel();
  std::vector<int64_t> split_sizes;
  split_sizes.reserve(sizes_num_rows);
  for (auto sizes_data = sizes_data_start; sizes_data != sizes_data_end;
       sizes_data += sizes_num_columns) {
    split_sizes.push_back(
        num_bytes(IntArrayRef(sizes_data, sizes_num_columns)));
  }
  std::vector<int64_t> nonzero_split_sizes;
  for (const auto split_size : split_sizes) {
    if (split_size > 0) {
      nonzero_split_sizes.push_back(split_size);
    }
  }
  const auto buffer = nt.get_buffer();
  std::vector<Tensor> buffers_;
  if (!nonzero_split_sizes.empty()) {
    buffers_ = at::split_with_sizes(buffer, nonzero_split_sizes, 0);
  }

  std::vector<Tensor> buffers;
  buffers.reserve(split_sizes.size());
  int64_t next_buffer = 0;
  auto sizes_ptr = sizes_data_start;
  for (const auto split_size : split_sizes) {
    Tensor to_pad;
    IntArrayRef tensor_sizes(sizes_ptr, sizes_num_columns);
    if (split_size > 0) {
      to_pad = buffers_[next_buffer++].reshape(tensor_sizes);
    } else {
      to_pad = at::empty(tensor_sizes, buffer.options());
    }
    buffers.push_back(pad_tensor_to_shape(to_pad, max_size, padding));
    sizes_ptr += sizes_num_columns;
  }
  auto ret_val = at::stack(buffers);

  // Pad output tensor to output_size if provided
  if (output_size.has_value()) {
    auto output_size_ = output_size.value();
    TORCH_CHECK(
        (int64_t)output_size_.size() == ret_val.dim(),
        "Length of output_size does not match NestedTensor dims. Broadcasting is not supported.");
    for (int64_t i = 0; i < (int64_t)ret_val.dim(); i++) {
      TORCH_CHECK(
          output_size_[i] >= ret_val.size(i),
          "Value in output_size is less than NestedTensor padded size. Truncation is not supported.");
    }
    return pad_tensor_to_shape(ret_val, output_size_, padding);
  }
  return ret_val;
}
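// Illustrative example for NestedTensor_to_padded_tensor_generic (values
// assumed): a nested tensor holding a 3x8 and a 2x8 constituent has
// max_size = {3, 8}; the flat buffer is split into chunks of 24 and 16
// elements, each chunk is reshaped and padded to 3x8 with `padding`, and the
// results are stacked into a 2x3x8 output.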

Tensor NestedTensor_embedding(
    const Tensor& weight,
    const Tensor& indices,
    int64_t padding_idx,
    bool scale_grad_by_freq,
    bool sparse) {
  const auto* nt_indices = get_nested_tensor_impl(indices);
  TORCH_CHECK(
      !weight.is_nested(), "NestedTensor weight not supported for embedding");
  TORCH_CHECK(indices.dim() < 3);
  TORCH_CHECK(indices.dim() > 0, "NestedTensor embedding doesn't support empty indices.");
  TORCH_CHECK(weight.dim() == 2);
  TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_indices));
  TORCH_CHECK(weight.is_contiguous());

  const auto& indices_buffer = nt_indices->get_buffer();
  auto result_buffer = at::embedding(
      weight, indices_buffer, padding_idx, scale_grad_by_freq, sparse);
  const auto& sizes = nt_indices->get_nested_sizes();
  auto new_sizes = at::empty({sizes.size(0)}, sizes.options());
  new_sizes.fill_(weight.sizes()[1]);
  new_sizes = new_sizes.reshape({new_sizes.size(0), 1});
  new_sizes = at::cat({sizes, new_sizes}, 1);
  return at::detail::make_tensor<NestedTensorImpl>(
      result_buffer.reshape({-1}), std::move(new_sizes));
}
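// Illustrative example for NestedTensor_embedding (values assumed): looking
// up a nested tensor of index vectors with lengths {3, 2} in a weight of
// shape [V, D] yields nested sizes [[3, D], [2, D]]: each index row becomes a
// row of D-dimensional embedding vectors.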

// Very rudimentary sum_dim for prototyping with torch_scatter.segment_reduce.
Tensor NestedTensor_sum_dim_CPU(
    const Tensor& self,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    std::optional<ScalarType> dtype) {
  // Only allow reductions across the last dim
  auto dims = opt_dims.value_or(IntArrayRef{});
  TORCH_CHECK(
      dims.size() == 1,
      "NestedTensor only allows reduction of a single dimension for now."
  );
  auto dim = maybe_wrap_dim(dims[0], self.dim());
  TORCH_CHECK(
      dim == self.dim() - 1,
      "NestedTensor can only be reduced across the last dimension for now ",
      "got dimension ",
      dim,
      " instead.");
  // Always keep reduced dim for now
  // This is to avoid the case where the nested tensors are 1D and keepdim=False
  // making the nested tensors -> elements (e.g. sum(nt([1, 2, 3], [4, 5]), -1) -> nt(6, 9))
  TORCH_CHECK(keepdim, "NestedTensor always requires keepdim=True for now.");
  // acc_dtype is not supported for now
  TORCH_CHECK(!dtype, "NestedTensor does not support dtype argument for now.");

  auto nt_input = get_nested_tensor_impl(self);
  TORCH_CHECK(
      nested_tensor_impl_is_contiguous(nt_input),
      "NestedTensor does not support reductions when the input is noncontiguous for now.");
  int64_t ntensors = nt_input->size(0);
  if (ntensors == 0) {
    return self;
  }
  const Tensor& buffer = nt_input->get_buffer();

  auto sizemat = nt_input->get_nested_sizes();
  // create output size tensor for keepdim=True
  auto output_sizemat = sizemat.clone();
  output_sizemat.select(1, -1).fill_(1);

  auto num_segments = at::prod(output_sizemat, -1);
  auto segment_lengths = sizemat.select(1, -1);
  const int64_t new_numel = at::sum(num_segments).item<int64_t>();
  auto output_buffer = buffer.new_empty(IntArrayRef(new_numel));

  // This logic assumes for now that
  // (1) all the nested tensors are contiguous
  // (2) the nested tensors are stored contiguously in the buffer
  AT_DISPATCH_ALL_TYPES_AND2(
    ScalarType::Half, ScalarType::BFloat16, buffer.scalar_type(), "nested_sum_dim_cpu", [&]() {
    auto* output_data = output_buffer.data_ptr<scalar_t>();
    const auto* input_data = buffer.const_data_ptr<scalar_t>();
    int64_t out_idx = 0, in_idx = 0;
    for (const auto i : c10::irange(ntensors)) {
      int64_t segments = num_segments[i].item<int64_t>();
      int64_t segment_length = segment_lengths[i].item<int64_t>();
      for (auto j = 0; j < segments; j++) {
        scalar_t res = 0;
        for (auto k = 0; k < segment_length; k++) {
          res += input_data[in_idx];
          in_idx += 1;
        }
        output_data[out_idx] = res;
        out_idx += 1;
      }
    }
  });

  return wrap_buffer(output_buffer, output_sizemat);
}
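// Illustrative example for NestedTensor_sum_dim_CPU (values assumed):
// summing nt([1, 2, 3], [4, 5]) over the last dim with keepdim=True yields
// nt([6], [9]); each constituent contributes prod(sizes[:-1]) segments of
// length sizes[-1], reduced sequentially out of the flat buffer.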

Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) {
  auto self_ptr = get_nested_tensor_impl(self);
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
                           strides = NestedTensor_get_strides(self_ptr);
  int64_t *offsets_ptr = self_ptr->get_storage_offsets().data_ptr<int64_t>();
  const at::Tensor& buffer = self_ptr->get_unsafe_storage_as_tensor();
  int64_t positive_dim = at::maybe_wrap_dim(dim, self_ptr->dim());
  int64_t ntensors = self_ptr->size(0);
  TORCH_CHECK_INDEX(ntensors > 0, "You can only select when the NT is not empty.");
  int64_t ndims = static_cast<int64_t>(sizes[0].size());
  if (positive_dim == 0) {
    TORCH_CHECK_INDEX(
        index >= -ntensors && index < ntensors,
        "index ",
        index,
        " is out of bounds for dimension 0 with size ",
        ntensors);
    int64_t positive_index = index < 0 ? index + ntensors : index;
    return buffer.as_strided(
        sizes[positive_index],
        strides[positive_index],
        offsets_ptr[positive_index]);
  } else {
    auto new_sizes = at::empty({ntensors, ndims-1}, TensorOptions().dtype(kLong));
    auto new_strides = at::empty({ntensors, ndims-1}, TensorOptions().dtype(kLong));
    auto new_offsets = at::empty({ntensors}, TensorOptions().dtype(kLong));
    for (int64_t i : c10::irange(ntensors)) {
      int64_t *size_ptr = new_sizes[i].data_ptr<int64_t>();
      int64_t *stride_ptr = new_strides[i].data_ptr<int64_t>();

      int64_t dim_idx = 0;
      for (int64_t j : c10::irange(ndims)) {
        if (j != dim - 1) {
          size_ptr[dim_idx] = sizes[i][j];
          stride_ptr[dim_idx] = strides[i][j];
          ++dim_idx;
        } else {
          TORCH_CHECK_INDEX(
              index >= 0 && index < sizes[i][j],
              "index ",
              index,
              " is out of bounds for dimension ",
              j,
              " of the ",
              i,
              "th constituent tensor with size ",
              sizes[i][j]);
          new_offsets[i] = offsets_ptr[i] + index * strides[i][j];
        }
      }
    }
    return create_nested_view_tensor(self, new_sizes, new_strides, new_offsets);
  }
}
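// Illustrative example for select_nested (values assumed): for constituents
// of sizes {3, 8} and {2, 8}, select(dim=2, index=1) drops the selected
// column from each size/stride row (leaving sizes {3} and {2}) and adds
// index * stride of that dimension to each storage offset, returning a view
// rather than a copy.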

std::tuple<Tensor,Tensor> native_dropout_nested(const Tensor& input, double p, std::optional<bool> train) {
  auto input_ptr = get_nested_tensor_impl(input);
  const Tensor& input_buffer = input_ptr->get_unsafe_storage_as_tensor(),
      & sizemat = input_ptr->get_nested_sizes(),
      & stridemat = input_ptr->get_nested_strides();
  const auto offsets = input_ptr->get_storage_offsets();
  Tensor output_buffer, mask_buffer;
  if (input_buffer.numel() == 0) {
    output_buffer = input_buffer.clone();
    mask_buffer = input_buffer.clone();
  }
  else {
    std::tie(output_buffer, mask_buffer) = at::native_dropout(input_buffer, p, train);
  }
  // regular tensor dropout reuses input size and stride
  // i.e. if input is not contiguous, then output is also discontiguous
  Tensor output = wrap_buffer(output_buffer, sizemat.clone(), stridemat.clone(), offsets.clone()),
      mask = wrap_buffer(mask_buffer, sizemat.clone(), stridemat.clone(), offsets.clone());
  return std::make_tuple(output, mask);
}

Tensor softmax_nested(
    const Tensor& input,
    const int64_t dim,
    const bool half_to_float) {
  auto input_ptr = get_nested_tensor_impl(input);
  int64_t ntensors = input_ptr->size(0);
  if (ntensors == 0) {
    return input.clone();
  }
  int64_t positive_dim = at::maybe_wrap_dim(dim, input_ptr->dim());
  TORCH_CHECK(
      positive_dim >= 1,
      "Cannot apply softmax across nested dimension 0");
  // create a contiguous output
  // TODO: We would ideally use an empty_like here, but that is not supported
  // for nested tensors yet. Since we are only using the buffer for the options
  // and size it is okay to use unsafe_storage_as_tensor here.
  const Tensor& buffer = input_ptr->get_unsafe_storage_as_tensor(),
      & sizemat = input_ptr->get_nested_sizes();
  Tensor output_buffer = buffer.new_empty(buffer.sizes());
  Tensor output = wrap_buffer(output_buffer, sizemat.clone());
  // call tensor softmax
  // TODO: for cpu, maybe use `parallel_for` if benchmarks show necessity
  //       to do that, have to merge `aten/src/ATen/native/cpu/SoftMaxKernel.cpp/softmax_kernel`
  //       1. it uses `parallel_for` internally, and we cannot nest multi-threading
  //       2. we cannot dispatch in a multi-threaded context (in this case at::_softmax_out)
  std::vector<Tensor> input_unbind = input.unbind(),
      output_unbind = output.unbind();
  for (int64_t i = 0; i < ntensors; i++) {
    at::_softmax_out(
        output_unbind[i],
        input_unbind[i],
        positive_dim - 1,
        half_to_float);
  }
  return output;
}

Tensor NestedTensor_all(
    const Tensor& input,
    const int64_t dim,
    const bool keepdim) {
  auto input_ptr = get_nested_tensor_impl(input);
  int64_t ntensors = input_ptr->size(0);
  if (ntensors == 0) {
    return input.clone();
  }
  int64_t positive_dim = at::maybe_wrap_dim(dim, input_ptr->dim());
  TORCH_CHECK(
      positive_dim >= 1,
      "Cannot apply all across nested dimension 0");
  const Tensor& buffer = input_ptr->get_unsafe_storage_as_tensor(),
      & sizemat = input_ptr->get_nested_sizes();

  Tensor output_buffer = buffer.new_empty(buffer.sizes());

  Tensor output_size = sizemat.clone();
  if (keepdim) {
    output_size.select(1, positive_dim - 1).fill_(1);
  } else {
    output_size = output_size.slice(1, 0, positive_dim - 1);
  }

  Tensor output = wrap_buffer(output_buffer, output_size.contiguous());

  std::vector<Tensor> input_unbind = input.unbind(),
      output_unbind = output.unbind();
  for (int64_t i = 0; i < ntensors; i++) {
    at::all_out(
        output_unbind[i],
        input_unbind[i],
        positive_dim - 1,
        keepdim);
  }
  return output;
}

Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) {
  auto self_ptr = get_nested_tensor_impl(self);
  // check input dimensions
  int64_t ndims = self_ptr->dim();
  int64_t positive_dim0 = at::maybe_wrap_dim(dim0, ndims),
      positive_dim1 = at::maybe_wrap_dim(dim1, ndims);
  if (positive_dim0 == positive_dim1) {
    return self;
  }
  TORCH_CHECK(positive_dim0 > 0 && positive_dim1 > 0, "Nested tensor dimension 0 cannot be transposed");
  // -- to exclude the implicit batch dimension
  ndims--;
  positive_dim0--;
  positive_dim1--;
  // transpose = switch `dim0` and `dim1` columns of `sizemat` and `stridemat`
  const Tensor& sizemat = self_ptr->get_nested_sizes(),
      & stridemat = self_ptr->get_nested_strides();
  Tensor column_indices = sizemat.new_empty(ndims);
  int64_t* column_indices_ptr = column_indices.data_ptr<int64_t>();
  std::iota(column_indices_ptr, column_indices_ptr + ndims, 0);
  column_indices_ptr[positive_dim0] = positive_dim1;
  column_indices_ptr[positive_dim1] = positive_dim0;
  // create transposed `sizemat` and `stridemat`
  Tensor sizemat_transposed = at::index_select(sizemat, 1, column_indices),
      stridemat_transposed = at::index_select(stridemat, 1, column_indices);
  return create_nested_view_tensor(
      self, sizemat_transposed, stridemat_transposed, self_ptr->get_storage_offsets().clone());
}
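// Illustrative example for transpose_nested (values assumed): transposing
// dims 1 and 2 of a nested tensor with size rows [[3, 8], [2, 8]] swaps the
// corresponding columns of the size and stride matrices, giving sizes
// [[8, 3], [8, 2]] and strides [[1, 8], [1, 8]]; no data is moved, only
// metadata.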

Tensor squeeze_nested(const Tensor& self) {
  TORCH_CHECK(false,
  "squeeze(): For nested tensors, squeeze without the dim argument is not supported ",
  "at the moment, however you can use squeeze(Tensor self, int dim) instead ",
  "if you need this feature, please open an issue on github describing your use case.");
  return self;
}

Tensor squeeze_dim_nested(const Tensor& self, IntArrayRef dims) {
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t ndim = self_ptr->dim();
  auto mask = at::dim_list_to_bitset(dims, ndim);
  TORCH_CHECK(!mask.test(0),
  "squeeze(): For nested tensors, squeezing dimension 0 is not supported at the moment ",
  "if you need this feature, please open an issue on github describing your use case.");
  const Tensor& sizemat = self_ptr->get_nested_sizes();
  const Tensor& stridemat = self_ptr->get_nested_strides();
  // if tensor.size(dim) != 1 torch.squeeze will return the result, we do the same here
  for (const auto d : c10::irange(ndim)) {
    if (mask.test(d)) {
      std::optional<int64_t> size_dim = self_ptr->opt_size(d);
      if (!(size_dim.has_value() && *size_dim == 1)) {
        mask.reset(d);
      }
    }
  }

  if (!mask.any()) {
    // detach to avoid triggering throw_error_if_base_and_tensor_are_same
    return self.detach();
  }
  // if ndim == 2 and we pass the above if statement we should have a
  // nested tensor of singleton tensors
  TORCH_CHECK(ndim > static_cast<int64_t>(1 + dims.size()),
  "squeeze(): For nested tensors, squeezing a nested tensor of singleton tensors is not ",
  "supported at the moment, if you need this feature, please open an issue on github ",
  "describing your use case.");
  const auto new_ndim = ndim - mask.count();
  auto column_indices = sizemat.new_empty(static_cast<int64_t>(new_ndim) - 1);
  int64_t* column_indices_ptr = column_indices.data_ptr<int64_t>();
  for (const auto d : c10::irange(1, ndim)) {
    if (!mask.test(d)) {
      *column_indices_ptr++ = d - 1;
    }
  }
  auto sizemat_squeezed = at::index_select(sizemat, 1, column_indices);
  auto stridemat_squeezed = at::index_select(stridemat, 1, column_indices);
  return create_nested_view_tensor(
      self, sizemat_squeezed, stridemat_squeezed, self_ptr->get_storage_offsets().clone());
}

Tensor squeeze_dim_nested(const Tensor& self, int64_t dim) {
  return squeeze_dim_nested(self, IntArrayRef{dim});
}

Tensor unsqueeze_nested(const Tensor& self, int64_t dim) {
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t ndim = self_ptr->dim();
  int64_t wrapped_dim = at::maybe_wrap_dim(dim, ndim + 1);
  TORCH_CHECK(wrapped_dim > 0,
  "unsqueeze(): For nested tensors, unsqueezing dimension 0 is not supported at the moment ",
  "if you need this feature, please open an issue on github describing your use case.");
  const Tensor& sizemat = self_ptr->get_nested_sizes();
  const Tensor& stridemat = self_ptr->get_nested_strides();
  auto mat_dim = wrapped_dim - 1;
  Tensor new_size = sizemat.new_ones({sizemat.size(0), 1});
  Tensor sizemat_unsqueezed = at::cat({sizemat.slice(1, 0, mat_dim),
                                       new_size,
                                       sizemat.slice(1, mat_dim, ndim)}, 1);
  Tensor new_stride;
  if (wrapped_dim == ndim) {
    new_stride = stridemat.new_ones({stridemat.size(0), 1});
  } else {
    new_stride = (stridemat.select(1, mat_dim) * sizemat.select(1, mat_dim)).unsqueeze(-1);
  }
  Tensor stridemat_unsqueezed = at::cat({stridemat.slice(1, 0, mat_dim),
                                         new_stride,
                                         stridemat.slice(1, mat_dim, ndim)}, 1);
  return create_nested_view_tensor(
      self, sizemat_unsqueezed, stridemat_unsqueezed, self_ptr->get_storage_offsets().clone());
}
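// Illustrative example for unsqueeze_nested (values assumed): unsqueezing
// dim 1 of a contiguous nested tensor with size rows [[3, 8], [2, 8]] inserts
// a column of ones, giving sizes [[1, 3, 8], [1, 2, 8]]; the new stride
// column is stride * size of the dimension it displaces (here 8 * 3 = 24 and
// 8 * 2 = 16).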

// utilities supporting `view_nested` and `reshape_nested`
namespace {
// Args:
//     sizes: the sizes of original nested tensor
//     strides: the strides of original nested tensor
//     proposed_shape: user proposed new shape
//     op: the options for new size and stride matrices
// Returns:
//     whether viewable
//     size matrix after reshape
//     stride matrix after reshape (not fully populated if not viewable)
inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
    const std::vector<IntArrayRef>& sizes,
    const std::vector<IntArrayRef>& strides,
    const IntArrayRef& proposed_shape,
    const c10::TensorOptions& op) {
  int64_t ntensors = static_cast<int64_t>(sizes.size());
  int64_t ndims_underlying = static_cast<int64_t>(sizes[0].size());
  int64_t ndims_underlying_reshaped = static_cast<int64_t>(proposed_shape.size() - 1);
  bool viewable = true;
  Tensor sizemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op),
      stridemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op);
  int64_t* sizemat_reshaped_ptr = sizemat_reshaped.mutable_data_ptr<int64_t>(),
      * stridemat_reshaped_ptr = stridemat_reshaped.mutable_data_ptr<int64_t>();
  for (int64_t itensor = 0; itensor < ntensors; itensor++) {
    const IntArrayRef& size = sizes[itensor],
        & stride = strides[itensor];
    // compute reshaped size
    std::vector<int64_t> size_reshaped_vector(proposed_shape.begin() + 1, proposed_shape.end());
    // only allow one pre-existing dimension to have proposed shape == -1
    int64_t infer_index_old = -1;
    // some negative sizes remain to be inferred
    if (ndims_underlying < ndims_underlying_reshaped) {
      int64_t numel = 1, numel_reshaped = 1;
      // replace negative sizes for old dimensions with old sizes
      for (int64_t idim = 0; idim < ndims_underlying; idim++) {
        int64_t& size_reshaped = size_reshaped_vector[idim];
        TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
        if (size_reshaped == -1) {
          TORCH_CHECK(infer_index_old == -1, "only one dimension can be inferred");
          size_reshaped = size[idim];
          infer_index_old = idim;
        }
        numel *= size[idim];
        numel_reshaped *= size_reshaped;
      }
      // infer negative size for new dimension
      int64_t infer_index = -1;
      for (int64_t idim = ndims_underlying; idim < ndims_underlying_reshaped; idim++) {
        const int64_t& size_reshaped = size_reshaped_vector[idim];
        if (size_reshaped >= 0) {
          numel_reshaped *= size_reshaped;
        }
        else if (size_reshaped == -1) {
          if (infer_index > -1) {
            throw std::runtime_error("only one dimension can be inferred");
          }
          else {
            infer_index = idim;
          }
        }
        else {
          AT_ERROR("invalid shape dimension ", size_reshaped);
        }
      }
      // See Note [Special size rule for nested tensor]
      TORCH_CHECK(infer_index == -1, "nested tensor does not infer shape");
      TORCH_CHECK(
          numel == numel_reshaped,
          "shape '", proposed_shape, "' ",
          "is invalid for input of size ", numel);
    }
    // all negative sizes can be replaced
    else {
      int64_t numel = 1, numel_reshaped = 1;
      for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
        int64_t& size_reshaped = size_reshaped_vector[idim];
        TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
        if (size_reshaped == -1) {
          size_reshaped = size[idim];
        }
        numel *= size[idim];
        numel_reshaped *= size_reshaped;
      }
      for (int64_t idim = ndims_underlying_reshaped; idim < ndims_underlying; idim++) {
        numel *= size[idim];
      }
      TORCH_CHECK(
          numel == numel_reshaped,
          "shape '", proposed_shape, "' ",
          "is invalid for input of size ", numel);
    }
    IntArrayRef size_reshaped(size_reshaped_vector);
    // compute reshaped stride
    auto opt_stride_reshaped = at::detail::computeStride(size, stride, size_reshaped);
    // reshape as view is possible
    if (opt_stride_reshaped.has_value()) {
      const IntArrayRef& stride_reshaped = *opt_stride_reshaped;
      // fill reshaped size and stride into sizemat and stridemat
      for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
        sizemat_reshaped_ptr[idim] = size_reshaped[idim];
        stridemat_reshaped_ptr[idim] = stride_reshaped[idim];
      }
      sizemat_reshaped_ptr += ndims_underlying_reshaped;
      stridemat_reshaped_ptr += ndims_underlying_reshaped;
    }
    // reshape as view is impossible
    else {
      viewable = false;
      // fill reshaped size into sizemat
      for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
        sizemat_reshaped_ptr[idim] = size_reshaped[idim];
      }
      sizemat_reshaped_ptr += ndims_underlying_reshaped;
    }
  }
  return std::make_tuple(viewable, sizemat_reshaped, stridemat_reshaped);
}
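// Illustrative example for NestedTensor_compute_size_stride (values
// assumed): reshaping constituents of sizes {3, 8} and {2, 8} with
// proposed_shape {2, -1, 2, 4} replaces the -1 in the ragged position with
// the old size (3 and 2 respectively) rather than inferring it, yielding
// per-constituent shapes {3, 2, 4} and {2, 2, 4}; see the note below.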
} // namespace

// Note [Special size rule for nested tensor]
// Instead of inferring size, -1 means "inherit the old size", so:
// * negative size is legal for a ragged dimension
// * however, we only allow one -1
// In principle we could still infer a dimension;
// we are designing better semantics to include both inheritance and inference.
Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
  TORCH_CHECK(
      !proposed_shape.empty(),
      "shape '[]' is invalid for a nested tensor");
  auto self_ptr = get_nested_tensor_impl(self);
  // basic information before reshaping
  int64_t ntensors = self_ptr->size(0);
  TORCH_CHECK(
      ntensors > 0,
      "empty nested tensor cannot be reshaped");
  // basic information after reshaping
  int64_t ntensors_reshaped = proposed_shape[0];
  TORCH_CHECK(
      ntensors == ntensors_reshaped,
      "view: For now nested view cannot change or infer the implicit batch dimension");
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
      strides = NestedTensor_get_strides(self_ptr);
  // reshaping underlying tensor dimensions does not change offset
  // determine reshaped size and stride
  const Tensor& sizemat = self_ptr->get_nested_sizes();
  auto [viewable, sizemat_reshaped, stridemat_reshaped] = NestedTensor_compute_size_stride(
      sizes, strides, proposed_shape, sizemat.options());
  TORCH_CHECK(
      viewable,
      "view size is not compatible with input tensor's size and stride "
      "(at least one dimension spans across two contiguous subspaces). "
      "Use .reshape(...) instead.");
  return create_nested_view_tensor(self, sizemat_reshaped, stridemat_reshaped, self_ptr->get_storage_offsets().clone());
}
/**
 * Create a buffer tensor that is a view of self
 *
 * This serves as the boundary between nested and non nested tensor
 * view conversions
 *
 * @return Returns a new non nested tensor that
 * aliases the same storage as self
 */
Tensor values_nested(const Tensor& self) {
  TORCH_INTERNAL_ASSERT(self.is_nested(), "Can only create a buffer from Nested Tensor");
  auto* nt_self = get_nested_tensor_impl(self);
  return nt_self->get_unsafe_storage_as_tensor();
}

/**
 * Create a nested tensor that is a view of a buffer
 *
 * This serves as the boundary between non nested tensor and nested
 * view conversions
 *
 * @return Returns a nested tensor that
 * aliases the same storage as buffer
 */
Tensor _nested_view_from_buffer(
    const Tensor& buffer,
    const Tensor& nested_sizes,
    const Tensor& nested_strides,
    const Tensor& storage_offsets) {
  TORCH_INTERNAL_ASSERT(
      !buffer.is_nested(),
      "Can only create a Nested Tensor from a normal tensor buffer");
  TORCH_INTERNAL_ASSERT(buffer.dim() == 1, "The input buffer must be flat");
  TORCH_INTERNAL_ASSERT(nested_sizes.dim() == 2, "Expected the nested size tensor to be two dimensional.");
  uint64_t num_elements_nested_size = at::prod(nested_sizes, 1).sum().item<int64_t>();
  uint64_t buffer_storage_size = buffer.storage().nbytes()/buffer.dtype().itemsize();
  TORCH_INTERNAL_ASSERT(
      buffer_storage_size == num_elements_nested_size,
      "The number of elements in the buffer must equal the nested tensor size but buffer size: ",
      buffer_storage_size,
      " and nested tensor size: ",
      num_elements_nested_size,
      ".");

  TORCH_INTERNAL_ASSERT(nested_strides.dim() == 2, "Expected the nested stride tensor to be two dimensional.");
  TORCH_INTERNAL_ASSERT(nested_sizes.size(0) == nested_strides.size(0), "Expected the first dimension of nested size and nested stride tensor to be equal.");
  TORCH_INTERNAL_ASSERT(nested_strides.size(0) == storage_offsets.size(0), "Expected the first dimension of nested stride tensor to equal the length of offsets.");
  return at::detail::make_tensor<NestedTensorImpl>(
    c10::TensorImpl::VIEW,
    buffer,
    nested_sizes,
    nested_strides,
    storage_offsets);
}

std::tuple<Tensor, Tensor> _nested_compute_contiguous_strides_offsets(const Tensor& nested_size) {
  return std::make_tuple(
      construct_nested_strides(nested_size),
      construct_offsets(nested_size));
}

// See Note [Special size rule for nested tensor]
Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
  TORCH_CHECK(
      !proposed_shape.empty(),
      "shape '[]' is invalid for a nested tensor");
  auto self_ptr = get_nested_tensor_impl(self);
  // basic information before reshaping
  int64_t ntensors = self_ptr->size(0);
  TORCH_CHECK(
      ntensors > 0,
      "empty nested tensor cannot be reshaped");
  // basic information after reshaping
  int64_t ntensors_reshaped = proposed_shape[0];
  TORCH_CHECK(
      ntensors == ntensors_reshaped,
      "reshape: For now nested reshape cannot change or infer the implicit batch dimension");
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
      strides = NestedTensor_get_strides(self_ptr);
  // reshaping underlying tensor dimensions does not change offset
  // determine reshaped size and stride
  const Tensor& sizemat = self_ptr->get_nested_sizes();
  auto [viewable, sizemat_reshaped, stridemat_reshaped] = NestedTensor_compute_size_stride(
      sizes, strides, proposed_shape, sizemat.options());
  if (viewable) {
    return self.view(proposed_shape);
  }
  else {
    return self.clone(at::MemoryFormat::Contiguous).view(proposed_shape);
  }
}

Tensor reshape_nested_symint(const Tensor& self, SymIntArrayRef proposed_shape) {
  // Jagged layout NT decomp
  if (self.layout() == at::kJagged) {
    // TODO: Expand decomp to handle other viewable cases
    bool viewable = self.is_contiguous();
    return (
        viewable ? self.view_symint(proposed_shape) :
        self.clone(at::MemoryFormat::Contiguous).view_symint(proposed_shape)
    );
  }

  return reshape_nested(self, C10_AS_INTARRAYREF_SLOW(proposed_shape));
}

Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
  // Jagged layout NT decomp
  if (self.layout() == at::kJagged) {
    return self.reshape_symint(other.sym_sizes());
  }

  auto other_ptr = get_nested_tensor_impl(other);
  // TODO: this is to reproduce other_ptr->opt_sizes_
  //       if an accessor is provided in the future, can replace this
  std::vector<int64_t> sizes;
  for (int64_t i = 0; i < other_ptr->dim(); i++) {
    std::optional<int64_t> opt_size = other_ptr->opt_size(i);
    if (opt_size.has_value()) {
      sizes.push_back(*opt_size);
    }
    else {
      sizes.push_back(-1);
    }
  }
  // reshape with other.opt_sizes_
  return self.reshape(sizes);
}

Tensor& normal_nested_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
  self_buf.normal_(mean, std, std::move(gen));
  return self;
}

// returns true if the sizes are compatible to be concatenated along the specified dim
// sizes should match outside of the dim of concatenation
static bool can_cat_nested_sizes(const Tensor& nested_sizes1, const Tensor& nested_sizes2, int64_t cat_dim) {
  if (nested_sizes1.sizes() != nested_sizes2.sizes()) {
    return false;
  }

  auto nested_sizes1_ptr = nested_sizes1.data_ptr<int64_t>();
  auto nested_sizes2_ptr = nested_sizes2.data_ptr<int64_t>();
  const auto num_components = nested_sizes1.size(0);
  const auto num_dims = nested_sizes1.size(1);
  for (auto c : c10::irange(num_components)) {
    for (auto d : c10::irange(num_dims)) {
      // subtract 1 to account for batch dim
      auto component_cat_dim = cat_dim - 1;
      if (d == component_cat_dim) {
        continue;
      }
      if (nested_sizes1_ptr[c * num_dims + d] != nested_sizes2_ptr[c * num_dims + d]) {
        return false;
      }
    }
  }

  return true;
}

// cat a list of NTs that are representable as jagged
static Tensor cat_nested_as_jagged(
    const MaterializedITensorListRef& tensors,
    int64_t dim) {
  const auto first_item = tensors[0].get();
  const auto first_item_dim = first_item.dim();
  const auto first_item_batch_size = first_item.size(0);
  std::vector<Tensor> jagged_views;
  for (auto i : c10::irange(tensors.size())) {
    auto t = tensors[i].get();
    TORCH_CHECK(t.is_nested(),
        "cat(): expected each tensor in given list to be nested");
    TORCH_CHECK(t.is_contiguous(),
        "cat(): only contiguous nested tensors are supported");
    if (i > 0) {
      TORCH_CHECK(
          can_cat_nested_sizes(
              get_nested_tensor_impl(first_item)->get_nested_sizes(),
              get_nested_tensor_impl(t)->get_nested_sizes(),
              dim),
          "cat(): expected all nested tensors to have matching ragged structures outside of the concatenated dim");
    }
    // only support inputs in the form (B, *, D_0, D_1, ...)
    // i.e. require at most a single ragged dim next to the batch dim
    auto *nt_impl = get_nested_tensor_impl(t);
    std::vector<int64_t> jagged_size;
    jagged_size.push_back(-1);
    for (auto d : c10::irange(first_item_dim - 2)) {
      TORCH_CHECK(nt_impl->opt_size(d + 2).has_value(),
          "cat(): only nested tensors with a single ragged dim next to the batch dim are supported");
      jagged_size.push_back(nt_impl->size(d + 2));
    }
    auto jagged = nt_impl->get_buffer().view(jagged_size);
    jagged_views.push_back(jagged);
  }

  // view each of the NTs as jagged for the cat() call
  auto new_buffer = at::cat(jagged_views, dim - 1);

  // wrap result into nested tensor
  const auto component_dim = first_item_dim - 1;
  auto new_dim_size = new_buffer.size(dim - 1);
  auto new_sizes = get_nested_tensor_impl(tensors[0].get())->get_nested_sizes().clone();
  auto new_sizes_ptr = new_sizes.data_ptr<int64_t>();
  for (const auto i : c10::irange(first_item_batch_size)) {
    new_sizes_ptr[i * component_dim + (dim - 1)] = new_dim_size;
  }
  return at::detail::make_tensor<NestedTensorImpl>(
      new_buffer.view(-1), new_sizes);
}
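// Illustrative example for cat_nested_as_jagged (values assumed):
// concatenating two nested tensors with sizes [[3, 8], [2, 8]] and
// [[3, 4], [2, 4]] along dim=2 views each buffer as a jagged (5, D) tensor,
// cats them into a (5, 12) buffer, and rewrites the last size column, giving
// nested sizes [[3, 12], [2, 12]].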

static Tensor cat_nested_impl(
    const MaterializedITensorListRef& tensors,
    int64_t dim) {
  dim = maybe_wrap_dim(dim, tensors[0].get());
  if (dim == 0) {
    // handle simple case of dim=0: concat NT components
    std::vector<at::Tensor> buffers;
    std::vector<at::Tensor> sizes;
    for (const auto i : c10::irange(tensors.size())) {
      const Tensor& t = tensors[i];
      TORCH_CHECK(
          t.is_nested(), "Expected each tensor in given list to be nested.");
      TORCH_CHECK(
          t.is_contiguous(),
          "Expected each tensor in given list to be contiguous.");
      auto t_ptr = get_nested_tensor_impl(t);
      buffers.push_back(t_ptr->get_buffer().view({-1}));
      sizes.push_back(t_ptr->get_nested_sizes());
    }
    return at::detail::make_tensor<NestedTensorImpl>(
        at::cat(buffers).view({-1}), at::cat(sizes, 0));
  }

  // NB: support for other dims is restricted to nested tensors representable as jagged
  return cat_nested_as_jagged(tensors, dim);
}
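// Illustrative example for the dim=0 path of cat_nested_impl (values
// assumed): concatenating nested tensors with sizes [[3, 8], [2, 8]] and
// [[4, 8]] along dim=0 simply concatenates the flat buffers and stacks the
// size rows, producing a nested tensor with sizes [[3, 8], [2, 8], [4, 8]].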

Tensor cat_nested(const ITensorListRef& tensors, int64_t dim) {
  auto materialized = tensors.materialize();
  return cat_nested_impl(materialized, at::legacy_cat_wrap_dim(dim, materialized));
}

} // namespace at::native