xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/utils/op_registry.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/utils/op_registry.h>
2 
// Location for commonly used shape registries.
4 
5 namespace torch {
6 namespace jit {
7 
8 // Requirements:
9 //   dims           : preserved from the first argument
10 //   scalar type    : preserved from the first argument (doesn't have to
11 //                    match other arguments)
12 //   device         : always matching and preserved
13 //   tensor inputs  : *
14 //   tensor outputs : 1
// NB: these ops (with slight adjustments) are good candidates for restarts.
//     Knowing the type and device of weights or biases is usually enough to
//     infer the output type.
nn_ops_first_input_preserving()18 std::shared_ptr<OperatorSet> nn_ops_first_input_preserving() {
19   std::shared_ptr<OperatorSet> ops = std::make_shared<OperatorSet>(OperatorSet{
20       "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
21       "aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
22       "aten::conv2d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
23       "aten::conv3d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
24       "aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad) -> Tensor",
25       "aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
26       "aten::conv_transpose2d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
27       "aten::conv_transpose3d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
28       "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
29       "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor", // deprecated _convolution
30       "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
31       "aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor",
32       "aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor",
33       "aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor",
34       "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
35       "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor",
36       "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor",
37       "aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
38       "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
39       "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
40       "aten::max_unpool2d(Tensor self, Tensor indices, int[] output_size) -> Tensor",
41       "aten::max_unpool3d(Tensor self, Tensor indices, int[] output_size, int[] stride, int[] padding) -> Tensor",
42       "aten::reflection_pad1d(Tensor self, int[] padding) -> Tensor",
43       "aten::reflection_pad2d(Tensor self, int[] padding) -> Tensor",
44       "aten::reflection_pad3d(Tensor self, int[] padding) -> Tensor",
45       "aten::replication_pad1d(Tensor self, int[] padding) -> Tensor",
46       "aten::replication_pad2d(Tensor self, int[] padding) -> Tensor",
47       "aten::replication_pad3d(Tensor self, int[] padding) -> Tensor",
48       "aten::upsample_bilinear2d(Tensor self, int[] output_size, bool align_corners, float? scales_h, float? scales_w) -> Tensor",
49       "aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners, float? scales) -> Tensor",
50       "aten::upsample_nearest1d(Tensor self, int[] output_size, float? scales) -> Tensor",
51       "aten::upsample_nearest2d(Tensor self, int[] output_size, float? scales_h, float? scales_w) -> Tensor",
52       "aten::upsample_nearest3d(Tensor self, int[] output_size, float? scales_d, float? scales_h, float? scales_w) -> Tensor",
53       "aten::upsample_trilinear3d(Tensor self, int[] output_size, bool align_corners, float? scales_d, float? scales_h, float? scales_w) -> Tensor",
54       "aten::prelu(Tensor self, Tensor weight) -> Tensor",
55 
56       // Added because Hardswish is really hard to convert to metatensors
57       "aten::hardswish(Tensor self) -> Tensor",
58       "aten::hardswish_(Tensor self) -> Tensor",
59   });
60   return ops;
61 };
62 
63 // Requirements:
//   dims           : changed from the first argument
65 //   scalar type    : preserved from the first argument
66 //   device         : always matching and preserved
67 //   tensor inputs  : 1
68 //   tensor outputs : 1
ops_one_tensor_in_shape_transform()69 std::shared_ptr<OperatorSet> ops_one_tensor_in_shape_transform() {
70   std::shared_ptr<OperatorSet> ops = std::make_shared<OperatorSet>(OperatorSet{
71       "aten::flatten(Tensor self, int start_dim, int end_dim) -> Tensor",
72   });
73   return ops;
74 };
75 } // namespace jit
76 } // namespace torch
77