#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/Tensor.h>
#include <c10/util/accumulate.h> // c10::multiply_integers
#include <torch/library.h>

namespace at::native::metal {

using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;

// Returns true when input1 is the tensor being broadcast, i.e. the output
// should take its sizes from input2 instead.
static inline bool broadCastFirstInput(MPSImage* X1, MPSImage* X2) {
  return (X2.height > 1 && X1.height == 1) ||
      (X2.width > 1 && X1.width == 1);
}

// Validates that the two inputs are elementwise-compatible: the channel
// dimensions must match exactly, broadcasting along the batch dimension
// requires the channel count to be a multiple of 4 (an MPSImage texture
// slice packs four channels), and H/W may only broadcast from a size-1
// dimension.
static inline void checkInputs(const Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(
      channelsSize(input1) == channelsSize(input2),
      "Metal binary elementwise ops require the channel dimensions to be equal!");
  if (batchSize(input1) != batchSize(input2)) {
    TORCH_CHECK(
        channelsSize(input1) % 4 == 0,
        "Metal binary elementwise ops require the channel count to be a multiple of 4 to broadcast along the batch dimension!");
  }

  const uint32_t input1_h = heightSize(input1);
  const uint32_t input1_w = widthSize(input1);
  const uint32_t input2_h = heightSize(input2);
  const uint32_t input2_w = widthSize(input2);

  const std::string broadcast_error_msg =
      "Incompatible input dimensions for broadcasting for Metal binary elementwise op!";
  if (input1_h != input2_h) {
    if (input1_h > input2_h) {
      TORCH_CHECK(input2_h == 1, broadcast_error_msg);
      TORCH_CHECK(input2_w == input1_w || input2_w == 1, broadcast_error_msg);
    } else if (input2_h > input1_h) {
      TORCH_CHECK(input1_h == 1, broadcast_error_msg);
      TORCH_CHECK(input1_w == input2_w || input1_w == 1, broadcast_error_msg);
    }
  } else if (input1_w != input2_w) {
    if (input1_w > input2_w) {
      TORCH_CHECK(input2_w == 1, broadcast_error_msg);
    } else if (input2_w > input1_w) {
      // Heights are equal in this branch, so only input1's width may
      // broadcast (the original checked input1_h here, which was a bug).
      TORCH_CHECK(input1_w == 1, broadcast_error_msg);
    }
  }
}
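// Illustrative NCHW shape pairs for the checks above; worked examples added
// for exposition, not an exhaustive specification:
//   {1, 4, 8, 8} + {1, 4, 1, 8} -> OK: input2 broadcasts along H
//   {1, 4, 8, 8} + {1, 4, 1, 1} -> OK: input2 broadcasts along H and W
//   {2, 4, 8, 8} + {1, 4, 8, 8} -> OK: batch broadcast, channels % 4 == 0
//   {2, 3, 8, 8} + {1, 3, 8, 8} -> error: batch broadcast with 3 channels
//   {1, 4, 8, 8} + {1, 2, 8, 8} -> error: channel sizes differ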
// Out-of-place custom-shader path: runs `arrayKernel` (texture arrays) or
// `nonarrayKernel` (plain 2D textures) over the two inputs and writes the
// result into a freshly allocated temporary image.
static Tensor binaryElementwiseShaderKernel(
    const Tensor& input1,
    const Tensor& input2,
    const std::string& arrayKernel,
    const std::string& nonarrayKernel) {
  checkInputs(input1, input2);
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  TORCH_CHECK(
      X1.numberOfImages == X2.numberOfImages &&
          X1.featureChannels == X2.featureChannels,
      "Metal binary elementwise ops require the same number of images and feature channels!");
  IntArrayRef outputSize = input1.sizes();
  if (broadCastFirstInput(X1, X2)) {
    outputSize = input2.sizes();
  }
  // Empty tensors: nothing to compute, return an empty output.
  if (c10::multiply_integers(outputSize) == 0) {
    return makeTensor({outputSize.vec()}, input1.options());
  }
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
  TORCH_CHECK(
      [cb1 isEqual:cb2], "inputs have different Metal command buffers");
  mt.texture()->allocateTemporaryStorage(outputSize, cb1);
  MPSImage* Y = mt.texture()->image();
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      pipelineState:mpscnn::kernelFor(X1, arrayKernel, nonarrayKernel)];
  id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setTexture:[X1 texture] atIndex:0];
  [encoder setTexture:[X2 texture] atIndex:1];
  [encoder setTexture:[Y texture] atIndex:2];
  const auto& launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  return makeTensor(std::move(mt), input1.options());
}

// In-place variant: same dispatch, but the result image is swapped into
// input1's backing storage instead of materializing a new tensor.
static Tensor& binaryElementwiseShaderKernel_(
    Tensor& input1,
    const Tensor& input2,
    const std::string& arrayKernel,
    const std::string& nonarrayKernel) {
  checkInputs(input1, input2);
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  TORCH_CHECK(
      X1.numberOfImages == X2.numberOfImages &&
          X1.featureChannels == X2.featureChannels,
      "Metal binary elementwise ops require the same number of images and feature channels!");
  IntArrayRef outputSize = input1.sizes();
  if (broadCastFirstInput(X1, X2)) {
    outputSize = input2.sizes();
  }
  if (c10::multiply_integers(outputSize) == 0) {
    return input1;
  }
  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
  TORCH_CHECK(
      [cb1 isEqual:cb2], "inputs have different Metal command buffers");
  MPSImage* Y = createTemporaryImage(cb1, outputSize.vec());
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      pipelineState:mpscnn::kernelFor(X1, arrayKernel, nonarrayKernel)];
  id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setTexture:[X1 texture] atIndex:0];
  [encoder setTexture:[X2 texture] atIndex:1];
  [encoder setTexture:[Y texture] atIndex:2];
  const auto& launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input1;
}
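// The MPSCNN path below relies on MPSCNNBinaryKernel's per-axis sampling
// strides to broadcast: a stride of 0 makes the kernel re-read the same
// row/column of a size-1 input. For example, with a {1, 4, 8, 8} primary
// image and a {1, 4, 1, 8} secondary image, secondaryStrideInPixelsY is set
// to 0 below, so every output row samples row 0 of X2.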
template <typename T>
Tensor binaryElementwiseMPSCNNKernel(
    const Tensor& input1,
    const Tensor& input2) {
  checkInputs(input1, input2);
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  TORCH_CHECK(
      X1.numberOfImages == X2.numberOfImages &&
          X1.featureChannels == X2.featureChannels,
      "Metal binary elementwise ops require the same number of images and feature channels!");
  IntArrayRef outputSize = input1.sizes();
  if (broadCastFirstInput(X1, X2)) {
    outputSize = input2.sizes();
  }
  if (c10::multiply_integers(outputSize) == 0) {
    return makeTensor({outputSize.vec()}, input1.options());
  }
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
  TORCH_CHECK(
      [cb1 isEqual:cb2], "inputs have different Metal command buffers");
  mt.texture()->allocateTemporaryStorage(outputSize, cb1);
  MPSImage* Y = mt.texture()->image();
  T* kernel = [[T alloc] initWithDevice:[MetalContext sharedInstance].device];
  kernel.primaryStrideInPixelsY = X1.height == 1 ? 0 : 1;
  kernel.primaryStrideInPixelsX = X1.width == 1 ? 0 : 1;
  kernel.secondaryStrideInPixelsY = X2.height == 1 ? 0 : 1;
  kernel.secondaryStrideInPixelsX = X2.width == 1 ? 0 : 1;
  [kernel encodeToCommandBuffer:cb1.buffer
                   primaryImage:X1
                 secondaryImage:X2
               destinationImage:Y];
  return makeTensor(std::move(mt), input1.options());
}

// In-place MPSCNN variant: encodes into a temporary image and swaps it into
// input1's backing storage.
template <typename T>
Tensor& binaryElementwiseMPSCNNKernel_(Tensor& input1, const Tensor& input2) {
  checkInputs(input1, input2);
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  TORCH_CHECK(
      X1.numberOfImages == X2.numberOfImages &&
          X1.featureChannels == X2.featureChannels,
      "Metal binary elementwise ops require the same number of images and feature channels!");
  IntArrayRef outputSize = input1.sizes();
  if (broadCastFirstInput(X1, X2)) {
    outputSize = input2.sizes();
  }
  if (c10::multiply_integers(outputSize) == 0) {
    return input1;
  }
  MetalCommandBuffer* cb1 = getCommandBuffer(input1);
  MetalCommandBuffer* cb2 = getCommandBuffer(input2);
  TORCH_CHECK(
      [cb1 isEqual:cb2], "inputs have different Metal command buffers");
  MPSImage* Y = createTemporaryImage(cb1, outputSize.vec());
  T* kernel = [[T alloc] initWithDevice:[MetalContext sharedInstance].device];
  kernel.primaryStrideInPixelsY = X1.height == 1 ? 0 : 1;
  kernel.primaryStrideInPixelsX = X1.width == 1 ? 0 : 1;
  kernel.secondaryStrideInPixelsY = X2.height == 1 ? 0 : 1;
  kernel.secondaryStrideInPixelsX = X2.width == 1 ? 0 : 1;
  [kernel encodeToCommandBuffer:cb1.buffer
                   primaryImage:X1
                 secondaryImage:X2
               destinationImage:Y];
  MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input1;
}
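// Every op wrapper below follows the same dispatch pattern: prefer the
// MPSCNNArithmetic kernels (MPSCNNAdd, MPSCNNSubtract, MPSCNNMultiply,
// MPSCNNDivide) on iOS 11.3+, and fall back to the custom Metal shaders
// otherwise. A minimal sketch of the pattern, using a hypothetical
// MPSCNNFoo kernel and "elementwise_foo" shader names:
//
//   static Tensor foo_Tensor(const Tensor& input1, const Tensor& input2) {
//     TORCH_CHECK(input1.is_metal());
//     auto input2_ = input2.is_metal() ? input2 : input2.metal();
//     if (@available(iOS 11.3, *)) {
//       return binaryElementwiseMPSCNNKernel<MPSCNNFoo>(input1, input2_);
//     } else {
//       return binaryElementwiseShaderKernel(
//           input1, input2_, "elementwise_foo", "elementwise_foo_nonarray");
//     }
//   }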
// NB: `alpha` is accepted for schema compatibility with aten::add/aten::sub
// but is not applied; these kernels compute the unscaled elementwise result.
static Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) {
  TORCH_CHECK(input1.is_metal());
  auto input2_ = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel<MPSCNNAdd>(input1, input2_);
  } else {
    return binaryElementwiseShaderKernel(
        input1, input2_, "elementwise_add", "elementwise_add_nonarray");
  }
}

static Tensor& add__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) {
  TORCH_CHECK(input1.is_metal());
  auto input2_ = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel_<MPSCNNAdd>(input1, input2_);
  } else {
    return binaryElementwiseShaderKernel_(
        input1, input2_, "elementwise_add", "elementwise_add_nonarray");
  }
}

static Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) {
  TORCH_CHECK(input1.is_metal());
  auto input2_ = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel<MPSCNNSubtract>(input1, input2_);
  } else {
    return binaryElementwiseShaderKernel(
        input1, input2_, "elementwise_sub", "elementwise_sub_nonarray");
  }
}

static Tensor& sub__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) {
  TORCH_CHECK(input1.is_metal());
  auto input2_ = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel_<MPSCNNSubtract>(input1, input2_);
  } else {
    return binaryElementwiseShaderKernel_(
        input1, input2_, "elementwise_sub", "elementwise_sub_nonarray");
  }
}

static Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(input1.is_metal());
  auto input2_ = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel<MPSCNNMultiply>(input1, input2_);
  } else {
    return binaryElementwiseShaderKernel(
        input1, input2_, "elementwise_mul", "elementwise_mul_nonarray");
  }
}

static Tensor& mul__Tensor(Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(input1.is_metal());
  auto input2_ = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel_<MPSCNNMultiply>(input1, input2_);
  } else {
    return binaryElementwiseShaderKernel_(
        input1, input2_, "elementwise_mul", "elementwise_mul_nonarray");
  }
}

static Tensor div_Tensor(const Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(input1.is_metal());
  auto input2_ = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel<MPSCNNDivide>(input1, input2_);
  } else {
    return binaryElementwiseShaderKernel(
        input1, input2_, "elementwise_div", "elementwise_div_nonarray");
  }
}

static Tensor& div__Tensor(Tensor& input1, const Tensor& input2) {
  TORCH_CHECK(input1.is_metal());
  auto input2_ = input2.is_metal() ? input2 : input2.metal();
  if (@available(iOS 11.3, *)) {
    return binaryElementwiseMPSCNNKernel_<MPSCNNDivide>(input1, input2_);
  } else {
    return binaryElementwiseShaderKernel_(
        input1, input2_, "elementwise_div", "elementwise_div_nonarray");
  }
}

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::add.Tensor"), TORCH_FN(add_Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::add_.Tensor"), TORCH_FN(add__Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::mul.Tensor"), TORCH_FN(mul_Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::mul_.Tensor"), TORCH_FN(mul__Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::sub.Tensor"), TORCH_FN(sub_Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::sub_.Tensor"), TORCH_FN(sub__Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::div.Tensor"), TORCH_FN(div_Tensor));
  m.impl(TORCH_SELECTIVE_NAME("aten::div_.Tensor"), TORCH_FN(div__Tensor));
}

} // namespace at::native::metal
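// Usage sketch (illustrative; assumes a Metal-enabled build of PyTorch):
//
//   at::Tensor a = at::rand({1, 4, 8, 8});
//   at::Tensor b = at::rand({1, 4, 1, 8});
//   // Moving the tensors to Metal routes at::add through add_Tensor above;
//   // .cpu() copies the result back to the CPU backend.
//   at::Tensor c = at::add(a.metal(), b.metal()).cpu();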