/**
 * This is a handwritten file that accompanies the codegenerated header
 * LazyShapeDtype.h
 *
 * The purpose of these shape/dtype inference methods is to fill gaps
 * where we do not yet have structured kernels in pytorch core.  Ops
 * for which there _are_ structured kernels can use meta::op() to infer
 * shape/dtype, and codegen makes use of this.  Ops for which there are not
 * yet structured kernels can still be used with lazy_tensor codegen, but
 * require manual intervention to implement compute_shape_{op} and
 * compute_dtype_{op}.
 *
 * READ THIS!
 *
 * 1. Beware: Tech Debt!
 * ---------------------
 * These functions are tech debt.  We want to delete them all and use structured
 * kernels instead, but it's a lot faster to write these, so we're decoupling
 * the two efforts in order to move fast while adding support for codegenned
 * Lazy Tensor ops.
 *
 * Codegenned Lazy Tensor ops with handwritten shape formulae are still better
 * than fully handwritten Lazy Tensor ops (which also have handwritten shape
 * formulae).
 *
 * 2. Structured Kernels For The Win
 * ---------------------------------
 * Long term, more and more ops should be supported as 'structured kernels'.
 * Consider doing your part and porting an op.  As ops get ported over, the
 * codegen will automatically notice and stop generating declarations for these
 * shape formulae, so we'll need to manually clean up the unused functions in
 * this file, or somehow automate that.
 *
 * https://dev-discuss.pytorch.org/t/slides-from-structured-kernel-presentation/179
 *
 * 3. How to figure out the shape/dtype
 * ------------------------------------
 * Unfortunately there isn't a one-stop shop for learning the output shape
 * formulae for all operators.  This is partly because some operators are not
 * part of our 'public' API, including backward operators which users don't
 * directly invoke.
 *
 * Check our opinfo registry:
 *  https://github.com/pytorch/pytorch/blob/13b859983183ea9938deb5030ac9a0747841f0a8/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
 *
 * Read the manual (for ops that are 1:1 with the python frontend):
 *  https://pytorch.org/docs/stable/generated/torch.trace.html
 *
 */
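// A purely illustrative sketch (`my_elementwise_op` is a hypothetical op, not
// one handled in this file): most handwritten shape functions below follow
// this pattern, returning one Shape per op output and, for elementwise ops,
// simply propagating the input's dtype and sizes:
//
//   std::vector<Shape> compute_shape_my_elementwise_op(const at::Tensor& self) {
//     return {Shape(self.scalar_type(), self.sizes().vec())};
//   }
//
// Reductions instead return a 0-dim Shape, and multi-output ops return one
// Shape per output, in output order.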

#include <torch/csrc/lazy/core/shape_inference.h>

#include <ATen/AccumulateType.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/TensorConversions.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/core/dynamic_ir.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/util.h>
#include <ostream>
#include <vector>

namespace torch {
namespace lazy {

// Copied from ATen/native/utils/ParamUtils.h, which apparently I can't include
// from here?
static std::vector<int64_t> expand_param_if_needed(
    at::IntArrayRef list_param,
    const char* param_name,
    int64_t expected_dim) {
  if (list_param.size() == 1) {
    return std::vector<int64_t>(expected_dim, list_param[0]);
  } else if ((int64_t)list_param.size() != expected_dim) {
    std::ostringstream ss;
    ss << "expected " << param_name << " to be a single integer value or a "
       << "list of " << expected_dim << " values to match the convolution "
       << "dimensions, but got " << param_name << "=" << list_param;
    AT_ERROR(ss.str());
  } else {
    return list_param.vec();
  }
}

// It seems more common to not use parameters than to use them, so disable
// unused-parameter warning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"

TORCH_API std::vector<Shape> compute_shape_arange_out(
    const at::Scalar& start,
    const at::Scalar& end,
    const at::Scalar& step,
    at::Tensor& out) {
  double size_d = 0;
  // shape inference code copied from RangeFactories.cpp arange_out function
  // Note: AT_DISPATCH_ALL_TYPES_AND is just a macro that defines the correct
  // c++ scalar_t type depending on out tensor

  AT_DISPATCH_ALL_TYPES_AND(
      c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out", [&]() {
        // Note: acc_type further defines an accumulation type depending on the
        // scalar_t and whether it's on cuda vs cpu.
        using accscalar_t = at::acc_type<scalar_t, false>;
        auto xstart = start.to<accscalar_t>();
        auto xend = end.to<accscalar_t>();
        auto xstep = step.to<accscalar_t>();

        // We use double precision for (start - end) / step
        // to compute size_d for consistency across devices.
        // The problem with using accscalar_t is that accscalar_t might be
        // float32 on gpu for a float32 scalar_t, but double on cpu for the
        // same, and the effective output size starts differing on CPU vs GPU
        // because of precision issues, which we don't want. The corner case we
        // do want to take into account is int64_t, which has higher precision
        // than double.
        // NOLINTNEXTLINE(bugprone-branch-clone)
        if constexpr (std::is_same_v<scalar_t, int64_t>) {
          size_d = std::ceil(
              static_cast<double>(
                  end.to<accscalar_t>() - start.to<accscalar_t>()) /
              step.to<accscalar_t>());
        } else {
          size_d = std::ceil(
              static_cast<double>(end.to<double>() - start.to<double>()) /
              step.to<double>());
        }

        TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
        TORCH_CHECK(
            std::isfinite(static_cast<double>(xstart)) &&
                std::isfinite(static_cast<double>(xend)),
            "unsupported range: ",
            xstart,
            " -> ",
            xend);
        TORCH_CHECK(
            ((xstep > 0) && (xend >= xstart)) ||
                ((xstep < 0) && (xend <= xstart)),
            "upper bound and larger bound inconsistent with step sign");

        TORCH_CHECK(
            size_d >= 0 &&
                size_d <=
                    static_cast<double>(std::numeric_limits<int64_t>::max()),
            "invalid size, possible overflow?");
      });

  int64_t size = static_cast<int64_t>(size_d);

  // From torch.arange docs:
  // dtype (torch.dtype, optional) – the desired data type of returned tensor.
  // Default: if None, uses a global default (see
  // torch.set_default_dtype()). If dtype is not given, infer the data
  // type from the other input arguments. If any of start, end, or stop are
  // floating-point, the dtype is inferred to be the default dtype, see
  // get_default_dtype(). Otherwise, the dtype is inferred to be torch.int64.

  return {Shape(out.scalar_type(), {size})};
}

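// abs() of a complex tensor produces a real-valued result, so the output dtype
// is the corresponding real value type; otherwise the input's dtype and sizes
// pass straight through.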
std::vector<Shape> compute_shape_abs(const at::Tensor& self) {
  if (self.is_complex()) {
    const auto float_type = c10::toRealValueType(self.scalar_type());
    return {Shape(float_type, self.sizes().vec())};
  }
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

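// The bernoulli overloads preserve the input's shape and dtype; neither the
// probability nor the generator affects the output shape.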
std::vector<Shape> compute_shape_bernoulli(
    const at::Tensor& self,
    ::std::optional<at::Generator> generator) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_bernoulli(
    const at::Tensor& self,
    double p,
    ::std::optional<at::Generator> generator) {
  return compute_shape_bernoulli(self, generator);
}

std::vector<Shape> compute_shape_binary_cross_entropy(
    const at::Tensor& self,
    const at::Tensor& target,
    const ::std::optional<at::Tensor>& weight,
    int64_t reduction) {
  if (reduction == at::Reduction::None) {
    return {Shape(self.scalar_type(), self.sizes().vec())};
  }
  return {Shape(self.scalar_type(), {})};
}

std::vector<Shape> compute_shape_binary_cross_entropy_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    const at::Tensor& target,
    const ::std::optional<at::Tensor>& weight,
    int64_t reduction) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_constant_pad_nd(
    const at::Tensor& self,
    at::IntArrayRef pad,
    const at::Scalar& value) {
  // Based on aten/src/ATen/native/ConstantPadNd.cpp::constant_pad_nd
  TORCH_CHECK(
      pad.size() % 2 == 0,
      "Length of pad must be even but instead it equals ",
      pad.size());

  auto input_sizes = self.sizes();
  auto l_inp = self.dim();

  auto l_pad = pad.size() / 2;
  auto l_diff = l_inp - l_pad;
  TORCH_CHECK(
      l_inp >= (int64_t)l_pad,
      "Length of pad should be no more than twice the number of "
      "dimensions of the input. Pad length is ",
      pad.size(),
      " while the input has ",
      l_inp,
      " dimensions.");

  std::vector<int64_t> new_shape;
  for (size_t i = 0; i < (size_t)l_diff; i++) {
    new_shape.emplace_back(input_sizes[i]);
  }

  for (const auto i : c10::irange((size_t)l_pad)) {
    auto pad_idx = pad.size() - ((i + 1) * 2);
    auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1];
    TORCH_CHECK(
        new_dim > 0,
        "The input size ",
        input_sizes[l_diff + i],
        ", plus negative padding ",
        pad[pad_idx],
        " and ",
        pad[pad_idx + 1],
        " resulted in a negative output size, "
        "which is invalid. Check dimension ",
        l_diff + i,
        " of your input.");
    new_shape.emplace_back(new_dim);
  }
  return {Shape(self.scalar_type(), new_shape)};
}

std::vector<Shape> compute_shape_convolution_backward(
    const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& weight,
    at::OptionalIntArrayRef bias_sizes,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool transposed,
    at::IntArrayRef output_padding,
    int64_t groups,
    ::std::array<bool, 3> output_mask) {
  if (bias_sizes.has_value()) {
    return {
        Shape(input.scalar_type(), input.sizes().vec()),
        Shape(weight.scalar_type(), weight.sizes().vec()),
        Shape(grad_output.scalar_type(), bias_sizes.value().vec())};
  } else {
    // TODO(whc) not sure whether to return 2 shapes here, or a 3rd one that is
    // empty
    return {
        Shape(input.scalar_type(), input.sizes().vec()),
        Shape(weight.scalar_type(), weight.sizes().vec())};
  }
}

std::vector<Shape> compute_shape_convolution(
    const at::Tensor& input,
    const at::Tensor& weight,
    const ::std::optional<at::Tensor>& bias,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool transposed,
    at::IntArrayRef output_padding,
    int64_t groups) {
  int64_t dim = weight.ndimension() - 2;
  TORCH_CHECK(dim > 0, "weight should have at least three dimensions");

  // at::convolution performs parameter expansion before running kernels on
  // the expanded parameters; we must do the same.  Shape formulae access
  // different dimensions of e.g. output_padding, but output_padding may be
  // passed in as a scalar.  Sadly, accessing output_padding[1] in this case
  // gives incorrect results rather than an indexing error.
  auto expanded_stride = expand_param_if_needed(stride, "stride", dim);
  auto expanded_padding = expand_param_if_needed(padding, "padding", dim);
  auto expanded_dilation = expand_param_if_needed(dilation, "dilation", dim);
  if (!transposed) {
    return {Shape(
        input.scalar_type(),
        at::native::conv_output_size(
            input.sizes(),
            weight.sizes(),
            expanded_padding,
            expanded_stride,
            expanded_dilation))};
  } else {
    auto expanded_output_padding =
        expand_param_if_needed(output_padding, "output_padding", dim);
    auto out_shape = at::native::conv_input_size(
        input.sizes(),
        weight.sizes(),
        expanded_padding,
        expanded_output_padding,
        expanded_stride,
        expanded_dilation,
        groups);
    return {Shape(input.scalar_type(), out_shape)};
  }
}

std::vector<Shape> compute_shape_masked_fill(
    const at::Tensor& self,
    const at::Tensor& mask,
    const at::Scalar& value) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_masked_fill(
    const at::Tensor& self,
    const at::Tensor& mask,
    const at::Tensor& value) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

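// max()/min() with no dim argument reduce over all elements, returning a
// 0-dim shape with the input's dtype.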
std::vector<Shape> compute_shape_max(const at::Tensor& self) {
  TORCH_CHECK(
      self.numel() > 0,
      "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.");
  return {Shape(self.scalar_type(), {})};
}

std::vector<Shape> compute_shape_min(const at::Tensor& self) {
  TORCH_CHECK(
      self.numel() > 0,
      "min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.");
  return {Shape(self.scalar_type(), {})};
}

static std::vector<Shape> compute_shape_nonzero(
    const at::Tensor& t,
    bool as_tuple) {
  if (as_tuple) {
    auto res = std::vector<Shape>();
    for (auto dim_size : t.sizes()) {
      res.emplace_back(Shape(at::kLong, {dim_size}));
    }
    return res;
  }
  int64_t max_elements = 1;
  for (auto dim_size : t.sizes()) {
    max_elements *= dim_size;
  }
  return {Shape(at::kLong, {max_elements, (int64_t)t.sizes().size()})};
}

std::vector<Shape> compute_shape_nonzero(const at::Tensor& self) {
  return compute_shape_nonzero(self, false);
}

std::vector<Shape> compute_shape_embedding(
    const at::Tensor& weight,
    const at::Tensor& indices,
    int64_t padding_idx,
    bool scale_grad_by_freq,
    bool sparse) {
  // Based on aten/src/ATen/native/Embedding.cpp::embedding.
  std::vector<int64_t> out_sizes = indices.sizes().vec();
  out_sizes.emplace_back(weight.size(1));
  return {Shape(weight.scalar_type(), out_sizes)};
}

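// The std() overloads all funnel into the (dim, correction, keepdim) variant;
// with a dim list the reduced dimensions are dropped (or kept as size 1 when
// keepdim is true), otherwise the result is a 0-dim shape.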
std::vector<Shape> compute_shape_std(const at::Tensor& self, bool unbiased) {
  return compute_shape_std(self, ::std::nullopt, ::std::nullopt, false);
}
std::vector<Shape> compute_shape_std(
    const at::Tensor& self,
    at::OptionalIntArrayRef dim,
    bool unbiased,
    bool keepdim) {
  return compute_shape_std(self, dim, ::std::nullopt, keepdim);
}
std::vector<Shape> compute_shape_std(
    const at::Tensor& self,
    at::OptionalIntArrayRef dim,
    const ::std::optional<at::Scalar>& correction,
    bool keepdim) {
  if (dim.has_value()) {
    auto shape = at::native::shape_from_dim_mask(
        self, at::native::make_dim_mask(dim.value(), self.dim()), keepdim);
    return {Shape(
        self.scalar_type(), std::vector<int64_t>(shape.begin(), shape.end()))};
  }
  return {Shape(self.scalar_type(), {})};
}

std::vector<Shape> compute_shape_embedding_dense_backward(
    const at::Tensor& grad_output,
    const at::Tensor& indices,
    int64_t num_weights,
    int64_t padding_idx,
    bool scale_grad_by_freq) {
  // Based on aten/src/ATen/native/Embedding.cpp::embedding_dense_backward_cpu.
  return {
      Shape(grad_output.scalar_type(), {num_weights, grad_output.size(-1)})};
}

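// expand() may add new leading dimensions; a size entry of -1 keeps the
// existing size and is only valid for dimensions that already exist.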
std::vector<Shape> compute_shape_expand(
    const at::Tensor& self,
    at::IntArrayRef size,
    bool implicit) {
  TORCH_CHECK_GE(static_cast<int64_t>(size.size()), self.dim());
  size_t num_new_dimensions = size.size() - self.dim();
  std::vector<int64_t> padded_self(num_new_dimensions, 0);
  padded_self.insert(
      padded_self.end(), self.sizes().begin(), self.sizes().end());
  std::vector<int64_t> target_size(size.size());
  for (const auto idx : c10::irange(size.size())) {
    target_size[idx] = size[idx] == -1 ? padded_self[idx] : size[idx];
  }
  return {Shape(self.scalar_type(), target_size)};
}

std::vector<Shape> compute_shape_expand(
    const at::Tensor& self,
    c10::SymIntArrayRef size,
    bool implicit) {
  TORCH_CHECK_GE(static_cast<int64_t>(size.size()), self.dim());
  std::vector<c10::SymInt> _sizes = ToVector<c10::SymInt>(size);
  size_t num_new_dimensions = _sizes.size() - self.dim();
  std::vector<int64_t> padded_self(num_new_dimensions, 0);
  padded_self.insert(
      padded_self.end(), self.sizes().begin(), self.sizes().end());
  std::vector<int64_t> target_size(_sizes.size());
  for (const auto idx : c10::irange(_sizes.size())) {
    if (auto ma = _sizes[idx].maybe_as_int()) {
      target_size[idx] = *ma;
      if (*ma == -1) {
        // -1 can't be specified for non-existing dimensions
        TORCH_CHECK(idx >= num_new_dimensions);
        target_size[idx] = padded_self[idx];
      } else {
        target_size[idx] = *ma;
      }
    } else {
      auto* lazySymNode = dynamic_cast<torch::lazy::SymNodeImpl*>(
          _sizes[idx].toSymNodeImplUnowned());
      TORCH_INTERNAL_ASSERT(lazySymNode);
      auto size_node = lazySymNode->node_;
      auto static_value =
          std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)
              ->getStaticValue();
      target_size[idx] = static_value;
    }
  }
  return {Shape(self.scalar_type(), target_size)};
}

std::vector<Shape> compute_shape_index_select(
    const at::Tensor& self,
    int64_t dim,
    const at::Tensor& index) {
  // Based on definition of
  // https://pytorch.org/docs/stable/generated/torch.index_select.html. Promote
  // Rank 0 index tensor to a 1 * 1 tensor.
  dim = at::maybe_wrap_dim(dim, self);
  auto index_dim = index.dim() > 0 ? index.dim() : 1;
  auto index_size = index.dim() > 0 ? index.size(0) : 1;
  TORCH_CHECK(index_dim == 1);

  auto self_sizes = self.sizes();
  std::vector<int64_t> output_sizes(self_sizes.begin(), self_sizes.end());
  TORCH_CHECK(!output_sizes.empty(), "Empty output_sizes is not supported.");
  output_sizes[dim] = index_size;

  return {Shape(self.scalar_type(), output_sizes)};
}

std::vector<Shape> compute_shape_inverse(const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_isnan(const at::Tensor& self) {
  return {Shape(c10::ScalarType::Bool, self.sizes().vec())};
}

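// cat() keeps every dimension of the first tensor except `dim`, which becomes
// the sum of the inputs' sizes along that dimension.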
std::vector<Shape> compute_shape_cat(at::TensorList tensors, int64_t dim) {
  // TODO(whc) support cat in codegen and move this to compute_*_cat functions
  std::vector<int64_t> out_shape(
      tensors[0].sizes().begin(), tensors[0].sizes().end());

  dim = at::maybe_wrap_dim(dim, tensors);
  size_t extended_dim_shape = 0;
  for (auto& tensor : tensors) {
    extended_dim_shape += tensor.sizes()[dim];
  }
  TORCH_CHECK(!out_shape.empty(), "Scalar tensors are not supported in cat.");
  TORCH_CHECK(
      extended_dim_shape <=
          static_cast<size_t>(std::numeric_limits<int64_t>::max()),
      "Size overflow");
  out_shape[dim] = extended_dim_shape;
  return {Shape(tensors[0].scalar_type(), out_shape)};
}

TORCH_API std::vector<torch::lazy::Shape> compute_shape_cholesky(
    const at::Tensor& self,
    bool upper) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
    const at::Tensor& input,
    const ::std::optional<at::Tensor>& weight,
    const ::std::optional<at::Tensor>& bias,
    const ::std::optional<at::Tensor>& running_mean,
    const ::std::optional<at::Tensor>& running_var,
    bool training,
    double momentum,
    double eps) {
  std::vector<torch::lazy::Shape> shapes;
  shapes.reserve(3);
  shapes.emplace_back(input.scalar_type(), input.sizes().vec());

  // A separate mean and var need to be kept for each channel.
  TORCH_CHECK(
      input.sizes().size() >= 2,
      "Input tensor must have at least batch and channel dimensions!");
  int64_t num_features = input.size(1);

  if (running_mean.has_value()) {
    shapes.emplace_back(
        running_mean.value().scalar_type(), running_mean.value().sizes().vec());
  } else {
    shapes.emplace_back(
        at::get_default_dtype_as_scalartype(),
        std::vector<int64_t>{num_features});
  }

  if (running_var.has_value()) {
    shapes.emplace_back(
        running_var.value().scalar_type(), running_var.value().sizes().vec());
  } else {
    shapes.emplace_back(
        at::get_default_dtype_as_scalartype(),
        std::vector<int64_t>{num_features});
  }
  return shapes;
}

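// Returns grad_input with the input's shape, plus grad_weight and grad_bias,
// each a vector of length C (the number of channels).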
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
    const at::Tensor& grad_out,
    const at::Tensor& input,
    const ::std::optional<at::Tensor>& weight,
    const ::std::optional<at::Tensor>& running_mean,
    const ::std::optional<at::Tensor>& running_var,
    const ::std::optional<at::Tensor>& save_mean,
    const ::std::optional<at::Tensor>& save_invstd,
    bool train,
    double eps,
    ::std::array<bool, 3> output_mask) {
  std::vector<torch::lazy::Shape> shapes;
  shapes.reserve(3);
  shapes.emplace_back(input.scalar_type(), input.sizes().vec());

  // A separate mean and var need to be kept for each channel.
  TORCH_CHECK(
      input.sizes().size() >= 2,
      "Input tensor must have at least batch and channel dimensions!");
  int64_t num_features = input.size(1);

  // `weight` and `bias` are vectors of length C (the number of channels).
  shapes.emplace_back(
      at::get_default_dtype_as_scalartype(),
      std::vector<int64_t>{num_features});
  shapes.emplace_back(
      at::get_default_dtype_as_scalartype(),
      std::vector<int64_t>{num_features});

  return shapes;
}

std::vector<Shape> compute_shape_native_layer_norm(
    const at::Tensor& input,
    at::IntArrayRef normalized_shape,
    const ::std::optional<at::Tensor>& weight,
    const ::std::optional<at::Tensor>& bias,
    double eps) {
  // Copied from aten/src/ATen/native/layer_norm.cpp::layer_norm_cpu_out.
  auto input_shape = input.sizes().vec();
  const size_t axis = input.dim() - normalized_shape.size();

  std::vector<int64_t> stat_shape;
  for (const auto idx : c10::irange(axis)) {
    TORCH_CHECK(idx < input_shape.size(), "Shape mismatch");
    stat_shape.emplace_back(input_shape[idx]);
  }
  for (const auto idx : c10::irange(axis, input.dim())) {
    (void)idx; // Suppress unused variable warning
    stat_shape.emplace_back(1);
  }

  return {
      Shape(input.scalar_type(), input_shape),
      Shape(input.scalar_type(), stat_shape),
      Shape(input.scalar_type(), stat_shape)};
}

std::vector<Shape> compute_shape_native_layer_norm_backward(
    const at::Tensor& grad_out,
    const at::Tensor& input,
    at::IntArrayRef normalized_shape,
    const at::Tensor& mean,
    const at::Tensor& rstd,
    const ::std::optional<at::Tensor>& weight,
    const ::std::optional<at::Tensor>& bias,
    ::std::array<bool, 3> output_mask) {
  std::vector<Shape> shapes;
  shapes.emplace_back(
      input.scalar_type(),
      output_mask[0] ? input.sizes().vec() : std::vector<int64_t>{});
  shapes.emplace_back(
      weight && weight->defined() ? weight->scalar_type() : input.scalar_type(),
      output_mask[1] && weight ? weight->sizes().vec()
                               : std::vector<int64_t>{});
  shapes.emplace_back(
      bias && bias->defined() ? bias->scalar_type() : input.scalar_type(),
      output_mask[2] && bias ? bias->sizes().vec() : std::vector<int64_t>{});
  return shapes;
}

std::vector<Shape> compute_shape_mean(
    const at::Tensor& self,
    ::std::optional<at::ScalarType> dtype) {
  if (dtype.has_value()) {
    return {Shape(dtype.value(), {})};
  }
  return {Shape(self.scalar_type(), {})};
}

std::vector<Shape> compute_shape_new_empty_strided(
    const at::Tensor& self,
    at::IntArrayRef size,
    at::IntArrayRef stride,
    ::std::optional<at::ScalarType> dtype,
    ::std::optional<at::Layout> layout,
    ::std::optional<at::Device> device,
    ::std::optional<bool> pin_memory) {
  return {Shape(dtype.has_value() ? *dtype : self.scalar_type(), size.vec())};
}

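// Matrix-vector product: an (n, m) matrix times an (m,) vector yields an (n,)
// vector, so the output is 1-D with length self.size(0).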
std::vector<Shape> compute_shape_mv(
    const at::Tensor& self,
    const at::Tensor& vec) {
  return {Shape(self.scalar_type(), {self.size(0)})};
}

std::vector<Shape> compute_shape_native_dropout(
    const at::Tensor& input,
    double p,
    ::std::optional<bool> train) {
  return {
      Shape(input.scalar_type(), input.sizes().vec()),
      Shape(c10::ScalarType::Bool, input.sizes().vec())};
}

std::vector<Shape> compute_shape_native_dropout_backward(
    const at::Tensor& grad_output,
    const at::Tensor& mask,
    double scale) {
  return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
}

std::vector<Shape> compute_shape_random(
    const at::Tensor& self,
    ::std::optional<at::Generator> generator) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_random(
    const at::Tensor& self,
    int64_t to,
    ::std::optional<at::Generator> generator) {
  return compute_shape_random(self, generator);
}

std::vector<Shape> compute_shape_random(
    const at::Tensor& self,
    int64_t from,
    ::std::optional<int64_t> to,
    ::std::optional<at::Generator> generator) {
  return compute_shape_random(self, generator);
}

std::vector<Shape> compute_shape_relu(const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_sum(
    const at::Tensor& self,
    ::std::optional<at::ScalarType> dtype) {
  if (dtype.has_value()) {
    return {Shape(dtype.value(), {})};
  }
  // It's undocumented, but torch::sum promotes all integral types to int64_t by
  // default
  if (isIntegralType(self.scalar_type(), /*includeBool*/ true)) {
    return {Shape(c10::ScalarType::Long, {})};
  }
  return {Shape(self.scalar_type(), {})};
}

std::vector<Shape> compute_shape_zero(const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

TORCH_API std::vector<torch::lazy::Shape> compute_shape_take(
    const at::Tensor& self,
    const at::Tensor& index) {
  return {Shape(self.scalar_type(), index.sizes().vec())};
}

std::vector<Shape> compute_shape_trace(const at::Tensor& self) {
  return {Shape(self.scalar_type(), {})};
}

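// sort() returns (values, indices): values match the input's shape and dtype,
// indices are int64 with the same shape.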
std::vector<Shape> compute_shape_sort(
    const at::Tensor& self,
    int64_t dim,
    bool descending) {
  return {
      Shape(self.scalar_type(), self.sizes().vec()),
      Shape(c10::ScalarType::Long, self.sizes().vec())};
}

std::vector<Shape> compute_shape_slogdet(const at::Tensor& self) {
  // assumes self.shape is {*, n, n} and returns shape *
  TORCH_INTERNAL_ASSERT(self.dim() >= 2);
  std::vector<int64_t> out_sizes(self.sizes().begin(), self.sizes().end() - 2);
  // Doesn't check input dtype, but output dtype either matches it,
  // or the actual slogdet operation will throw if it's an unsupported type.
  // Sign and det outputs hold the same shape, dtype.
  return {
      Shape(self.scalar_type(), out_sizes),
      Shape(self.scalar_type(), out_sizes)};
}

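// The logical_* ops always produce a Bool tensor; the binary variants
// broadcast their inputs to a common shape first.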
std::vector<torch::lazy::Shape> compute_shape_logical_and(
    const at::Tensor& self,
    const at::Tensor& other) {
  TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes()));
  return {Shape(
      c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))};
}

std::vector<torch::lazy::Shape> compute_shape_logical_not(
    const at::Tensor& self) {
  return {Shape(c10::ScalarType::Bool, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_logical_or(
    const at::Tensor& self,
    const at::Tensor& other) {
  TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes()));
  return {Shape(
      c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))};
}

std::vector<torch::lazy::Shape> compute_shape_logical_xor(
    const at::Tensor& self,
    const at::Tensor& other) {
  TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes()));
  return {Shape(
      c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))};
}

std::vector<Shape> compute_shape_smooth_l1_loss_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    const at::Tensor& target,
    int64_t reduction,
    double beta) {
  // The `grad_output` tensor is really the input to this kernel, and while its
  // shape may vary following the logic of the forward output, the output of
  // this kernel should have fixed shapes matching the inputs to the forward
  // kernel.
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_logdet(const at::Tensor& self) {
  // assumes self.shape is {*, n, n} and returns shape *
  TORCH_INTERNAL_ASSERT(self.dim() >= 2);
  std::vector<int64_t> out_sizes(self.sizes().begin(), self.sizes().end() - 2);
  // Doesn't check input dtype, but output dtype either matches it,
  // or the actual logdet operation will throw if it's an unsupported type
  return {Shape(self.scalar_type(), out_sizes)};
}

std::vector<Shape> compute_shape_log_sigmoid_forward(const at::Tensor& self) {
  // Based on definition of
  // aten/src/ATen/native/Activation.cpp::log_sigmoid_forward_out_cpu.
  return {
      Shape(self.scalar_type(), self.sizes().vec()),
      Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_log_sigmoid_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    const at::Tensor& buffer) {
  // Based on definition of
  // aten/src/ATen/native/Activation.cpp::log_sigmoid_backward_cpu*.
  return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
}

std::vector<Shape> compute_shape_nll_loss2d_forward(
    const at::Tensor& self,
    const at::Tensor& target,
    const ::std::optional<at::Tensor>& weight,
    int64_t reduction,
    int64_t ignore_index) {
  // Based on definition of
  // aten/src/ATen/native/LossNLL2d.cpp:nll_loss2d_forward_cpu
  auto sizes =
      (reduction == at::Reduction::Reduction::None ? target.sizes().vec()
                                                   : std::vector<int64_t>{});
  return {Shape(self.scalar_type(), sizes), Shape(self.scalar_type(), {})};
}

std::vector<Shape> compute_shape_nll_loss2d_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    const at::Tensor& target,
    const ::std::optional<at::Tensor>& weight,
    int64_t reduction,
    int64_t ignore_index,
    const at::Tensor& total_weight) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_grid_sampler_2d(
    const at::Tensor& input,
    const at::Tensor& grid,
    int64_t interpolation_mode,
    int64_t padding_mode,
    bool align_corners) {
  // from aten/src/ATen/native/cpu/GridSamplerKernel.cpp
  int64_t N = input.size(0);
  int64_t C = input.size(1);
  int64_t H = grid.size(1);
  int64_t W = grid.size(2);
  return {Shape(input.scalar_type(), {N, C, H, W})};
}

std::vector<Shape> compute_shape_grid_sampler_2d_backward(
    const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& grid,
    int64_t interpolation_mode,
    int64_t padding_mode,
    bool align_corners,
    ::std::array<bool, 2> output_mask) {
  // from aten/src/ATen/native/cpu/GridSamplerKernel.cpp
  auto grad_input_shape = Shape(input.scalar_type(), input.sizes().vec());
  auto grad_grid_shape = Shape(grid.scalar_type(), grid.sizes().vec());
  return {grad_input_shape, grad_grid_shape};
}

std::vector<Shape> compute_shape_flip(
    const at::Tensor& self,
    at::IntArrayRef dims) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape__adaptive_avg_pool2d(
    const at::Tensor& self,
    at::IntArrayRef output_size) {
  // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
  // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp`
  TORCH_CHECK(
      output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2");
  TORCH_CHECK(
      (output_size[0] >= 0 && output_size[1] >= 0),
      "adaptive_avg_pool2d: elements of output_size must be greater than or equal to 0 ",
      "but received {",
      output_size[0],
      ", ",
      output_size[1],
      "}");
  int64_t ndim = self.ndimension();
  for (const auto i : c10::irange(1, ndim)) {
    TORCH_CHECK(
        self.size(i) > 0,
        "adaptive_avg_pool2d(): Expected self to have non-zero size for non-batch dimensions, "
        "but Tensor has sizes ",
        self.sizes(),
        " with dimension ",
        i,
        " being "
        "empty");
  }
  TORCH_CHECK(
      (ndim == 3 || ndim == 4),
      "adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got ",
      self.sizes());

  int64_t channels = self.size(-3);
  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  if (ndim == 3) {
    return {Shape(self.scalar_type(), {channels, output_height, output_width})};
  } else {
    int64_t nbatch = self.size(0);
    return {Shape(
        self.scalar_type(), {nbatch, channels, output_height, output_width})};
  }
}

std::vector<Shape> compute_shape__adaptive_avg_pool2d_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self) {
  // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
  int64_t ndim = grad_output.ndimension();

  for (const auto i : c10::irange(1, ndim)) {
    TORCH_CHECK(
        grad_output.size(i) > 0,
        "adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, "
        "but grad_output has sizes ",
        grad_output.sizes(),
        " with dimension ",
        i,
        " being "
        "empty");
  }

  TORCH_CHECK(
      (ndim == 3 || ndim == 4),
      "adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got ",
      self.sizes());
  TORCH_CHECK(
      self.dtype() == grad_output.dtype(),
      "expected dtype ",
      self.dtype(),
      " for `grad_output` but got dtype ",
      grad_output.dtype());

  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape__adaptive_avg_pool3d(
    const at::Tensor& self,
    at::IntArrayRef output_size) {
  // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
  // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp`
  TORCH_CHECK(
      output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3");
  TORCH_CHECK(
      (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0),
      "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ",
      "but received {",
      output_size[0],
      ", ",
      output_size[1],
      ", ",
      output_size[2],
      "}");
  int64_t ndim = self.ndimension();
  for (const auto i : c10::irange(1, ndim)) {
    TORCH_CHECK(
        self.size(i) > 0,
        "adaptive_avg_pool3d(): Expected self to have non-zero size for non-batch dimensions, "
        "but Tensor has sizes ",
        self.sizes(),
        " with dimension ",
        i,
        " being "
        "empty");
  }
  TORCH_CHECK(
      (ndim == 4 || ndim == 5),
      "adaptive_avg_pool3d(): Expected 4D or 5D tensor, but got ",
      self.sizes());

  int64_t channels = self.size(-4);
  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];

  if (ndim == 4) {
    return {Shape(
        self.scalar_type(),
        {channels, output_depth, output_height, output_width})};
  } else {
    int64_t nbatch = self.size(0);
    return {Shape(
        self.scalar_type(),
        {nbatch, channels, output_depth, output_height, output_width})};
  }
}

std::vector<Shape> compute_shape__adaptive_avg_pool3d_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self) {
  // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
  int64_t ndim = grad_output.ndimension();

  for (const auto i : c10::irange(1, ndim)) {
    TORCH_CHECK(
        grad_output.size(i) > 0,
        "adaptive_avg_pool3d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, "
        "but grad_output has sizes ",
        grad_output.sizes(),
        " with dimension ",
        i,
        " being "
        "empty");
  }

  TORCH_CHECK(
      (ndim == 4 || ndim == 5),
      "adaptive_avg_pool3d_backward(): Expected 4D or 5D tensor, but got ",
      self.sizes());
  TORCH_CHECK(
      self.dtype() == grad_output.dtype(),
      "expected dtype ",
      self.dtype(),
      " for `grad_output` but got dtype ",
      grad_output.dtype());

  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_glu_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    int64_t dim) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_glu_jvp(
    const at::Tensor& glu,
    const at::Tensor& x,
    const at::Tensor& dx,
    int64_t dim) {
  return {Shape(glu.scalar_type(), glu.sizes().vec())};
}

std::vector<Shape> compute_shape_clamp_min(
    const at::Tensor& self,
    const at::Scalar& min) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape__to_copy(
    const at::Tensor& self,
    ::std::optional<at::ScalarType> dtype,
    ::std::optional<at::Layout> layout,
    ::std::optional<at::Device> device,
    ::std::optional<bool> pin_memory,
    bool non_blocking,
    ::std::optional<at::MemoryFormat> memory_format) {
  if (dtype) {
    return {Shape(*dtype, self.sizes().vec())};
  }
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

TORCH_API std::vector<Shape> compute_shape_clone(
    const at::Tensor& self,
    ::std::optional<at::MemoryFormat> memory_format) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

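// stack() inserts a new dimension of size tensors.size() at the (wrapped) dim;
// all inputs must have identical shapes.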
std::vector<Shape> compute_shape_stack(at::TensorList tensors, int64_t dim) {
  TORCH_CHECK(!tensors.empty(), "stack expects a non-empty TensorList");
  auto wrapped_dim = at::maybe_wrap_dim(dim, tensors[0].ndimension() + 1);

  // Copied from 'check_stack_inputs' in TensorShape.cpp
  at::IntArrayRef entry_shape = tensors[0].sizes();
  for (const auto i : c10::irange(1, tensors.size())) {
    TORCH_CHECK(
        tensors[i].sizes() == entry_shape,
        "stack expects each tensor to be equal size, but got ",
        entry_shape,
        " at entry 0 and ",
        tensors[i].sizes(),
        " at entry ",
        i);
  }

  auto result_sizes = tensors[0].sizes().vec();
  result_sizes.insert(result_sizes.begin() + wrapped_dim, tensors.size());
  return {Shape(tensors[0].scalar_type(), result_sizes)};
}

std::vector<Shape> compute_shape_repeat(
    const at::Tensor& self,
    at::IntArrayRef repeats) {
  TORCH_CHECK_GE(static_cast<int64_t>(repeats.size()), self.dim());
  size_t num_new_dimensions = repeats.size() - self.dim();
  std::vector<int64_t> padded_size(num_new_dimensions, 1);
  padded_size.insert(
      padded_size.end(), self.sizes().begin(), self.sizes().end());
  std::vector<int64_t> target_size(repeats.size());
  for (const auto idx : c10::irange(repeats.size())) {
    target_size[idx] = padded_size[idx] * repeats[idx];
  }
  return {Shape(self.scalar_type(), target_size)};
}

std::vector<Shape> compute_shape_narrow_copy_symint(
    const at::Tensor& self,
    int64_t dim,
    int64_t start,
    c10::SymInt length) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_hardswish(const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_hardswish_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_selu(const at::Tensor& self) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

// Non-Native Ops
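// These shape functions take lazy IR `Output` values (whose shapes are already
// known) and plain scalars instead of at::Tensors.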
std::vector<Shape> compute_shape_scalar(
    const at::Scalar& value,
    const at::ScalarType& type) {
  return {Shape(type, {})};
}
std::vector<Shape> compute_shape_expand(
    const Output& input,
    const std::vector<int64_t>& size,
    const bool& is_scalar_expand) {
  return {Shape(input.shape().scalar_type(), size)};
}
std::vector<Shape> compute_shape_view(
    const Output& input,
    const std::vector<int64_t>& output_sizes) {
  const Shape& input_shape = input.shape();
  const auto complete_output_sizes =
      at::infer_size(output_sizes, input_shape.numel());
  return {Shape(input_shape.scalar_type(), complete_output_sizes)};
}
std::vector<Shape> compute_shape_cast(
    const Output& input,
    const at::ScalarType& dtype,
    const ::std::optional<at::ScalarType>& stype) {
  Shape shape = input.shape();
  shape.set_scalar_type(dtype);
  return {shape};
}

// View Ops
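// The plain view ops compute the shape of the resulting view; their
// *_view_update counterparts describe writing view data back into a base
// tensor, so they generally return the base/target shape.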
std::vector<Shape> compute_shape_as_strided_view_update(
    const Output& target,
    const Output& input,
    const std::vector<int64_t>& size,
    const std::vector<int64_t>& stride,
    const int64_t& storage_offset) {
  return {Shape(target.shape().scalar_type(), size)};
}
std::vector<Shape> compute_shape_as_strided(
    const Output& input,
    const std::vector<int64_t>& size,
    const std::vector<int64_t>& stride,
    const int64_t& storage_offset) {
  return {Shape(input.shape().scalar_type(), size)};
}
std::vector<Shape> compute_shape_diagonal_view_update(
    const Output& target,
    const Output& input,
    const int64_t& offset,
    const int64_t& dim1,
    const int64_t& dim2) {
  return {target.shape()};
}
std::vector<Shape> compute_shape_diagonal(
    const Output& input,
    const int64_t& offset,
    const int64_t& dim1,
    const int64_t& dim2) {
  return {MakeDiagonalShape(input.shape(), offset, dim1, dim2)};
}
std::vector<Shape> compute_shape_narrow_view_update(
    const Output& input,
    const Output& source,
    const std::vector<int64_t>& base_indices) {
  return {input.shape()};
}
std::vector<Shape> compute_shape_narrow(
    const Output& input,
    const std::vector<int64_t>& base_indices,
    const std::vector<int64_t>& sizes) {
  return {Shape(input.shape().scalar_type(), sizes)};
}
std::vector<Shape> compute_shape_permute(
    const Output& input,
    const std::vector<int64_t>& dims) {
  return {MakePermuteShape(input.shape(), dims)};
}
std::vector<Shape> compute_shape_resize(
    const Output& input,
    const std::vector<int64_t>& size) {
  return {Shape(input.shape().scalar_type(), size)};
}
std::vector<Shape> compute_shape_select_view_update(
    const Output& target,
    const Output& source,
    const int64_t& dim,
    const int64_t& start,
    const int64_t& end,
    const int64_t& stride) {
  return {target.shape()};
}
std::vector<Shape> compute_shape_select(
    const Output& input,
    const int64_t& dim,
    const int64_t& start,
    const int64_t& end,
    const int64_t& stride) {
  return {MakeSelectShape(input.shape(), dim, start, end, stride)};
}
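// Squeeze/unsqueeze only drop or add size-1 dimensions and never change the
// scalar type; e.g. squeezing dim 1 of a [2, 1, 3] input yields [2, 3], and
// unsqueezing that result at dim 1 restores [2, 1, 3].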
std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim) {
  const auto& input_shape = input.shape();
  return {torch::lazy::Shape(
      input_shape.scalar_type(),
      BuildSqueezedDimensions(input_shape.sizes(), dim))};
}
std::vector<Shape> compute_shape_unsqueeze(
    const Output& input,
    const int& dim) {
  const auto& input_shape = input.shape();
  return {torch::lazy::Shape(
      input_shape.scalar_type(),
      BuildUnsqueezedDimensions(input_shape.sizes(), dim))};
}

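// The *_scatter shape helpers below share one pattern: materialize 'self' and
// 'src' as meta tensors (shape/stride only, no data, on the Meta device), run
// the functional scatter op from the CompositeExplicitAutogradNonFunctional
// key on them, and read the result's dtype and sizes.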
std::vector<Shape> compute_shape_select_scatter(
    const at::Tensor& self,
    const at::Tensor& src,
    int64_t dim,
    int64_t index) {
  auto self_meta = at::native::empty_strided_meta_symint(
      self.sym_sizes(),
      self.sym_strides(),
      /*dtype=*/::std::make_optional(self.scalar_type()),
      /*layout=*/::std::make_optional(self.layout()),
      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/::std::nullopt);
  auto src_meta = at::native::empty_strided_meta_symint(
      src.sym_sizes(),
      src.sym_strides(),
      /*dtype=*/::std::make_optional(src.scalar_type()),
      /*layout=*/::std::make_optional(src.layout()),
      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/::std::nullopt);
  auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
      self_meta, src_meta, dim, index);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<Shape> compute_shape_diagonal_scatter(
    const at::Tensor& self,
    const at::Tensor& src,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  auto self_meta = at::native::empty_strided_meta_symint(
      self.sym_sizes(),
      self.sym_strides(),
      /*dtype=*/::std::make_optional(self.scalar_type()),
      /*layout=*/::std::make_optional(self.layout()),
      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/::std::nullopt);
  auto src_meta = at::native::empty_strided_meta_symint(
      src.sym_sizes(),
      src.sym_strides(),
      /*dtype=*/::std::make_optional(src.scalar_type()),
      /*layout=*/::std::make_optional(src.layout()),
      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/::std::nullopt);
  auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
      self_meta, src_meta, offset, dim1, dim2);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<Shape> compute_shape_slice_scatter_symint(
    const at::Tensor& self,
    const at::Tensor& src,
    int64_t dim,
    ::std::optional<c10::SymInt> start,
    ::std::optional<c10::SymInt> end,
    c10::SymInt step) {
  auto self_meta = at::native::empty_strided_meta_symint(
      self.sym_sizes(),
      self.sym_strides(),
      /*dtype=*/::std::make_optional(self.scalar_type()),
      /*layout=*/::std::make_optional(self.layout()),
      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/::std::nullopt);
  auto src_meta = at::native::empty_strided_meta_symint(
      src.sym_sizes(),
      src.sym_strides(),
      /*dtype=*/::std::make_optional(src.scalar_type()),
      /*layout=*/::std::make_optional(src.layout()),
      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/::std::nullopt);
  auto out_meta =
      at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
          self_meta, src_meta, dim, start, end, step);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<Shape> compute_shape_as_strided_scatter_symint(
    const at::Tensor& self,
    const at::Tensor& src,
    at::SymIntArrayRef size,
    at::SymIntArrayRef stride,
    ::std::optional<c10::SymInt> storage_offset) {
  auto self_meta = at::native::empty_strided_meta_symint(
      self.sym_sizes(),
      self.sym_strides(),
      /*dtype=*/::std::make_optional(self.scalar_type()),
      /*layout=*/::std::make_optional(self.layout()),
      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/::std::nullopt);
  auto src_meta = at::native::empty_strided_meta_symint(
      src.sym_sizes(),
      src.sym_strides(),
      /*dtype=*/::std::make_optional(src.scalar_type()),
      /*layout=*/::std::make_optional(src.layout()),
      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/::std::nullopt);
  auto out_meta =
      at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
          self_meta, src_meta, size, stride, storage_offset);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

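// Random-fill ops: the output always has the same shape and dtype as self,
// independent of the sampling parameters.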
std::vector<Shape> compute_shape_normal_functional(
    const at::Tensor& self,
    double mean,
    double std,
    ::std::optional<at::Generator> generator) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<Shape> compute_shape_uniform(
    const at::Tensor& self,
    double from,
    double to,
    ::std::optional<at::Generator> generator) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

// Restore unused-parameters warnings
#pragma GCC diagnostic pop

} // namespace lazy
} // namespace torch