1 #pragma once
2
3 #include <torch/nn/options/fold.h>
4
5 namespace torch {
6 namespace nn {
7 namespace functional {
8
9 #ifndef DOXYGEN_SHOULD_SKIP_THIS
10 namespace detail {
// Implementation behind the public `fold` functional.
//
// Only 2-D (unbatched) and 3-D (batched) inputs are accepted. Any other rank
// fails through TORCH_CHECK with an "Input Error" message that reports the
// rank actually received. Accepted inputs are passed, together with all
// geometry arguments, straight to `torch::col2im`.
inline Tensor fold(
    const Tensor& input,
    ExpandingArray<2> output_size,
    ExpandingArray<2> kernel_size,
    ExpandingArray<2> dilation,
    ExpandingArray<2> padding,
    ExpandingArray<2> stride) {
  const auto ndim = input.dim();
  const bool is_unbatched = ndim == 2;
  const bool is_batched = ndim == 3;

  // Fail fast on unsupported ranks. The unconditional TORCH_CHECK always
  // throws and keeps the exact message existing callers and tests match.
  if (!is_unbatched && !is_batched) {
    TORCH_CHECK(
        false,
        "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported "
        "(got ",
        ndim,
        "D)");
  }

  return torch::col2im(
      input, output_size, kernel_size, dilation, padding, stride);
}
30 } // namespace detail
31 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
32
33 /// See
34 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.fold
35 /// about the exact behavior of this functional.
36 ///
37 /// See the documentation for `torch::nn::functional::FoldFuncOptions` class to
38 /// learn what optional arguments are supported for this functional.
39 ///
40 /// Example:
41 /// ```
42 /// namespace F = torch::nn::functional;
43 /// F::fold(input, F::FoldFuncOptions({3, 2}, {2, 2}));
44 /// ```
fold(const Tensor & input,const FoldFuncOptions & options)45 inline Tensor fold(const Tensor& input, const FoldFuncOptions& options) {
46 return detail::fold(
47 input,
48 options.output_size(),
49 options.kernel_size(),
50 options.dilation(),
51 options.padding(),
52 options.stride());
53 }
54
55 // ============================================================================
56
57 #ifndef DOXYGEN_SHOULD_SKIP_THIS
58 namespace detail {
// Implementation behind the public `unfold` functional.
//
// Only 4-D inputs are accepted. Any other rank fails through TORCH_CHECK with
// an "Input Error" message that reports the rank actually received. Accepted
// inputs are passed unchanged to `torch::im2col`.
//
// NOTE(review): newer releases of torch::im2col may also accept 3-D
// (unbatched) input. This wrapper intentionally keeps the 4-D-only contract.
// Confirm against callers and tests before relaxing it.
inline Tensor unfold(
    const Tensor& input,
    ExpandingArray<2> kernel_size,
    ExpandingArray<2> dilation,
    ExpandingArray<2> padding,
    ExpandingArray<2> stride) {
  const auto ndim = input.dim();

  // Reject unsupported ranks up front. The unconditional TORCH_CHECK always
  // throws and preserves the exact message text.
  if (ndim != 4) {
    TORCH_CHECK(
        false,
        "Input Error: Only 4D input Tensors are supported "
        "(got ",
        ndim,
        "D)");
  }

  return torch::im2col(input, kernel_size, dilation, padding, stride);
}
76 } // namespace detail
77 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
78
79 /// See
80 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.unfold
81 /// about the exact behavior of this functional.
82 ///
83 /// See the documentation for `torch::nn::functional::UnfoldFuncOptions` class
84 /// to learn what optional arguments are supported for this functional.
85 ///
86 /// Example:
87 /// ```
88 /// namespace F = torch::nn::functional;
89 /// F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2));
90 /// ```
unfold(const Tensor & input,const UnfoldFuncOptions & options)91 inline Tensor unfold(const Tensor& input, const UnfoldFuncOptions& options) {
92 return detail::unfold(
93 input,
94 options.kernel_size(),
95 options.dilation(),
96 options.padding(),
97 options.stride());
98 }
99
100 } // namespace functional
101 } // namespace nn
102 } // namespace torch
103