1 /**
2 * This is a handwritten file that accompanies codegenerated header
3 * LazyShapeDtype.h
4 *
 * The purpose of these shape/dtype inference methods is to fill gaps
6 * where we do not yet have structured kernels in pytorch core. Ops
7 * for which there _are_ structured kernels can use meta::op() to infer
8 * shape/dtype, and codegen makes use of this. Ops for which there are not
9 * yet structured kernels can still be used with lazy_tensor codegen, but
10 * require manual intervention to implement compute_shape_{op} and
11 * compute_dtype_{op}.
12 *
13 * READ THIS!
14 *
15 * 1. Beware: Tech Debt!
16 * ---------------------
 * These functions are tech debt. We want to delete them all and use structured
 * kernels instead, but writing these is much faster, so we are decoupling the
 * two efforts in order to add support for codegenned Lazy Tensor ops quickly.
20 *
21 * Codegenned Lazy Tensor ops with handwritten shape formulae are still better
22 * than fully handwritten Lazy Tensor ops (which also have handwritten shape
23 * formulae).
24 *
25 * 2. Structured Kernels For The Win
26 * ---------------------------------
27 * Long term, more and more ops should be supported as 'structured kernels'.
28 * Consider doing your part and porting an op. As ops get ported over, the
29 * codegen will automatically notice and stop generating declarations for these
30 * shape formulae, so we'll need to manually clean up the unused functions in
31 * this file, or somehow automate that.
32 *
33 * https://dev-discuss.pytorch.org/t/slides-from-structured-kernel-presentation/179
34 *
35 * 3. How to figure out the shape/dtype
36 * ------------------------------------
37 * Unfortunately there isn't a one-stop-shop for learning the output shape
38 * formulae for all operators. This is partly because some operators are not
39 * part of our 'public' API, including backward operators which users don't
40 * directly invoke.
41 *
 * Check the JIT symbolic shape registry:
 * https://github.com/pytorch/pytorch/blob/13b859983183ea9938deb5030ac9a0747841f0a8/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
 *
 * Read the public documentation (for ops that are 1:1 with the Python
 * frontend), e.g.:
 * https://pytorch.org/docs/stable/generated/torch.trace.html
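 *
 * Another option, used further down in this file for the *_scatter ops: if the
 * op has a meta or CompositeExplicitAutogradNonFunctional kernel, you can run
 * it on meta tensors (no real data) and read the shape/dtype off the result.
 * A minimal sketch, assuming a hypothetical op `foo(self)` with such a kernel:
 *
 *   auto self_meta = at::native::empty_strided_meta_symint(
 *       self.sym_sizes(),
 *       self.sym_strides(),
 *       ::std::make_optional(self.scalar_type()),
 *       ::std::make_optional(self.layout()),
 *       ::std::make_optional(c10::Device(c10::kMeta)),
 *       ::std::nullopt);
 *   auto out_meta = at::compositeexplicitautogradnonfunctional::foo(self_meta);
 *   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};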
47 *
48 */
49
50 #include <torch/csrc/lazy/core/shape_inference.h>
51
52 #include <ATen/AccumulateType.h>
53 #include <ATen/CompositeExplicitAutogradFunctions.h>
54 #include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
55 #include <ATen/Dispatch.h>
56 #include <ATen/ExpandUtils.h>
57 #include <ATen/Functions.h>
58 #include <ATen/InferSize.h>
59 #include <ATen/NativeFunctions.h>
60 #include <ATen/WrapDimUtils.h>
61 #include <ATen/native/ConvUtils.h>
62 #include <ATen/native/ReduceOpsUtils.h>
63 #include <ATen/native/TensorConversions.h>
64 #include <c10/core/ScalarType.h>
65 #include <torch/csrc/lazy/core/dynamic_ir.h>
66 #include <torch/csrc/lazy/core/ops/utils.h>
67 #include <torch/csrc/lazy/core/shape.h>
68 #include <torch/csrc/lazy/core/util.h>
69 #include <ostream>
70 #include <vector>
71
72 namespace torch {
73 namespace lazy {
74
// Copied from ATen/native/utils/ParamUtils.h, which apparently cannot be
// included from here.
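// For example (illustrative): expand_param_if_needed({2}, "stride", 3) returns
// {2, 2, 2}, while an explicit list of the wrong length, e.g. {2, 2} with
// expected_dim == 3, trips the error below.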
static std::vector<int64_t> expand_param_if_needed(
78 at::IntArrayRef list_param,
79 const char* param_name,
80 int64_t expected_dim) {
81 if (list_param.size() == 1) {
82 return std::vector<int64_t>(expected_dim, list_param[0]);
83 } else if ((int64_t)list_param.size() != expected_dim) {
84 std::ostringstream ss;
85 ss << "expected " << param_name << " to be a single integer value or a "
86 << "list of " << expected_dim << " values to match the convolution "
87 << "dimensions, but got " << param_name << "=" << list_param;
88 AT_ERROR(ss.str());
89 } else {
90 return list_param.vec();
91 }
92 }
93
// Most of these functions leave some of their parameters unused, so disable
// the unused-parameter warning for the rest of the file.
96 #pragma GCC diagnostic push
97 #pragma GCC diagnostic ignored "-Wunused-parameter"
98
TORCH_API std::vector<Shape> compute_shape_arange_out(
100 const at::Scalar& start,
101 const at::Scalar& end,
102 const at::Scalar& step,
103 at::Tensor& out) {
104 double size_d = 0;
  // Shape inference code copied from RangeFactories.cpp arange_out function.
  // Note: AT_DISPATCH_ALL_TYPES_AND is a macro that defines the correct C++
  // scalar_t type depending on the out tensor's dtype.
108
109 AT_DISPATCH_ALL_TYPES_AND(
110 c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out", [&]() {
        // Note: acc_type further defines an accumulation type depending on the
        // scalar_t and whether it runs on CUDA vs CPU.
113 using accscalar_t = at::acc_type<scalar_t, false>;
114 auto xstart = start.to<accscalar_t>();
115 auto xend = end.to<accscalar_t>();
116 auto xstep = step.to<accscalar_t>();
117
        // We use double precision for (end - start) / step to compute size_d
        // for consistency across devices. The problem with using accscalar_t
        // is that accscalar_t might be float32 on GPU for a float32 scalar_t,
        // but double on CPU for the same, and the effective output size would
        // start differing on CPU vs GPU because of precision issues, which we
        // don't want. The corner case we do want to take into account is
        // int64_t, which has higher precision than double.
        // NOLINTNEXTLINE(bugprone-branch-clone)
126 if constexpr (std::is_same_v<scalar_t, int64_t>) {
127 size_d = std::ceil(
128 static_cast<double>(
129 end.to<accscalar_t>() - start.to<accscalar_t>()) /
130 step.to<accscalar_t>());
131 } else {
132 size_d = std::ceil(
133 static_cast<double>(end.to<double>() - start.to<double>()) /
134 step.to<double>());
135 }
136
137 TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
138 TORCH_CHECK(
139 std::isfinite(static_cast<double>(xstart)) &&
140 std::isfinite(static_cast<double>(xend)),
141 "unsupported range: ",
142 xstart,
143 " -> ",
144 xend);
145 TORCH_CHECK(
146 ((xstep > 0) && (xend >= xstart)) ||
147 ((xstep < 0) && (xend <= xstart)),
148 "upper bound and larger bound inconsistent with step sign");
149
150 TORCH_CHECK(
151 size_d >= 0 &&
152 size_d <=
153 static_cast<double>(std::numeric_limits<int64_t>::max()),
154 "invalid size, possible overflow?");
155 });
156
157 int64_t size = static_cast<int64_t>(size_d);
158
  // From the torch.arange docs:
  // dtype (torch.dtype, optional) – the desired data type of the returned
  // tensor. Default: if None, uses a global default (see
  // torch.set_default_dtype()). If dtype is not given, infer the data type
  // from the other input arguments. If any of start, end, or step are
  // floating-point, the dtype is inferred to be the default dtype, see
  // get_default_dtype(). Otherwise, the dtype is inferred to be torch.int64.
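  //
  // For example (Python): torch.arange(5) produces an int64 tensor, while
  // torch.arange(0., 5) produces a tensor of the default dtype. Here `out`
  // already carries the resolved dtype, so we simply reuse out.scalar_type().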
166
167 return {Shape(out.scalar_type(), {size})};
168 }
169
std::vector<Shape> compute_shape_abs(const at::Tensor& self) {
171 if (self.is_complex()) {
172 const auto float_type = c10::toRealValueType(self.scalar_type());
173 return {Shape(float_type, self.sizes().vec())};
174 }
175 return {Shape(self.scalar_type(), self.sizes().vec())};
176 }
177
std::vector<Shape> compute_shape_bernoulli(
179 const at::Tensor& self,
180 ::std::optional<at::Generator> generator) {
181 return {Shape(self.scalar_type(), self.sizes().vec())};
182 }
183
std::vector<Shape> compute_shape_bernoulli(
185 const at::Tensor& self,
186 double p,
187 ::std::optional<at::Generator> generator) {
188 return compute_shape_bernoulli(self, generator);
189 }
190
std::vector<Shape> compute_shape_binary_cross_entropy(
192 const at::Tensor& self,
193 const at::Tensor& target,
194 const ::std::optional<at::Tensor>& weight,
195 int64_t reduction) {
196 if (reduction == at::Reduction::None) {
197 return {Shape(self.scalar_type(), self.sizes().vec())};
198 }
199 return {Shape(self.scalar_type(), {})};
200 }
201
std::vector<Shape> compute_shape_binary_cross_entropy_backward(
203 const at::Tensor& grad_output,
204 const at::Tensor& self,
205 const at::Tensor& target,
206 const ::std::optional<at::Tensor>& weight,
207 int64_t reduction) {
208 return {Shape(self.scalar_type(), self.sizes().vec())};
209 }
210
std::vector<Shape> compute_shape_constant_pad_nd(
212 const at::Tensor& self,
213 at::IntArrayRef pad,
214 const at::Scalar& value) {
215 // Based on aten/src/ATen/native/ConstantPadNd.cpp::constant_pad_nd
216 TORCH_CHECK(
217 pad.size() % 2 == 0,
218 "Length of pad must be even but instead it equals ",
219 pad.size());
220
221 auto input_sizes = self.sizes();
222 auto l_inp = self.dim();
223
224 auto l_pad = pad.size() / 2;
225 auto l_diff = l_inp - l_pad;
226 TORCH_CHECK(
227 l_inp >= (int64_t)l_pad,
228 "Length of pad should be no more than twice the number of "
229 "dimensions of the input. Pad length is ",
      pad.size(),
      " while the input has ",
      l_inp,
      " dimensions.");
234
235 std::vector<int64_t> new_shape;
236 for (size_t i = 0; i < (size_t)l_diff; i++) {
237 new_shape.emplace_back(input_sizes[i]);
238 }
239
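  // `pad` is laid out as (last_dim_low, last_dim_high, prev_dim_low, ...):
  // pairs apply starting from the last dimension and move backwards. For
  // example, with a 2-D input and pad = {1, 2, 3, 4}, the last dimension grows
  // by 1 + 2 and the first dimension by 3 + 4.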
240 for (const auto i : c10::irange((size_t)l_pad)) {
241 auto pad_idx = pad.size() - ((i + 1) * 2);
242 auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1];
243 TORCH_CHECK(
244 new_dim > 0,
245 "The input size ",
246 input_sizes[l_diff + i],
247 ", plus negative padding ",
248 pad[pad_idx],
249 " and ",
250 pad[pad_idx + 1],
251 " resulted in a negative output size, "
252 "which is invalid. Check dimension ",
253 l_diff + i,
254 " of your input.");
255 new_shape.emplace_back(new_dim);
256 }
257 return {Shape(self.scalar_type(), new_shape)};
258 }
259
std::vector<Shape> compute_shape_convolution_backward(
261 const at::Tensor& grad_output,
262 const at::Tensor& input,
263 const at::Tensor& weight,
264 at::OptionalIntArrayRef bias_sizes,
265 at::IntArrayRef stride,
266 at::IntArrayRef padding,
267 at::IntArrayRef dilation,
268 bool transposed,
269 at::IntArrayRef output_padding,
270 int64_t groups,
271 ::std::array<bool, 3> output_mask) {
272 if (bias_sizes.has_value()) {
273 return {
274 Shape(input.scalar_type(), input.sizes().vec()),
275 Shape(weight.scalar_type(), weight.sizes().vec()),
276 Shape(grad_output.scalar_type(), bias_sizes.value().vec())};
277 } else {
278 // TODO(whc) not sure whether to return 2 shapes here, or a 3rd one that is
279 // empty
280 return {
281 Shape(input.scalar_type(), input.sizes().vec()),
282 Shape(weight.scalar_type(), weight.sizes().vec())};
283 }
284 }
285
std::vector<Shape> compute_shape_convolution(
287 const at::Tensor& input,
288 const at::Tensor& weight,
289 const ::std::optional<at::Tensor>& bias,
290 at::IntArrayRef stride,
291 at::IntArrayRef padding,
292 at::IntArrayRef dilation,
293 bool transposed,
294 at::IntArrayRef output_padding,
295 int64_t groups) {
296 int64_t dim = weight.ndimension() - 2;
297 TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
298
  // at::convolution performs parameter expansion before running kernels on
  // the expanded parameters, so we must do the same. Shape formulae access
  // different dimensions of e.g. output_padding, but output_padding may be
  // passed in as a scalar. Sadly, accessing output_padding[1] in this case
  // gives incorrect results rather than an indexing error.
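  // For example, stride may arrive here as the single-element list {2}; after
  // expansion it becomes {2, 2} for a 2-d convolution, so the formulae below
  // can safely index every spatial dimension.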
304 auto expanded_stride = expand_param_if_needed(stride, "stride", dim);
305 auto expanded_padding = expand_param_if_needed(padding, "padding", dim);
306 auto expanded_dilation = expand_param_if_needed(dilation, "dilation", dim);
307 if (!transposed) {
308 return {Shape(
309 input.scalar_type(),
310 at::native::conv_output_size(
311 input.sizes(),
312 weight.sizes(),
313 expanded_padding,
314 expanded_stride,
315 expanded_dilation))};
316 } else {
317 auto expanded_output_padding =
318 expand_param_if_needed(output_padding, "output_padding", dim);
319 auto out_shape = at::native::conv_input_size(
320 input.sizes(),
321 weight.sizes(),
322 expanded_padding,
323 expanded_output_padding,
324 expanded_stride,
325 expanded_dilation,
326 groups);
327 return {Shape(input.scalar_type(), out_shape)};
328 }
329 }
330
std::vector<Shape> compute_shape_masked_fill(
332 const at::Tensor& self,
333 const at::Tensor& mask,
334 const at::Scalar& value) {
335 return {Shape(self.scalar_type(), self.sizes().vec())};
336 }
337
std::vector<Shape> compute_shape_masked_fill(
339 const at::Tensor& self,
340 const at::Tensor& mask,
341 const at::Tensor& value) {
342 return {Shape(self.scalar_type(), self.sizes().vec())};
343 }
344
std::vector<Shape> compute_shape_max(const at::Tensor& self) {
346 TORCH_CHECK(
347 self.numel() > 0,
348 "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.");
349 return {Shape(self.scalar_type(), {})};
350 }
351
std::vector<Shape> compute_shape_min(const at::Tensor& self) {
353 TORCH_CHECK(
354 self.numel() > 0,
355 "min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.");
356 return {Shape(self.scalar_type(), {})};
357 }
358
static std::vector<Shape> compute_shape_nonzero(
360 const at::Tensor& t,
361 bool as_tuple) {
362 if (as_tuple) {
363 auto res = std::vector<Shape>();
364 for (auto dim_size : t.sizes()) {
365 res.emplace_back(Shape(at::kLong, {dim_size}));
366 }
367 return res;
368 }
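  // nonzero's true output size is data-dependent, so for lazy shape inference
  // we return the worst case: every element nonzero, i.e. {numel, ndim}. For a
  // 2x3 input this yields {6, 2} regardless of the tensor's contents.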
369 int64_t max_elements = 1;
370 for (auto dim_size : t.sizes()) {
371 max_elements *= dim_size;
372 }
373 return {Shape(at::kLong, {max_elements, (int64_t)t.sizes().size()})};
374 }
375
std::vector<Shape> compute_shape_nonzero(const at::Tensor& self) {
377 return compute_shape_nonzero(self, false);
378 }
379
std::vector<Shape> compute_shape_embedding(
381 const at::Tensor& weight,
382 const at::Tensor& indices,
383 int64_t padding_idx,
384 bool scale_grad_by_freq,
385 bool sparse) {
386 // Based on aten/src/ATen/native/Embedding.cpp::embedding.
387 std::vector<int64_t> out_sizes = indices.sizes().vec();
388 out_sizes.emplace_back(weight.size(1));
389 return {Shape(weight.scalar_type(), out_sizes)};
390 }
391
std::vector<Shape> compute_shape_std(const at::Tensor& self, bool unbiased) {
393 return compute_shape_std(self, ::std::nullopt, ::std::nullopt, false);
394 }
std::vector<Shape> compute_shape_std(
396 const at::Tensor& self,
397 at::OptionalIntArrayRef dim,
398 bool unbiased,
399 bool keepdim) {
400 return compute_shape_std(self, dim, ::std::nullopt, keepdim);
401 }
std::vector<Shape> compute_shape_std(
403 const at::Tensor& self,
404 at::OptionalIntArrayRef dim,
405 const ::std::optional<at::Scalar>& correction,
406 bool keepdim) {
407 if (dim.has_value()) {
408 auto shape = at::native::shape_from_dim_mask(
409 self, at::native::make_dim_mask(dim.value(), self.dim()), keepdim);
410 return {Shape(
411 self.scalar_type(), std::vector<int64_t>(shape.begin(), shape.end()))};
412 }
413 return {Shape(self.scalar_type(), {})};
414 }
415
std::vector<Shape> compute_shape_embedding_dense_backward(
417 const at::Tensor& grad_output,
418 const at::Tensor& indices,
419 int64_t num_weights,
420 int64_t padding_idx,
421 bool scale_grad_by_freq) {
422 // Based on aten/src/ATen/native/Embedding.cpp::embedding_dense_backward_cpu.
423 return {
424 Shape(grad_output.scalar_type(), {num_weights, grad_output.size(-1)})};
425 }
426
std::vector<Shape> compute_shape_expand(
428 const at::Tensor& self,
429 at::IntArrayRef size,
430 bool implicit) {
431 TORCH_CHECK_GE(static_cast<int64_t>(size.size()), self.dim());
432 size_t num_new_dimensions = size.size() - self.dim();
433 std::vector<int64_t> padded_self(num_new_dimensions, 0);
434 padded_self.insert(
435 padded_self.end(), self.sizes().begin(), self.sizes().end());
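  // For example, a self of shape {3, 1} expanded to size = {2, -1, 4}:
  // padded_self is {0, 3, 1}, -1 keeps the existing extent, and the resulting
  // shape is {2, 3, 4}.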
436 std::vector<int64_t> target_size(size.size());
437 for (const auto idx : c10::irange(size.size())) {
438 target_size[idx] = size[idx] == -1 ? padded_self[idx] : size[idx];
439 }
440 return {Shape(self.scalar_type(), target_size)};
441 }
442
std::vector<Shape> compute_shape_expand(
444 const at::Tensor& self,
445 c10::SymIntArrayRef size,
446 bool implicit) {
447 TORCH_CHECK_GE(static_cast<int64_t>(size.size()), self.dim());
448 std::vector<c10::SymInt> _sizes = ToVector<c10::SymInt>(size);
449 size_t num_new_dimensions = _sizes.size() - self.dim();
450 std::vector<int64_t> padded_self(num_new_dimensions, 0);
451 padded_self.insert(
452 padded_self.end(), self.sizes().begin(), self.sizes().end());
453 std::vector<int64_t> target_size(_sizes.size());
454 for (const auto idx : c10::irange(_sizes.size())) {
455 if (auto ma = _sizes[idx].maybe_as_int()) {
456 target_size[idx] = *ma;
457 if (*ma == -1) {
458 // -1 can't be specified for non-existing dimensions
459 TORCH_CHECK(idx >= num_new_dimensions);
460 target_size[idx] = padded_self[idx];
461 } else {
462 target_size[idx] = *ma;
463 }
464 } else {
465 auto* lazySymNode = dynamic_cast<torch::lazy::SymNodeImpl*>(
466 _sizes[idx].toSymNodeImplUnowned());
467 TORCH_INTERNAL_ASSERT(lazySymNode);
468 auto size_node = lazySymNode->node_;
469 auto static_value =
470 std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)
471 ->getStaticValue();
472 target_size[idx] = static_value;
473 }
474 }
475 return {Shape(self.scalar_type(), target_size)};
476 }
477
std::vector<Shape> compute_shape_index_select(
479 const at::Tensor& self,
480 int64_t dim,
481 const at::Tensor& index) {
  // Based on the definition of
  // https://pytorch.org/docs/stable/generated/torch.index_select.html.
  // Promote a rank-0 index tensor to a 1-D tensor of size 1.
485 dim = at::maybe_wrap_dim(dim, self);
486 auto index_dim = index.dim() > 0 ? index.dim() : 1;
487 auto index_size = index.dim() > 0 ? index.size(0) : 1;
488 TORCH_CHECK(index_dim == 1);
489
490 auto self_sizes = self.sizes();
491 std::vector<int64_t> output_sizes(self_sizes.begin(), self_sizes.end());
492 TORCH_CHECK(!output_sizes.empty(), "Empty output_sizes is not supported.");
493 output_sizes[dim] = index_size;
494
495 return {Shape(self.scalar_type(), output_sizes)};
496 }
497
std::vector<Shape> compute_shape_inverse(const at::Tensor& self) {
499 return {Shape(self.scalar_type(), self.sizes().vec())};
500 }
501
std::vector<Shape> compute_shape_isnan(const at::Tensor& self) {
503 return {Shape(c10::ScalarType::Bool, self.sizes().vec())};
504 }
505
std::vector<Shape> compute_shape_cat(at::TensorList tensors, int64_t dim) {
507 // TODO(whc) support cat in codegen and move this to compute_*_cat functions
508 std::vector<int64_t> out_shape(
509 tensors[0].sizes().begin(), tensors[0].sizes().end());
510
511 dim = at::maybe_wrap_dim(dim, tensors);
512 size_t extended_dim_shape = 0;
513 for (auto& tensor : tensors) {
514 extended_dim_shape += tensor.sizes()[dim];
515 }
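  // For example, concatenating tensors of shapes {2, 3} and {4, 3} along
  // dim = 0 yields {6, 3}: only the cat dimension is summed.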
516 TORCH_CHECK(!out_shape.empty(), "Scalar tensors are not supported in cat.");
517 TORCH_CHECK(
518 extended_dim_shape <=
519 static_cast<size_t>(std::numeric_limits<int64_t>::max()),
520 "Size overflow");
521 out_shape[dim] = extended_dim_shape;
522 return {Shape(tensors[0].scalar_type(), out_shape)};
523 }
524
TORCH_API std::vector<torch::lazy::Shape> compute_shape_cholesky(
526 const at::Tensor& self,
527 bool upper) {
528 return {Shape(self.scalar_type(), self.sizes().vec())};
529 }
530
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
532 const at::Tensor& input,
533 const ::std::optional<at::Tensor>& weight,
534 const ::std::optional<at::Tensor>& bias,
535 const ::std::optional<at::Tensor>& running_mean,
536 const ::std::optional<at::Tensor>& running_var,
537 bool training,
538 double momentum,
539 double eps) {
540 std::vector<torch::lazy::Shape> shapes;
541 shapes.reserve(3);
542 shapes.emplace_back(input.scalar_type(), input.sizes().vec());
543
544 // A separate mean and var needs to be kept for each channel.
545 TORCH_CHECK(
546 input.sizes().size() >= 2,
547 "Input tensor must have at least batch and channel dimensions!");
548 int64_t num_features = input.size(1);
549
550 if (running_mean.has_value()) {
551 shapes.emplace_back(
552 running_mean.value().scalar_type(), running_mean.value().sizes().vec());
553 } else {
554 shapes.emplace_back(
555 at::get_default_dtype_as_scalartype(),
556 std::vector<int64_t>{num_features});
557 }
558
559 if (running_var.has_value()) {
560 shapes.emplace_back(
561 running_var.value().scalar_type(), running_var.value().sizes().vec());
562 } else {
563 shapes.emplace_back(
564 at::get_default_dtype_as_scalartype(),
565 std::vector<int64_t>{num_features});
566 }
567 return shapes;
568 }
569
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
571 const at::Tensor& grad_out,
572 const at::Tensor& input,
573 const ::std::optional<at::Tensor>& weight,
574 const ::std::optional<at::Tensor>& running_mean,
575 const ::std::optional<at::Tensor>& running_var,
576 const ::std::optional<at::Tensor>& save_mean,
577 const ::std::optional<at::Tensor>& save_invstd,
578 bool train,
579 double eps,
580 ::std::array<bool, 3> output_mask) {
581 std::vector<torch::lazy::Shape> shapes;
582 shapes.reserve(3);
583 shapes.emplace_back(input.scalar_type(), input.sizes().vec());
584
585 // A separate mean and var needs to be kept for each channel.
586 TORCH_CHECK(
587 input.sizes().size() >= 2,
588 "Input tensor must have at least batch and channel dimensions!");
589 int64_t num_features = input.size(1);
590
  // `weight` and `bias` are vectors of length C (number of channels)
592 shapes.emplace_back(
593 at::get_default_dtype_as_scalartype(),
594 std::vector<int64_t>{num_features});
595 shapes.emplace_back(
596 at::get_default_dtype_as_scalartype(),
597 std::vector<int64_t>{num_features});
598
599 return shapes;
600 }
601
std::vector<Shape> compute_shape_native_layer_norm(
603 const at::Tensor& input,
604 at::IntArrayRef normalized_shape,
605 const ::std::optional<at::Tensor>& weight,
606 const ::std::optional<at::Tensor>& bias,
607 double eps) {
608 // Copied from aten/src/ATen/native/layer_norm.cpp::layer_norm_cpu_out.
609 auto input_shape = input.sizes().vec();
610 const size_t axis = input.dim() - normalized_shape.size();
611
612 std::vector<int64_t> stat_shape;
613 for (const auto idx : c10::irange(axis)) {
614 TORCH_CHECK(idx < input_shape.size(), "Shape mismatch");
615 stat_shape.emplace_back(input_shape[idx]);
616 }
617 for (const auto idx : c10::irange(axis, input.dim())) {
618 (void)idx; // Suppress unused variable warning
619 stat_shape.emplace_back(1);
620 }
621
622 return {
623 Shape(input.scalar_type(), input_shape),
624 Shape(input.scalar_type(), stat_shape),
625 Shape(input.scalar_type(), stat_shape)};
626 }
627
std::vector<Shape> compute_shape_native_layer_norm_backward(
629 const at::Tensor& grad_out,
630 const at::Tensor& input,
631 at::IntArrayRef normalized_shape,
632 const at::Tensor& mean,
633 const at::Tensor& rstd,
634 const ::std::optional<at::Tensor>& weight,
635 const ::std::optional<at::Tensor>& bias,
636 ::std::array<bool, 3> output_mask) {
637 std::vector<Shape> shapes;
638 shapes.emplace_back(
639 input.scalar_type(),
640 output_mask[0] ? input.sizes().vec() : std::vector<int64_t>{});
641 shapes.emplace_back(
642 weight && weight->defined() ? weight->scalar_type() : input.scalar_type(),
643 output_mask[1] && weight ? weight->sizes().vec()
644 : std::vector<int64_t>{});
645 shapes.emplace_back(
646 bias && bias->defined() ? bias->scalar_type() : input.scalar_type(),
647 output_mask[2] && bias ? bias->sizes().vec() : std::vector<int64_t>{});
648 return shapes;
649 }
650
std::vector<Shape> compute_shape_mean(
652 const at::Tensor& self,
653 ::std::optional<at::ScalarType> dtype) {
654 if (dtype.has_value()) {
655 return {Shape(dtype.value(), {})};
656 }
657 return {Shape(self.scalar_type(), {})};
658 }
659
std::vector<Shape> compute_shape_new_empty_strided(
661 const at::Tensor& self,
662 at::IntArrayRef size,
663 at::IntArrayRef stride,
664 ::std::optional<at::ScalarType> dtype,
665 ::std::optional<at::Layout> layout,
666 ::std::optional<at::Device> device,
667 ::std::optional<bool> pin_memory) {
668 return {Shape(dtype.has_value() ? *dtype : self.scalar_type(), size.vec())};
669 }
670
std::vector<Shape> compute_shape_mv(
672 const at::Tensor& self,
673 const at::Tensor& vec) {
674 return {Shape(self.scalar_type(), {self.size(0)})};
675 }
676
std::vector<Shape> compute_shape_native_dropout(
678 const at::Tensor& input,
679 double p,
680 ::std::optional<bool> train) {
681 return {
682 Shape(input.scalar_type(), input.sizes().vec()),
683 Shape(c10::ScalarType::Bool, input.sizes().vec())};
684 }
685
std::vector<Shape> compute_shape_native_dropout_backward(
687 const at::Tensor& grad_output,
688 const at::Tensor& mask,
689 double scale) {
690 return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
691 }
692
std::vector<Shape> compute_shape_random(
694 const at::Tensor& self,
695 ::std::optional<at::Generator> generator) {
696 return {Shape(self.scalar_type(), self.sizes().vec())};
697 }
698
std::vector<Shape> compute_shape_random(
700 const at::Tensor& self,
701 int64_t to,
702 ::std::optional<at::Generator> generator) {
703 return compute_shape_random(self, generator);
704 }
705
std::vector<Shape> compute_shape_random(
707 const at::Tensor& self,
708 int64_t from,
709 ::std::optional<int64_t> to,
710 ::std::optional<at::Generator> generator) {
711 return compute_shape_random(self, generator);
712 }
713
std::vector<Shape> compute_shape_relu(const at::Tensor& self) {
715 return {Shape(self.scalar_type(), self.sizes().vec())};
716 }
717
std::vector<Shape> compute_shape_sum(
719 const at::Tensor& self,
720 ::std::optional<at::ScalarType> dtype) {
721 if (dtype.has_value()) {
722 return {Shape(dtype.value(), {})};
723 }
724 // It's undocumented, but torch::sum promotes all integral types to int64_t by
725 // default
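  // For example, summing a kInt or kBool tensor produces a kLong scalar, while
  // floating-point inputs keep their dtype.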
726 if (isIntegralType(self.scalar_type(), /*includeBool*/ true)) {
727 return {Shape(c10::ScalarType::Long, {})};
728 }
729 return {Shape(self.scalar_type(), {})};
731 }
732
std::vector<Shape> compute_shape_zero(const at::Tensor& self) {
734 return {Shape(self.scalar_type(), self.sizes().vec())};
735 }
736
TORCH_API std::vector<torch::lazy::Shape> compute_shape_take(
738 const at::Tensor& self,
739 const at::Tensor& index) {
740 return {Shape(self.scalar_type(), index.sizes().vec())};
741 }
742
std::vector<Shape> compute_shape_trace(const at::Tensor& self) {
744 return {Shape(self.scalar_type(), {})};
745 }
746
std::vector<Shape> compute_shape_sort(
748 const at::Tensor& self,
749 int64_t dim,
750 bool descending) {
751 return {
752 Shape(self.scalar_type(), self.sizes().vec()),
753 Shape(c10::ScalarType::Long, self.sizes().vec())};
754 }
755
std::vector<Shape> compute_shape_slogdet(const at::Tensor& self) {
757 // assumes self.shape is {*, n, n} and returns shape *
758 TORCH_INTERNAL_ASSERT(self.dim() >= 2);
759 std::vector<int64_t> out_sizes(self.sizes().begin(), self.sizes().end() - 2);
760 // Doesn't check input dtype, but output dtype either matches it,
761 // or the actual slogdet operation will throw if it's an unsupported type.
762 // Sign and det outputs hold the same shape, dtype.
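  // For example, a batched input of shape {3, 4, 4} produces sign and
  // logabsdet outputs of shape {3}; a plain {n, n} input produces 0-dim
  // outputs.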
763 return {
764 Shape(self.scalar_type(), out_sizes),
765 Shape(self.scalar_type(), out_sizes)};
766 }
767
std::vector<torch::lazy::Shape> compute_shape_logical_and(
769 const at::Tensor& self,
770 const at::Tensor& other) {
771 TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes()));
772 return {Shape(
773 c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))};
774 }
775
std::vector<torch::lazy::Shape> compute_shape_logical_not(
777 const at::Tensor& self) {
778 return {Shape(c10::ScalarType::Bool, self.sizes().vec())};
779 }
780
std::vector<torch::lazy::Shape> compute_shape_logical_or(
782 const at::Tensor& self,
783 const at::Tensor& other) {
784 TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes()));
785 return {Shape(
786 c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))};
787 }
788
std::vector<torch::lazy::Shape> compute_shape_logical_xor(
790 const at::Tensor& self,
791 const at::Tensor& other) {
792 TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes()));
793 return {Shape(
794 c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))};
795 }
796
std::vector<Shape> compute_shape_smooth_l1_loss_backward(
798 const at::Tensor& grad_output,
799 const at::Tensor& self,
800 const at::Tensor& target,
801 int64_t reduction,
802 double beta) {
803 // The `grad_output` tensor is really the input to this kernel, and while its
804 // shape may vary following the logic of the forward output, the output of
805 // this kernel should have fixed shapes matching the inputs to the forward
806 // kernel.
807 return {Shape(self.scalar_type(), self.sizes().vec())};
808 }
809
std::vector<Shape> compute_shape_logdet(const at::Tensor& self) {
811 // assumes self.shape is {*, n, n} and returns shape *
812 TORCH_INTERNAL_ASSERT(self.dim() >= 2);
813 std::vector<int64_t> out_sizes(self.sizes().begin(), self.sizes().end() - 2);
814 // Doesn't check input dtype, but output dtype either matches it,
815 // or the actual logdet operation will throw if it's an unsupported type
816 return {Shape(self.scalar_type(), out_sizes)};
817 }
818
std::vector<Shape> compute_shape_log_sigmoid_forward(const at::Tensor& self) {
820 // Based on definition of
821 // aten/src/ATen/native/Activation.cpp::log_sigmoid_forward_out_cpu.
822 return {
823 Shape(self.scalar_type(), self.sizes().vec()),
824 Shape(self.scalar_type(), self.sizes().vec())};
825 }
826
std::vector<Shape> compute_shape_log_sigmoid_backward(
828 const at::Tensor& grad_output,
829 const at::Tensor& self,
830 const at::Tensor& buffer) {
831 // Based on definition of
832 // aten/src/ATen/native/Activation.cpp::log_sigmoid_backward_cpu*.
833 return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
834 }
835
std::vector<Shape> compute_shape_nll_loss2d_forward(
837 const at::Tensor& self,
838 const at::Tensor& target,
839 const ::std::optional<at::Tensor>& weight,
840 int64_t reduction,
841 int64_t ignore_index) {
842 // Based on definition of
843 // aten/src/ATen/native/LossNLL2d.cpp:nll_loss2d_forward_cpu
844 auto sizes =
845 (reduction == at::Reduction::Reduction::None ? target.sizes().vec()
846 : std::vector<int64_t>{});
847 return {Shape(self.scalar_type(), sizes), Shape(self.scalar_type(), {})};
848 }
849
std::vector<Shape> compute_shape_nll_loss2d_backward(
851 const at::Tensor& grad_output,
852 const at::Tensor& self,
853 const at::Tensor& target,
854 const ::std::optional<at::Tensor>& weight,
855 int64_t reduction,
856 int64_t ignore_index,
857 const at::Tensor& total_weight) {
858 return {Shape(self.scalar_type(), self.sizes().vec())};
859 }
860
std::vector<Shape> compute_shape_grid_sampler_2d(
862 const at::Tensor& input,
863 const at::Tensor& grid,
864 int64_t interpolation_mode,
865 int64_t padding_mode,
866 bool align_corners) {
  // From aten/src/ATen/native/cpu/GridSamplerKernel.cpp
868 int64_t N = input.size(0);
869 int64_t C = input.size(1);
870 int64_t H = grid.size(1);
871 int64_t W = grid.size(2);
872 return {Shape(input.scalar_type(), {N, C, H, W})};
873 }
874
std::vector<Shape> compute_shape_grid_sampler_2d_backward(
876 const at::Tensor& grad_output,
877 const at::Tensor& input,
878 const at::Tensor& grid,
879 int64_t interpolation_mode,
880 int64_t padding_mode,
881 bool align_corners,
882 ::std::array<bool, 2> output_mask) {
  // From aten/src/ATen/native/cpu/GridSamplerKernel.cpp
884 auto grad_input_shape = Shape(input.scalar_type(), input.sizes().vec());
885 auto grad_grid_shape = Shape(grid.scalar_type(), grid.sizes().vec());
886 return {grad_input_shape, grad_grid_shape};
887 }
888
std::vector<Shape> compute_shape_flip(
890 const at::Tensor& self,
891 at::IntArrayRef dims) {
892 return {Shape(self.scalar_type(), self.sizes().vec())};
893 }
894
std::vector<Shape> compute_shape__adaptive_avg_pool2d(
896 const at::Tensor& self,
897 at::IntArrayRef output_size) {
898 // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
899 // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp`
900 TORCH_CHECK(
901 output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2");
902 TORCH_CHECK(
903 (output_size[0] >= 0 && output_size[1] >= 0),
904 "adaptive_avg_pool2d: elements of output_size must be greater than or equal to 0 ",
905 "but received {",
906 output_size[0],
907 ", ",
908 output_size[1],
909 "}");
910 int64_t ndim = self.ndimension();
911 for (const auto i : c10::irange(1, ndim)) {
912 TORCH_CHECK(
913 self.size(i) > 0,
914 "adaptive_avg_pool2d(): Expected self to have non-zero size for non-batch dimensions, "
915 "but Tensor has sizes ",
916 self.sizes(),
917 " with dimension ",
918 i,
919 " being "
920 "empty");
921 }
922 TORCH_CHECK(
923 (ndim == 3 || ndim == 4),
924 "adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got ",
925 self.sizes());
926
927 int64_t channels = self.size(-3);
928 int64_t output_height = output_size[0];
929 int64_t output_width = output_size[1];
930
931 if (ndim == 3) {
932 return {Shape(self.scalar_type(), {channels, output_height, output_width})};
933 } else {
934 int64_t nbatch = self.size(0);
935 return {Shape(
936 self.scalar_type(), {nbatch, channels, output_height, output_width})};
937 }
938 }
939
std::vector<Shape> compute_shape__adaptive_avg_pool2d_backward(
941 const at::Tensor& grad_output,
942 const at::Tensor& self) {
943 // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
944 int64_t ndim = grad_output.ndimension();
945
946 for (const auto i : c10::irange(1, ndim)) {
947 TORCH_CHECK(
948 grad_output.size(i) > 0,
949 "adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, "
950 "but grad_output has sizes ",
951 grad_output.sizes(),
952 " with dimension ",
953 i,
954 " being "
955 "empty");
956 }
957
958 TORCH_CHECK(
959 (ndim == 3 || ndim == 4),
960 "adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got ",
961 self.sizes());
962 TORCH_CHECK(
963 self.dtype() == grad_output.dtype(),
964 "expected dtype ",
965 self.dtype(),
966 " for `grad_output` but got dtype ",
967 grad_output.dtype());
968
969 return {Shape(self.scalar_type(), self.sizes().vec())};
970 }
971
std::vector<Shape> compute_shape__adaptive_avg_pool3d(
973 const at::Tensor& self,
974 at::IntArrayRef output_size) {
975 // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
976 // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp`
977 TORCH_CHECK(
978 output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3");
979 TORCH_CHECK(
980 (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0),
981 "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ",
982 "but received {",
983 output_size[0],
984 ", ",
985 output_size[1],
986 ", ",
987 output_size[2],
988 "}");
989 int64_t ndim = self.ndimension();
990 for (const auto i : c10::irange(1, ndim)) {
991 TORCH_CHECK(
992 self.size(i) > 0,
993 "adaptive_avg_pool3d(): Expected self to have non-zero size for non-batch dimensions, "
994 "but Tensor has sizes ",
995 self.sizes(),
996 " with dimension ",
997 i,
998 " being "
999 "empty");
1000 }
1001 TORCH_CHECK(
1002 (ndim == 4 || ndim == 5),
1003 "adaptive_avg_pool3d(): Expected 4D or 5D tensor, but got ",
1004 self.sizes());
1005
1006 int64_t channels = self.size(-4);
1007 int64_t output_depth = output_size[0];
1008 int64_t output_height = output_size[1];
1009 int64_t output_width = output_size[2];
1010
1011 if (ndim == 4) {
1012 return {Shape(
1013 self.scalar_type(),
1014 {channels, output_depth, output_height, output_width})};
1015 } else {
1016 int64_t nbatch = self.size(0);
1017 return {Shape(
1018 self.scalar_type(),
1019 {nbatch, channels, output_depth, output_height, output_width})};
1020 }
1021 }
1022
std::vector<Shape> compute_shape__adaptive_avg_pool3d_backward(
1024 const at::Tensor& grad_output,
1025 const at::Tensor& self) {
1026 // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
1027 int64_t ndim = grad_output.ndimension();
1028
1029 for (const auto i : c10::irange(1, ndim)) {
1030 TORCH_CHECK(
1031 grad_output.size(i) > 0,
1032 "adaptive_avg_pool3d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, "
1033 "but grad_output has sizes ",
1034 grad_output.sizes(),
1035 " with dimension ",
1036 i,
1037 " being "
1038 "empty");
1039 }
1040
1041 TORCH_CHECK(
1042 (ndim == 4 || ndim == 5),
1043 "adaptive_avg_pool3d_backward(): Expected 4D or 5D tensor, but got ",
1044 self.sizes());
1045 TORCH_CHECK(
1046 self.dtype() == grad_output.dtype(),
1047 "expected dtype ",
1048 self.dtype(),
1049 " for `grad_output` but got dtype ",
1050 grad_output.dtype());
1051
1052 return {Shape(self.scalar_type(), self.sizes().vec())};
1053 }
1054
std::vector<Shape> compute_shape_glu_backward(
1056 const at::Tensor& grad_output,
1057 const at::Tensor& self,
1058 int64_t dim) {
1059 return {Shape(self.scalar_type(), self.sizes().vec())};
1060 }
1061
std::vector<Shape> compute_shape_glu_jvp(
1063 const at::Tensor& glu,
1064 const at::Tensor& x,
1065 const at::Tensor& dx,
1066 int64_t dim) {
1067 return {Shape(glu.scalar_type(), glu.sizes().vec())};
1068 }
1069
std::vector<Shape> compute_shape_clamp_min(
1071 const at::Tensor& self,
1072 const at::Scalar& min) {
1073 return {Shape(self.scalar_type(), self.sizes().vec())};
1074 }
1075
std::vector<Shape> compute_shape__to_copy(
1077 const at::Tensor& self,
1078 ::std::optional<at::ScalarType> dtype,
1079 ::std::optional<at::Layout> layout,
1080 ::std::optional<at::Device> device,
1081 ::std::optional<bool> pin_memory,
1082 bool non_blocking,
1083 ::std::optional<at::MemoryFormat> memory_format) {
1084 if (dtype) {
1085 return {Shape(*dtype, self.sizes().vec())};
1086 }
1087 return {Shape(self.scalar_type(), self.sizes().vec())};
1088 }
1089
TORCH_API std::vector<Shape> compute_shape_clone(
1091 const at::Tensor& self,
1092 ::std::optional<at::MemoryFormat> memory_format) {
1093 return {Shape(self.scalar_type(), self.sizes().vec())};
1094 }
1095
std::vector<Shape> compute_shape_stack(at::TensorList tensors, int64_t dim) {
1097 TORCH_CHECK(!tensors.empty(), "stack expects a non-empty TensorList");
1098 auto wrapped_dim = at::maybe_wrap_dim(dim, tensors[0].ndimension() + 1);
1099
1100 // Copied from 'check_stack_inputs' in TensorShape.cpp
1101 at::IntArrayRef entry_shape = tensors[0].sizes();
1102 for (const auto i : c10::irange(1, tensors.size())) {
1103 TORCH_CHECK(
1104 tensors[i].sizes() == entry_shape,
1105 "stack expects each tensor to be equal size, but got ",
1106 entry_shape,
1107 " at entry 0 and ",
1108 tensors[i].sizes(),
1109 " at entry ",
1110 i);
1111 }
1112
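  // For example, stacking three {2, 3} tensors along dim = 0 yields {3, 2, 3}:
  // a new dimension of size tensors.size() is inserted at wrapped_dim.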
1113 auto result_sizes = tensors[0].sizes().vec();
1114 result_sizes.insert(result_sizes.begin() + wrapped_dim, tensors.size());
1115 return {Shape(tensors[0].scalar_type(), result_sizes)};
1116 }
1117
std::vector<Shape> compute_shape_repeat(
1119 const at::Tensor& self,
1120 at::IntArrayRef repeats) {
1121 TORCH_CHECK_GE(static_cast<int64_t>(repeats.size()), self.dim());
1122 size_t num_new_dimensions = repeats.size() - self.dim();
1123 std::vector<int64_t> padded_size(num_new_dimensions, 1);
1124 padded_size.insert(
1125 padded_size.end(), self.sizes().begin(), self.sizes().end());
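  // For example, a self of shape {2, 3} with repeats = {2, 1, 4}: padded_size
  // becomes {1, 2, 3} and the resulting shape is {2, 2, 12}.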
1126 std::vector<int64_t> target_size(repeats.size());
1127 for (const auto idx : c10::irange(repeats.size())) {
1128 target_size[idx] = padded_size[idx] * repeats[idx];
1129 }
1130 return {Shape(self.scalar_type(), target_size)};
1131 }
1132
std::vector<Shape> compute_shape_narrow_copy_symint(
1134 const at::Tensor& self,
1135 int64_t dim,
1136 int64_t start,
1137 c10::SymInt length) {
1138 return {Shape(self.scalar_type(), self.sizes().vec())};
1139 }
1140
std::vector<Shape> compute_shape_hardswish(const at::Tensor& self) {
1142 return {Shape(self.scalar_type(), self.sizes().vec())};
1143 }
1144
std::vector<Shape> compute_shape_hardswish_backward(
1146 const at::Tensor& grad_output,
1147 const at::Tensor& self) {
1148 return {Shape(self.scalar_type(), self.sizes().vec())};
1149 }
1150
std::vector<Shape> compute_shape_selu(const at::Tensor& self) {
1152 return {Shape(self.scalar_type(), self.sizes().vec())};
1153 }
1154
1155 // Non-Native Ops
std::vector<Shape> compute_shape_scalar(
1157 const at::Scalar& value,
1158 const at::ScalarType& type) {
1159 return {Shape(type, {})};
1160 }
std::vector<Shape> compute_shape_expand(
1162 const Output& input,
1163 const std::vector<int64_t>& size,
1164 const bool& is_scalar_expand) {
1165 return {Shape(input.shape().scalar_type(), size)};
1166 }
std::vector<Shape> compute_shape_view(
1168 const Output& input,
1169 const std::vector<int64_t>& output_sizes) {
1170 const Shape& input_shape = input.shape();
1171 const auto complete_output_sizes =
1172 at::infer_size(output_sizes, input_shape.numel());
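  // at::infer_size resolves a single -1 entry from the element count, e.g. for
  // a 24-element input, output_sizes = {2, -1, 3} becomes {2, 4, 3}.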
1173 return {Shape(input_shape.scalar_type(), complete_output_sizes)};
1174 }
std::vector<Shape> compute_shape_cast(
1176 const Output& input,
1177 const at::ScalarType& dtype,
1178 const ::std::optional<at::ScalarType>& stype) {
1179 Shape shape = input.shape();
1180 shape.set_scalar_type(dtype);
1181 return {shape};
1182 }
1183
1184 // View Ops
std::vector<Shape> compute_shape_as_strided_view_update(
1186 const Output& target,
1187 const Output& input,
1188 const std::vector<int64_t>& size,
1189 const std::vector<int64_t>& stride,
1190 const int64_t& storage_offset) {
1191 return {Shape(target.shape().scalar_type(), size)};
1192 }
std::vector<Shape> compute_shape_as_strided(
1194 const Output& input,
1195 const std::vector<int64_t>& size,
1196 const std::vector<int64_t>& stride,
1197 const int64_t& storage_offset) {
1198 return {Shape(input.shape().scalar_type(), size)};
1199 }
std::vector<Shape> compute_shape_diagonal_view_update(
1201 const Output& target,
1202 const Output& input,
1203 const int64_t& offset,
1204 const int64_t& dim1,
1205 const int64_t& dim2) {
1206 return {target.shape()};
1207 }
std::vector<Shape> compute_shape_diagonal(
1209 const Output& input,
1210 const int64_t& offset,
1211 const int64_t& dim1,
1212 const int64_t& dim2) {
1213 return {MakeDiagonalShape(input.shape(), offset, dim1, dim2)};
1214 }
std::vector<Shape> compute_shape_narrow_view_update(
1216 const Output& input,
1217 const Output& source,
1218 const std::vector<int64_t>& base_indices) {
1219 return {input.shape()};
1220 }
std::vector<Shape> compute_shape_narrow(
1222 const Output& input,
1223 const std::vector<int64_t>& base_indices,
1224 const std::vector<int64_t>& sizes) {
1225 return {Shape(input.shape().scalar_type(), sizes)};
1226 }
std::vector<Shape> compute_shape_permute(
1228 const Output& input,
1229 const std::vector<int64_t>& dims) {
1230 return {MakePermuteShape(input.shape(), dims)};
1231 }
std::vector<Shape> compute_shape_resize(
1233 const Output& input,
1234 const std::vector<int64_t>& size) {
1235 return {Shape(input.shape().scalar_type(), size)};
1236 }
std::vector<Shape> compute_shape_select_view_update(
1238 const Output& target,
1239 const Output& source,
1240 const int64_t& dim,
1241 const int64_t& start,
1242 const int64_t& end,
1243 const int64_t& stride) {
1244 return {target.shape()};
1245 }
std::vector<Shape> compute_shape_select(
1247 const Output& input,
1248 const int64_t& dim,
1249 const int64_t& start,
1250 const int64_t& end,
1251 const int64_t& stride) {
1252 return {MakeSelectShape(input.shape(), dim, start, end, stride)};
1253 }
std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim) {
1255 const auto& input_shape = input.shape();
1256 return {torch::lazy::Shape(
1257 input_shape.scalar_type(),
1258 BuildSqueezedDimensions(input_shape.sizes(), dim))};
1259 }
std::vector<Shape> compute_shape_unsqueeze(
1261 const Output& input,
1262 const int& dim) {
1263 const auto& input_shape = input.shape();
1264 return {torch::lazy::Shape(
1265 input_shape.scalar_type(),
1266 BuildUnsqueezedDimensions(input_shape.sizes(), dim))};
1267 }
1268
std::vector<Shape> compute_shape_select_scatter(
1270 const at::Tensor& self,
1271 const at::Tensor& src,
1272 int64_t dim,
1273 int64_t index) {
1274 auto self_meta = at::native::empty_strided_meta_symint(
1275 self.sym_sizes(),
1276 self.sym_strides(),
1277 /*dtype=*/::std::make_optional(self.scalar_type()),
1278 /*layout=*/::std::make_optional(self.layout()),
1279 /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
1280 /*pin_memory=*/::std::nullopt);
1281 auto src_meta = at::native::empty_strided_meta_symint(
1282 src.sym_sizes(),
1283 src.sym_strides(),
1284 /*dtype=*/::std::make_optional(src.scalar_type()),
1285 /*layout=*/::std::make_optional(src.layout()),
1286 /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
1287 /*pin_memory=*/::std::nullopt);
1288 auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
1289 self_meta, src_meta, dim, index);
1290 return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
1291 }
1292
std::vector<Shape> compute_shape_diagonal_scatter(
1294 const at::Tensor& self,
1295 const at::Tensor& src,
1296 int64_t offset,
1297 int64_t dim1,
1298 int64_t dim2) {
1299 auto self_meta = at::native::empty_strided_meta_symint(
1300 self.sym_sizes(),
1301 self.sym_strides(),
1302 /*dtype=*/::std::make_optional(self.scalar_type()),
1303 /*layout=*/::std::make_optional(self.layout()),
1304 /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
1305 /*pin_memory=*/::std::nullopt);
1306 auto src_meta = at::native::empty_strided_meta_symint(
1307 src.sym_sizes(),
1308 src.sym_strides(),
1309 /*dtype=*/::std::make_optional(src.scalar_type()),
1310 /*layout=*/::std::make_optional(src.layout()),
1311 /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
1312 /*pin_memory=*/::std::nullopt);
1313 auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
1314 self_meta, src_meta, offset, dim1, dim2);
1315 return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
1316 }
1317
std::vector<Shape> compute_shape_slice_scatter_symint(
1319 const at::Tensor& self,
1320 const at::Tensor& src,
1321 int64_t dim,
1322 ::std::optional<c10::SymInt> start,
1323 ::std::optional<c10::SymInt> end,
1324 c10::SymInt step) {
1325 auto self_meta = at::native::empty_strided_meta_symint(
1326 self.sym_sizes(),
1327 self.sym_strides(),
1328 /*dtype=*/::std::make_optional(self.scalar_type()),
1329 /*layout=*/::std::make_optional(self.layout()),
1330 /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
1331 /*pin_memory=*/::std::nullopt);
1332 auto src_meta = at::native::empty_strided_meta_symint(
1333 src.sym_sizes(),
1334 src.sym_strides(),
1335 /*dtype=*/::std::make_optional(src.scalar_type()),
1336 /*layout=*/::std::make_optional(src.layout()),
1337 /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
1338 /*pin_memory=*/::std::nullopt);
1339 auto out_meta =
1340 at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
1341 self_meta, src_meta, dim, start, end, step);
1342 return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
1343 }
1344
std::vector<Shape> compute_shape_as_strided_scatter_symint(
1346 const at::Tensor& self,
1347 const at::Tensor& src,
1348 at::SymIntArrayRef size,
1349 at::SymIntArrayRef stride,
1350 ::std::optional<c10::SymInt> storage_offset) {
1351 auto self_meta = at::native::empty_strided_meta_symint(
1352 self.sym_sizes(),
1353 self.sym_strides(),
1354 /*dtype=*/::std::make_optional(self.scalar_type()),
1355 /*layout=*/::std::make_optional(self.layout()),
1356 /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
1357 /*pin_memory=*/::std::nullopt);
1358 auto src_meta = at::native::empty_strided_meta_symint(
1359 src.sym_sizes(),
1360 src.sym_strides(),
1361 /*dtype=*/::std::make_optional(src.scalar_type()),
1362 /*layout=*/::std::make_optional(src.layout()),
1363 /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
1364 /*pin_memory=*/::std::nullopt);
1365 auto out_meta =
1366 at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
1367 self_meta, src_meta, size, stride, storage_offset);
1368 return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
1369 }
1370
std::vector<Shape> compute_shape_normal_functional(
1372 const at::Tensor& self,
1373 double mean,
1374 double std,
1375 ::std::optional<at::Generator> generator) {
1376 return {Shape(self.scalar_type(), self.sizes().vec())};
1377 }
1378
std::vector<Shape> compute_shape_uniform(
1380 const at::Tensor& self,
1381 double from,
1382 double to,
1383 ::std::optional<at::Generator> generator) {
1384 return {Shape(self.scalar_type(), self.sizes().vec())};
1385 }
1386
// Restore the unused-parameter warning
1388 #pragma GCC diagnostic pop
1389
1390 } // namespace lazy
1391 } // namespace torch
1392