xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalReduce.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/Tensor.h>
2#import <ATen/native/metal/MetalCommandBuffer.h>
3#import <ATen/native/metal/MetalTensorImpl.h>
4#import <ATen/native/metal/MetalTensorImplStorage.h>
5#import <ATen/native/metal/MetalTensorUtils.h>
6#import <ATen/native/metal/MetalContext.h>
7#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
8#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
9
10#include <ATen/ATen.h>
11#include <ATen/native/ReduceOpsUtils.h>
12#include <torch/library.h>
13
14namespace at::native::metal {
15
API_AVAILABLE(ios(11.3), macos(10.13))
// Returns a freshly allocated MPSNN mean-reduction kernel for the given
// tensor dimension, created on the shared MetalContext device:
//   1 -> MPSNNReduceFeatureChannelsMean
//   2 -> MPSNNReduceColumnMean
//   3 -> MPSNNReduceRowMean
// Any other value (including 0) yields nil.
// NOTE(review): the mapping presumably assumes an NCHW-shaped 4-D tensor
// (1 = channels, 2 = height, 3 = width) — confirm against imageFromTensor.
static inline MPSNNReduceUnary* kernelForReducedDim(int dim) {
  id<MTLDevice> mtlDevice = [MetalContext sharedInstance].device;
  switch (dim) {
    case 1:
      return [[MPSNNReduceFeatureChannelsMean alloc] initWithDevice:mtlDevice];
    case 2:
      return [[MPSNNReduceColumnMean alloc] initWithDevice:mtlDevice];
    case 3:
      return [[MPSNNReduceRowMean alloc] initWithDevice:mtlDevice];
    default:
      return nil;
  }
}
28
// Metal implementation of aten::mean.dim.
//
// Only 4-D inputs whose leading dimension (index 0) has size 1 are supported
// (reducing that dimension for larger sizes is tracked in T87340633).
//
// How it works:
// - Each selected dimension among 1, 2 and 3 is reduced by encoding one MPSNN
//   mean kernel (see kernelForReducedDim) into a new temporary image, chaining
//   the output of one reduction into the next.
// - Chained per-dimension means equal the joint mean, because every slice
//   along a given dimension holds the same number of elements.
// - Dim 0 has size 1, so including it in a reduction changes no values.
//
// @param input    Metal tensor to reduce; must be 4-D with size(0) == 1.
// @param opt_dims Dimensions to reduce. They are normalized through
//                 make_dim_mask. NOTE(review): per that helper's contract,
//                 negative dims are wrapped, out-of-range or duplicate dims
//                 are rejected, and nullopt or an empty list selects every
//                 dimension.
// @param keepdim  If true, reduced dimensions are kept with size 1;
//                 otherwise they are removed from the result shape.
// @param dtype    Must be unset or equal to input.scalar_type(); this kernel
//                 performs no dtype conversion.
// @return A Metal tensor with the reduced values, viewed to the output shape.
static Tensor wrapper_mean_dim(
    const Tensor& input,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    std::optional<ScalarType> dtype) {
  if (@available(iOS 11.3, macOS 10.13, *)) {
    TORCH_CHECK(
        input.dim() == 4,
        "Metal mean.dim expects a 4-D input, but got ",
        input.dim(),
        " dimensions");
    // TODO: [T87340633] Support reducing the batch dimension
    TORCH_CHECK(
        input.size(0) == 1,
        "Metal mean.dim expects input.size(0) == 1, but got ",
        input.size(0));
    TORCH_CHECK(
        !dtype.has_value() || *dtype == input.scalar_type(),
        "Metal mean.dim does not support casting to a different dtype");

    // Normalized selection of dims to reduce. Iterating this mask (instead of
    // the raw opt_dims) keeps negative dims from indexing imageSize out of
    // bounds, and makes "reduce all" (nullopt / empty dims) actually reduce.
    const auto mask = make_dim_mask(opt_dims, input.dim());

    MPSImage* X = imageFromTensor(input);
    auto imageSize = input.sizes().vec();
    MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
    MPSImage* Y = nil;
    // Dim 0 is skipped: it has size 1 and there is no kernel for it.
    for (int64_t dim = 1; dim < 4; ++dim) {
      if (!mask[dim]) {
        continue;
      }
      MPSNNReduceUnary* kernel = kernelForReducedDim(static_cast<int>(dim));
      TORCH_INTERNAL_ASSERT(
          kernel != nil, "No Metal mean kernel for dim ", dim);
      imageSize[dim] = 1;
      Y = createTemporaryImage(commandBuffer, imageSize);
      [kernel encodeToCommandBuffer:commandBuffer.buffer
                        sourceImage:X
                   destinationImage:Y];
      X = Y;
    }
    // If only dim 0 was selected, no kernel was encoded and there is no
    // output image; fail clearly instead of returning an imageless tensor.
    TORCH_CHECK(
        Y != nil,
        "Metal mean.dim requires reducing at least one of dims 1, 2 or 3");

    MetalTensorImplStorage mt{imageSize};
    mt.texture()->setCommandBuffer(commandBuffer);
    mt.texture()->setImage(Y);

    // Build the logical output shape. Walk backwards so erasing an entry does
    // not shift the indices still to be visited.
    auto shape = DimVector(input.sizes());
    for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0;
         --dim) {
      if (!mask[dim]) {
        continue;
      }
      if (keepdim) {
        shape[dim] = 1;
      } else {
        shape.erase(shape.begin() + dim);
      }
    }
    return makeTensor(std::move(mt), input.options()).view(shape);
  } else {
    // TODO: [T87350528] Fallback to shader kernels for 10.0 users
    TORCH_CHECK(
        false, "MPSNNReduceUnary is only available on iOS 11.3 and above");
  }
}
78
// Registers wrapper_mean_dim as the Metal-dispatch-key kernel for the
// aten::mean.dim operator.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::mean.dim"),
      TORCH_FN(wrapper_mean_dim));
}
82
83} // namespace at::native::metal
84