xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalPooling.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalCommandBuffer.h>
2#import <ATen/native/metal/MetalTensorImpl.h>
3#import <ATen/native/metal/MetalTensorImplStorage.h>
4#import <ATen/native/metal/MetalTensorUtils.h>
5#import <ATen/native/metal/MetalContext.h>
6#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
7#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
8#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
9
10#include <ATen/Tensor.h>
11#include <ATen/native/Pool.h>
12#include <torch/library.h>
13
namespace at::native::metal {

// Metal (MPSCNN) implementation of aten::max_pool2d.
//
// Args follow the aten schema: `kernel_size`, `stride`, `padding` and
// `dilation` are ordered (height, width). Each may also hold a single value,
// which applies to both dims. `stride` may be empty, in which case it defaults
// to the kernel size (same convention as the CPU kernel).
// Only dilation == 1 is supported. `ceil_mode` is only used when computing the
// output shape via pooling_output_shape; the MPS kernel runs with
// MPSImageEdgeModeClamp, so any window that extends past the input reads
// clamped edge pixels.
// Returns a Metal tensor whose leading dims match `input` and whose last two
// dims are (oH, oW).
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor max_pool2d(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  TORCH_CHECK(input.is_metal());
  TORCH_CHECK(
      input.dim() == 3 || input.dim() == 4,
      "max_pool2d: expected 3D or 4D input on Metal, got ",
      input.dim(),
      "D");

  // Normalize the (H, W) pooling parameters. A single value is broadcast to
  // both dims; an empty stride falls back to the kernel size.
  TORCH_CHECK(
      kernel_size.size() == 1 || kernel_size.size() == 2,
      "max_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int64_t kH = kernel_size[0];
  const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];

  TORCH_CHECK(
      stride.empty() || stride.size() == 1 || stride.size() == 2,
      "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int64_t sH = stride.empty() ? kH : stride[0];
  const int64_t sW =
      stride.empty() ? kW : (stride.size() == 1 ? sH : stride[1]);

  TORCH_CHECK(
      padding.size() == 1 || padding.size() == 2,
      "max_pool2d: padding must either be a single int, or a tuple of two ints");
  const int64_t pH = padding[0];
  const int64_t pW = padding.size() == 1 ? pH : padding[1];

  TORCH_CHECK(
      dilation.size() == 1 || dilation.size() == 2,
      "max_pool2d: dilation must be either a single int, or a tuple of two ints");
  const int64_t dH = dilation[0];
  const int64_t dW = dilation.size() == 1 ? dH : dilation[1];
  // Each dilation must be compared against 1 individually; a chained
  // `a == b == 1` would accept any equal pair such as (2, 2).
  TORCH_CHECK(dH == 1 && dW == 1, "dilation is not supported on MPSCNN");

  TORCH_CHECK(
      kH > 0 && kW > 0,
      "max_pool2d: kernel size should be greater than zero, but got kH: ",
      kH,
      " kW: ",
      kW);
  TORCH_CHECK(
      sH > 0 && sW > 0,
      "max_pool2d: stride should be greater than zero, but got sH: ",
      sH,
      " sW: ",
      sW);

  // Spatial dims are always the last two, so 3-D (C, H, W) and 4-D
  // (N, C, H, W) inputs are handled the same way.
  const int64_t iH = input.size(-2);
  const int64_t iW = input.size(-1);
  const int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode);
  const int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode);
  TORCH_CHECK(
      oH > 0 && oW > 0,
      "max_pool2d: computed output size (",
      oH,
      "x",
      oW,
      ") is too small for input size (",
      iH,
      "x",
      iW,
      ")");

  // Output keeps all leading dims of the input; only H and W change.
  SmallVector<int64_t, 4> outputSize(
      input.sizes().begin(), input.sizes().end());
  outputSize[outputSize.size() - 2] = oH;
  outputSize[outputSize.size() - 1] = oW;
  if (input.numel() == 0) {
    return makeTensor({IntArrayRef(outputSize).vec()}, input.options());
  }

  MPSImage* X = imageFromTensor(input);
  // MPS uses x for width and y for height, so the (H, W)-ordered PyTorch
  // parameters have to be crossed over here.
  MPSCNNPoolingMax* pool = [[MPSCNNPoolingMax alloc]
       initWithDevice:[MetalContext sharedInstance].device
          kernelWidth:kW
         kernelHeight:kH
      strideInPixelsX:sW
      strideInPixelsY:sH];
  [pool setEdgeMode:MPSImageEdgeModeClamp];
  [pool setOffset:{.x = mpscnn::computeMPSAlignOffset(kW, pW),
                   .y = mpscnn::computeMPSAlignOffset(kH, pH),
                   .z = 0}];

  MetalTensorImplStorage mt{IntArrayRef(outputSize).vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  [pool encodeToCommandBuffer:commandBuffer.buffer
                  sourceImage:X
             destinationImage:Y];
  return makeTensor(std::move(mt), input.options());
}

// Metal (MPSCNN) implementation of aten::adaptive_avg_pool2d, restricted to
// output_size == (1, 1). It averages across the width and height with a single
// MPSCNNPoolingAverage window the size of the whole image, producing a 1x1
// spatial output per channel.
// Accepts 3-D or 4-D input. The output keeps the input's leading dims, and its
// last two dims are (1, 1).
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 2 && output_size[0] == 1 && output_size[1] == 1,
      "adaptive_avg_pool2d: only output_size (1, 1) is supported on Metal");
  TORCH_CHECK(input.is_metal());
  TORCH_CHECK(
      input.dim() == 3 || input.dim() == 4,
      "adaptive_avg_pool2d: expected 3D or 4D input on Metal, got ",
      input.dim(),
      "D");

  // Replace only the spatial (last two) dims, so a 3-D (C, H, W) input yields
  // (C, 1, 1) rather than a 4-D shape.
  SmallVector<int64_t, 4> outputSize(
      input.sizes().begin(), input.sizes().end());
  outputSize[outputSize.size() - 2] = output_size[0];
  outputSize[outputSize.size() - 1] = output_size[1];
  if (input.numel() == 0) {
    return makeTensor({IntArrayRef(outputSize).vec()}, input.options());
  }

  MPSImage* X = imageFromTensor(input);
  // Kernel and stride equal the image extent: one window per channel.
  MPSCNNPoolingAverage* pool = [[MPSCNNPoolingAverage alloc]
       initWithDevice:[MetalContext sharedInstance].device
          kernelWidth:X.width
         kernelHeight:X.height
      strideInPixelsX:X.width
      strideInPixelsY:X.height];
  [pool setEdgeMode:MPSImageEdgeModeClamp];
  // Offset to the image midpoint, presumably so the single window is
  // positioned over the whole image — verify against MPS offset semantics.
  [pool setOffset:{.x = static_cast<NSInteger>(X.width / 2),
                   .y = static_cast<NSInteger>(X.height / 2),
                   .z = 0}];

  MetalTensorImplStorage mt{IntArrayRef(outputSize).vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  [pool encodeToCommandBuffer:commandBuffer.buffer
                  sourceImage:X
             destinationImage:Y];
  return makeTensor(std::move(mt), input.options());
}

// Registers the Metal pooling kernels for the aten operators above.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::max_pool2d"), TORCH_FN(max_pool2d));
  m.impl(
      TORCH_SELECTIVE_NAME("aten::adaptive_avg_pool2d"),
      TORCH_FN(adaptive_avg_pool2d));
}

} // namespace at::native::metal
110