// aten/src/ATen/native/metal/ops/MetalPadding.mm
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <torch/library.h>

namespace at::native::metal {

API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor reflection_pad2d(const Tensor& input, IntArrayRef padding) {
  TORCH_CHECK(input.is_metal());

  const int pad_dim = padding.size();
  const IntArrayRef input_size = input.sizes();
  const int input_dim = input_size.size();

  TORCH_CHECK(pad_dim == 1 || pad_dim == 4, "Padding sizes must be a 1-tuple or 4-tuple!");
  TORCH_CHECK(input_dim >= 2, "Input tensor must have dim >= 2!");

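  // Padding follows the ATen (left, right, top, bottom) order; a single
  // padding value is applied to all four sides.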
  NSUInteger pad_left = padding[0];
  NSUInteger pad_right = padding[0];
  NSUInteger pad_top = padding[0];
  NSUInteger pad_bottom = padding[0];
  if (pad_dim == 4) {
    pad_right = padding[1];
    pad_top = padding[2];
    pad_bottom = padding[3];
  }

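  // Only the last two dims (height and width) grow; all leading dims keep
  // their input sizes.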
  std::vector<int64_t> output_size(input_dim);
  for (int d = 0; d < input_dim; ++d) {
    if (d == input_dim - 1) {
      output_size[d] = input_size[d] + pad_right + pad_left;
    }
    else if (d == input_dim - 2) {
      output_size[d] = input_size[d] + pad_top + pad_bottom;
    }
    else {
      output_size[d] = input_size[d];
    }
  }

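  // Fetch the input MPSImage and allocate a temporary output texture on the
  // same command buffer.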
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  MetalTensorImplStorage mt{output_size};
  mt.texture()->allocateTemporaryStorage(output_size, commandBuffer);
  MPSImage* Y = mt.texture()->image();

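  // Encode the reflection_pad2d compute kernel; the output/input image sizes
  // and the four pad amounts are passed as pipeline constants.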
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:"reflection_pad2d"
                     Constants:@[
                       @(Y.height),
                       @(Y.width),
                       @(Y.featureChannels),
                       @(Y.numberOfImages),
                       @(X.height),
                       @(X.width),
                       @(X.featureChannels),
                       @(X.numberOfImages),
                       @(pad_left),
                       @(pad_right),
                       @(pad_top),
                       @(pad_bottom)
                     ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y texture] atIndex:1];

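  // Dispatch with launch parameters sized from the output image.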
  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  auto output = makeTensor(std::move(mt), input.options());
  return output;
}

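// Register this kernel as the Metal-backend implementation of
// aten::reflection_pad2d.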
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::reflection_pad2d"), TORCH_FN(reflection_pad2d));
}

} // namespace at::native::metal