1 #include <torch/nn/modules/fold.h>
2
3 #include <torch/expanding_array.h>
4 #include <torch/types.h>
5 #include <torch/utils.h>
6
7 namespace F = torch::nn::functional;
8
9 namespace torch {
10 namespace nn {
11
FoldImpl(const FoldOptions & options_)12 FoldImpl::FoldImpl(const FoldOptions& options_) : options(options_) {}
13
reset()14 void FoldImpl::reset() {}
15
pretty_print(std::ostream & stream) const16 void FoldImpl::pretty_print(std::ostream& stream) const {
17 stream << "torch::nn::Fold(output_size=" << options.output_size()
18 << ", kernel_size=" << options.kernel_size()
19 << ", dilation=" << options.dilation()
20 << ", padding=" << options.padding() << ", stride=" << options.stride()
21 << ")";
22 }
23
/// Applies the functional fold operation to `input`.
///
/// Passes `input` together with the output size, kernel size, dilation,
/// padding and stride stored in `options` to `F::detail::fold` and returns
/// its result.
Tensor FoldImpl::forward(const Tensor& input) {
  // Name each configured value so the call below reads positionally.
  const auto& output_size = options.output_size();
  const auto& kernel_size = options.kernel_size();
  const auto& dilation = options.dilation();
  const auto& padding = options.padding();
  const auto& stride = options.stride();

  return F::detail::fold(
      input, output_size, kernel_size, dilation, padding, stride);
}
33
34 // ============================================================================
35
UnfoldImpl(const UnfoldOptions & options_)36 UnfoldImpl::UnfoldImpl(const UnfoldOptions& options_) : options(options_) {}
37
reset()38 void UnfoldImpl::reset() {}
39
pretty_print(std::ostream & stream) const40 void UnfoldImpl::pretty_print(std::ostream& stream) const {
41 stream << "torch::nn::Unfold(kernel_size=" << options.kernel_size()
42 << ", dilation=" << options.dilation()
43 << ", padding=" << options.padding() << ", stride=" << options.stride()
44 << ")";
45 }
46
/// Applies the functional unfold operation to `input`.
///
/// Passes `input` together with the kernel size, dilation, padding and
/// stride stored in `options` to `F::detail::unfold` and returns its result.
Tensor UnfoldImpl::forward(const Tensor& input) {
  // Name each configured value so the call below reads positionally.
  const auto& kernel_size = options.kernel_size();
  const auto& dilation = options.dilation();
  const auto& padding = options.padding();
  const auto& stride = options.stride();

  return F::detail::unfold(input, kernel_size, dilation, padding, stride);
}
55
56 } // namespace nn
57 } // namespace torch
58