1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/TensorUtils.h> 5 #include <ATen/Utils.h> 6 #include <ATen/Parallel.h> 7 #include <ATen/native/cpu/utils.h> 8 #include <c10/util/irange.h> 9 10 #include <algorithm> 11 12 namespace at::native { 13 14 template <typename T> 15 static void im2col( 16 const T* data_im, 17 const int64_t channels, 18 const int64_t height, 19 const int64_t width, 20 const int64_t output_height, 21 const int64_t output_width, 22 const int64_t kernel_h, 23 const int64_t kernel_w, 24 const int64_t pad_h, 25 const int64_t pad_w, 26 const int64_t stride_h, 27 const int64_t stride_w, 28 const int64_t dilation_h, 29 const int64_t dilation_w, 30 T* data_col, 31 bool is_channels_last = false) { 32 const int64_t height_col = output_height; 33 const int64_t width_col = output_width; 34 const int64_t channels_col = channels * kernel_h * kernel_w; 35 36 if (is_channels_last) { 37 at::parallel_for(0, height_col * width_col, 0, [&](int64_t begin, int64_t end) { 38 int64_t h_col{0}, w_col{0}; 39 data_index_init(begin, h_col, height_col, w_col, width_col); 40 41 for (const auto i_col : c10::irange(begin, end)) { 42 for (const auto h_offset : c10::irange(kernel_h)) { 43 int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; 44 for (const auto w_offset : c10::irange(kernel_w)) { 45 int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; 46 47 const T* slice_im = data_im + (h_im * width + w_im) * channels; 48 T* slice_col = data_col + (i_col * kernel_h * kernel_w + h_offset * kernel_w + w_offset) * channels; 49 50 if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { 51 std::copy_n(slice_im, channels, slice_col); 52 } else { 53 std::fill_n(slice_col, channels, T(0)); 54 } 55 } 56 } 57 58 // move the next index 59 data_index_step(h_col, height_col, w_col, width_col); 60 } 61 }); 62 } else { 63 at::parallel_for(0, channels_col, 0, [&](int64_t begin, int64_t end) { 64 int64_t c_im{0}, h_offset{0}, w_offset{0}; 65 data_index_init(begin, c_im, channels, h_offset, kernel_h, w_offset, kernel_w); 66 67 for (const auto c_col : c10::irange(begin, end)) { 68 for (const auto h_col : c10::irange(height_col)) { 69 int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; 70 for (const auto w_col : c10::irange(width_col)) { 71 int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; 72 data_col[(c_col * height_col + h_col) * width_col + w_col] = 73 (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) 74 ? data_im[(c_im * height + h_im) * width + w_im] 75 : static_cast<T>(0); 76 } 77 } 78 79 // move to the next index 80 data_index_step(c_im, channels, h_offset, kernel_h, w_offset, kernel_w); 81 } 82 }); 83 } 84 } 85 86 template <typename T> 87 static void col2im( 88 const T* data_col, 89 const int64_t channels, 90 const int64_t height, 91 const int64_t width, 92 const int64_t output_height, 93 const int64_t output_width, 94 const int64_t kernel_h, 95 const int64_t kernel_w, 96 const int64_t pad_h, 97 const int64_t pad_w, 98 const int64_t stride_h, 99 const int64_t stride_w, 100 const int64_t dilation_h, 101 const int64_t dilation_w, 102 T* data_im, 103 bool is_channels_last = false) { 104 std::fill_n(data_im, height * width * channels, T(0)); 105 106 const int64_t height_col = output_height; 107 const int64_t width_col = output_width; 108 const int64_t channels_col = channels * kernel_h * kernel_w; 109 110 if (is_channels_last) { 111 for (const auto h_col : c10::irange(height_col)) { 112 for (const auto w_col : c10::irange(width_col)) { 113 for (const auto h_offset : c10::irange(kernel_h)) { 114 int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; 115 for (const auto w_offset : c10::irange(kernel_w)) { 116 int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; 117 118 T* slice_im = data_im + (h_im * width + w_im) * channels; 119 const T* slice_col = data_col + ((h_col * width_col + w_col) * kernel_h * kernel_w 120 + h_offset * kernel_w + w_offset) * channels; 121 122 if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) { 123 std::transform(slice_col, slice_col + channels, slice_im, slice_im, std::plus<T>()); 124 } 125 } 126 } 127 } 128 } 129 } else { 130 for (const auto c_col : c10::irange(channels_col)) { 131 int64_t w_offset = c_col % kernel_w; 132 int64_t h_offset = (c_col / kernel_w) % kernel_h; 133 int64_t c_im = c_col / kernel_h / kernel_w; 134 135 for (const auto h_col : c10::irange(height_col)) { 136 int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; 137 for (const auto w_col : c10::irange(width_col)) { 138 int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; 139 140 if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) 141 data_im[(c_im * height + h_im) * width + w_im] += 142 data_col[(c_col * height_col + h_col) * width_col + w_col]; 143 } 144 } 145 } 146 } 147 } 148 149 } // namespace at::native 150