1 #pragma once 2 #include <ATen/core/List.h> 3 #include <ATen/native/ConvUtils.h> 4 5 namespace at::native::quantized { 6 namespace { 7 // MakeConvOutputShape used from both CPU and CUDA libraries 8 // and exporting symbol from torch_cpu would probably take more storage 9 // than duplicating implementation which likely be inlined away 10 template <int kSpatialDim> 11 at::SmallVector<int64_t, kSpatialDim + 2> MakeConvOutputShape( 12 int N, // mini-batch 13 int M, // output channels 14 const std::array<int64_t, kSpatialDim>& input_image_shape, 15 const std::vector<int64_t>& kernel, 16 const torch::List<int64_t>& stride, 17 const torch::List<int64_t>& padding, 18 const torch::List<int64_t>& dilation); 19 20 #if defined(USE_CUDA) || defined(USE_PYTORCH_QNNPACK) 21 template <> 22 at::SmallVector<int64_t, 4> MakeConvOutputShape<2>( 23 int N, // mini-batch 24 int M, // output channels 25 const std::array<int64_t, 2>& input_image_shape, 26 const std::vector<int64_t>& kernel, 27 const at::List<int64_t>& stride, 28 const at::List<int64_t>& padding, 29 const at::List<int64_t>& dilation) { 30 const int H = input_image_shape[0]; 31 const int W = input_image_shape[1]; 32 const int64_t Y_H = 33 (H + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1; 34 const int64_t Y_W = 35 (W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1; 36 return {N, M, Y_H, Y_W}; 37 } 38 39 template <> 40 at::SmallVector<int64_t, 5> MakeConvOutputShape<3>( 41 int N, // mini-batch 42 int M, // output channels 43 const std::array<int64_t, 3>& input_image_shape, 44 const std::vector<int64_t>& kernel, 45 const at::List<int64_t>& stride, 46 const at::List<int64_t>& padding, 47 const torch::List<int64_t>& dilation) { 48 const int D = input_image_shape[0]; 49 const int H = input_image_shape[1]; 50 const int W = input_image_shape[2]; 51 const int64_t Y_D = 52 (D + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1; 53 const int64_t Y_H = 54 (H + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1; 55 const int64_t Y_W = 56 (W + 2 * padding[2] - dilation[2] * (kernel[2] - 1) - 1) / stride[2] + 1; 57 return {N, M, Y_D, Y_H, Y_W}; 58 } 59 60 #endif 61 } // anonymous namespace 62 } // namespace at::native::quantized 63