#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <torch/library.h>

namespace at::native::metal {

// Returns a freshly allocated MPS mean-reduction kernel for the given
// (non-negative) tensor dimension, or nil when no kernel is mapped to it:
//   dim 3 -> MPSNNReduceRowMean
//   dim 2 -> MPSNNReduceColumnMean
//   dim 1 -> MPSNNReduceFeatureChannelsMean
// Any other value (including 0, the batch dimension, and negative values)
// yields nil, so callers must normalize `dim` first and handle nil.
// NOTE(review): the mapping presumably assumes an NCHW tensor layout - confirm.
API_AVAILABLE(ios(11.3), macos(10.13))
static inline MPSNNReduceUnary* kernelForReducedDim(int dim) {
  id<MTLDevice> device = [MetalContext sharedInstance].device;
  if (dim == 3) {
    return [[MPSNNReduceRowMean alloc] initWithDevice:device];
  } else if (dim == 2) {
    return [[MPSNNReduceColumnMean alloc] initWithDevice:device];
  } else if (dim == 1) {
    return [[MPSNNReduceFeatureChannelsMean alloc] initWithDevice:device];
  }
  return nil;
}

// Metal implementation of aten::mean.dim.
//
// input    - 4-D Metal tensor whose dimension 0 has size 1.
// opt_dims - dimensions to average over; negative values count from the end.
//            At least one of them must resolve to dimension 1, 2 or 3,
//            because those are the only dimensions with an MPS kernel.
// keepdim  - when true, reduced dimensions are kept with size 1 in the
//            output shape; when false they are removed.
// dtype    - accepted to match the operator schema; not used here.
//
// Reductions are encoded one dimension at a time into the input's command
// buffer, each one reading the temporary image produced by the previous one.
// Raises a c10::Error (via TORCH_CHECK) on unsupported input, when no
// reduction could be encoded, or when the OS is older than iOS 11.3.
static Tensor wrapper_mean_dim(
    const Tensor& input,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    std::optional<ScalarType> dtype) {
  if (@available(iOS 11.3, *)) {
    MPSImage* X = imageFromTensor(input);
    auto imageSize = input.sizes().vec();
    TORCH_CHECK(
        imageSize.size() == 4,
        "Metal mean.dim expects a 4-D input, got ",
        imageSize.size(),
        " dimension(s)");
    // TODO: [T87340633] Support reducing the batch dimension
    TORCH_CHECK(
        imageSize[0] == 1,
        "Metal mean.dim expects dimension 0 to have size 1, got ",
        imageSize[0]);
    // Validates opt_dims (range, duplicates) and records which dimensions
    // are reduced; used below to compute the output shape.
    auto mask = make_dim_mask(opt_dims, input.dim());
    MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
    MPSImage* Y = nil;
    if (opt_dims.has_value()) {
      const int64_t ndim = input.dim();
      for (int64_t rawDim : opt_dims.value()) {
        // Normalize negative dims before indexing: a raw negative value
        // would write outside imageSize and never match a kernel.
        const int64_t dim = maybe_wrap_dim(rawDim, ndim);
        imageSize[dim] = 1;
        MPSNNReduceUnary* kernel =
            kernelForReducedDim(static_cast<int>(dim));
        if (kernel) {
          Y = createTemporaryImage(commandBuffer, imageSize);
          [kernel encodeToCommandBuffer:commandBuffer.buffer
                            sourceImage:X
                       destinationImage:Y];
          // Chain: the next reduction consumes this result.
          X = Y;
        }
      }
    }
    // Without this check a tensor backed by a nil image would be returned
    // when dims is absent, empty, or names only the batch dimension.
    TORCH_CHECK(
        Y != nil,
        "Metal mean.dim must reduce at least one of dimensions 1, 2 or 3; "
        "no reduction kernel was encoded");
    MetalTensorImplStorage mt{imageSize};
    mt.texture()->setCommandBuffer(commandBuffer);
    mt.texture()->setImage(Y);
    auto shape = DimVector(input.sizes());
    // Walk from the last dimension down so erasing an entry does not shift
    // the indices of dimensions still to be visited.
    for (int dim = static_cast<int>(shape.size()) - 1; dim >= 0; dim--) {
      if (mask[dim]) {
        if (keepdim) {
          shape[dim] = 1;
        } else {
          shape.erase(shape.begin() + dim);
        }
      }
    }
    auto output = makeTensor(std::move(mt), input.options()).view(shape);
    return output;
  } else {
    // TODO: [T87350528] Fallback to shader kernels for 10.0 users
    TORCH_CHECK(
        false, "MPSNNReduceUnary is only available on iOS 11.3 and above");
  }
}

// Registers the Metal kernel for aten::mean.dim.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::mean.dim"), TORCH_FN(wrapper_mean_dim));
}

} // namespace at::native::metal