// /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalConvolution.mm
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
#import <ATen/native/metal/mpscnn/MPSCNNConvOp.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#import <ATen/native/metal/ops/MetalConvolution.h>

#import <ATen/ATen.h>

namespace at::native::metal {

using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
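
// Eager-mode Metal conv2d: the input must already live on the Metal backend
// (backed by an MPSImage), while the weight and optional bias are expected on
// the CPU and are passed to a freshly created MPSCNNConvOp as raw float
// pointers.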
Tensor conv2d(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<at::Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  TORCH_CHECK(input.is_metal());
  Conv2DParams params{
      input.sizes(), weight.sizes(), padding, stride, dilation, groups};
  TORCH_INTERNAL_ASSERT(input.dim() == 4, "Expected 4-dimensional input");
  TORCH_INTERNAL_ASSERT(weight.dim() == 4, "Expected 4-dimensional weight");
  TORCH_CHECK(weight.device().type() == kCPU);
  auto outputSize = params.output_sizes();
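  // If any output dimension is zero there is nothing to compute; return an
  // empty Metal tensor instead of encoding any work.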
  if (c10::multiply_integers(outputSize) == 0) {
    return makeTensor({outputSize}, input.options());
  }
  MPSImage* X = imageFromTensor(input);
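  // Repack the CPU weight tensor to channels-last memory format before its
  // raw data pointer is handed to MPSCNNConvOp.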
  auto packedWeights = weight.contiguous(c10::MemoryFormat::ChannelsLast);
  // Build the MPSCNN convolution op; no neuron (activation) filter is fused
  // in the eager path.
  float* w = packedWeights.data_ptr<float>();
  float* b = bias.has_value() ? bias->data_ptr<float>() : nullptr;
  MPSCNNConvOp* op = [MPSCNNConvOp conv2d:params
                                  weights:w
                                     bias:b
                             neuronFilter:NeuronType::None];
  MetalTensorImplStorage mt{outputSize};
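  // Encode the convolution on the command buffer that the input tensor is
  // already recorded on; the output image is backed by temporary storage tied
  // to that command buffer.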
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  [op encode:commandBuffer.buffer sourceImage:X destinationImage:Y];
  auto output = makeTensor(std::move(mt), input.options());
  return output;
}

namespace prepack {

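// Prepacked conv2d: the weights, bias, and convolution parameters come from a
// Conv2dOpContext built ahead of time by the metal_prepack prepack step. The
// underlying MPSCNNConvOp is created lazily on the first run and cached on the
// context, so later runs only encode work.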
Tensor conv2d(const Tensor& input, Conv2dOpContext& context) {
  MPSImage* X = imageFromTensor(input);
  Conv2DParams params{input.sizes(),
                      context.get_weight().sizes(),
                      context.get_padding(),
                      context.get_stride(),
                      context.get_dilation(),
                      context.get_groups()};
  auto outputSize = params.output_sizes();
  if (c10::multiply_integers(outputSize) == 0) {
    return makeTensor({outputSize}, input.options());
  }
  MPSCNNConvOp* op = (__bridge MPSCNNConvOp*)(context.get_conv2dOpPtr());
  NeuronType nt = neuronType(context.get_output_min(), context.get_output_max());
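  // First run: build the MPSCNNConvOp and cache it on the context.
  // CFBridgingRetain transfers ownership of the Objective-C object to the
  // context; the registered release callback balances that retain via
  // CFBridgingRelease.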
  if (!op) {
    float* w = context.get_weight().data_ptr<float>();
    float* b = context.get_bias().has_value()
        ? ((*context.get_bias()).data_ptr<float>())
        : nullptr;
    op = [MPSCNNConvOp conv2d:params weights:w bias:b neuronFilter:nt];
    context.set_conv2dOpPtr((void*)CFBridgingRetain(op));
    context.set_releaseCallback(^(void* res) {
      if (res) {
        CFBridgingRelease(res);
      }
    });
  }
  MetalTensorImplStorage mt{outputSize};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y1 = mt.texture()->image();
  [op encode:commandBuffer.buffer sourceImage:X destinationImage:Y1];
  // "Fuse" hardtanh with the convolution: when the op carries clamp bounds,
  // encode an extra clamp pass on the same command buffer and swap in its
  // output image.
  if (nt == NeuronType::Clamp) {
    MPSImage* Y2 = createTemporaryImage(commandBuffer, [Y1 sizes]);
    float min = context.get_output_min().value().toFloat();
    float max = context.get_output_max().value().toFloat();
    MPSCNNClampOp* clampOp =
        [MPSCNNClampOp newWithTextures:@[ Y1, Y2 ] Args:@[ @(min), @(max) ]];
    [clampOp encode:commandBuffer.buffer];
    mt.texture()->setImage(Y2);
  }
  auto output = makeTensor(std::move(mt), input.options());
  return output;
}

static Tensor conv2d_prepack_run(
    const Tensor& input,
    const c10::intrusive_ptr<Conv2dOpContext>& op_context) {
  return conv2d(input, *op_context);
}

} // namespace prepack

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  // NB: this didn't actually do anything; need to generalize this to
  // work for general convolution and register to aten::convolution
  // m.impl(TORCH_SELECTIVE_NAME("aten::conv2d"), TORCH_FN(conv2d));
}

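// Register the prepacked conv2d runner as the Metal implementation of
// metal_prepack::conv2d_run.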
TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("metal_prepack::conv2d_run"),
      prepack::conv2d_prepack_run);
}

} // namespace at::native::metal