#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/im2col.h>
#include <ATen/native/im2col_shape_check.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/im2col_native.h>
#endif

namespace at::native {

namespace {

// Unfolds each (kernel_height x kernel_width) sliding window of the input
// into one column of the output, i.e. the CPU implementation behind
// at::im2col / torch.nn.Unfold.
static void im2col_out_cpu_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];

  im2col_shape_check(
      input_,
      Tensor(),
      kernel_height,
      kernel_width,
      dilation_height,
      dilation_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width);

  Tensor input = input_.contiguous();

  // A 3D (C, H, W) input is treated as a batch of size 1 and reshaped back
  // at the end.
  bool batched_input = true;

  if (input.dim() == 3) {
    batched_input = false;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
  }

  int64_t batch_size = input.size(0);
  int64_t n_input_plane = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);

  // Number of sliding-window positions along each spatial dimension.
  int64_t output_height = (input_height + 2 * pad_height -
                           (dilation_height * (kernel_height - 1) + 1)) /
          stride_height +
      1;
  int64_t output_width = (input_width + 2 * pad_width -
                          (dilation_width * (kernel_width - 1) + 1)) /
          stride_width +
      1;
  int64_t n_output_plane = n_input_plane * kernel_width * kernel_height;
  int64_t output_length = output_height * output_width;

  output.resize_({batch_size, n_output_plane, output_length});

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool,
      input.scalar_type(), "im2col_out_cpu", [&] {
        Tensor input_n;
        Tensor output_n;

        // Unfold one batch element at a time.
        for (const auto elt : c10::irange(batch_size)) {
          input_n = input.select(0, elt);
          output_n = output.select(0, elt);

          im2col<scalar_t>(
              input_n.const_data_ptr<scalar_t>(),
              n_input_plane,
              input_height,
              input_width,
              output_height,
              output_width,
              kernel_height,
              kernel_width,
              pad_height,
              pad_width,
              stride_height,
              stride_width,
              dilation_height,
              dilation_width,
              output_n.mutable_data_ptr<scalar_t>());
        }

        // Drop the synthetic batch dimension for 3D inputs.
        if (!batched_input) {
          output.resize_({n_output_plane, output_length});
        }
      });
}

} // namespace

Tensor& im2col_out_cpu(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride,
    Tensor& output) {
  im2col_out_cpu_template(
      output, input, kernel_size, dilation, padding, stride);
  return output;
}

Tensor im2col_cpu(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride) {
  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  im2col_out_cpu_template(
      output, input, kernel_size, dilation, padding, stride);
  return output;
}

} // namespace at::native
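
// A minimal usage sketch (not part of the operator itself), assuming the
// generated at::im2col entry point that dispatches to im2col_cpu above:
//
//   // x: a 1 x 3 x 5 x 5 float tensor
//   // auto cols = at::im2col(x, /*kernel_size=*/{3, 3}, /*dilation=*/{1, 1},
//   //                        /*padding=*/{0, 0}, /*stride=*/{1, 1});
//
// With these parameters the geometry computed in im2col_out_cpu_template is
//   output_height = (5 + 2*0 - (1*(3-1) + 1)) / 1 + 1 = 3
//   output_width  = 3
//   n_output_plane = 3 * 3 * 3 = 27, output_length = 3 * 3 = 9
// so `cols` has shape {1, 27, 9}, matching torch.nn.Unfold on the same input.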