1 #pragma once 2 3 #include <torch/nn/cloneable.h> 4 #include <torch/nn/functional/pixelshuffle.h> 5 #include <torch/nn/options/pixelshuffle.h> 6 7 #include <torch/csrc/Export.h> 8 9 namespace torch { 10 namespace nn { 11 12 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle 13 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 14 15 /// Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` 16 /// to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an 17 /// upscale factor. See 18 /// https://pytorch.org/docs/main/nn.html#torch.nn.PixelShuffle to learn about 19 /// the exact behavior of this module. 20 /// 21 /// See the documentation for `torch::nn::PixelShuffleOptions` class to learn 22 /// what constructor arguments are supported for this module. 23 /// 24 /// Example: 25 /// ``` 26 /// PixelShuffle model(PixelShuffleOptions(5)); 27 /// ``` 28 struct TORCH_API PixelShuffleImpl 29 : public torch::nn::Cloneable<PixelShuffleImpl> { 30 explicit PixelShuffleImpl(const PixelShuffleOptions& options_); 31 32 /// Pretty prints the `PixelShuffle` module into the given `stream`. 33 void pretty_print(std::ostream& stream) const override; 34 35 Tensor forward(const Tensor& input); 36 37 void reset() override; 38 39 /// The options with which this `Module` was constructed. 40 PixelShuffleOptions options; 41 }; 42 43 /// A `ModuleHolder` subclass for `PixelShuffleImpl`. 44 /// See the documentation for `PixelShuffleImpl` class to learn what methods it 45 /// provides, and examples of how to use `PixelShuffle` with 46 /// `torch::nn::PixelShuffleOptions`. See the documentation for `ModuleHolder` 47 /// to learn about PyTorch's module storage semantics. 48 TORCH_MODULE(PixelShuffle); 49 50 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelUnshuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 51 52 /// Reverses the PixelShuffle operation by rearranging elements in a tensor of 53 /// shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape :math:`(*, 54 /// C \times r^2, H, W)`, where r is a downscale factor. See 55 /// https://pytorch.org/docs/main/nn.html#torch.nn.PixelUnshuffle to learn 56 /// about the exact behavior of this module. 57 /// 58 /// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn 59 /// what constructor arguments are supported for this module. 60 /// 61 /// Example: 62 /// ``` 63 /// PixelUnshuffle model(PixelUnshuffleOptions(5)); 64 /// ``` 65 struct TORCH_API PixelUnshuffleImpl 66 : public torch::nn::Cloneable<PixelUnshuffleImpl> { 67 explicit PixelUnshuffleImpl(const PixelUnshuffleOptions& options_); 68 69 /// Pretty prints the `PixelUnshuffle` module into the given `stream`. 70 void pretty_print(std::ostream& stream) const override; 71 72 Tensor forward(const Tensor& input); 73 74 void reset() override; 75 76 /// The options with which this `Module` was constructed. 77 PixelUnshuffleOptions options; 78 }; 79 80 /// A `ModuleHolder` subclass for `PixelUnshuffleImpl`. 81 /// See the documentation for `PixelUnshuffleImpl` class to learn what methods 82 /// it provides, and examples of how to use `PixelUnshuffle` with 83 /// `torch::nn::PixelUnshuffleOptions`. See the documentation for `ModuleHolder` 84 /// to learn about PyTorch's module storage semantics. 85 TORCH_MODULE(PixelUnshuffle); 86 87 } // namespace nn 88 } // namespace torch 89