#import <ATen/native/metal/MetalTensorUtils.h>

namespace at {
namespace native {
namespace metal {

namespace {

// Returns the extent of the dimension located `offsetFromEnd` positions from
// the end of `tensor.sizes()` (1 == last dimension). If the tensor has fewer
// than `offsetFromEnd` dimensions, that dimension is treated as absent and 1
// is returned.
uint32_t sizeFromEnd(const Tensor& tensor, int64_t offsetFromEnd) {
  const IntArrayRef sizes = tensor.sizes();
  // Keep the rank signed (Tensor::dim() returns int64_t) so the comparison
  // and subtraction below stay in the same signed domain.
  const int64_t dims = tensor.dim();
  if (dims < offsetFromEnd) {
    return 1;
  }
  // The public API exposes extents as uint32_t; the narrowing from int64_t
  // is intentional and made explicit here. NOTE(review): extents above
  // UINT32_MAX would truncate — presumably never reached for Metal textures.
  return static_cast<uint32_t>(sizes[dims - offsetFromEnd]);
}

} // namespace

// The four accessors below read the trailing dimensions of a tensor by the
// names N (batch), C (channels), H (height), W (width), counting from the
// last dimension. Missing leading dimensions default to 1, so e.g. a 2-D
// tensor reports batchSize == channelsSize == 1.

// Size of the 4th-from-last dimension, or 1 if dim() < 4.
uint32_t batchSize(const Tensor& tensor) {
  return sizeFromEnd(tensor, 4);
}

// Size of the 3rd-from-last dimension, or 1 if dim() < 3.
uint32_t channelsSize(const Tensor& tensor) {
  return sizeFromEnd(tensor, 3);
}

// Size of the 2nd-from-last dimension, or 1 if dim() < 2.
uint32_t heightSize(const Tensor& tensor) {
  return sizeFromEnd(tensor, 2);
}

// Size of the last dimension, or 1 for a 0-dim (scalar) tensor.
uint32_t widthSize(const Tensor& tensor) {
  return sizeFromEnd(tensor, 1);
}

} // namespace metal
} // namespace native
} // namespace at