#include <ATen/NestedTensorImpl.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <optional>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_nested_tensor_size_native.h>
#include <ATen/ops/_nested_tensor_storage_offsets_native.h>
#include <ATen/ops/_nested_tensor_strides_native.h>
#include <ATen/ops/chunk_native.h>
#include <ATen/ops/split_with_sizes_native.h>
#endif

namespace at::native {

/**
 * Thin wrapper around get_nested_sizes that is registered as a native function
 *
 * @return The nested tensors' size tensor.
 */
at::Tensor _nested_tensor_size(const at::Tensor& self) {
  return get_nested_sizes(self);
}

// Thin wrappers around the nested strides and storage offsets accessors,
// registered as native functions.
at::Tensor _nested_tensor_strides(const at::Tensor& self) {
  return get_nested_tensor_impl(self)->get_nested_strides();
}

at::Tensor _nested_tensor_storage_offsets(const at::Tensor& self) {
  return get_nested_tensor_impl(self)->get_storage_offsets();
}
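
// Illustrative example (values assumed): for a contiguous nested tensor with
// component shapes (2, 3) and (4, 5), the metadata tensors above would be
//   sizes   = [[2, 3], [4, 5]]
//   strides = [[3, 1], [5, 1]]
//   offsets = [0, 6]
// i.e. each row describes one component, and each offset indexes into the
// shared storage.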

// Helper functions for getting information about a nested tensor's shape.
std::vector<int64_t> NestedTensor_get_max_size_from_size_tensor(
    const Tensor& sizes) {
  if (sizes.dim() == 0) {
    return {};
  }
  const auto sizes_ptr = sizes.data_ptr<int64_t>();
  const auto sizes_size_0 = sizes.sizes()[0];
  const auto sizes_size_1 = sizes.sizes()[1];
  TORCH_INTERNAL_ASSERT(sizes_size_1 > 0);
  std::vector<int64_t> results(sizes_size_1, 0);
  for (const auto ii : c10::irange(sizes_size_0)) {
    for (const auto jj : c10::irange(sizes_size_1)) {
      auto val = sizes_ptr[ii * sizes_size_1 + jj];
      if (results[jj] < val) {
        results[jj] = val;
      }
    }
  }
  return results;
}
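
// For example, with component shapes (2, 3) and (4, 1) the sizes tensor is
//   [[2, 3],
//    [4, 1]]
// and the column-wise maximum computed above is {4, 3}.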
54
NestedTensor_get_max_size(const NestedTensorImpl & nt)55 std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt) {
56 return NestedTensor_get_max_size_from_size_tensor(
57 nt.get_nested_sizes());
58 }

int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt) {
  std::optional<int64_t> last_dim = nt.opt_size(-1);
  TORCH_CHECK(
      last_dim != std::nullopt,
      "Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals: ",
      nt.get_nested_sizes().select(1, -1));
  return *last_dim;
}
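
// For example, component shapes (2, 8) and (5, 8) yield 8, while mixed
// trailing dimensions such as (2, 8) and (5, 4) trip the TORCH_CHECK above.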

std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int64_t dim) {
  int64_t ndim = self.dim();
  if (ndim == 0) {
    TORCH_CHECK_INDEX(false, "chunk() cannot be applied to a 0-dim tensor.");
  }
  dim = maybe_wrap_dim(dim, ndim);
  TORCH_CHECK(self.dim() - 1 == dim,
      "Chunk for nested tensors is currently only supported for the last dimension.");
  TORCH_CHECK(chunks > 0, "chunk expects `chunks` to be greater than 0, got: ", chunks);
  TORCH_CHECK(self.is_contiguous(), "chunk expects `self` to be contiguous.");
  auto self_impl = get_nested_tensor_impl(self);
  const int64_t last_dim_size = get_consistent_last_dim_of_nested_tensor(*self_impl);
  TORCH_CHECK(last_dim_size % chunks == 0,
      "Chunk for nested tensors is only supported for nested tensors with trailing dimension divisible by chunks, got: ",
      last_dim_size, " % ", chunks, " != 0");
  int64_t n_tensors = self.size(0);
  int64_t split_size = last_dim_size / chunks;
  std::vector<Tensor> splits(chunks);
  const auto& sizes = self_impl->get_nested_sizes();
  const auto& strides = self_impl->get_nested_strides();
  const auto offsets = self_impl->get_storage_offsets();
  int64_t *offsets_ptr = offsets.data_ptr<int64_t>();
  // Account for the implicit batch dim
  --dim;
  int64_t tensor_dim = sizes.size(1);
  for (const auto split_idx : c10::irange(chunks)) {
    auto new_sizes = sizes.clone();
    auto new_strides = strides.clone();
    // This copies offsets so we are safe to move
    auto new_offsets = offsets.clone();
    int64_t *size_ptr = new_sizes.data_ptr<int64_t>();
    int64_t *new_offsets_ptr = new_offsets.data_ptr<int64_t>();
    // Get start val for each split
    int64_t start_val = split_idx * split_size;
    for (int64_t i : c10::irange(n_tensors)) {
      const int64_t index = i * tensor_dim + dim;
      new_offsets_ptr[i] = offsets_ptr[i] + start_val;
      size_ptr[index] = split_size;
    }
    splits[split_idx] = create_nested_view_tensor(self, new_sizes, new_strides, new_offsets);
  }
  return splits;
}
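
// A minimal usage sketch (illustrative; `nt` is assumed to be a contiguous
// nested tensor with component shapes (2, 12) and (5, 12)):
//
//   std::vector<at::Tensor> chunks = nt.chunk(/*chunks=*/3, /*dim=*/-1);
//
// Each of the three results is a nested view whose components have shapes
// (2, 4) and (5, 4); only the size/offset metadata differs, and the storage
// is shared with `nt`.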

std::vector<Tensor> split_with_sizes_nested(
    const Tensor& self,
    c10::IntArrayRef split_sizes,
    int64_t dim) {
  int64_t ndim = self.dim();
  if (ndim == 0) {
    TORCH_CHECK_INDEX(false, "split_with_sizes() cannot be applied to a 0-dim tensor.");
  }
  dim = maybe_wrap_dim(dim, ndim);
  TORCH_CHECK(self.dim() - 1 == dim,
      "split_with_sizes for nested tensors is currently only supported for the last dimension.");
  auto num_splits = split_sizes.size();
  TORCH_CHECK(num_splits > 0,
      "split_with_sizes expects number of splits to be greater than 0, got: ", num_splits);
  TORCH_CHECK(self.is_contiguous(), "split_with_sizes expects `self` to be contiguous.");

  // Ensure entire dim is split.
  int64_t total_size = 0;
  for (const auto split_size : split_sizes) {
    total_size += split_size;
  }
  auto self_impl = get_nested_tensor_impl(self);
  auto self_size = get_consistent_last_dim_of_nested_tensor(*self_impl);
  TORCH_CHECK(total_size == self_size,
      "split_with_sizes expects split_sizes to sum exactly to ", self_size,
      " (input tensor's size at dimension ", dim, "), but got split_sizes=", split_sizes);

  int64_t n_tensors = self.size(0);
  std::vector<Tensor> splits(num_splits);
  const auto& sizes = self_impl->get_nested_sizes();
  const auto& strides = self_impl->get_nested_strides();
  const auto offsets = self_impl->get_storage_offsets();
  int64_t *offsets_ptr = offsets.data_ptr<int64_t>();
  // Account for the implicit batch dim
  --dim;
  int64_t tensor_dim = sizes.size(1);
  int64_t start_val = 0;
  for (const auto split_idx : c10::irange(num_splits)) {
    auto split_size = split_sizes[split_idx];
    auto new_sizes = sizes.clone();
    auto new_strides = strides.clone();
    auto new_offsets = offsets.clone();
    int64_t *size_ptr = new_sizes.data_ptr<int64_t>();
    int64_t *new_offsets_ptr = new_offsets.data_ptr<int64_t>();
    // Shift each component's offset by the running start and set its last dim
    // to this split's size.
    for (int64_t i : c10::irange(n_tensors)) {
      const int64_t index = i * tensor_dim + dim;
      new_offsets_ptr[i] = offsets_ptr[i] + start_val;
      size_ptr[index] = split_size;
    }
    start_val += split_size;
    splits[split_idx] = create_nested_view_tensor(self, new_sizes, new_strides, new_offsets);
  }
  return splits;
}
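
// A minimal usage sketch (illustrative; `nt` is assumed to be a contiguous
// nested tensor with component shapes (2, 12) and (5, 12)):
//
//   std::vector<at::Tensor> parts = nt.split_with_sizes({4, 8}, /*dim=*/-1);
//
// parts[0] has components of shapes (2, 4) and (5, 4), parts[1] has
// components of shapes (2, 8) and (5, 8); both are views into `nt`'s storage.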

} // namespace at::native