1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/enum.h> 6 #include <torch/expanding_array.h> 7 #include <torch/types.h> 8 9 namespace torch { 10 namespace nn { 11 12 namespace detail { 13 14 typedef std::variant< 15 enumtype::kZeros, 16 enumtype::kReflect, 17 enumtype::kReplicate, 18 enumtype::kCircular> 19 conv_padding_mode_t; 20 21 template <size_t D> 22 using conv_padding_t = 23 std::variant<ExpandingArray<D>, enumtype::kValid, enumtype::kSame>; 24 25 /// Options for a `D`-dimensional convolution or convolution transpose module. 26 template <size_t D> 27 struct ConvNdOptions { 28 using padding_t = conv_padding_t<D>; ConvNdOptionsConvNdOptions29 ConvNdOptions( 30 int64_t in_channels, 31 int64_t out_channels, 32 ExpandingArray<D> kernel_size) 33 : in_channels_(in_channels), 34 out_channels_(out_channels), 35 kernel_size_(std::move(kernel_size)) {} 36 37 /// The number of channels the input volumes will have. 38 /// Changing this parameter after construction __has no effect__. 39 TORCH_ARG(int64_t, in_channels); 40 41 /// The number of output channels the convolution should produce. 42 /// Changing this parameter after construction __has no effect__. 43 TORCH_ARG(int64_t, out_channels); 44 45 /// The kernel size to use. 46 /// For a `D`-dim convolution, must be a single number or a list of `D` 47 /// numbers. 48 /// This parameter __can__ be changed after construction. 49 TORCH_ARG(ExpandingArray<D>, kernel_size); 50 51 /// The stride of the convolution. 52 /// For a `D`-dim convolution, must be a single number or a list of `D` 53 /// numbers. 54 /// This parameter __can__ be changed after construction. 55 TORCH_ARG(ExpandingArray<D>, stride) = 1; 56 57 /// The padding to add to the input volumes. 58 /// For a `D`-dim convolution, must be a single number or a list of `D` 59 /// numbers. 60 /// This parameter __can__ be changed after construction. 61 TORCH_ARG(padding_t, padding) = 0; 62 63 public: decltypeConvNdOptions64 decltype(auto) padding(std::initializer_list<int64_t> il) { 65 return padding(IntArrayRef{il}); 66 } 67 68 /// The kernel dilation. 69 /// For a `D`-dim convolution, must be a single number or a list of `D` 70 /// numbers. 71 /// This parameter __can__ be changed after construction. 72 TORCH_ARG(ExpandingArray<D>, dilation) = 1; 73 74 /// If true, convolutions will be transpose convolutions (a.k.a. 75 /// deconvolutions). 76 /// Changing this parameter after construction __has no effect__. 77 TORCH_ARG(bool, transposed) = false; 78 79 /// For transpose convolutions, the padding to add to output volumes. 80 /// For a `D`-dim convolution, must be a single number or a list of `D` 81 /// numbers. 82 /// This parameter __can__ be changed after construction. 83 TORCH_ARG(ExpandingArray<D>, output_padding) = 0; 84 85 /// The number of convolution groups. 86 /// This parameter __can__ be changed after construction. 87 TORCH_ARG(int64_t, groups) = 1; 88 89 /// Whether to add a bias after individual applications of the kernel. 90 /// Changing this parameter after construction __has no effect__. 91 TORCH_ARG(bool, bias) = true; 92 93 /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or 94 /// `torch::kCircular`. Default: `torch::kZeros` 95 TORCH_ARG(conv_padding_mode_t, padding_mode) = torch::kZeros; 96 }; 97 98 } // namespace detail 99 100 // ============================================================================ 101 102 /// Options for a `D`-dimensional convolution module. 103 template <size_t D> 104 struct ConvOptions { 105 using padding_mode_t = detail::conv_padding_mode_t; 106 using padding_t = detail::conv_padding_t<D>; 107 ConvOptionsConvOptions108 ConvOptions( 109 int64_t in_channels, 110 int64_t out_channels, 111 ExpandingArray<D> kernel_size) 112 : in_channels_(in_channels), 113 out_channels_(out_channels), 114 kernel_size_(std::move(kernel_size)) {} 115 116 /// The number of channels the input volumes will have. 117 /// Changing this parameter after construction __has no effect__. 118 TORCH_ARG(int64_t, in_channels); 119 120 /// The number of output channels the convolution should produce. 121 /// Changing this parameter after construction __has no effect__. 122 TORCH_ARG(int64_t, out_channels); 123 124 /// The kernel size to use. 125 /// For a `D`-dim convolution, must be a single number or a list of `D` 126 /// numbers. 127 /// This parameter __can__ be changed after construction. 128 TORCH_ARG(ExpandingArray<D>, kernel_size); 129 130 /// The stride of the convolution. 131 /// For a `D`-dim convolution, must be a single number or a list of `D` 132 /// numbers. 133 /// This parameter __can__ be changed after construction. 134 TORCH_ARG(ExpandingArray<D>, stride) = 1; 135 136 /// The padding to add to the input volumes. 137 /// For a `D`-dim convolution, must be a single number or a list of `D` 138 /// numbers. 139 /// This parameter __can__ be changed after construction. 140 TORCH_ARG(padding_t, padding) = 0; 141 142 public: decltypeConvOptions143 decltype(auto) padding(std::initializer_list<int64_t> il) { 144 return padding(IntArrayRef{il}); 145 } 146 147 /// The kernel dilation. 148 /// For a `D`-dim convolution, must be a single number or a list of `D` 149 /// numbers. 150 /// This parameter __can__ be changed after construction. 151 TORCH_ARG(ExpandingArray<D>, dilation) = 1; 152 153 /// The number of convolution groups. 154 /// This parameter __can__ be changed after construction. 155 TORCH_ARG(int64_t, groups) = 1; 156 157 /// Whether to add a bias after individual applications of the kernel. 158 /// Changing this parameter after construction __has no effect__. 159 TORCH_ARG(bool, bias) = true; 160 161 /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or 162 /// `torch::kCircular`. Default: `torch::kZeros` 163 TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros; 164 }; 165 166 /// `ConvOptions` specialized for the `Conv1d` module. 167 /// 168 /// Example: 169 /// ``` 170 /// Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false)); 171 /// ``` 172 using Conv1dOptions = ConvOptions<1>; 173 174 /// `ConvOptions` specialized for the `Conv2d` module. 175 /// 176 /// Example: 177 /// ``` 178 /// Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false)); 179 /// ``` 180 using Conv2dOptions = ConvOptions<2>; 181 182 /// `ConvOptions` specialized for the `Conv3d` module. 183 /// 184 /// Example: 185 /// ``` 186 /// Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false)); 187 /// ``` 188 using Conv3dOptions = ConvOptions<3>; 189 190 // ============================================================================ 191 192 namespace functional { 193 194 /// Options for a `D`-dimensional convolution functional. 195 template <size_t D> 196 struct ConvFuncOptions { 197 using padding_t = torch::nn::detail::conv_padding_t<D>; 198 199 /// optional bias of shape `(out_channels)`. Default: ``None`` 200 TORCH_ARG(torch::Tensor, bias) = Tensor(); 201 202 /// The stride of the convolving kernel. 203 /// For a `D`-dim convolution, must be a single number or a list of `D` 204 /// numbers. 205 TORCH_ARG(ExpandingArray<D>, stride) = 1; 206 207 /// Implicit paddings on both sides of the input. 208 /// For a `D`-dim convolution, must be a single number or a list of `D` 209 /// numbers. 210 TORCH_ARG(padding_t, padding) = 0; 211 212 public: decltypeConvFuncOptions213 decltype(auto) padding(std::initializer_list<int64_t> il) { 214 return padding(IntArrayRef{il}); 215 } 216 217 /// The spacing between kernel elements. 218 /// For a `D`-dim convolution, must be a single number or a list of `D` 219 /// numbers. 220 TORCH_ARG(ExpandingArray<D>, dilation) = 1; 221 222 /// Split input into groups, `in_channels` should be divisible by 223 /// the number of groups. 224 TORCH_ARG(int64_t, groups) = 1; 225 }; 226 227 /// `ConvFuncOptions` specialized for `torch::nn::functional::conv1d`. 228 /// 229 /// Example: 230 /// ``` 231 /// namespace F = torch::nn::functional; 232 /// F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1)); 233 /// ``` 234 using Conv1dFuncOptions = ConvFuncOptions<1>; 235 236 /// `ConvFuncOptions` specialized for `torch::nn::functional::conv2d`. 237 /// 238 /// Example: 239 /// ``` 240 /// namespace F = torch::nn::functional; 241 /// F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1)); 242 /// ``` 243 using Conv2dFuncOptions = ConvFuncOptions<2>; 244 245 /// `ConvFuncOptions` specialized for `torch::nn::functional::conv3d`. 246 /// 247 /// Example: 248 /// ``` 249 /// namespace F = torch::nn::functional; 250 /// F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1)); 251 /// ``` 252 using Conv3dFuncOptions = ConvFuncOptions<3>; 253 254 } // namespace functional 255 256 // ============================================================================ 257 258 template <size_t D> 259 struct ConvTransposeOptions { 260 using padding_mode_t = detail::conv_padding_mode_t; 261 ConvTransposeOptionsConvTransposeOptions262 ConvTransposeOptions( 263 int64_t in_channels, 264 int64_t out_channels, 265 ExpandingArray<D> kernel_size) 266 : in_channels_(in_channels), 267 out_channels_(out_channels), 268 kernel_size_(std::move(kernel_size)) {} 269 270 /// The number of channels the input volumes will have. 271 /// Changing this parameter after construction __has no effect__. 272 TORCH_ARG(int64_t, in_channels); 273 274 /// The number of output channels the convolution should produce. 275 /// Changing this parameter after construction __has no effect__. 276 TORCH_ARG(int64_t, out_channels); 277 278 /// The kernel size to use. 279 /// For a `D`-dim convolution, must be a single number or a list of `D` 280 /// numbers. 281 /// This parameter __can__ be changed after construction. 282 TORCH_ARG(ExpandingArray<D>, kernel_size); 283 284 /// The stride of the convolution. 285 /// For a `D`-dim convolution, must be a single number or a list of `D` 286 /// numbers. 287 /// This parameter __can__ be changed after construction. 288 TORCH_ARG(ExpandingArray<D>, stride) = 1; 289 290 /// The padding to add to the input volumes. 291 /// For a `D`-dim convolution, must be a single number or a list of `D` 292 /// numbers. 293 /// This parameter __can__ be changed after construction. 294 TORCH_ARG(ExpandingArray<D>, padding) = 0; 295 296 /// For transpose convolutions, the padding to add to output volumes. 297 /// For a `D`-dim convolution, must be a single number or a list of `D` 298 /// numbers. 299 /// This parameter __can__ be changed after construction. 300 TORCH_ARG(ExpandingArray<D>, output_padding) = 0; 301 302 /// The number of convolution groups. 303 /// This parameter __can__ be changed after construction. 304 TORCH_ARG(int64_t, groups) = 1; 305 306 /// Whether to add a bias after individual applications of the kernel. 307 /// Changing this parameter after construction __has no effect__. 308 TORCH_ARG(bool, bias) = true; 309 310 /// The kernel dilation. 311 /// For a `D`-dim convolution, must be a single number or a list of `D` 312 /// numbers. 313 /// This parameter __can__ be changed after construction. 314 TORCH_ARG(ExpandingArray<D>, dilation) = 1; 315 316 /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or 317 /// `torch::kCircular`. Default: `torch::kZeros` 318 TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros; 319 }; 320 321 /// `ConvTransposeOptions` specialized for the `ConvTranspose1d` module. 322 /// 323 /// Example: 324 /// ``` 325 /// ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 326 /// 3).stride(1).bias(false)); 327 /// ``` 328 using ConvTranspose1dOptions = ConvTransposeOptions<1>; 329 330 /// `ConvTransposeOptions` specialized for the `ConvTranspose2d` module. 331 /// 332 /// Example: 333 /// ``` 334 /// ConvTranspose2d model(ConvTranspose2dOptions(3, 2, 335 /// 3).stride(1).bias(false)); 336 /// ``` 337 using ConvTranspose2dOptions = ConvTransposeOptions<2>; 338 339 /// `ConvTransposeOptions` specialized for the `ConvTranspose3d` module. 340 /// 341 /// Example: 342 /// ``` 343 /// ConvTranspose3d model(ConvTranspose3dOptions(2, 2, 344 /// 2).stride(1).bias(false)); 345 /// ``` 346 using ConvTranspose3dOptions = ConvTransposeOptions<3>; 347 348 // ============================================================================ 349 350 namespace functional { 351 352 /// Options for a `D`-dimensional convolution functional. 353 template <size_t D> 354 struct ConvTransposeFuncOptions { 355 /// optional bias of shape `(out_channels)`. Default: ``None`` 356 TORCH_ARG(torch::Tensor, bias) = Tensor(); 357 358 /// The stride of the convolving kernel. 359 /// For a `D`-dim convolution, must be a single number or a list of `D` 360 /// numbers. 361 TORCH_ARG(ExpandingArray<D>, stride) = 1; 362 363 /// Implicit paddings on both sides of the input. 364 /// For a `D`-dim convolution, must be a single number or a list of `D` 365 /// numbers. 366 TORCH_ARG(ExpandingArray<D>, padding) = 0; 367 368 /// Additional size added to one side of each dimension in the output shape. 369 /// Default: 0 370 TORCH_ARG(ExpandingArray<D>, output_padding) = 0; 371 372 /// Split input into groups, `in_channels` should be divisible by 373 /// the number of groups. 374 TORCH_ARG(int64_t, groups) = 1; 375 376 /// The spacing between kernel elements. 377 /// For a `D`-dim convolution, must be a single number or a list of `D` 378 /// numbers. 379 TORCH_ARG(ExpandingArray<D>, dilation) = 1; 380 }; 381 382 /// `ConvTransposeFuncOptions` specialized for 383 /// `torch::nn::functional::conv_transpose1d`. 384 /// 385 /// Example: 386 /// ``` 387 /// namespace F = torch::nn::functional; 388 /// F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1)); 389 /// ``` 390 using ConvTranspose1dFuncOptions = ConvTransposeFuncOptions<1>; 391 392 /// `ConvTransposeFuncOptions` specialized for 393 /// `torch::nn::functional::conv_transpose2d`. 394 /// 395 /// Example: 396 /// ``` 397 /// namespace F = torch::nn::functional; 398 /// F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1)); 399 /// ``` 400 using ConvTranspose2dFuncOptions = ConvTransposeFuncOptions<2>; 401 402 /// `ConvTransposeFuncOptions` specialized for 403 /// `torch::nn::functional::conv_transpose3d`. 404 /// 405 /// Example: 406 /// ``` 407 /// namespace F = torch::nn::functional; 408 /// F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1)); 409 /// ``` 410 using ConvTranspose3dFuncOptions = ConvTransposeFuncOptions<3>; 411 412 } // namespace functional 413 414 } // namespace nn 415 } // namespace torch 416