// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/fold.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/modules/fold.h>
2 
3 #include <torch/expanding_array.h>
4 #include <torch/types.h>
5 #include <torch/utils.h>
6 
7 namespace F = torch::nn::functional;
8 
9 namespace torch {
10 namespace nn {
11 
FoldImpl(const FoldOptions & options_)12 FoldImpl::FoldImpl(const FoldOptions& options_) : options(options_) {}
13 
reset()14 void FoldImpl::reset() {}
15 
pretty_print(std::ostream & stream) const16 void FoldImpl::pretty_print(std::ostream& stream) const {
17   stream << "torch::nn::Fold(output_size=" << options.output_size()
18          << ", kernel_size=" << options.kernel_size()
19          << ", dilation=" << options.dilation()
20          << ", padding=" << options.padding() << ", stride=" << options.stride()
21          << ")";
22 }
23 
forward(const Tensor & input)24 Tensor FoldImpl::forward(const Tensor& input) {
25   return F::detail::fold(
26       input,
27       options.output_size(),
28       options.kernel_size(),
29       options.dilation(),
30       options.padding(),
31       options.stride());
32 }
33 
34 // ============================================================================
35 
UnfoldImpl(const UnfoldOptions & options_)36 UnfoldImpl::UnfoldImpl(const UnfoldOptions& options_) : options(options_) {}
37 
reset()38 void UnfoldImpl::reset() {}
39 
pretty_print(std::ostream & stream) const40 void UnfoldImpl::pretty_print(std::ostream& stream) const {
41   stream << "torch::nn::Unfold(kernel_size=" << options.kernel_size()
42          << ", dilation=" << options.dilation()
43          << ", padding=" << options.padding() << ", stride=" << options.stride()
44          << ")";
45 }
46 
forward(const Tensor & input)47 Tensor UnfoldImpl::forward(const Tensor& input) {
48   return F::detail::unfold(
49       input,
50       options.kernel_size(),
51       options.dilation(),
52       options.padding(),
53       options.stride());
54 }
55 
56 } // namespace nn
57 } // namespace torch
58