xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/functional/pixelshuffle.h>
5 #include <torch/nn/options/pixelshuffle.h>
6 
7 #include <torch/csrc/Export.h>
8 
9 namespace torch {
10 namespace nn {
11 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
14 
15 /// Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
16 /// to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an
17 /// upscale factor. See
18 /// https://pytorch.org/docs/main/nn.html#torch.nn.PixelShuffle to learn about
19 /// the exact behavior of this module.
20 ///
21 /// See the documentation for `torch::nn::PixelShuffleOptions` class to learn
22 /// what constructor arguments are supported for this module.
23 ///
24 /// Example:
25 /// ```
26 /// PixelShuffle model(PixelShuffleOptions(5));
27 /// ```
28 struct TORCH_API PixelShuffleImpl
29     : public torch::nn::Cloneable<PixelShuffleImpl> {
30   explicit PixelShuffleImpl(const PixelShuffleOptions& options_);
31 
32   /// Pretty prints the `PixelShuffle` module into the given `stream`.
33   void pretty_print(std::ostream& stream) const override;
34 
35   Tensor forward(const Tensor& input);
36 
37   void reset() override;
38 
39   /// The options with which this `Module` was constructed.
40   PixelShuffleOptions options;
41 };
42 
/// Declares `PixelShuffle`, a `ModuleHolder` subclass for `PixelShuffleImpl`.
/// The methods available through it are those documented on the
/// `PixelShuffleImpl` class, which also shows an example of constructing a
/// `PixelShuffle` from `torch::nn::PixelShuffleOptions`. See the documentation
/// for `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(PixelShuffle);
49 
50 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelUnshuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
51 
52 /// Reverses the PixelShuffle operation by rearranging elements in a tensor of
53 /// shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape :math:`(*,
54 /// C \times r^2, H, W)`, where r is a downscale factor. See
55 /// https://pytorch.org/docs/main/nn.html#torch.nn.PixelUnshuffle to learn
56 /// about the exact behavior of this module.
57 ///
58 /// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn
59 /// what constructor arguments are supported for this module.
60 ///
61 /// Example:
62 /// ```
63 /// PixelUnshuffle model(PixelUnshuffleOptions(5));
64 /// ```
65 struct TORCH_API PixelUnshuffleImpl
66     : public torch::nn::Cloneable<PixelUnshuffleImpl> {
67   explicit PixelUnshuffleImpl(const PixelUnshuffleOptions& options_);
68 
69   /// Pretty prints the `PixelUnshuffle` module into the given `stream`.
70   void pretty_print(std::ostream& stream) const override;
71 
72   Tensor forward(const Tensor& input);
73 
74   void reset() override;
75 
76   /// The options with which this `Module` was constructed.
77   PixelUnshuffleOptions options;
78 };
79 
/// Declares `PixelUnshuffle`, a `ModuleHolder` subclass for
/// `PixelUnshuffleImpl`. The methods available through it are those documented
/// on the `PixelUnshuffleImpl` class, which also shows an example of
/// constructing a `PixelUnshuffle` from `torch::nn::PixelUnshuffleOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's module
/// storage semantics.
TORCH_MODULE(PixelUnshuffle);
86 
87 } // namespace nn
88 } // namespace torch
89