1 #include <torch/nn/modules/upsampling.h> 2 3 #include <string> 4 5 namespace F = torch::nn::functional; 6 7 namespace torch { 8 namespace nn { 9 UpsampleImpl(const UpsampleOptions & options_)10UpsampleImpl::UpsampleImpl( 11 const UpsampleOptions& options_) // NOLINT(modernize-pass-by-value) 12 : options(options_) {} 13 reset()14void UpsampleImpl::reset() {} 15 pretty_print(std::ostream & stream) const16void UpsampleImpl::pretty_print(std::ostream& stream) const { 17 stream << "torch::nn::Upsample("; 18 if (options.scale_factor() != std::nullopt) { 19 stream << "scale_factor=" << at::ArrayRef<double>(*options.scale_factor()); 20 } else { 21 stream << "size=" << at::ArrayRef<int64_t>(*options.size()); 22 } 23 stream << ", mode=" << enumtype::get_enum_name(options.mode()) << ")"; 24 } 25 forward(const Tensor & input)26Tensor UpsampleImpl::forward(const Tensor& input) { 27 F::InterpolateFuncOptions::mode_t mode; 28 if (std::holds_alternative<enumtype::kNearest>(options.mode())) { 29 mode = torch::kNearest; 30 } else if (std::holds_alternative<enumtype::kLinear>(options.mode())) { 31 mode = torch::kLinear; 32 } else if (std::holds_alternative<enumtype::kBilinear>(options.mode())) { 33 mode = torch::kBilinear; 34 } else if (std::holds_alternative<enumtype::kBicubic>(options.mode())) { 35 mode = torch::kBicubic; 36 } else if (std::holds_alternative<enumtype::kTrilinear>(options.mode())) { 37 mode = torch::kTrilinear; 38 } 39 40 return F::detail::interpolate( 41 input, 42 options.size(), 43 options.scale_factor(), 44 mode, 45 options.align_corners(), 46 std::nullopt, 47 false); 48 } 49 50 } // namespace nn 51 } // namespace torch 52