#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
#import <ATen/native/metal/mpscnn/MPSCNNConvOp.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#import <ATen/native/metal/ops/MetalConvolution.h>

#import <ATen/ATen.h>

namespace at::native::metal {

using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;

// 2D convolution on the Metal backend. `input` must be a 4-D Metal tensor;
// `weight` (and the optional `bias`) must be 4-D CPU tensors, and a fresh
// MPSCNN convolution kernel is built from them on every call.
Tensor conv2d(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<at::Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  TORCH_CHECK(input.is_metal());
  Conv2DParams params{
      input.sizes(), weight.sizes(), padding, stride, dilation, groups};
  TORCH_INTERNAL_ASSERT(input.dim() == 4, "Expected 4-dimensional input");
  TORCH_INTERNAL_ASSERT(weight.dim() == 4, "Expected 4-dimensional weight");
  TORCH_CHECK(weight.device().type() == kCPU);
  const auto outputSizes = params.output_sizes();
  if (c10::multiply_integers(outputSizes) == 0) {
    // An empty output needs no GPU work at all.
    return makeTensor({outputSizes}, input.options());
  }
  MPSImage* sourceImage = imageFromTensor(input);
  // Repack the weights channels-last before handing them to MPSCNN.
  const auto packedWeights = weight.contiguous(c10::MemoryFormat::ChannelsLast);
  float* weightData = packedWeights.data_ptr<float>();
  float* biasData = bias.has_value() ? bias->data_ptr<float>() : nullptr;
  MPSCNNConvOp* convOp = [MPSCNNConvOp conv2d:params
                                      weights:weightData
                                         bias:biasData
                                 neuronFilter:NeuronType::None];
  MetalTensorImplStorage storage{outputSizes};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  storage.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
  MPSImage* destinationImage = storage.texture()->image();
  [convOp encode:commandBuffer.buffer
          sourceImage:sourceImage
     destinationImage:destinationImage];
  return makeTensor(std::move(storage), input.options());
}

namespace prepack {

// Convolution with prepacked parameters. The MPSCNN kernel is constructed
// lazily on first use and cached inside `context`; ownership is transferred
// to the context via CFBridgingRetain and given back to ARC by the release
// callback.
Tensor conv2d(const Tensor& input, Conv2dOpContext& context) {
  MPSImage* sourceImage = imageFromTensor(input);
  Conv2DParams params{input.sizes(),
                      context.get_weight().sizes(),
                      context.get_padding(),
                      context.get_stride(),
                      context.get_dilation(),
                      context.get_groups()};
  const auto outputSizes = params.output_sizes();
  if (c10::multiply_integers(outputSizes) == 0) {
    // An empty output needs no GPU work at all.
    return makeTensor({outputSizes}, input.options());
  }
  MPSCNNConvOp* convOp = (__bridge MPSCNNConvOp*)(context.get_conv2dOpPtr());
  NeuronType nt =
      neuronType(context.get_output_min(), context.get_output_max());
  if (!convOp) {
    // First run: build the kernel and stash it in the op context.
    float* weightData = context.get_weight().data_ptr<float>();
    float* biasData = context.get_bias().has_value()
        ? ((*context.get_bias()).data_ptr<float>())
        : nullptr;
    convOp = [MPSCNNConvOp conv2d:params
                          weights:weightData
                             bias:biasData
                     neuronFilter:nt];
    context.set_conv2dOpPtr((void*)CFBridgingRetain(convOp));
    context.set_releaseCallback(^(void* res) {
      if (res) {
        CFBridgingRelease(res);
      }
    });
  }
  MetalTensorImplStorage storage{outputSizes};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  storage.texture()->allocateTemporaryStorage(outputSizes, commandBuffer);
  MPSImage* convOutput = storage.texture()->image();
  [convOp encode:commandBuffer.buffer
          sourceImage:sourceImage
     destinationImage:convOutput];
  if (nt == NeuronType::Clamp) {
    // Apply the fused hardtanh (clamp) as a follow-up pass on the conv
    // result, then publish the clamped image as the output texture.
    MPSImage* clampedOutput =
        createTemporaryImage(commandBuffer, [convOutput sizes]);
    float outputMin = context.get_output_min().value().toFloat();
    float outputMax = context.get_output_max().value().toFloat();
    MPSCNNClampOp* clampOp =
        [MPSCNNClampOp newWithTextures:@[ convOutput, clampedOutput ]
                                  Args:@[ @(outputMin), @(outputMax) ]];
    [clampOp encode:commandBuffer.buffer];
    storage.texture()->setImage(clampedOutput);
  }
  return makeTensor(std::move(storage), input.options());
}

// Thin adapter matching the registered op schema: unwraps the intrusive_ptr
// and forwards to prepack::conv2d above.
static Tensor conv2d_prepack_run(
    const Tensor& input,
    const c10::intrusive_ptr<Conv2dOpContext>& op_context) {
  return conv2d(input, *op_context);
}

} // namespace prepack

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  // NB: this didn't actually do anything; need to generalize this to
  // work for general convolution and register to aten::convolution
  // m.impl(TORCH_SELECTIVE_NAME("aten::conv2d"), TORCH_FN(conv2d));
}

TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("metal_prepack::conv2d_run"),
      prepack::conv2d_prepack_run);
}

} // namespace at::native::metal