xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/pixelshuffle.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/types.h>
6 
7 namespace torch {
8 namespace nn {
9 
10 /// Options for the `PixelShuffle` module.
11 ///
12 /// Example:
13 /// ```
14 /// PixelShuffle model(PixelShuffleOptions(5));
15 /// ```
16 struct TORCH_API PixelShuffleOptions {
PixelShuffleOptionsPixelShuffleOptions17   PixelShuffleOptions(int64_t upscale_factor)
18       : upscale_factor_(upscale_factor) {}
19 
20   /// Factor to increase spatial resolution by
21   TORCH_ARG(int64_t, upscale_factor);
22 };
23 
24 /// Options for the `PixelUnshuffle` module.
25 ///
26 /// Example:
27 /// ```
28 /// PixelUnshuffle model(PixelUnshuffleOptions(5));
29 /// ```
30 struct TORCH_API PixelUnshuffleOptions {
PixelUnshuffleOptionsPixelUnshuffleOptions31   /* implicit */ PixelUnshuffleOptions(int64_t downscale_factor)
32       : downscale_factor_(downscale_factor) {}
33 
34   /// Factor to decrease spatial resolution by
35   TORCH_ARG(int64_t, downscale_factor);
36 };
37 
38 namespace functional {
39 /// Options for `torch::nn::functional::pixel_shuffle`.
40 ///
41 /// See the documentation for `torch::nn::PixelShuffleOptions` class to learn
42 /// what arguments are supported.
43 ///
44 /// Example:
45 /// ```
46 /// namespace F = torch::nn::functional;
47 /// F::pixel_shuffle(x, F::PixelShuffleFuncOptions(2));
48 /// ```
49 using PixelShuffleFuncOptions = PixelShuffleOptions;
50 
51 /// Options for `torch::nn::functional::pixel_unshuffle`.
52 ///
53 /// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn
54 /// what arguments are supported.
55 ///
56 /// Example:
57 /// ```
58 /// namespace F = torch::nn::functional;
59 /// F::pixel_unshuffle(x, F::PixelUnshuffleFuncOptions(2));
60 /// ```
61 using PixelUnshuffleFuncOptions = PixelUnshuffleOptions;
62 } // namespace functional
63 
64 } // namespace nn
65 } // namespace torch
66