1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/enum.h> 6 #include <torch/expanding_array.h> 7 #include <torch/types.h> 8 9 namespace torch { 10 namespace nn { 11 12 /// Options for a `D`-dimensional ReflectionPad module. 13 template <size_t D> 14 struct TORCH_API ReflectionPadOptions { ReflectionPadOptionsReflectionPadOptions15 ReflectionPadOptions(ExpandingArray<D * 2> padding) : padding_(padding) {} 16 17 /// The size of the padding. 18 /// If it is `int`, uses the same padding in all boundaries. 19 /// If it is a 2-`tuple` (for ReflectionPad1d), uses (padding_left, 20 /// padding_right). If it is a 4-`tuple` (for ReflectionPad2d), uses 21 /// (padding_left, padding_right, padding_top, padding_bottom). If it is a 22 /// 6-`tuple` (for ReflectionPad3d), uses (padding_left, padding_right, 23 /// padding_top, padding_bottom, padding_front, padding_back). 24 25 TORCH_ARG(ExpandingArray<D * 2>, padding); 26 }; 27 28 /// `ReflectionPadOptions` specialized for the `ReflectionPad1d` module. 29 /// 30 /// Example: 31 /// ``` 32 /// ReflectionPad1d model(ReflectionPad1dOptions({3, 1})); 33 /// ``` 34 using ReflectionPad1dOptions = ReflectionPadOptions<1>; 35 36 /// `ReflectionPadOptions` specialized for the `ReflectionPad2d` module. 37 /// 38 /// Example: 39 /// ``` 40 /// ReflectionPad2d model(ReflectionPad2dOptions({1, 1, 2, 0})); 41 /// ``` 42 using ReflectionPad2dOptions = ReflectionPadOptions<2>; 43 44 /// `ReflectionPadOptions` specialized for the `ReflectionPad3d` module. 45 /// 46 /// Example: 47 /// ``` 48 /// ReflectionPad3d model(ReflectionPad3dOptions({1, 1, 2, 0, 1, 1})); 49 /// ``` 50 using ReflectionPad3dOptions = ReflectionPadOptions<3>; 51 52 // ============================================================================ 53 54 /// Options for a `D`-dimensional ReplicationPad module. 55 template <size_t D> 56 struct TORCH_API ReplicationPadOptions { ReplicationPadOptionsReplicationPadOptions57 ReplicationPadOptions(ExpandingArray<D * 2> padding) : padding_(padding) {} 58 59 /// The size of the padding. 60 /// - If it is `int`, uses the same padding in all boundaries. 61 /// - If it is a 2-`tuple` (for ReplicationPad1d), uses (padding_left, 62 /// padding_right). 63 /// - If it is a 4-`tuple` (for ReplicationPad2d), uses (padding_left, 64 /// padding_right, padding_top, padding_bottom). 65 /// - If it is a 6-`tuple` (for ReplicationPad3d), uses 66 /// (padding_left, padding_right, padding_top, padding_bottom, 67 /// padding_front, padding_back). 68 TORCH_ARG(ExpandingArray<D * 2>, padding); 69 }; 70 71 /// `ReplicationPadOptions` specialized for the `ReplicationPad1d` module. 72 /// 73 /// Example: 74 /// ``` 75 /// ReplicationPad1d model(ReplicationPad1dOptions({3, 1})); 76 /// ``` 77 using ReplicationPad1dOptions = ReplicationPadOptions<1>; 78 79 /// `ReplicationPadOptions` specialized for the `ReplicationPad2d` module. 80 /// 81 /// Example: 82 /// ``` 83 /// ReplicationPad2d model(ReplicationPad2dOptions({1, 1, 2, 0})); 84 /// ``` 85 using ReplicationPad2dOptions = ReplicationPadOptions<2>; 86 87 /// `ReplicationPadOptions` specialized for the `ReplicationPad3d` module. 88 /// 89 /// Example: 90 /// ``` 91 /// ReplicationPad3d model(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2})); 92 /// ``` 93 using ReplicationPad3dOptions = ReplicationPadOptions<3>; 94 95 // ============================================================================ 96 97 template <size_t D> 98 struct TORCH_API ZeroPadOptions { ZeroPadOptionsZeroPadOptions99 ZeroPadOptions(ExpandingArray<D * 2> padding) : padding_(padding) {} 100 101 /// The size of the padding. 102 /// - If it is `int`, uses the same padding in all boundaries. 103 /// - If it is a 2-`tuple` (for ZeroPad1d), uses (padding_left, 104 /// padding_right). 105 /// - If it is a 4-`tuple` (for ZeroPad2d), uses (padding_left, padding_right, 106 /// padding_top, padding_bottom). 107 /// - If it is a 6-`tuple` (for ZeroPad3d), uses 108 /// (padding_left, padding_right, padding_top, padding_bottom, 109 /// padding_front, padding_back). 110 TORCH_ARG(ExpandingArray<D * 2>, padding); 111 }; 112 113 /// `ZeroPadOptions` specialized for the `ZeroPad1d` module. 114 /// 115 /// Example: 116 /// ``` 117 /// ConstantPad1d model(ConstantPad1dOptions({3, 1}); 118 /// ``` 119 using ZeroPad1dOptions = ZeroPadOptions<1>; 120 121 /// `ZeroPadOptions` specialized for the `ZeroPad2d` module. 122 /// 123 /// Example: 124 /// ``` 125 /// ConstantPad2d model(ConstantPad2dOptions({1, 1, 2, 0}); 126 /// ``` 127 using ZeroPad2dOptions = ZeroPadOptions<2>; 128 129 /// `ZeroPadOptions` specialized for the `ZeroPad3d` module. 130 /// 131 /// Example: 132 /// ``` 133 /// ConstantPad3d model(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}); 134 /// ``` 135 using ZeroPad3dOptions = ZeroPadOptions<3>; 136 137 // ============================================================================ 138 139 /// Options for a `D`-dimensional ConstantPad module. 140 template <size_t D> 141 struct TORCH_API ConstantPadOptions { ConstantPadOptionsConstantPadOptions142 ConstantPadOptions(ExpandingArray<D * 2> padding, double value) 143 : padding_(padding), value_(value) {} 144 145 /// The size of the padding. 146 /// - If it is `int`, uses the same padding in all boundaries. 147 /// - If it is a 2-`tuple` (for ConstantPad1d), uses (padding_left, 148 /// padding_right). 149 /// - If it is a 4-`tuple` (for ConstantPad2d), uses (padding_left, 150 /// padding_right, padding_top, padding_bottom). 151 /// - If it is a 6-`tuple` (for ConstantPad3d), uses 152 /// (padding_left, padding_right, padding_top, padding_bottom, 153 /// padding_front, padding_back). 154 TORCH_ARG(ExpandingArray<D * 2>, padding); 155 156 /// Fill value for constant padding. 157 TORCH_ARG(double, value); 158 }; 159 160 /// `ConstantPadOptions` specialized for the `ConstantPad1d` module. 161 /// 162 /// Example: 163 /// ``` 164 /// ConstantPad1d model(ConstantPad1dOptions({3, 1}, 3.5)); 165 /// ``` 166 using ConstantPad1dOptions = ConstantPadOptions<1>; 167 168 /// `ConstantPadOptions` specialized for the `ConstantPad2d` module. 169 /// 170 /// Example: 171 /// ``` 172 /// ConstantPad2d model(ConstantPad2dOptions({3, 0, 2, 1}, 3.5)); 173 /// ``` 174 using ConstantPad2dOptions = ConstantPadOptions<2>; 175 176 /// `ConstantPadOptions` specialized for the `ConstantPad3d` module. 177 /// 178 /// Example: 179 /// ``` 180 /// ConstantPad3d model(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5)); 181 /// ``` 182 using ConstantPad3dOptions = ConstantPadOptions<3>; 183 184 // ============================================================================ 185 186 namespace functional { 187 188 /// Options for `torch::nn::functional::pad`. 189 /// 190 /// Example: 191 /// ``` 192 /// namespace F = torch::nn::functional; 193 /// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1, 194 /// 2}).mode(torch::kReplicate)); 195 /// ``` 196 struct TORCH_API PadFuncOptions { 197 typedef std::variant< 198 enumtype::kConstant, 199 enumtype::kReflect, 200 enumtype::kReplicate, 201 enumtype::kCircular> 202 mode_t; 203 204 PadFuncOptions(std::vector<int64_t> pad); 205 206 /// m-elements tuple, where m/2 <= input dimensions and m is even. 207 TORCH_ARG(std::vector<int64_t>, pad); 208 209 /// "constant", "reflect", "replicate" or "circular". Default: "constant" 210 TORCH_ARG(mode_t, mode) = torch::kConstant; 211 212 /// fill value for "constant" padding. Default: 0 213 TORCH_ARG(double, value) = 0; 214 }; 215 216 } // namespace functional 217 218 } // namespace nn 219 } // namespace torch 220