xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/shape_inference.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <c10/core/ScalarType.h>
5 #include <c10/core/SymInt.h>
6 #include <c10/core/SymIntArrayRef.h>
7 #include <c10/core/SymNodeImpl.h>
8 #include <c10/macros/Export.h>
9 #include <torch/csrc/lazy/backend/backend_data.h>
10 #include <torch/csrc/lazy/core/ir.h>
11 #include <torch/csrc/lazy/core/shape.h>
12 #include <torch/csrc/lazy/core/tensor.h>
13 #include <optional>
14 #include <vector>
15 
16 namespace torch {
17 namespace lazy {
18 // Turn clang-format off, as we rely on the whole signature being on one line
19 // for codegen.
20 // clang-format off
// Shape-inference declarations whose parameter lists use at:: / c10:: types
// (at::Tensor, at::Scalar, at::IntArrayRef, c10::SymInt, ...). Every
// declaration in this group returns a std::vector<torch::lazy::Shape>.
//
// Several names are overloaded and are distinguished only by parameter list:
// compute_shape_bernoulli, compute_shape_expand, compute_shape_masked_fill,
// compute_shape_random and compute_shape_std.
//
// compute_shape_arange_out is the only declaration in this group that takes a
// non-const tensor reference (`at::Tensor & out`).
//
// NOTE(review): presumably each function corresponds to the ATen operator
// named after the `compute_shape_` prefix, and the returned vector holds one
// Shape per operator output -- confirm against the definitions.
// NOTE(review): ::std::array is used in several signatures below, but <array>
// is not among this header's direct includes; presumably it arrives
// transitively -- confirm.
21 TORCH_API std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size);
22 TORCH_API std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self);
23 TORCH_API std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool3d(const at::Tensor & self, at::IntArrayRef output_size);
24 TORCH_API std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self);
25 TORCH_API std::vector<torch::lazy::Shape> compute_shape_abs(const at::Tensor & self);
26 TORCH_API std::vector<torch::lazy::Shape> compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out);
27 TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, ::std::optional<at::Generator> generator);
28 TORCH_API std::vector<torch::lazy::Shape> compute_shape_bernoulli(const at::Tensor & self, double p, ::std::optional<at::Generator> generator);
29 TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight, int64_t reduction);
30 TORCH_API std::vector<torch::lazy::Shape> compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight, int64_t reduction);
31 TORCH_API std::vector<torch::lazy::Shape> compute_shape_cat(at::TensorList tensors, int64_t dim);
32 TORCH_API std::vector<torch::lazy::Shape> compute_shape_cholesky(const at::Tensor & self, bool upper);
33 TORCH_API std::vector<torch::lazy::Shape> compute_shape_clamp_min(const at::Tensor & self, const at::Scalar & min);
34 TORCH_API std::vector<torch::lazy::Shape> compute_shape_clone(const at::Tensor & self, ::std::optional<at::MemoryFormat> memory_format);
35 TORCH_API std::vector<torch::lazy::Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
36 TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
37 TORCH_API std::vector<torch::lazy::Shape> compute_shape_convolution_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask);
38 TORCH_API std::vector<torch::lazy::Shape> compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse);
39 TORCH_API std::vector<torch::lazy::Shape> compute_shape_embedding_dense_backward(const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq);
40 TORCH_API std::vector<torch::lazy::Shape> compute_shape_expand(const at::Tensor & self, at::IntArrayRef size, bool implicit);
41 TORCH_API std::vector<torch::lazy::Shape> compute_shape_expand(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit);
42 TORCH_API std::vector<torch::lazy::Shape> compute_shape_flip(const at::Tensor & self, at::IntArrayRef dims);
43 TORCH_API std::vector<torch::lazy::Shape> compute_shape_glu_backward(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim);
44 TORCH_API std::vector<torch::lazy::Shape> compute_shape_glu_jvp(const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim);
45 TORCH_API std::vector<torch::lazy::Shape> compute_shape_grid_sampler_2d(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners);
46 TORCH_API std::vector<torch::lazy::Shape> compute_shape_grid_sampler_2d_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask);
47 TORCH_API std::vector<torch::lazy::Shape> compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index);
48 TORCH_API std::vector<torch::lazy::Shape> compute_shape_inverse(const at::Tensor & self);
49 TORCH_API std::vector<torch::lazy::Shape> compute_shape_isnan(const at::Tensor & self);
50 TORCH_API std::vector<torch::lazy::Shape> compute_shape_log_sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer);
51 TORCH_API std::vector<torch::lazy::Shape> compute_shape_log_sigmoid_forward(const at::Tensor & self);
52 TORCH_API std::vector<torch::lazy::Shape> compute_shape_logdet(const at::Tensor & self);
53 TORCH_API std::vector<torch::lazy::Shape> compute_shape_logical_and(const at::Tensor & self, const at::Tensor & other);
54 TORCH_API std::vector<torch::lazy::Shape> compute_shape_logical_not(const at::Tensor & self);
55 TORCH_API std::vector<torch::lazy::Shape> compute_shape_logical_or(const at::Tensor & self, const at::Tensor & other);
56 TORCH_API std::vector<torch::lazy::Shape> compute_shape_logical_xor(const at::Tensor & self, const at::Tensor & other);
57 TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value);
58 TORCH_API std::vector<torch::lazy::Shape> compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value);
59 TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & self);
60 TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, ::std::optional<at::ScalarType> dtype);
61 TORCH_API std::vector<torch::lazy::Shape> compute_shape_min(const at::Tensor & self);
62 TORCH_API std::vector<torch::lazy::Shape> compute_shape_mv(const at::Tensor & self, const at::Tensor & vec);
63 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, const ::std::optional<at::Tensor> & running_mean, const ::std::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
64 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & running_mean, const ::std::optional<at::Tensor> & running_var, const ::std::optional<at::Tensor> & save_mean, const ::std::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask);
65 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at::Tensor & input, double p, ::std::optional<bool> train);
66 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale);
67 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, double eps);
68 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional<at::Tensor> & weight, const ::std::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask);
69 TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory);
70 TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
71 TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const ::std::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
72 TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
73 TORCH_API std::vector<torch::lazy::Shape> compute_shape_normal_functional(const at::Tensor & self, double mean, double std, ::std::optional<at::Generator> generator);
74 TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, ::std::optional<at::Generator> generator);
75 TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t to, ::std::optional<at::Generator> generator);
76 TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t from, ::std::optional<int64_t> to, ::std::optional<at::Generator> generator);
77 TORCH_API std::vector<torch::lazy::Shape> compute_shape_relu(const at::Tensor & self);
78 TORCH_API std::vector<torch::lazy::Shape> compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats);
79 TORCH_API std::vector<torch::lazy::Shape> compute_shape_slogdet(const at::Tensor & self);
80 TORCH_API std::vector<torch::lazy::Shape> compute_shape_smooth_l1_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta);
81 TORCH_API std::vector<torch::lazy::Shape> compute_shape_sort(const at::Tensor & self, int64_t dim, bool descending);
82 TORCH_API std::vector<torch::lazy::Shape> compute_shape_stack(at::TensorList tensors, int64_t dim);
83 TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & self, bool unbiased);
84 TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim);
85 TORCH_API std::vector<torch::lazy::Shape> compute_shape_std(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional<at::Scalar> & correction, bool keepdim);
86 TORCH_API std::vector<torch::lazy::Shape> compute_shape_sum(const at::Tensor & self, ::std::optional<at::ScalarType> dtype);
87 TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tensor & self, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, bool non_blocking, ::std::optional<at::MemoryFormat> memory_format);
88 TORCH_API std::vector<torch::lazy::Shape> compute_shape_take(const at::Tensor & self, const at::Tensor & index);
89 TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
90 TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero(const at::Tensor & self);
91 TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy_symint(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);
92 TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish(const at::Tensor & self);
93 TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self);
94 TORCH_API std::vector<torch::lazy::Shape> compute_shape_selu(const at::Tensor & self);
95 TORCH_API std::vector<torch::lazy::Shape> compute_shape_uniform(const at::Tensor & self, double from, double to, ::std::optional<at::Generator> generator);
96 
97 // Non-Native ops
// Unlike the at::Tensor-based declarations earlier in this header, these are
// written with the unqualified names `Shape` and `Output`, which resolve inside
// the enclosing torch::lazy namespace, and they pass sizes as
// std::vector<int64_t> and scalar arguments by const reference.
// compute_shape_expand here is a further overload of the name already declared
// for at::Tensor inputs.
// NOTE(review): `Output` is presumably the IR value type from
// torch/csrc/lazy/core/ir.h -- confirm.
98 TORCH_API std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type);
99 TORCH_API std::vector<Shape> compute_shape_expand(const Output& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand);
100 TORCH_API std::vector<Shape> compute_shape_view(const Output& input0, const std::vector<int64_t>& output_sizes);
101 TORCH_API std::vector<Shape> compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const ::std::optional<at::ScalarType>& stype);
102 
103 // View Ops
104 // (Now that the functionalization pass is used, we should remove these in a later PR.)
// Each declaration takes one or two `Output` operands followed by integer or
// std::vector<int64_t> parameters passed by const reference, and returns a
// std::vector<Shape>.
// The as_strided, diagonal and select `*_view_update` variants take one more
// `Output` operand than their plain counterpart; narrow_view_update takes a
// second `Output` (`source`) and has no `sizes` parameter.
// NOTE(review): compute_shape_squeeze / compute_shape_unsqueeze take `dim` as
// `const int&`, while every other dimension parameter in this header is
// int64_t. Changing it would require updating the definitions and callers.
105 TORCH_API std::vector<Shape> compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
106 TORCH_API std::vector<Shape> compute_shape_as_strided(const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
107 TORCH_API std::vector<Shape> compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
108 TORCH_API std::vector<Shape> compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
109 TORCH_API std::vector<Shape> compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector<int64_t>& base_indices);
110 TORCH_API std::vector<Shape> compute_shape_narrow(const Output& input, const std::vector<int64_t>& base_indices, const std::vector<int64_t>& sizes);
111 TORCH_API std::vector<Shape> compute_shape_permute(const Output& input, const std::vector<int64_t>& dims);
112 TORCH_API std::vector<Shape> compute_shape_resize(const Output& input, const std::vector<int64_t>& size);
113 TORCH_API std::vector<Shape> compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
114 TORCH_API std::vector<Shape> compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
115 TORCH_API std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim);
116 TORCH_API std::vector<Shape> compute_shape_unsqueeze(const Output& input, const int& dim);
117 
// Scatter variants. All four take `self` and `src` as const at::Tensor
// references and return a std::vector<torch::lazy::Shape>. The two `_symint`
// declarations take c10::SymInt / c10::SymIntArrayRef parameters (optional for
// `start`, `end` and `storage_offset`); the other two take plain int64_t
// arguments.
118 TORCH_API std::vector<torch::lazy::Shape> compute_shape_select_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index);
119 TORCH_API std::vector<torch::lazy::Shape> compute_shape_diagonal_scatter(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2);
120 TORCH_API std::vector<torch::lazy::Shape> compute_shape_slice_scatter_symint(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional<c10::SymInt> start, ::std::optional<c10::SymInt> end, c10::SymInt step);
121 TORCH_API std::vector<torch::lazy::Shape> compute_shape_as_strided_scatter_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset);
122 // clang-format on
123 } // namespace lazy
124 } // namespace torch
125