xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/upsampling.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/modules/upsampling.h>
2 
3 #include <string>
4 
5 namespace F = torch::nn::functional;
6 
7 namespace torch {
8 namespace nn {
9 
UpsampleImpl(const UpsampleOptions & options_)10 UpsampleImpl::UpsampleImpl(
11     const UpsampleOptions& options_) // NOLINT(modernize-pass-by-value)
12     : options(options_) {}
13 
reset()14 void UpsampleImpl::reset() {}
15 
pretty_print(std::ostream & stream) const16 void UpsampleImpl::pretty_print(std::ostream& stream) const {
17   stream << "torch::nn::Upsample(";
18   if (options.scale_factor() != std::nullopt) {
19     stream << "scale_factor=" << at::ArrayRef<double>(*options.scale_factor());
20   } else {
21     stream << "size=" << at::ArrayRef<int64_t>(*options.size());
22   }
23   stream << ", mode=" << enumtype::get_enum_name(options.mode()) << ")";
24 }
25 
// Upsamples `input` by delegating to F::detail::interpolate with this
// module's size, scale_factor, mode and align_corners options.
//
// The module's mode variant (UpsampleOptions::mode_t) is widened into the
// functional's mode variant (F::InterpolateFuncOptions::mode_t) via
// std::visit: each alternative is copied into the alternative of the same
// type. This is exhaustive at compile time. An alternative with no
// counterpart fails to build instead of silently keeping the
// default-constructed mode, which is what the previous if/else-if chain of
// holds_alternative checks would have done.
Tensor UpsampleImpl::forward(const Tensor& input) {
  const F::InterpolateFuncOptions::mode_t mode = std::visit(
      [](const auto& alternative) -> F::InterpolateFuncOptions::mode_t {
        return alternative;
      },
      options.mode());

  // The last two arguments are passed unchanged as std::nullopt and false.
  // NOTE(review): they presumably map to recompute_scale_factor and
  // antialias — confirm against the F::detail::interpolate signature.
  return F::detail::interpolate(
      input,
      options.size(),
      options.scale_factor(),
      mode,
      options.align_corners(),
      std::nullopt,
      false);
}
49 
50 } // namespace nn
51 } // namespace torch
52