xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/functional/fold.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/options/fold.h>
4 
5 namespace torch {
6 namespace nn {
7 namespace functional {
8 
9 #ifndef DOXYGEN_SHOULD_SKIP_THIS
10 namespace detail {
/// Implementation of `fold` that operates on already-unpacked arguments.
///
/// Accepts a 2D (unbatched) or 3D (batched) `input` and forwards every
/// argument, unchanged, to `torch::col2im`. Any other rank is rejected through
/// a `TORCH_CHECK` failure whose message reports the rank that was received.
inline Tensor fold(
    const Tensor& input,
    ExpandingArray<2> output_size,
    ExpandingArray<2> kernel_size,
    ExpandingArray<2> dilation,
    ExpandingArray<2> padding,
    ExpandingArray<2> stride) {
  const auto ndim = input.dim();
  // Reject unsupported ranks up front so the success path is a single,
  // unconditional return rather than an else-branch that only throws.
  TORCH_CHECK(
      ndim == 2 || ndim == 3,
      "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported "
      "(got ",
      ndim,
      "D)");
  return torch::col2im(
      input, output_size, kernel_size, dilation, padding, stride);
}
30 } // namespace detail
31 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
32 
33 /// See
34 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.fold
35 /// about the exact behavior of this functional.
36 ///
37 /// See the documentation for `torch::nn::functional::FoldFuncOptions` class to
38 /// learn what optional arguments are supported for this functional.
39 ///
40 /// Example:
41 /// ```
42 /// namespace F = torch::nn::functional;
43 /// F::fold(input, F::FoldFuncOptions({3, 2}, {2, 2}));
44 /// ```
fold(const Tensor & input,const FoldFuncOptions & options)45 inline Tensor fold(const Tensor& input, const FoldFuncOptions& options) {
46   return detail::fold(
47       input,
48       options.output_size(),
49       options.kernel_size(),
50       options.dilation(),
51       options.padding(),
52       options.stride());
53 }
54 
55 // ============================================================================
56 
57 #ifndef DOXYGEN_SHOULD_SKIP_THIS
58 namespace detail {
/// Implementation of `unfold` that operates on already-unpacked arguments.
///
/// Only a 4D `input` is accepted. It is forwarded, together with every other
/// argument unchanged, to `torch::im2col`. Any other rank is rejected through a
/// `TORCH_CHECK` failure whose message reports the rank that was received.
inline Tensor unfold(
    const Tensor& input,
    ExpandingArray<2> kernel_size,
    ExpandingArray<2> dilation,
    ExpandingArray<2> padding,
    ExpandingArray<2> stride) {
  const auto ndim = input.dim();
  // Guard clause: fail fast on unsupported ranks, leaving a single
  // unconditional return for the supported case.
  TORCH_CHECK(
      ndim == 4,
      "Input Error: Only 4D input Tensors are supported "
      "(got ",
      ndim,
      "D)");
  return torch::im2col(input, kernel_size, dilation, padding, stride);
}
76 } // namespace detail
77 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
78 
79 /// See
80 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.unfold
81 /// about the exact behavior of this functional.
82 ///
83 /// See the documentation for `torch::nn::functional::UnfoldFuncOptions` class
84 /// to learn what optional arguments are supported for this functional.
85 ///
86 /// Example:
87 /// ```
88 /// namespace F = torch::nn::functional;
89 /// F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2));
90 /// ```
unfold(const Tensor & input,const UnfoldFuncOptions & options)91 inline Tensor unfold(const Tensor& input, const UnfoldFuncOptions& options) {
92   return detail::unfold(
93       input,
94       options.kernel_size(),
95       options.dilation(),
96       options.padding(),
97       options.stride());
98 }
99 
100 } // namespace functional
101 } // namespace nn
102 } // namespace torch
103