1 #ifndef MetalConvParams_h 2 #define MetalConvParams_h 3 4 #include <c10/util/ArrayRef.h> 5 6 namespace at::native::metal { 7 8 struct Conv2DParams final { Conv2DParamsfinal9 Conv2DParams() {} 10 Conv2DParams( 11 c10::IntArrayRef inputSizes, 12 c10::IntArrayRef weightSizes, 13 c10::IntArrayRef padding, 14 c10::IntArrayRef stride, 15 c10::IntArrayRef dilation, 16 int64_t groups); 17 output_sizesfinal18 std::vector<int64_t> output_sizes() const { 19 return {N, OC, OH, OW}; 20 } 21 isDepthwisefinal22 bool isDepthwise() const { 23 // Currently, only channel multiplier of 1 is supported 24 // i.e. inputFeatureChannels == outputFeatureChannels 25 return G > 1 && IC == 1 && OC == G && OC == C; 26 } 27 28 int64_t N; // batch size 29 int64_t C; // channels 30 int64_t H; // input height 31 int64_t W; // input width 32 int64_t OC; // output channels 33 int64_t IC; // input channels 34 int64_t KH; // kernel height 35 int64_t KW; // kernel width 36 int64_t SY; // stride y (height) 37 int64_t SX; // stride x (width) 38 int64_t PY; // padding y (height) 39 int64_t PX; // padding x (width) 40 int64_t DY; // dilation y (height) 41 int64_t DX; // dilation x (width) 42 int64_t G; // groups 43 int64_t OW; // output width 44 int64_t OH; // output height 45 }; 46 47 } // namespace at::native::metal 48 49 #endif /* MetalConvParams_h */ 50