#include <ATen/native/nested/NestedTensorTransformerFunctions.h>

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/native/nested/NestedTensorUtils.h>

#include <c10/util/string_view.h>
#include <c10/util/Exception.h>
#include <optional>

namespace at::native {
namespace {

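// Shared validation for the NestedTensor x dense-matrix kernels below: the
// dense operand must not itself be nested, the nested input must be a
// contiguous 3-D NestedTensor, the dense matrix must be 2-D, and the nested
// input's (consistent) last dimension must match the dense matrix's
// contraction dimension (dim 1 for "Linear", which transposes its weight;
// dim 0 otherwise).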
inline void check_nested_tensor_matrix_constraints(
    const Tensor& nested_tensor,
    const Tensor& dense_matrix,
    c10::string_view caller) {
  auto* nt_input = get_nested_tensor_impl(nested_tensor);
  TORCH_INTERNAL_ASSERT(nt_input != nullptr);
  TORCH_CHECK(
      !dense_matrix.is_nested(),
      caller,
      " does not support nested weight when input is a nested tensor.");
  // TODO: support noncontiguous case
  // error out for now
  TORCH_CHECK(
      nested_tensor_impl_is_contiguous(nt_input),
      "for now linear only supports contiguous nested tensor");
  TORCH_CHECK(
      nested_tensor.dim() == 3 && dense_matrix.dim() == 2,
      caller,
      " requires nested_tensor.dim == 3 and dense_matrix.dim == 2."
      " Nested tensor dim: ",
      nested_tensor.dim(),
      ". Dense tensor dim: ",
      dense_matrix.dim());
  const auto last_dim = get_consistent_last_dim_of_nested_tensor(*nt_input);
  // We check the second dimension for linear because it transposes the weight
  // before the matrix multiply.
  int64_t dim_constraint = (caller == "Linear") ? 1 : 0;
  auto dense_size = dense_matrix.size(dim_constraint);
  TORCH_CHECK(
      last_dim == dense_size,
      "Shape mismatch for NestedTensor ",
      caller,
      ": Expected input's (a nested tensor) 'last_dim' to equal 'weight.size(",
      dim_constraint,
      ")',",
      " but got: last_dim = ",
      last_dim,
      ", and weight.size(",
      dim_constraint,
      ") = ",
      dense_size);
}
} // namespace

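// Linear forward for a contiguous 3-D NestedTensor input and a dense
// [out_features, in_features] weight: the packed buffer is viewed as
// (-1, in_features), pushed through at::linear, and the result is re-wrapped
// with the input's nested sizes, whose last column is overwritten with
// weight.size(0) (the output feature count).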
Tensor nested_linear(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias_opt) {
  check_nested_tensor_matrix_constraints(input, weight, c10::string_view{"Linear"});
  auto* nt_input = get_nested_tensor_impl(input);
  const Tensor& input_buffer = nt_input->get_buffer();
  Tensor result_buffer =
      at::linear(input_buffer.reshape({-1, weight.size(1)}), weight, bias_opt);
  result_buffer = result_buffer.reshape({-1});
  int64_t weight_size_1 = weight.size(0);
  Tensor new_sizes = nt_input->get_nested_sizes().clone();
  // Now the last entry in every row of new_sizes should be weight_size_1.
  new_sizes.index_put_({at::indexing::Slice(), -1}, weight_size_1);
  return wrap_buffer(result_buffer, new_sizes);
}

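// Matmul of a contiguous 3-D NestedTensor with a dense 2-D matrix. Analogous
// to nested_linear above, but contracts against other.size(0) via at::mm and
// writes other.size(1) into the last column of the cloned nested sizes.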
Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other) {
  check_nested_tensor_matrix_constraints(self, other, c10::string_view{"Matmul"});
  auto* nt_self = get_nested_tensor_impl_or_null(self);
  const Tensor& self_buffer = nt_self->get_buffer();
  Tensor result_buffer =
      at::mm(self_buffer.reshape({-1, other.sizes()[0]}), other);
  result_buffer = result_buffer.reshape({-1});
  int64_t other_size_1 = other.sizes()[1];
  Tensor new_sizes = nt_self->get_nested_sizes().clone();
  // Now the last entry in every row of new_sizes should be other_size_1.
  new_sizes.index_put_({at::indexing::Slice(), -1}, other_size_1);
  return wrap_buffer(result_buffer, new_sizes);
}

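// Fused epilogue used by the transformer path: computes
// beta * self + alpha * (mat1 @ mat2), where mat1 is a contiguous 3-D
// NestedTensor and self/mat2 are dense. When use_gelu has a value, the
// activation is fused via at::_addmm_activation (its flag selects GELU vs.
// ReLU); otherwise a plain at::addmm is issued on the flattened buffer.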
Tensor NestedTensor_times_Tensor_plus_Tensor_addmm(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const c10::Scalar& beta,
    const c10::Scalar& alpha,
    std::optional<bool> use_gelu) {
  // Interesting case: alpha * NT * T + beta * T
  const auto* nt_mat1 = get_nested_tensor_impl_or_null(mat1);
  TORCH_INTERNAL_ASSERT(nt_mat1 != nullptr);
  TORCH_INTERNAL_ASSERT(!mat2.is_nested());
  TORCH_INTERNAL_ASSERT(!self.is_nested());
  TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(nt_mat1));
  TORCH_INTERNAL_ASSERT(mat1.dim() == 3 && mat2.dim() == 2);
  TORCH_INTERNAL_ASSERT(
      get_consistent_last_dim_of_nested_tensor(*nt_mat1) == mat2.sizes()[0]);
  const Tensor& mat1_buffer = nt_mat1->get_buffer();
  Tensor result_buffer = !use_gelu.has_value()
      ? at::addmm(
            self, mat1_buffer.reshape({-1, mat2.sizes()[0]}), mat2, beta, alpha)
      : at::_addmm_activation(
            self,
            mat1_buffer.reshape({-1, mat2.sizes()[0]}),
            mat2,
            beta,
            alpha,
            *use_gelu);
  result_buffer = result_buffer.reshape({-1});
  int64_t other_size_1 = mat2.sizes()[1];
  Tensor new_sizes = nt_mat1->get_nested_sizes().clone();
  new_sizes.index_put_({at::indexing::Slice(), -1}, other_size_1);
  return at::detail::make_tensor<NestedTensorImpl>(
      std::move(result_buffer), std::move(new_sizes));
}

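// In-place elementwise add of two NestedTensors with identical nested sizes;
// both must be contiguous so their flat buffers can be added directly.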
Tensor NestedTensor_add_NestedTensor_in_place(
    const Tensor& self,
    const Tensor& other) {
  TORCH_INTERNAL_ASSERT(self.is_nested() && other.is_nested());
  const auto& nt_self = *get_nested_tensor_impl(self);
  const auto& nt_other = *get_nested_tensor_impl(other);

  const auto& self_sizes = nt_self.get_nested_sizes();
  const auto& other_sizes = nt_other.get_nested_sizes();

  TORCH_CHECK(at::equal(self_sizes, other_sizes));
  TORCH_INTERNAL_ASSERT(
      nested_tensor_impl_is_contiguous(&nt_self) &&
      nested_tensor_impl_is_contiguous(&nt_other));
  nt_self.get_buffer().view({-1}).add_(nt_other.get_buffer().view({-1}));
  return self;
}

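// Reference (CPU) path for masked softmax over nested attention scores: for
// each batch item, softmax is applied only over the valid
// [seq_len, seq_len] block (seq_len read from the query's nested sizes) and
// the padded remainder of the score matrix is zero-filled. Note that despite
// the name, no dropout is applied here.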
Tensor NestedTensor_softmax_dropout(const Tensor& self, const Tensor& query) {
  const auto* query_nt = get_nested_tensor_impl_or_null(query);
  TORCH_INTERNAL_ASSERT(query_nt != nullptr);
  TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(query_nt));

  const Tensor& sizes = query_nt->get_nested_sizes();
  const auto num_tensors = sizes.sizes()[0];

  auto output = at::empty_like(self, {}, at::MemoryFormat::Contiguous);
  TORCH_INTERNAL_ASSERT(output.is_contiguous());

  const auto max_seq_len = self.sizes()[2];

  for (int64_t i = 0; i < num_tensors; i++) {
    auto seq_len = sizes.index({i, 0}).item<int64_t>();
    auto subseq = self.index(
        {i,
         indexing::Slice(),
         indexing::Slice(0, seq_len),
         indexing::Slice(0, seq_len)});
    auto subscores = at::softmax(subseq, subseq.dim() - 1);
    output.index_put_(
        {i,
         indexing::Slice(),
         indexing::Slice(0, seq_len),
         indexing::Slice(0, seq_len)},
        subscores);
    output.index_put_(
        {i,
         indexing::Slice(),
         indexing::Slice(0, seq_len),
         indexing::Slice(seq_len, max_seq_len)},
        0);
    output.index_put_(
        {i,
         indexing::Slice(),
         indexing::Slice(seq_len, max_seq_len),
         indexing::Slice(0, max_seq_len)},
        0);
  }
  return output;
}

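// CUDA path: builds a boolean B x T padding mask from the nested query sizes
// (T = self.size(2)) and defers the masked softmax to _masked_softmax.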
Tensor NestedTensor_softmax_dropout_cuda(const Tensor& self, const Tensor& query) {
  std::optional<Tensor> attn_mask;

  attn_mask = NestedTensor_to_mask(query, 2, self.size(2));
  attn_mask = attn_mask->to(query.device(), /*non-blocking=*/true);
  return _masked_softmax(self, *attn_mask, self.dim() - 1, /*mask type*/ 1); // NestedTensor_to_mask produces a BxT mask
}

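// Builds an int32 exclusive prefix sum of per-tensor element counts from a
// nested size tensor: offsets[0] = 0 and offsets[i + 1] = offsets[i] +
// prod(sizes[i]). extra_elements reserves additional trailing slots in the
// returned buffer beyond the sizes.size(0) + 1 offsets actually written.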
Tensor NestedTensor_batch_offsets_from_size_tensor(
    const Tensor& sizes,
    int64_t extra_elements) {
  int64_t* const sizes_ptr = sizes.data_ptr<int64_t>();
  Tensor offsets = at::empty({1 + sizes.size(0) + extra_elements}, at::kInt);
  int32_t* const offsets_ptr = offsets.mutable_data_ptr<int32_t>();
  offsets_ptr[0] = 0;
  const auto sizes_size_1 = sizes.size(1);
  const auto sizes_size_0 = sizes.size(0);
  for (const auto i : c10::irange(sizes_size_0)) {
    int64_t prod = 1;
    for (const auto j : c10::irange(sizes_size_1)) {
      prod *= sizes_ptr[i * sizes_size_1 + j];
    }
    offsets_ptr[i + 1] = offsets_ptr[i] + static_cast<int32_t>(prod);
  }
  return offsets;
}

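// Converts a contiguous 3-D NestedTensor into a [B, T] boolean padding mask
// (only mask_dim == 2 is supported right now): entries start out true, and
// the first sizes[i][0] positions of each row are cleared to false, so true
// marks padding. T defaults to the max first-dim size unless mask_dim_length
// is given.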
Tensor NestedTensor_to_mask(const Tensor& nt, std::optional<int64_t> mask_dim, std::optional<int64_t> mask_dim_length) {
  auto* nt_impl = get_nested_tensor_impl(nt);
  TORCH_CHECK(
      nested_tensor_impl_is_contiguous(nt_impl),
      "to_mask only works on contiguous NestedTensors.");
  TORCH_CHECK(
      !mask_dim || *mask_dim < nt.dim(),
      "Requested mask dimension ",
      *mask_dim,
      " is bigger than dimension ",
      nt.dim(),
      " of given NestedTensor.");

  // TODO: port optimization for 1x1 tensors from
  // pytorch/nestedtensor's version.

  TORCH_CHECK(
      mask_dim && *mask_dim == 2 && nt.dim() == 3,
      "Only the special case of mask_dim == 2 on a 3-D NestedTensor is supported right now.");
  const auto& sizes = nt_impl->get_nested_sizes();
  // Shape: # of tensors in our NestedTensor by max size along first dim
  // TODO: calculate this without allocating a std::vector.
  const auto result_size_1 = mask_dim_length ? *mask_dim_length : NestedTensor_get_max_size(*nt_impl)[0];
  auto result = at::ones({sizes.sizes()[0], result_size_1}, at::kBool);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  auto* result_data = result.data_ptr<bool>();
  auto* sizes_ptr = sizes.data_ptr<int64_t>();
  const auto sizes_size_1 = sizes.sizes()[1];
  for (const auto ii : c10::irange(sizes.sizes()[0])) {
    auto length = sizes_ptr[ii * sizes_size_1];
    for (const auto jj : c10::irange(length)) {
      result_data[ii * result_size_1 + jj] = false;
    }
  }
  return result;
}

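// CPU reference for converting a jagged (packed) values tensor plus offsets
// into a dense [batch_size, max_length, ...] tensor filled with
// padding_value. Batch items longer than max_length are truncated to match
// the CUDA kernel's behavior.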
Tensor _jagged_to_padded_dense_forward_cpu(
    const Tensor& values,
    TensorList offsets_list,
    c10::IntArrayRef max_lengths,
    const double padding_value) {
  // TODO: Make this kernel more efficient using TensorIterator or something.
  TORCH_INTERNAL_ASSERT(
      offsets_list.size() == 1 && max_lengths.size() == 1,
      "_jagged_to_padded_dense_forward(): only a single jagged dim is supported for now");

  // allocate appropriately-sized padded tensor
  const auto& offsets = offsets_list[0];
  TORCH_CHECK(
      offsets.dim() == 1,
      "_jagged_to_padded_dense_forward(): expected 1D offsets, but got offsets.dim() == ",
      offsets.dim());

  auto batch_size = offsets.size(0) - 1;
  auto max_length = max_lengths[0];
  auto values_shape = values.sizes().vec();
  std::vector<int64_t> padded_shape;
  padded_shape.reserve(values.dim() + 1);
  padded_shape.push_back(batch_size);
  padded_shape.push_back(max_length);
  padded_shape.insert(padded_shape.end(), values_shape.begin() + 1, values_shape.end());
  Tensor padded = values.new_full(padded_shape, padding_value);

  // copy data to padded tensor
  for (auto i : c10::irange(batch_size)) {
    auto start_offset = offsets[i].item<int64_t>();
    auto end_offset = offsets[i + 1].item<int64_t>();
    auto length = end_offset - start_offset;
    // NB: truncate to max length to match CUDA kernel behavior.
    length = std::min(length, max_length);
    auto source = values.slice(0, start_offset, start_offset + length);
    auto dst = padded.select(0, i).slice(0, 0, length);
    dst.copy_(source);
  }

  return padded;
}

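// CPU reference for the inverse conversion: gathers the first `length` rows
// of each padded batch item (length = offsets[i + 1] - offsets[i]) back into
// a packed values tensor of total length total_L, defaulting to the final
// offset when total_L is not given.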
Tensor _padded_dense_to_jagged_forward_cpu(
    const Tensor& padded,
    TensorList offsets_list,
    std::optional<int64_t> total_L) {
  // TODO: Make this kernel more efficient using TensorIterator or something.
  TORCH_INTERNAL_ASSERT(
      offsets_list.size() == 1,
      "_padded_dense_to_jagged_forward(): only a single jagged dim is supported for now");

  // allocate appropriately-sized values tensor
  const auto& offsets = offsets_list[0];
  TORCH_CHECK(
      offsets.dim() == 1,
      "_padded_dense_to_jagged_forward(): expected 1D offsets, but got offsets.dim() == ",
      offsets.dim());

  auto final_offset = offsets[-1].item<int64_t>();
  int64_t total_L_val = total_L.has_value() ? (*total_L) : final_offset;
  if (total_L.has_value()) {
    // error if the offsets try to index past the end of the packed dimension
    TORCH_CHECK(
        final_offset == total_L_val,
        "_padded_dense_to_jagged_forward(): final offset should match total_L value");
  }

  TORCH_CHECK(
      padded.dim() >= 2,
      "_padded_dense_to_jagged_forward(): expected padded dim >= 2, but padded.dim() == ",
      padded.dim());

  std::vector<int64_t> values_shape;
  values_shape.reserve(padded.dim() - 1);
  values_shape.push_back(total_L_val);
  auto padded_shape = padded.sizes();
  values_shape.insert(values_shape.end(), padded_shape.begin() + 2, padded_shape.end());
  Tensor values = padded.new_empty(values_shape);

  // copy data to values tensor
  auto batch_size = offsets.size(0) - 1;
  for (auto i : c10::irange(batch_size)) {
    auto start_offset = offsets[i].item<int64_t>();
    auto end_offset = offsets[i + 1].item<int64_t>();
    auto length = end_offset - start_offset;

    TORCH_CHECK(
        length <= padded_shape[1],
        "_padded_dense_to_jagged_forward(): found batch item of length ", length,
        " when max length specified by padded input is ", padded_shape[1]);

    auto dst = values.slice(0, start_offset, end_offset);
    auto source = padded.select(0, i).slice(0, 0, length);
    dst.copy_(source);
  }

  return values;
}

} // namespace at::native