xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalConvParams.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifndef MetalConvParams_h
2 #define MetalConvParams_h
3 
4 #include <c10/util/ArrayRef.h>
5 
6 namespace at::native::metal {
7 
8 struct Conv2DParams final {
Conv2DParamsfinal9   Conv2DParams() {}
10   Conv2DParams(
11       c10::IntArrayRef inputSizes,
12       c10::IntArrayRef weightSizes,
13       c10::IntArrayRef padding,
14       c10::IntArrayRef stride,
15       c10::IntArrayRef dilation,
16       int64_t groups);
17 
output_sizesfinal18   std::vector<int64_t> output_sizes() const {
19     return {N, OC, OH, OW};
20   }
21 
isDepthwisefinal22   bool isDepthwise() const {
23     // Currently, only channel multiplier of 1 is supported
24     // i.e. inputFeatureChannels == outputFeatureChannels
25     return G > 1 && IC == 1 && OC == G && OC == C;
26   }
27 
28   int64_t N; // batch size
29   int64_t C; // channels
30   int64_t H; // input height
31   int64_t W; // input width
32   int64_t OC; // output channels
33   int64_t IC; // input channels
34   int64_t KH; // kernel height
35   int64_t KW; // kernel width
36   int64_t SY; // stride y (height)
37   int64_t SX; // stride x (width)
38   int64_t PY; // padding y (height)
39   int64_t PX; // padding x (width)
40   int64_t DY; // dilation y (height)
41   int64_t DX; // dilation x (width)
42   int64_t G; // groups
43   int64_t OW; // output width
44   int64_t OH; // output height
45 };
46 
47 } // namespace at::native::metal
48 
49 #endif /* MetalConvParams_h */
50