#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <torch/library.h>

namespace at::native::metal {

// Metal implementation of aten::reflection_pad2d.
//
// input:   a Metal tensor with at least two dimensions. The last dimension is
//          padded by (left, right) and the second-to-last by (top, bottom).
// padding: either a single value applied to all four sides, or four values in
//          the order {left, right, top, bottom}.
//
// Returns a new Metal tensor whose last two dimensions are enlarged by the
// padding; every other dimension keeps the input's size.
//
// Raises (via TORCH_CHECK) when the input is not a Metal tensor, when the
// padding is not a 1- or 4-tuple, when the input has fewer than two
// dimensions, or when a padding amount is not smaller than the dimension it
// reflects across.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor reflection_pad2d(const Tensor& input, IntArrayRef padding) {
  TORCH_CHECK(input.is_metal());

  const int pad_dim = padding.size();
  const IntArrayRef input_size = input.sizes();
  const int input_dim = input_size.size();

  TORCH_CHECK(pad_dim == 1 || pad_dim == 4, "Padding sizes must be a 1-tuple or 4-tuple!");
  TORCH_CHECK(input_dim >= 2, "Input tensor must have dim >= 2!");

  // A 1-tuple means "the same padding on every side".
  const bool per_side = (pad_dim == 4);
  const int64_t pad_l = padding[0];
  const int64_t pad_r = per_side ? padding[1] : padding[0];
  const int64_t pad_t = per_side ? padding[2] : padding[0];
  const int64_t pad_b = per_side ? padding[3] : padding[0];

  const int64_t input_h = input_size[input_dim - 2];
  const int64_t input_w = input_size[input_dim - 1];

  // Reflection mirrors the input around its border without repeating the
  // edge element, so each padding amount must be strictly smaller than the
  // dimension it reflects across. Reject such input up front instead of
  // handing the kernel coordinates that have no valid source element.
  TORCH_CHECK(
      pad_l < input_w && pad_r < input_w,
      "Padding size should be less than the corresponding input dimension, but got: padding (",
      pad_l, ", ", pad_r, ") at dimension ", input_dim - 1, " of input ", input_size);
  TORCH_CHECK(
      pad_t < input_h && pad_b < input_h,
      "Padding size should be less than the corresponding input dimension, but got: padding (",
      pad_t, ", ", pad_b, ") at dimension ", input_dim - 2, " of input ", input_size);

  // Only the two innermost dimensions grow; the rest are copied unchanged.
  std::vector<int64_t> output_size(input_size.begin(), input_size.end());
  output_size[input_dim - 1] = input_w + pad_l + pad_r;
  output_size[input_dim - 2] = input_h + pad_t + pad_b;

  // The kernel constants are boxed as unsigned integers.
  // NOTE(review): negative padding is not rejected here and is converted to
  // unsigned exactly as before; whether the kernel supports it is not visible
  // from this file - confirm against the shader.
  const NSUInteger pad_left = static_cast<NSUInteger>(pad_l);
  const NSUInteger pad_right = static_cast<NSUInteger>(pad_r);
  const NSUInteger pad_top = static_cast<NSUInteger>(pad_t);
  const NSUInteger pad_bottom = static_cast<NSUInteger>(pad_b);

  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  MetalTensorImplStorage mt{output_size};
  mt.texture()->allocateTemporaryStorage(output_size, commandBuffer);
  MPSImage* Y = mt.texture()->image();

  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  // Constants, in order: output H/W/C/N, input H/W/C/N, then the four pads
  // (left, right, top, bottom).
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:"reflection_pad2d"
                     Constants:@[
                       @(Y.height),
                       @(Y.width),
                       @(Y.featureChannels),
                       @(Y.numberOfImages),
                       @(X.height),
                       @(X.width),
                       @(X.featureChannels),
                       @(X.numberOfImages),
                       @(pad_left),
                       @(pad_right),
                       @(pad_top),
                       @(pad_bottom)
                     ]];

  [encoder setComputePipelineState:state];
  // Texture slot 0 is the source image, slot 1 the padded destination.
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y texture] atIndex:1];

  // Launch dimensions are derived from the output image.
  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  auto output = makeTensor(std::move(mt), input.options());
  return output;
}

// Register the Metal kernel for aten::reflection_pad2d.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::reflection_pad2d"), TORCH_FN(reflection_pad2d));
}

} // namespace at::native::metal