#include <ATen/NestedTensorImpl.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <optional>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_nested_tensor_size_native.h>
#include <ATen/ops/_nested_tensor_storage_offsets_native.h>
#include <ATen/ops/_nested_tensor_strides_native.h>
#include <ATen/ops/chunk_native.h>
#include <ATen/ops/split_with_sizes_native.h>
#endif

namespace at::native {

/**
 * Thin wrapper around get_nested_sizes that is registered as a native function
 *
 * @return The nested tensors' size tensor.
 */
at::Tensor _nested_tensor_size(const at::Tensor& self) {
  return get_nested_sizes(self);
}

at::Tensor _nested_tensor_strides(const at::Tensor& self) {
  return get_nested_tensor_impl(self)->get_nested_strides();
}

at::Tensor _nested_tensor_storage_offsets(const at::Tensor& self) {
  return get_nested_tensor_impl(self)->get_storage_offsets();
}
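
// Illustrative sketch (not part of the original file): the three accessors
// above expose a nested tensor's shape metadata. Assuming the
// _nested_tensor_from_tensor_list constructor is available, a contiguous
// nested tensor built from components of shapes [2, 3] and [4, 3] would
// report roughly:
//
//   auto nt = at::_nested_tensor_from_tensor_list(
//       {at::randn({2, 3}), at::randn({4, 3})});
//   at::_nested_tensor_size(nt);            // [[2, 3], [4, 3]]
//   at::_nested_tensor_strides(nt);         // [[3, 1], [3, 1]]
//   at::_nested_tensor_storage_offsets(nt); // [0, 6]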

// Helper functions for getting information about a nested tensor's shape.
std::vector<int64_t> NestedTensor_get_max_size_from_size_tensor(
    const Tensor& sizes) {
  if (sizes.dim() == 0) {
    return {};
  }
  const auto sizes_ptr = sizes.data_ptr<int64_t>();
  const auto sizes_size_0 = sizes.sizes()[0];
  const auto sizes_size_1 = sizes.sizes()[1];
  TORCH_INTERNAL_ASSERT(sizes_size_1 > 0);
  std::vector<int64_t> results(sizes_size_1, 0);
  for (const auto ii : c10::irange(sizes_size_0)) {
    for (const auto jj : c10::irange(sizes_size_1)) {
      auto val = sizes_ptr[ii * sizes_size_1 + jj];
      if (results[jj] < val) {
        results[jj] = val;
      }
    }
  }
  return results;
}
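
// Worked example (illustrative, not part of the original file): for a sizes
// tensor holding component shapes [2, 3] and [4, 1], i.e.
//   sizes = [[2, 3],
//            [4, 1]]
// the maximum is taken column-wise over components, so
// NestedTensor_get_max_size_from_size_tensor(sizes) returns {4, 3}.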

std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt) {
  return NestedTensor_get_max_size_from_size_tensor(nt.get_nested_sizes());
}

int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt) {
  std::optional<int64_t> last_dim = nt.opt_size(-1);
  TORCH_CHECK(
      last_dim != std::nullopt,
      "Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals: ",
      nt.get_nested_sizes().select(1, -1));
  return *last_dim;
}
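
// Illustrative note (not part of the original file): components of shapes
// [2, 5] and [3, 5] share a trailing dimension, so this returns 5; shapes
// [2, 5] and [3, 4] would fail the TORCH_CHECK above because opt_size(-1)
// has no consistent value.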

std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int64_t dim) {
  int64_t ndim = self.dim();
  if (ndim == 0) {
    TORCH_CHECK_INDEX(false, "chunk() cannot be applied to a 0-dim tensor.");
  }
  dim = maybe_wrap_dim(dim, ndim);
  TORCH_CHECK(self.dim() - 1 == dim,
           "Chunk for nested tensors is currently only supported for the last dimension.");
  TORCH_CHECK(chunks > 0, "chunk expects `chunks` to be greater than 0, got: ", chunks);
  TORCH_CHECK(self.is_contiguous(), "chunk expects `self` to be contiguous.");
  auto self_impl = get_nested_tensor_impl(self);
  const int64_t last_dim_size = get_consistent_last_dim_of_nested_tensor(*self_impl);
  TORCH_CHECK(last_dim_size % chunks == 0,
           "Chunk for nested tensors is only supported for nested tensors with trailing dimension divisible by chunks, got: ",
           last_dim_size, " % ", chunks, " != 0");
  int64_t n_tensors = self.size(0);
  int64_t split_size = last_dim_size / chunks;
  std::vector<Tensor> splits(chunks);
  const auto& sizes = self_impl->get_nested_sizes();
  const auto& strides = self_impl->get_nested_strides();
  const auto offsets = self_impl->get_storage_offsets();
  int64_t* offsets_ptr = offsets.data_ptr<int64_t>();
  // Account for the implicit batch dim
  --dim;
  int64_t tensor_dim = sizes.size(1);
  for (const auto split_idx : c10::irange(chunks)) {
    auto new_sizes = sizes.clone();
    auto new_strides = strides.clone();
    // This copies offsets so we are safe to move
    auto new_offsets = offsets.clone();
    int64_t* size_ptr = new_sizes.data_ptr<int64_t>();
    int64_t* new_offsets_ptr = new_offsets.data_ptr<int64_t>();
    // Get start val for each split
    int64_t start_val = split_idx * split_size;
    for (int64_t i : c10::irange(n_tensors)) {
      const int64_t index = i * tensor_dim + dim;
      new_offsets_ptr[i] = offsets_ptr[i] + start_val;
      size_ptr[index] = split_size;
    }
    splits[split_idx] = create_nested_view_tensor(self, new_sizes, new_strides, new_offsets);
  }
  return splits;
}
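
// Usage sketch (illustrative, not part of the original file): at the ATen
// level, chunk on a contiguous nested tensor is expected to dispatch to
// chunk_nested_tensor above, producing `chunks` views that alias the
// original storage. Assuming _nested_tensor_from_tensor_list is available:
//
//   auto nt = at::_nested_tensor_from_tensor_list(
//       {at::randn({2, 6}), at::randn({3, 6})});
//   auto pieces = at::chunk(nt, /*chunks=*/2, /*dim=*/-1);
//   // pieces[0] holds components of shapes [2, 3] and [3, 3]; so does
//   // pieces[1], shifted 3 elements into each component's storage.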
112 
split_with_sizes_nested(const Tensor & self,c10::IntArrayRef split_sizes,int64_t dim)113 std::vector<Tensor> split_with_sizes_nested(
114     const Tensor& self,
115     c10::IntArrayRef split_sizes,
116     int64_t dim) {
117   int64_t ndim = self.dim();
118   if (ndim == 0) {
119     TORCH_CHECK_INDEX(false, "split_with_sizes() cannot be applied to a 0-dim tensor.");
120   }
121   dim = maybe_wrap_dim(dim, ndim);
122   TORCH_CHECK(self.dim() - 1 == dim,
123            "split_with_sizes for nested tensors is currently only supported for the last dimension.");
124   auto num_splits = split_sizes.size();
125   TORCH_CHECK(num_splits > 0,
126            "split_with_sizes expects number of splits to be greater than 0, got: ", num_splits);
127   TORCH_CHECK(self.is_contiguous(), "split_with_sizes expects `self` to be contiguous.");
128 
129   // Ensure entire dim is split.
130   int64_t total_size = 0;
131   for (const auto split_size : split_sizes) {
132       total_size += split_size;
133   }
134   auto self_impl = get_nested_tensor_impl(self);
135   auto self_size = get_consistent_last_dim_of_nested_tensor(*self_impl);
136   TORCH_CHECK(total_size == self_size,
137           "split_with_sizes expects split_sizes to sum exactly to ", self_size,
138           " (input tensor's size at dimension ", dim, "), but got split_sizes=", split_sizes);
139 
140   int64_t n_tensors = self.size(0);
141   std::vector<Tensor> splits(num_splits);
142   const auto& sizes = self_impl->get_nested_sizes();
143   const auto& strides = self_impl->get_nested_strides();
144   const auto offsets = self_impl->get_storage_offsets();
145   int64_t *offsets_ptr = offsets.data_ptr<int64_t>();
146   // Account for the implicit batch dim
147   --dim;
148   int64_t tensor_dim = sizes.size(1);
149   int64_t start_val = 0;
150   for (const auto split_idx : c10::irange(num_splits)) {
151     auto split_size = split_sizes[split_idx];
152     auto new_sizes = sizes.clone();
153     auto new_strides = strides.clone();
154     auto new_offsets = offsets.clone();
155     int64_t *size_ptr = new_sizes.data_ptr<int64_t>();
156     int64_t *new_offsets_ptr = new_offsets.data_ptr<int64_t>();
157     // Get start val for each split
158     for (int64_t i : c10::irange(n_tensors)) {
159       const int64_t index = i * tensor_dim + dim;
160       new_offsets_ptr[i] = offsets_ptr[i] + start_val;
161       size_ptr[index] = split_size;
162     }
163     start_val += split_size;
164     splits[split_idx] = create_nested_view_tensor(self, new_sizes, new_strides, new_offsets);
165   }
166   return splits;
167 }
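
// Usage sketch (illustrative, not part of the original file): same setup as
// the chunk example above, but with explicit split sizes that must sum to
// the consistent trailing dimension (6 here):
//
//   auto nt = at::_nested_tensor_from_tensor_list(
//       {at::randn({2, 6}), at::randn({3, 6})});
//   auto parts = at::split_with_sizes(nt, /*split_sizes=*/{2, 4}, /*dim=*/-1);
//   // parts[0] has trailing dimension 2, parts[1] has trailing dimension 4.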

} // namespace at::native