xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalCommandBuffer.h>
2#import <ATen/native/metal/MetalTensorImpl.h>
3#import <ATen/native/metal/MetalTensorImplStorage.h>
4#import <ATen/native/metal/MetalTensorUtils.h>
5#import <ATen/native/metal/MetalContext.h>
6#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
7#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
8#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
9
10#include <ATen/Tensor.h>
11#include <torch/library.h>
12
13namespace at::native::metal {
14
15using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
16
// Returns true when X1 is the input being broadcast, i.e. X1 has a
// degenerate (== 1) height or width where X2's is larger. In that case
// the caller takes X2's sizes as the output shape.
static inline bool broadCastFirstInput(MPSImage* X1, MPSImage* X2) {
  const bool broadcastHeight = X1.height == 1 && X2.height > 1;
  const bool broadcastWidth = X1.width == 1 && X2.width > 1;
  return broadcastHeight || broadcastWidth;
}
24
// Validates that two Metal tensors are compatible for a binary elementwise
// op; raises via TORCH_CHECK otherwise.
//  - Channel counts must match exactly.
//  - Broadcasting across the batch dimension additionally requires the
//    channel count to be a multiple of 4 (texture packing constraint).
//  - Height/width may broadcast only when the smaller extent is exactly 1.
static inline void checkInputs(const Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(
      channelsSize(input1) == channelsSize(input2),
      "Metal binary elementwise ops require channel dimension to be equal!");
  if (batchSize(input1) != batchSize(input2)) {
    TORCH_CHECK(
        channelsSize(input1) % 4 == 0,
        "Metal binary elementwise ops require channel to be a multiple of 4 to broadcast along batch dimension!")
  }

  const uint32_t input1_h = heightSize(input1);
  const uint32_t input1_w = widthSize(input1);
  const uint32_t input2_h = heightSize(input2);
  const uint32_t input2_w = widthSize(input2);

  const std::string broadcast_error_msg =
      "Incompatible input dimensions for broadcasting for Metal binary elementwise op!";
  if (input1_h != input2_h) {
    if (input1_h > input2_h) {
      TORCH_CHECK(input2_h == 1, broadcast_error_msg);
      TORCH_CHECK(input2_w == input1_w || input2_w == 1, broadcast_error_msg);
    } else if (input2_h > input1_h) {
      TORCH_CHECK(input1_h == 1, broadcast_error_msg);
      TORCH_CHECK(input1_w == input2_w || input1_w == 1, broadcast_error_msg);
    }
  } else if (input1_w != input2_w) {
    // Heights are equal here, so only the width needs to broadcast: the
    // smaller width must be exactly 1.
    if (input1_w > input2_w) {
      TORCH_CHECK(input2_w == 1, broadcast_error_msg);
    } else if (input2_w > input1_w) {
      // Bug fix: previously checked input1_h == 1, but the heights are
      // already known equal in this branch; the width-broadcast condition
      // is input1_w == 1, mirroring the input1_w > input2_w case above.
      TORCH_CHECK(input1_w == 1, broadcast_error_msg);
    }
  }
}
58
// Applies a binary elementwise op to two Metal tensors using a custom
// compute shader and returns a new Metal tensor.
//
// arrayKernel / nonarrayKernel name the shader variants for texture-array
// vs. single-texture images; mpscnn::kernelFor picks the right one for X1.
// Both inputs must live on the same Metal command buffer.
static Tensor binaryElementwiseShaderKernel(
    const Tensor& input1,
    const Tensor& input2,
    const std::string& arrayKernel,
    const std::string& nonarrayKernel) {
  checkInputs(input1, input2);
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  // Batch and channel extents must match exactly; only H/W may broadcast.
  TORCH_CHECK(X1.numberOfImages == X2.numberOfImages &&
              X1.featureChannels == X2.featureChannels)
  // The broadcast (larger) input determines the output shape.
  IntArrayRef outputSize = input1.sizes();
  if (broadCastFirstInput(X1, X2)) {
    outputSize = input2.sizes();
  }
  // Empty output: nothing to encode, return an empty Metal tensor.
  if(c10::multiply_integers(outputSize) == 0){
    return makeTensor({outputSize.vec()}, input1.options());
  }
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
  TORCH_CHECK(
      [cb1 isEqual:cb2], @"inputs have different Metal command buffers");
  mt.texture()->allocateTemporaryStorage(outputSize, cb1);
  MPSImage* Y = mt.texture()->image();
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      pipelineState:mpscnn::kernelFor(X1, arrayKernel, nonarrayKernel)];
  // Encode the shader: inputs at texture slots 0/1, output at slot 2.
  id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setTexture:[X1 texture] atIndex:0];
  [encoder setTexture:[X2 texture] atIndex:1];
  [encoder setTexture:[Y texture] atIndex:2];
  const auto& launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  auto output = makeTensor(std::move(mt), input1.options());
  return output;
}
98
// In-place variant of binaryElementwiseShaderKernel: the result is computed
// into a fresh temporary image and then swapped into input1's storage via
// setImage, so the mutation is at the tensor level, not the texture level.
// Returns input1.
static Tensor& binaryElementwiseShaderKernel_(
    Tensor& input1,
    const Tensor& input2,
    const std::string& arrayKernel,
    const std::string& nonarrayKernel) {
  checkInputs(input1, input2);
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  // Batch and channel extents must match exactly; only H/W may broadcast.
  TORCH_CHECK(X1.numberOfImages == X2.numberOfImages &&
              X1.featureChannels == X2.featureChannels)
  IntArrayRef outputSize = input1.sizes();
  if (broadCastFirstInput(X1, X2)) {
    outputSize = input2.sizes();
  }
  // Empty tensor: nothing to compute, input1 is returned unchanged.
  if(c10::multiply_integers(outputSize) == 0){
      return input1;
  }
  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
  TORCH_CHECK(
      [cb1 isEqual:cb2], @"inputs have different Metal command buffers");
  MPSImage* Y = createTemporaryImage(cb1, outputSize.vec());
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      pipelineState:mpscnn::kernelFor(X1, arrayKernel, nonarrayKernel)];
  // Encode the shader: inputs at texture slots 0/1, output at slot 2.
  id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setTexture:[X1 texture] atIndex:0];
  [encoder setTexture:[X2 texture] atIndex:1];
  [encoder setTexture:[Y texture] atIndex:2];
  const auto& launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  // Swap the freshly computed image into input1's opaque storage.
  MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input1;
}
138
// Applies a binary elementwise op using a built-in MPSCNN kernel and
// returns a new Metal tensor.
//
// T is an MPSCNNArithmetic subclass (MPSCNNAdd, MPSCNNSubtract, ...).
// A pixel stride of 0 along a size-1 axis makes the kernel reuse that
// single row/column, implementing H/W broadcasting (per MPSCNNArithmetic's
// stride semantics).
template <typename T>
Tensor binaryElementwiseMPSCNNKernel(
    const Tensor& input1,
    const Tensor& input2) {
  checkInputs(input1, input2);
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  // Batch and channel extents must match exactly; only H/W may broadcast.
  TORCH_CHECK(X1.numberOfImages == X2.numberOfImages &&
              X1.featureChannels == X2.featureChannels)
  // The broadcast (larger) input determines the output shape.
  IntArrayRef outputSize = input1.sizes();
  if (broadCastFirstInput(X1, X2)) {
    outputSize = input2.sizes();
  }
  // Empty output: nothing to encode, return an empty Metal tensor.
  if(c10::multiply_integers(outputSize) == 0){
      return makeTensor({outputSize.vec()}, input1.options());
  }
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
  TORCH_CHECK(
      [cb1 isEqual:cb2], @"inputs have different Metal command buffers");
  mt.texture()->allocateTemporaryStorage(outputSize, cb1);
  MPSImage* Y = mt.texture()->image();
  T* kernel = [[T alloc] initWithDevice:[MetalContext sharedInstance].device];
  // Stride 0 on a degenerate axis broadcasts that axis.
  kernel.primaryStrideInPixelsY = X1.height == 1 ? 0 : 1;
  kernel.primaryStrideInPixelsX = X1.width == 1 ? 0 : 1;
  kernel.secondaryStrideInPixelsY = X2.height == 1 ? 0 : 1;
  kernel.secondaryStrideInPixelsX = X2.width == 1 ? 0 : 1;
  [kernel encodeToCommandBuffer:cb1.buffer
                   primaryImage:X1
                 secondaryImage:X2
               destinationImage:Y];
  auto output = makeTensor(std::move(mt), input1.options());
  return output;
}
174
// In-place variant of binaryElementwiseMPSCNNKernel: computes into a fresh
// temporary image, then swaps it into input1's storage via setImage.
// T is an MPSCNNArithmetic subclass; a pixel stride of 0 along a size-1
// axis implements H/W broadcasting. Returns input1.
template <typename T>
Tensor& binaryElementwiseMPSCNNKernel_(Tensor& input1, const Tensor& input2) {
  checkInputs(input1, input2);
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  // Batch and channel extents must match exactly; only H/W may broadcast.
  TORCH_CHECK(X1.numberOfImages == X2.numberOfImages &&
              X1.featureChannels == X2.featureChannels)
  IntArrayRef outputSize = input1.sizes();
  if (broadCastFirstInput(X1, X2)) {
    outputSize = input2.sizes();
  }
  // Empty tensor: nothing to compute, input1 is returned unchanged.
  if(c10::multiply_integers(outputSize) == 0){
    return input1;
  }
  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
  TORCH_CHECK(
      [cb1 isEqual:cb2], @"inputs have different Metal command buffers");
  MPSImage* Y = createTemporaryImage(cb1, outputSize.vec());
  T* kernel = [[T alloc] initWithDevice:[MetalContext sharedInstance].device];
  // Stride 0 on a degenerate axis broadcasts that axis.
  kernel.primaryStrideInPixelsY = X1.height == 1 ? 0 : 1;
  kernel.primaryStrideInPixelsX = X1.width == 1 ? 0 : 1;
  kernel.secondaryStrideInPixelsY = X2.height == 1 ? 0 : 1;
  kernel.secondaryStrideInPixelsX = X2.width == 1 ? 0 : 1;
  [kernel encodeToCommandBuffer:cb1.buffer
                   primaryImage:X1
                 secondaryImage:X2
               destinationImage:Y];
  // Swap the freshly computed image into input1's opaque storage.
  MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input1;
}
208
// aten::add.Tensor on Metal: prefer the built-in MPSCNN kernel on
// iOS 11.3+, otherwise fall back to the custom compute shader.
// NOTE(review): `alpha` is unused here, so results differ from CPU when
// alpha != 1 — confirm callers only pass the default.
static Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) {
  TORCH_CHECK(input1.is_metal());
  auto other = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel<MPSCNNAdd>(input1, other);
  }
  return binaryElementwiseShaderKernel(
      input1, other, "elementwise_add", "elementwise_add_nonarray");
}
219
// In-place aten::add_.Tensor on Metal; dispatches like add_Tensor.
// NOTE(review): `alpha` is unused here — confirm callers only pass the
// default.
static Tensor& add__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) {
  TORCH_CHECK(input1.is_metal());
  auto other = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel_<MPSCNNAdd>(input1, other);
  }
  return binaryElementwiseShaderKernel_(
      input1, other, "elementwise_add", "elementwise_add_nonarray");
}
230
// aten::sub.Tensor on Metal: prefer the built-in MPSCNN kernel on
// iOS 11.3+, otherwise fall back to the custom compute shader.
// NOTE(review): `alpha` is unused here — confirm callers only pass the
// default.
static Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) {
  TORCH_CHECK(input1.is_metal());
  auto other = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel<MPSCNNSubtract>(input1, other);
  }
  return binaryElementwiseShaderKernel(
      input1, other, "elementwise_sub", "elementwise_sub_nonarray");
}
241
// In-place aten::sub_.Tensor on Metal; dispatches like sub_Tensor.
// NOTE(review): `alpha` is unused here — confirm callers only pass the
// default.
static Tensor& sub__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) {
  TORCH_CHECK(input1.is_metal());
  auto other = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel_<MPSCNNSubtract>(input1, other);
  }
  return binaryElementwiseShaderKernel_(
      input1, other, "elementwise_sub", "elementwise_sub_nonarray");
}
252
// aten::mul.Tensor on Metal: prefer the built-in MPSCNN kernel on
// iOS 11.3+, otherwise fall back to the custom compute shader.
static Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(input1.is_metal());
  auto other = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel<MPSCNNMultiply>(input1, other);
  }
  return binaryElementwiseShaderKernel(
      input1, other, "elementwise_mul", "elementwise_mul_nonarray");
}
263
// In-place aten::mul_.Tensor on Metal; dispatches like mul_Tensor.
static Tensor& mul__Tensor(Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(input1.is_metal());
  auto other = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel_<MPSCNNMultiply>(input1, other);
  }
  return binaryElementwiseShaderKernel_(
      input1, other, "elementwise_mul", "elementwise_mul_nonarray");
}
274
// aten::div.Tensor on Metal: prefer the built-in MPSCNN kernel on
// iOS 11.3+, otherwise fall back to the custom compute shader.
static Tensor div_Tensor(const Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(input1.is_metal());
  auto other = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel<MPSCNNDivide>(input1, other);
  }
  return binaryElementwiseShaderKernel(
      input1, other, "elementwise_div", "elementwise_div_nonarray");
}
285
// In-place aten::div_.Tensor on Metal; dispatches like div_Tensor.
static Tensor& div__Tensor(Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(input1.is_metal());
  auto other = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel_<MPSCNNDivide>(input1, other);
  }
  return binaryElementwiseShaderKernel_(
      input1, other, "elementwise_div", "elementwise_div_nonarray");
}
296
// Register the Metal-backend implementations of the four binary
// elementwise ops, grouping each out-of-place kernel with its in-place
// counterpart. Registration order has no runtime effect.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::add.Tensor"), TORCH_FN(add_Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::add_.Tensor"), TORCH_FN(add__Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::sub.Tensor"), TORCH_FN(sub_Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::sub_.Tensor"), TORCH_FN(sub__Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::mul.Tensor"), TORCH_FN(mul_Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::mul_.Tensor"), TORCH_FN(mul__Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::div.Tensor"), TORCH_FN(div_Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::div_.Tensor"), TORCH_FN(div__Tensor));
}
307
308} // namespace at::native::metal
309