xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalTensorUtils.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalTensorUtils.h>
2
3namespace at {
4namespace native {
5namespace metal {
6
7uint32_t batchSize(const Tensor& tensor) {
8  const IntArrayRef sizes = tensor.sizes();
9  const uint32_t dims = tensor.dim();
10  if (dims < 4) {
11    return 1;
12  }
13  return sizes[dims - 4];
14}
15
16uint32_t channelsSize(const Tensor& tensor) {
17  const IntArrayRef sizes = tensor.sizes();
18  const uint32_t dims = tensor.dim();
19  if (dims < 3) {
20    return 1;
21  }
22  return sizes[dims - 3];
23}
24
25uint32_t heightSize(const Tensor& tensor) {
26  const IntArrayRef sizes = tensor.sizes();
27  const uint32_t dims = tensor.dim();
28  if (dims < 2) {
29    return 1;
30  }
31  return sizes[dims - 2];
32}
33
34uint32_t widthSize(const Tensor& tensor) {
35  const IntArrayRef sizes = tensor.sizes();
36  const uint32_t dims = tensor.dim();
37  if (dims < 1) {
38    return 1;
39  }
40  return sizes[dims - 1];
41}
42
43}
44}
45}
46