#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalPrepackOpContext.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <torch/library.h>

namespace at::native::metal {

// Metal implementation of aten::addmm restricted to the linear-layer case:
// computes input @ weight + bias, with beta == alpha == 1.0 enforced.
// `input` lives on the Metal device; `weight` ({IC, OC}) and `bias` stay on
// CPU and are uploaded as convolution weights. The matmul is expressed as a
// 1x1 convolution over an {N, IC, 1, 1} reshape of the input.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor addmm(
    const Tensor& bias,
    const Tensor& input,
    const Tensor& weight,
    const Scalar& beta,
    const Scalar& alpha) {
  TORCH_CHECK(input.is_metal());
  TORCH_CHECK(weight.device() == kCPU && weight.dim() == 2);
  TORCH_CHECK(bias.device() == kCPU);
  TORCH_CHECK(beta.toFloat() == 1.0f);
  TORCH_CHECK(alpha.toFloat() == 1.0f);
  if (input.numel() == 0 || weight.numel() == 0) {
    // Match the non-empty output shape {N, OC}. `weight` is {IC, OC} here
    // (it is transposed below), so the output feature dimension is
    // weight.size(1) — returning weight.size(0) would yield {N, IC}.
    return makeTensor({{input.size(0), weight.size(1)}}, input.options());
  }
  // Here we treat the matrix multiplication as a 1x1 convolution:
  // transpose the weight to {OC, IC} and view it as {OC, IC, 1, 1}.
  auto weight_ =
      weight.t().view({weight.size(1), weight.size(0), 1, 1}).contiguous();
  // Reshape the input tensor to {N, C, 1, 1}
  auto input_ = input.view({input.size(0), input.size(1), 1, 1});
  MPSImage* X = imageFromTensor(input_);
  Conv2DParams params;
  params.N = X.numberOfImages;
  params.OC = weight_.size(0);
  params.IC = weight_.size(1);
  params.KH = params.KW = 1;
  params.H = params.W = 1;
  // The fully-connected kernel expects channels-last packed weights.
  auto packedWeights = weight_.contiguous(c10::MemoryFormat::ChannelsLast);
  MetalTensorImplStorage mt{{params.N, params.OC}};
  SmallVector<int64_t, 4> textureSize = {params.N, params.OC, 1, 1};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input_);
  mt.texture()->allocateTemporaryStorage(textureSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  float* w = packedWeights.data_ptr<float>();
  float* b = bias.data_ptr<float>();
  MPSCNNFullyConnectedOp* fc =
      [MPSCNNFullyConnectedOp linear:params
                             weights:w
                                bias:b
                        neuronFilter:NeuronType::None];
  [fc encode:commandBuffer.buffer sourceImage:X destinationImage:Y];
  // The output texture becomes {N, OC, 1, 1}. Reshape it to {N, OC}.
  auto output =
      makeTensor(std::move(mt), input.options()).view({params.N, params.OC});
  return output;
}

namespace prepack {

// Runs a prepacked linear layer on Metal. The fully-connected kernel (with
// weights/bias already baked in) is lazily created on first use and cached
// inside `context` as an opaque, CFBridgingRetain'ed pointer; the context's
// release callback balances that retain. Clamp-style activations
// (output_min/output_max) are applied by a separate clamp pass.
static Tensor linear(const Tensor& input, LinearOpContext& context) {
  TORCH_CHECK(input.is_metal());
  TORCH_CHECK(context.get_weight().device() == kCPU);
  TORCH_CHECK(context.get_weight().dim() == 4);
  if (input.numel() == 0 || context.get_weight().numel() == 0) {
    // Prepacked weight is {OC, IC, 1, 1}, so size(0) is the output dim.
    return makeTensor(
        {{input.size(0), context.get_weight().size(0)}}, input.options());
  }
  // Reshape the input tensor to {N, C, 1, 1}
  auto input_ = input.view({input.size(0), input.size(1), 1, 1});
  MPSImage* X = imageFromTensor(input_);
  Conv2DParams params;
  params.N = X.numberOfImages;
  params.OC = context.get_weight().size(0);
  params.IC = context.get_weight().size(1);
  params.KH = params.KW = 1;
  params.H = params.W = 1;
  MPSCNNFullyConnectedOp* op =
      (__bridge MPSCNNFullyConnectedOp*)(context.get_opaqueOpPtr());
  NeuronType nt =
      neuronType(context.get_output_min(), context.get_output_max());
  if (!op) {
    // First invocation: build the MPS op from the prepacked weights and
    // cache it on the context so subsequent runs skip the upload.
    float* w = context.get_weight().data_ptr<float>();
    float* b = context.get_bias().has_value()
        ? ((*context.get_bias()).data_ptr<float>())
        : nullptr;
    op = [MPSCNNFullyConnectedOp linear:params
                                weights:w
                                   bias:b
                           neuronFilter:nt];
    context.set_opaqueOpPtr((void*)CFBridgingRetain(op));
    context.set_releaseCallback(^(void* res) {
      if (res) {
        CFBridgingRelease(res);
      }
    });
  }
  MetalTensorImplStorage mt{{params.N, params.OC}};
  SmallVector<int64_t, 4> textureSize = {params.N, params.OC, 1, 1};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input_);
  mt.texture()->allocateTemporaryStorage(textureSize, commandBuffer);
  MPSImage* Y1 = mt.texture()->image();
  // HACK alert:
  // Here we force X to become static before encoding.
  // We've seen weird crashes in the MaskRCNN model complaining about
  // a "sub-image" was released before its readCount was zero.
  // TODO[T93395421]: Figure out the root cause and remove this line.
  X = createStaticImage((MPSTemporaryImage*)X, commandBuffer, NO);
  [op encode:commandBuffer.buffer sourceImage:X destinationImage:Y1];
  if (nt == NeuronType::Clamp) {
    // Clamp can't be fused into the cached op here; run it as a second pass
    // into a fresh temporary image and publish that as the result.
    MPSImage* Y2 = createTemporaryImage(commandBuffer, [Y1 sizes]);
    float min = context.get_output_min().value().toFloat();
    float max = context.get_output_max().value().toFloat();
    MPSCNNClampOp* clampOp =
        [MPSCNNClampOp newWithTextures:@[ Y1, Y2 ] Args:@[ @(min), @(max) ]];
    [clampOp encode:commandBuffer.buffer];
    mt.texture()->setImage(Y2);
  }
  // The output texture becomes {N, OC, 1, 1}. Reshape it to {N, OC}.
  auto output =
      makeTensor(std::move(mt), input.options()).view({params.N, params.OC});
  return output;
}

// Dispatcher entry point: unwraps the intrusive_ptr and forwards to linear().
static Tensor linear_run(
    const Tensor& input,
    const c10::intrusive_ptr<LinearOpContext>& op_context) {
  return linear(input, *op_context);
}

} // namespace prepack

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::addmm"), TORCH_FN(addmm));
}

TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("metal_prepack::linear_run"),
      TORCH_FN(prepack::linear_run));
}

} // namespace at::native::metal