1 #include <torch/nn/modules/pixelshuffle.h> 2 3 namespace F = torch::nn::functional; 4 5 namespace torch { 6 namespace nn { 7 PixelShuffleImpl(const PixelShuffleOptions & options_)8PixelShuffleImpl::PixelShuffleImpl(const PixelShuffleOptions& options_) 9 : options(options_) {} 10 pretty_print(std::ostream & stream) const11void PixelShuffleImpl::pretty_print(std::ostream& stream) const { 12 stream << "torch::nn::PixelShuffle(upscale_factor=" 13 << options.upscale_factor() << ")"; 14 } 15 reset()16void PixelShuffleImpl::reset() {} 17 forward(const Tensor & input)18Tensor PixelShuffleImpl::forward(const Tensor& input) { 19 return F::detail::pixel_shuffle(input, options.upscale_factor()); 20 } 21 PixelUnshuffleImpl(const PixelUnshuffleOptions & options_)22PixelUnshuffleImpl::PixelUnshuffleImpl(const PixelUnshuffleOptions& options_) 23 : options(options_) {} 24 pretty_print(std::ostream & stream) const25void PixelUnshuffleImpl::pretty_print(std::ostream& stream) const { 26 stream << "torch::nn::PixelUnshuffle(downscale_factor=" 27 << options.downscale_factor() << ")"; 28 } 29 reset()30void PixelUnshuffleImpl::reset() {} 31 forward(const Tensor & input)32Tensor PixelUnshuffleImpl::forward(const Tensor& input) { 33 return F::detail::pixel_unshuffle(input, options.downscale_factor()); 34 } 35 36 } // namespace nn 37 } // namespace torch 38