xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vol2col.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstring>
4 
5 namespace at::native {
6 
// Unfolds a volume into a column buffer.
//
// data_vol is read as a contiguous [channels, depth, height, width] array.
// data_col is written as a contiguous
// [channels * kT * kernel_height * kernel_width, depth_col, height_col,
// width_col] array. Every element of data_col is written: taps that fall
// outside the volume (i.e. in the padding) are stored as 0.
//
// pT/pH/pW are subtracted from the sampled coordinate (padding), dT/dH/dW
// scale the output coordinate (stride), and dilationT/H/W scale the kernel
// offset (dilation).
template <typename T>
void vol2col(
    const T* data_vol,
    const int64_t channels,
    const int64_t depth,
    const int64_t height,
    const int64_t width,
    const int64_t depth_col,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t kT,
    const int64_t kernel_height,
    const int64_t kernel_width,
    const int64_t pT,
    const int64_t pH,
    const int64_t pW,
    const int64_t dT,
    const int64_t dH,
    const int64_t dW,
    const int64_t dilationT,
    const int64_t dilationH,
    const int64_t dilationW,
    T* data_col) {
  const int64_t col_rows = channels * kT * kernel_height * kernel_width;

  for (int64_t row = 0; row < col_rows; ++row) {
    // A column row encodes (input channel, kernel t, kernel h, kernel w),
    // with the kernel width offset varying fastest.
    const int64_t kw = row % kernel_width;
    const int64_t kh = (row / kernel_width) % kernel_height;
    const int64_t kt = (row / kernel_width / kernel_height) % kT;
    const int64_t in_channel = row / kT / kernel_height / kernel_width;

    for (int64_t ot = 0; ot < depth_col; ++ot) {
      const int64_t it = ot * dT - pT + kt * dilationT;
      const bool depth_ok = it >= 0 && it < depth;

      for (int64_t oh = 0; oh < height_col; ++oh) {
        const int64_t ih = oh * dH - pH + kh * dilationH;
        const bool plane_ok = depth_ok && ih >= 0 && ih < height;
        // Index of the first element of this output row in data_col.
        const int64_t dst_base =
            ((row * depth_col + ot) * height_col + oh) * width_col;

        for (int64_t ow = 0; ow < width_col; ++ow) {
          const int64_t iw = ow * dW - pW + kw * dilationW;
          const int64_t dst = dst_base + ow;

          if (!plane_ok || iw < 0 || iw >= width) {
            // The tap lies outside the volume: emit the padding value.
            data_col[dst] = 0;
            continue;
          }
          data_col[dst] =
              data_vol[((in_channel * depth + it) * height + ih) * width + iw];
        }
      }
    }
  }
}
57 
// Folds a column buffer back into a volume (the accumulation counterpart of
// vol2col).
//
// data_col is read as a contiguous
// [channels * kT * kernel_height * kernel_width, out_depth, out_height,
// out_width] array. data_vol is a contiguous [channels, depth, height, width]
// array that is first cleared and then accumulated into: every column entry
// whose tap lands inside the volume is added to the corresponding voxel, so
// voxels covered by several kernel taps receive the sum of all of them.
// Entries whose tap lands in the padding are ignored.
//
// pT/pH/pW are subtracted from the target coordinate (padding), dT/dH/dW
// scale the column coordinate (stride), and dilationT/H/W scale the kernel
// offset (dilation).
template <typename T>
void col2vol(
    const T* data_col,
    const int64_t channels,
    const int64_t depth,
    const int64_t height,
    const int64_t width,
    const int64_t out_depth,
    const int64_t out_height,
    const int64_t out_width,
    const int64_t kT,
    const int64_t kernel_height,
    const int64_t kernel_width,
    const int64_t pT,
    const int64_t pH,
    const int64_t pW,
    const int64_t dT,
    const int64_t dH,
    const int64_t dW,
    const int64_t dilationT,
    const int64_t dilationH,
    const int64_t dilationW,
    T* data_vol) {
  // Clear the whole output volume before accumulating.
  // <cstring> only guarantees the std:: qualified name, so use std::memset.
  // NOTE: this relies on an all-zero byte pattern being T's zero value.
  std::memset(data_vol, 0, sizeof(T) * depth * height * width * channels);

  const int64_t channels_col = channels * kT * kernel_height * kernel_width;
  for (int64_t c = 0; c < channels_col; ++c) {
    // A column row encodes (volume channel, kernel t, kernel h, kernel w),
    // with the kernel width offset varying fastest.
    const int64_t w_offset = c % kernel_width;
    const int64_t h_offset = (c / kernel_width) % kernel_height;
    const int64_t t_offset = (c / kernel_width / kernel_height) % kT;
    const int64_t c_vol = c / kT / kernel_height / kernel_width;

    for (int64_t t = 0; t < out_depth; ++t) {
      const int64_t t_pad = t * dT - pT + t_offset * dilationT;
      for (int64_t h = 0; h < out_height; ++h) {
        const int64_t h_pad = h * dH - pH + h_offset * dilationH;
        for (int64_t w = 0; w < out_width; ++w) {
          const int64_t w_pad = w * dW - pW + w_offset * dilationW;
          // Only taps that fall inside the volume contribute.
          if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height &&
              w_pad >= 0 && w_pad < width) {
            data_vol
                [((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] +=
                data_col[((c * out_depth + t) * out_height + h) * out_width + w];
          }
        }
      }
    }
  }
}
108 
109 } // namespace at::native
110