#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/Tensor.h>
#include <ATen/native/Pool.h>
#include <torch/library.h>

namespace at::native::metal {

// Metal/MPSCNN implementation of aten::max_pool2d.
//
// Parameters follow ATen's (height, width) ordering: kernel_size[0] is the
// kernel height and kernel_size[1] the kernel width (the same convention used
// when computing the output shape with pooling_output_shape below). A single
// value applies to both dimensions, and an empty `stride` defaults to the
// kernel size. Dilation other than 1 is rejected.
//
// Returns an (N, C, oH, oW) Metal tensor. The pooling work is encoded onto the
// input's command buffer.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor max_pool2d(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  TORCH_CHECK(input.is_metal());
  // The code below reads sizes()[0..3] and treats the tensor as an NCHW image,
  // so only 4-D input can be handled. (3-D input used to be accepted here and
  // then read sizes()[3] out of bounds.)
  TORCH_CHECK(
      input.dim() == 4,
      "max_pool2d on Metal expects a 4-D (N, C, H, W) input, got ",
      input.dim(),
      "-D");
  TORCH_CHECK(
      kernel_size.size() == 1 || kernel_size.size() == 2,
      "max_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  TORCH_CHECK(
      stride.empty() || stride.size() == 1 || stride.size() == 2,
      "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  TORCH_CHECK(
      padding.size() == 1 || padding.size() == 2,
      "max_pool2d: padding must either be a single int, or a tuple of two ints");
  TORCH_CHECK(
      dilation.size() == 1 || dilation.size() == 2,
      "max_pool2d: dilation must be either a single int, or a tuple of two ints");

  // Expand each parameter to an explicit (height, width) pair.
  const int64_t kH = kernel_size[0];
  const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];
  const int64_t sH = stride.empty() ? kH : stride[0];
  const int64_t sW =
      stride.empty() ? kW : (stride.size() == 1 ? sH : stride[1]);
  const int64_t pH = padding[0];
  const int64_t pW = padding.size() == 1 ? pH : padding[1];
  const int64_t dH = dilation[0];
  const int64_t dW = dilation.size() == 1 ? dH : dilation[1];
  // Both dilations must be exactly 1. The former `d0 == d1 == 1` parsed as
  // `(d0 == d1) == 1`, which let e.g. dilation=(2, 2) through.
  TORCH_CHECK(dH == 1 && dW == 1, "dilation is not supported on MPSCNN");

  const int64_t iN = input.sizes()[0];
  const int64_t iC = input.sizes()[1];
  const int64_t iH = input.sizes()[2];
  const int64_t iW = input.sizes()[3];
  const int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode);
  const int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode);
  SmallVector<int64_t, 4> outputSize{iN, iC, oH, oW};
  if (input.numel() == 0) {
    return makeTensor({IntArrayRef(outputSize).vec()}, input.options());
  }

  MPSImage* X = imageFromTensor(input);
  // MPSCNN's X axis is the image width and its Y axis the image height, so
  // the width-side values (kW, sW, pW) go to the X/width arguments. The old
  // code passed the height-side values there, which was only correct for
  // square kernels, strides and padding.
  MPSCNNPoolingMax* pool = [[MPSCNNPoolingMax alloc]
       initWithDevice:[MetalContext sharedInstance].device
          kernelWidth:kW
         kernelHeight:kH
      strideInPixelsX:sW
      strideInPixelsY:sH];
  [pool setEdgeMode:MPSImageEdgeModeClamp];
  [pool setOffset:{.x = mpscnn::computeMPSAlignOffset(kW, pW),
                   .y = mpscnn::computeMPSAlignOffset(kH, pH),
                   .z = 0}];

  MetalTensorImplStorage mt{IntArrayRef(outputSize).vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  [pool encodeToCommandBuffer:commandBuffer.buffer
                  sourceImage:X
             destinationImage:Y];
  auto output = makeTensor(std::move(mt), input.options());
  return output;
}

// Metal/MPSCNN implementation of aten::adaptive_avg_pool2d, restricted to
// global average pooling (output_size == (1, 1)). A single average-pooling
// window covering the whole H x W image yields an (N, C, 1, 1) Metal tensor.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 2,
      "adaptive_avg_pool2d: output_size must be a tuple of two ints");
  // Averages across the width and height and outputs a 1x1xC image.
  TORCH_CHECK(output_size[0] == 1 && output_size[1] == 1);
  TORCH_CHECK(input.is_metal());
  // The input is indexed as NCHW and converted to an MPS image below.
  TORCH_CHECK(
      input.dim() == 4,
      "adaptive_avg_pool2d on Metal expects a 4-D (N, C, H, W) input, got ",
      input.dim(),
      "-D");
  SmallVector<int64_t, 4> outputSize{
      input.sizes()[0], input.sizes()[1], output_size[0], output_size[1]};
  if (input.numel() == 0) {
    return makeTensor({IntArrayRef(outputSize).vec()}, input.options());
  }

  MPSImage* X = imageFromTensor(input);
  // A single window the size of the whole image, stepped by the image size,
  // produces exactly one output pixel.
  MPSCNNPoolingAverage* pool = [[MPSCNNPoolingAverage alloc]
       initWithDevice:[MetalContext sharedInstance].device
          kernelWidth:X.width
         kernelHeight:X.height
      strideInPixelsX:X.width
      strideInPixelsY:X.height];
  [pool setEdgeMode:MPSImageEdgeModeClamp];
  // Center the window on the image.
  [pool setOffset:{.x = static_cast<NSInteger>(X.width / 2),
                   .y = static_cast<NSInteger>(X.height / 2),
                   .z = 0}];

  MetalTensorImplStorage mt{IntArrayRef(outputSize).vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  [pool encodeToCommandBuffer:commandBuffer.buffer
                  sourceImage:X
             destinationImage:Y];
  auto output = makeTensor(std::move(mt), input.options());
  return output;
}

// Register the Metal kernels for the aten pooling operators.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::max_pool2d"), TORCH_FN(max_pool2d));
  m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_avg_pool2d"), TORCH_FN(adaptive_avg_pool2d));
}

} // namespace at::native::metal