xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalConvParams.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalConvParams.h>
2
3#include <cmath>
4
5namespace at {
6namespace native {
7namespace metal {
8
9Conv2DParams::Conv2DParams(
10    c10::IntArrayRef inputSizes,
11    c10::IntArrayRef weightSizes,
12    c10::IntArrayRef padding,
13    c10::IntArrayRef stride,
14    c10::IntArrayRef dilation,
15    int64_t groups)
16    : N(inputSizes[0]),
17      C(inputSizes[1]),
18      H(inputSizes[2]),
19      W(inputSizes[3]),
20      OC(weightSizes[0]),
21      IC(weightSizes[1]),
22      KH(weightSizes[2]),
23      KW(weightSizes[3]),
24      SY(stride[0]),
25      SX(stride[1]),
26      PY(padding[0]),
27      PX(padding[1]),
28      DY(dilation[0]),
29      DX(dilation[1]),
30      G(groups) {
31  OH = std::floor((H + 2 * PY - DY * (KH - 1) - 1) / SY + 1);
32  OW = std::floor((W + 2 * PX - DX * (KW - 1) - 1) / SX + 1);
33};
34
35}
36}
37}
38