xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/pixelshuffle.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/modules/pixelshuffle.h>
2 
3 namespace F = torch::nn::functional;
4 
5 namespace torch {
6 namespace nn {
7 
PixelShuffleImpl(const PixelShuffleOptions & options_)8 PixelShuffleImpl::PixelShuffleImpl(const PixelShuffleOptions& options_)
9     : options(options_) {}
10 
pretty_print(std::ostream & stream) const11 void PixelShuffleImpl::pretty_print(std::ostream& stream) const {
12   stream << "torch::nn::PixelShuffle(upscale_factor="
13          << options.upscale_factor() << ")";
14 }
15 
reset()16 void PixelShuffleImpl::reset() {}
17 
forward(const Tensor & input)18 Tensor PixelShuffleImpl::forward(const Tensor& input) {
19   return F::detail::pixel_shuffle(input, options.upscale_factor());
20 }
21 
PixelUnshuffleImpl(const PixelUnshuffleOptions & options_)22 PixelUnshuffleImpl::PixelUnshuffleImpl(const PixelUnshuffleOptions& options_)
23     : options(options_) {}
24 
pretty_print(std::ostream & stream) const25 void PixelUnshuffleImpl::pretty_print(std::ostream& stream) const {
26   stream << "torch::nn::PixelUnshuffle(downscale_factor="
27          << options.downscale_factor() << ")";
28 }
29 
reset()30 void PixelUnshuffleImpl::reset() {}
31 
forward(const Tensor & input)32 Tensor PixelUnshuffleImpl::forward(const Tensor& input) {
33   return F::detail::pixel_unshuffle(input, options.downscale_factor());
34 }
35 
36 } // namespace nn
37 } // namespace torch
38