1#import <ATen/native/metal/MetalConvParams.h> 2 3#include <cmath> 4 5namespace at { 6namespace native { 7namespace metal { 8 9Conv2DParams::Conv2DParams( 10 c10::IntArrayRef inputSizes, 11 c10::IntArrayRef weightSizes, 12 c10::IntArrayRef padding, 13 c10::IntArrayRef stride, 14 c10::IntArrayRef dilation, 15 int64_t groups) 16 : N(inputSizes[0]), 17 C(inputSizes[1]), 18 H(inputSizes[2]), 19 W(inputSizes[3]), 20 OC(weightSizes[0]), 21 IC(weightSizes[1]), 22 KH(weightSizes[2]), 23 KW(weightSizes[3]), 24 SY(stride[0]), 25 SX(stride[1]), 26 PY(padding[0]), 27 PX(padding[1]), 28 DY(dilation[0]), 29 DX(dilation[1]), 30 G(groups) { 31 OH = std::floor((H + 2 * PY - DY * (KH - 1) - 1) / SY + 1); 32 OW = std::floor((W + 2 * PX - DX * (KW - 1) - 1) / SX + 1); 33}; 34 35} 36} 37} 38