// aten/src/ATen/native/metal/ops/MetalTranspose.mm
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/ATen.h>
#include <torch/library.h>

namespace at::native::metal {

// TODO: Move this function to MetalContext
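// Copies a host-side std::vector into a freshly allocated MTLBuffer so the
// data can be read by a compute shader. The buffer uses the write-combined
// CPU cache mode since it is written once by the CPU and only read by the GPU.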
template<typename T>
id<MTLBuffer> _makeMTLBuffer(const std::vector<T>& src) {
    id<MTLBuffer> buffer = [[MetalContext sharedInstance].device
          newBufferWithLength:src.size() * sizeof(T)
                      options:MTLResourceCPUCacheModeWriteCombined];
    memcpy(buffer.contents, src.data(), src.size() * sizeof(T));
    return buffer;
}

static Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) {
  TORCH_CHECK(input.is_metal());
  auto ndims = input.dim();
  // Support a maximum of eight dimensions on mobile
  TORCH_CHECK(ndims <= 8);
  dim0 = maybe_wrap_dim(dim0, ndims);
  dim1 = maybe_wrap_dim(dim1, ndims);
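  // Transposing a dimension with itself is a no-op; return the input as-is.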
  if (dim0 == dim1) {
    return input;
  }
  auto outputSizes = input.sizes().vec();
  std::swap(outputSizes[dim0], outputSizes[dim1]);
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
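  // 2-D tensors take a fast path through MPSImageTranspose, which writes the
  // transposed source image into a freshly allocated destination image.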
  if (input.dim() == 2) {
    MetalTensorImplStorage mt{outputSizes};
    mt.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
    MPSImage* Y = mt.texture()->image();
    MPSImageTranspose* transpose = [[MPSImageTranspose alloc]
        initWithDevice:[MetalContext sharedInstance].device];
    [transpose encodeToCommandBuffer:commandBuffer.buffer
                         sourceImage:X
                    destinationImage:Y];
    auto output = makeTensor(std::move(mt), input.options());
    return output;
  } else {
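    // Higher-rank tensors go through a custom "transpose" compute shader. The
    // input and output shapes are uploaded as ushort buffers, while the two
    // swapped dimensions and the image geometry are baked into the pipeline
    // state as specialization constants.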
    id<MTLBuffer> sizeBuf1 = _makeMTLBuffer<ushort>(
        std::vector<ushort>{input.sizes().begin(), input.sizes().end()});
    id<MTLBuffer> sizeBuf2 = _makeMTLBuffer<ushort>(
        std::vector<ushort>{outputSizes.begin(), outputSizes.end()});
    MetalTensorImplStorage mt{outputSizes};
    mt.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
    MPSImage* Y = mt.texture()->image();
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer.buffer computeCommandEncoder];
    id<MTLComputePipelineState> state =
        [[MetalContext sharedInstance] specializedPipelineState:"transpose"
                                                      Constants:@[
                                                        @(dim0),
                                                        @(dim1),
                                                        @(input.dim()),
                                                        @(X.numberOfImages),
                                                        @(X.featureChannels),
                                                        @(Y.numberOfImages),
                                                        @(Y.featureChannels),
                                                      ]];

    [encoder setComputePipelineState:state];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[Y texture] atIndex:1];
    [encoder setBuffer:sizeBuf1 offset:0 atIndex:0];
    [encoder setBuffer:sizeBuf2 offset:0 atIndex:1];

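    // Launch enough threadgroups to cover every texel of the output image.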
    const auto& launchParams =
        mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    auto output = makeTensor(std::move(mt), input.options());
    return output;
  }
}

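// aten::t for the Metal backend: restricted to 2-D tensors, where it reduces
// to swapping the two dimensions.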
static Tensor t(const Tensor& input) {
  TORCH_CHECK(input.is_metal());
  TORCH_CHECK(input.dim() == 2);
  return metal::transpose(input, 0, input.dim() < 2 ? 0 : 1);
}

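// Register the Metal implementations under the Metal dispatch key so that
// aten::t and aten::transpose.int on Metal tensors resolve to the kernels
// above.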
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::t"), TORCH_FN(t));
  m.impl(TORCH_SELECTIVE_NAME("aten::transpose.int"), TORCH_FN(transpose));
};
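
// Illustrative usage sketch (an assumption, not exercised in this file): with
// these kernels registered, the dispatcher routes calls on Metal tensors here,
// e.g. on a Metal-enabled build:
//
//   at::Tensor x = at::rand({3, 4});
//   at::Tensor y = x.metal().transpose(0, 1).cpu();  // hits the 2-D path
//
// where .metal() copies the tensor to the GPU backend and .cpu() copies the
// result back to the host.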

} // namespace at::native::metal