// aten/src/ATen/native/metal/ops/MetalConcat.mm

#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/Tensor.h>
#include <ATen/native/UpSample.h>
#include <torch/library.h>

namespace at::native::metal {

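// Concatenates along the batch dimension (dim == 0). Metal tensors are
// backed by MPSImage texture arrays holding N * ceil(C / 4) RGBA slices,
// so batch concatenation is a slice-wise copy: each input is dispatched
// through the "copy_offset" (texture array) / "copy_offset_nonarray"
// (single slice) kernel with the destination slice offset passed in a
// one-element uniform buffer. Example: an input with N = 2, C = 6 occupies
// 2 * ceil(6 / 4) = 4 slices, so `cat_dim4_pointer` advances by 4 and the
// next input starts 4 slices further into the output texture.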
static Tensor cat_batch(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) {
  MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor);
  MPSImage* Y = mt.texture()->image();
  ushort cat_dim4_pointer = 0;
  for (const auto& t : tensors) {
    MPSImage* X = imageFromTensor(t);
    MetalCommandBuffer* Xcb = getCommandBuffer(t);
    TORCH_CHECK(
        [commandBuffer isEqual:Xcb],
        @"inputs have different Metal command buffers");
    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer.buffer computeCommandEncoder];
    id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
        pipelineState:mpscnn::kernelFor(
                          X, "copy_offset", "copy_offset_nonarray")];
    id<MTLBuffer> offsetBuffer = [[MetalContext sharedInstance].device
        newBufferWithLength:1 * sizeof(ushort)
                    options:MTLResourceCPUCacheModeWriteCombined];
    ushort* offsetBufferPtr = (ushort*)[offsetBuffer contents];
    offsetBufferPtr[0] = cat_dim4_pointer;

    [encoder setComputePipelineState:state];
    [encoder setTexture:[X texture] atIndex:0];
    [encoder setTexture:[Y texture] atIndex:1];
    [encoder setBuffer:offsetBuffer offset:0 atIndex:0];

    const auto& launchParams =
        mpscnn::spatialPointwiseKernelLaunchParams(state, X);
    [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
            threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
    [encoder endEncoding];
    cat_dim4_pointer += t.size(0) * ((t.size(1) + 3) / 4);
  }
  auto output = makeTensor(std::move(mt), tensor.options());
  return output;
}

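// Concatenates along the feature/channel dimension (dim == 1). Channels are
// packed four per texel (RGBA), so when an input starts at a channel offset
// that is not a multiple of 4 its data straddles a texel boundary in the
// output. Each input therefore goes through the aligned "append_features"
// kernel or the misaligned "append_features_off" kernel, with a small
// staging texture (T below) carrying the partially filled boundary texel
// between iterations.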
static Tensor cat_feature(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) {
  MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor);
  MPSImage* Y = mt.texture()->image();
  ushort channel_offset = 0;

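  // Staging texture with a fixed channel count of 4, i.e. one texel slice
  // per image. The misaligned append path reads from it, and the
  // "store_features" pass below refreshes it after each input is appended.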
  auto temp_size = tensor.sizes().vec();
  temp_size[1] = 4;
  MetalTensorImplStorage tt{temp_size};
  tt.texture()->setCommandBuffer(commandBuffer);
  tt.texture()->allocateTemporaryStorage(temp_size, commandBuffer);
  MPSImage* T = tt.texture()->image();

  for (const auto& t : tensors) {
    MPSImage* X = imageFromTensor(t);
    MetalCommandBuffer* Xcb = getCommandBuffer(t);
    TORCH_CHECK(
        [commandBuffer isEqual:Xcb],
        @"inputs have different Metal command buffers");
    ushort tex_offset = channel_offset % 4;
    std::string kernelString = tex_offset == 0 ? "append_features" : "append_features_off";

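    // Pass 1: append this input's channels to the output Y. The six ushorts
    // in `offsetBuffer` describe the packed-channel geometry:
    //   [0] slices this input spans in the output, including the carried-over
    //       partial texel: ceil((C_x + tex_offset) / 4)
    //   [1] slices per image in the output: ceil(C_y / 4)
    //   [2] output slice where this input begins: channel_offset / 4
    //   [3] slices per image in this input: ceil(C_x / 4)
    //   [4] total destination slices written across the batch: N_x * [0]
    //   [5] channel offset inside the boundary texel: tex_offset
    // The aligned variant reads X and writes Y; the misaligned variant also
    // reads the staging texture T to fill the partially occupied texel.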
    {
      id<MTLComputeCommandEncoder> encoder =
          [commandBuffer.buffer computeCommandEncoder];
      id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
          specializedPipelineState:kernelString
                         Constants:@[
                           @(T.height),
                           @(T.width),
                           @(T.featureChannels),
                           @(T.numberOfImages),
                           @(X.height),
                           @(X.width),
                           @(X.featureChannels),
                           @(X.numberOfImages),
                         ]];
      id<MTLBuffer> offsetBuffer = [[MetalContext sharedInstance].device
          newBufferWithLength:6 * sizeof(ushort)
                      options:MTLResourceCPUCacheModeWriteCombined];
      ushort* offsetBufferPtr = (ushort*)[offsetBuffer contents];
      offsetBufferPtr[0] = (X.featureChannels + tex_offset + 3) / 4;
      offsetBufferPtr[1] = (Y.featureChannels + 3) / 4;
      offsetBufferPtr[2] = channel_offset / 4;
      offsetBufferPtr[3] = (X.featureChannels + 3) / 4;
      offsetBufferPtr[4] = X.numberOfImages * offsetBufferPtr[0];
      offsetBufferPtr[5] = tex_offset;

      [encoder setComputePipelineState:state];
      if (tex_offset == 0) {
        [encoder setTexture:[X texture] atIndex:0];
        [encoder setTexture:[Y texture] atIndex:1];
        [encoder setBuffer:offsetBuffer offset:0 atIndex:0];
      } else {
        [encoder setTexture:[X texture] atIndex:0];
        [encoder setTexture:[T texture] atIndex:1];
        [encoder setTexture:[Y texture] atIndex:2];
        [encoder setBuffer:offsetBuffer offset:0 atIndex:0];
      }

      ushort featureChannels = X.featureChannels;
      if (channel_offset % 4 > 0) {
        featureChannels += tex_offset;
      }
      const auto& launchParams =
          metal::mpscnn::spatialPointwiseKernelLaunchParams(
              state, X.numberOfImages, featureChannels, X.height, X.width);
      [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
              threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
      [encoder endEncoding];
    }

    channel_offset += X.featureChannels;

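    // Pass 2: keep the staging texture in sync. With `channel_offset` already
    // advanced past this input, "store_features" is run over T's extent with
    // Y bound at index 0 and T at index 1, which (judging from the bindings)
    // snapshots the output's current boundary slice into T so the next
    // misaligned append can blend into the partially filled texel rather than
    // overwrite it.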
    {
      id<MTLComputeCommandEncoder> encoder =
          [commandBuffer.buffer computeCommandEncoder];

      id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
          specializedPipelineState:"store_features"
                         Constants:@[
                           @(T.height),
                           @(T.width),
                           @(T.featureChannels),
                           @(T.numberOfImages),
                         ]];
      id<MTLBuffer> offsetBuffer = [[MetalContext sharedInstance].device
          newBufferWithLength:2 * sizeof(ushort)
                      options:MTLResourceCPUCacheModeWriteCombined];
      ushort* offsetBufferPtr = (ushort*)[offsetBuffer contents];
      offsetBufferPtr[0] = channel_offset / 4;
      offsetBufferPtr[1] = (Y.featureChannels + 3) / 4;

      [encoder setComputePipelineState:state];
      [encoder setTexture:[Y texture] atIndex:0];
      [encoder setTexture:[T texture] atIndex:1];
      [encoder setBuffer:offsetBuffer offset:0 atIndex:0];

      const auto& launchParams =
          metal::mpscnn::spatialPointwiseKernelLaunchParams(state, T);
      [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
              threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
      [encoder endEncoding];
    }
  }
  auto output = makeTensor(std::move(mt), tensor.options());
  return output;
}

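// Dispatcher registered below for aten::cat on the Metal backend. It checks
// that every input is a 4-D Metal tensor whose sizes match outside the
// concatenation dimension, allocates the output texture storage, and routes
// to cat_feature (dim == 1) or cat_batch (dim == 0).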
static Tensor cat(const ITensorListRef& tensors, int64_t dim) {
  TORCH_CHECK(
      dim == 0 || dim == 1,
      "Metal cat is implemented only for the batch or channel dimension");
  int64_t cat_dim_size = 0;
  TORCH_CHECK(!tensors.empty(), "cat expected a non-empty list of Tensors");
  at::Tensor tensor = *tensors.begin();
  MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor);
  for (const auto& t : tensors) {
    TORCH_CHECK(t.dim() == 4, "Metal cat expects 4-dimensional inputs");
    TORCH_CHECK(t.is_metal(), "Metal cat expects metal tensors");

    for (int d = 0; d < 4; ++d) {
      if (d == dim) {
        continue;
      }
      TORCH_CHECK(
          t.size(d) == tensor.size(d),
          "Metal cat inputs must have matching sizes except the concatenated dimension");
    }
    cat_dim_size += t.size(dim);
  }
  auto result_size = tensor.sizes().vec();
  result_size[dim] = cat_dim_size;
  TORCH_CHECK(
      result_size[0] * ((result_size[1] + 3) / 4) > 1,
      "Output tensor must be a texture array");
  MetalTensorImplStorage mt{result_size};
  mt.texture()->setCommandBuffer(commandBuffer);
  mt.texture()->allocateTemporaryStorage(result_size, commandBuffer);

  if (dim == 1) {
    return cat_feature(tensor, tensors, mt);
  }
  return cat_batch(tensor, tensors, mt);
}

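// Illustrative sketch (not part of the original file): with the registration
// below, an aten::cat over Metal tensors reaches the kernels above roughly as
// follows. The `.metal()` conversion on the caller's side is an assumption,
// not something this file defines.
//
//   at::Tensor a = at::rand({1, 6, 8, 8}).metal();   // assumed conversion op
//   at::Tensor b = at::rand({1, 6, 8, 8}).metal();
//   at::Tensor y0 = at::cat({a, b}, /*dim=*/0);      // -> cat_batch
//   at::Tensor y1 = at::cat({a, b}, /*dim=*/1);      // -> cat_feature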
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::cat"), TORCH_FN(cat));
}

} // namespace at::native::metal