xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalAddmm.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/Tensor.h>
2#import <ATen/native/metal/MetalCommandBuffer.h>
3#import <ATen/native/metal/MetalPrepackOpContext.h>
4#import <ATen/native/metal/MetalTensorImpl.h>
5#import <ATen/native/metal/MetalTensorImplStorage.h>
6#import <ATen/native/metal/MetalTensorUtils.h>
7#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
8#import <ATen/native/metal/MetalContext.h>
9#import <ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.h>
10#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
11#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
12
13#include <torch/library.h>
14
15namespace at::native::metal {
16
API_AVAILABLE(ios(11.0), macos(10.13))
// Metal implementation of aten::addmm(bias, input, weight, beta, alpha):
//   output = beta * bias + alpha * (input @ weight)
// Only beta == alpha == 1.0 is supported. The GEMM is lowered to a 1x1
// fully-connected convolution encoded on the GPU command buffer.
//
// Shapes: input is a Metal tensor {N, K}; weight is a CPU tensor {K, OC};
// bias lives on CPU. The result is a Metal tensor {N, OC}.
static Tensor addmm(
    const Tensor& bias,
    const Tensor& input,
    const Tensor& weight,
    const Scalar& beta,
    const Scalar& alpha) {
  TORCH_CHECK(input.is_metal());
  TORCH_CHECK(weight.device() == kCPU && weight.dim() == 2);
  TORCH_CHECK(bias.device() == kCPU);
  // The Metal path does not implement scaling; reject non-default factors.
  TORCH_CHECK(beta.toFloat() == 1.0f);
  TORCH_CHECK(alpha.toFloat() == 1.0f);
  if (input.numel() == 0 || weight.numel() == 0) {
    // aten::addmm yields {mat1.rows, mat2.cols}. weight is {K, OC} here
    // (it is transposed below, so OC == weight.size(1)); the previous code
    // returned weight.size(0), which disagreed with the non-empty path.
    return makeTensor({{input.size(0), weight.size(1)}}, input.options());
  }
  // Treat the matrix multiplication as a convolution:
  // weight {K, OC} -> transposed -> {OC, K, 1, 1} conv filter.
  auto weight_ =
      weight.t().view({weight.size(1), weight.size(0), 1, 1}).contiguous();
  // Reshape the input tensor to {N, C, 1, 1}
  auto input_ = input.view({input.size(0), input.size(1), 1, 1});
  MPSImage* X = imageFromTensor(input_);
  Conv2DParams params;
  params.N = X.numberOfImages;
  params.OC = weight_.size(0);
  params.IC = weight_.size(1);
  params.KH = params.KW = 1;
  params.H = params.W = 1;
  auto packedWeights = weight_.contiguous(c10::MemoryFormat::ChannelsLast);
  MetalTensorImplStorage mt{{params.N, params.OC}};
  SmallVector<int64_t, 4> textureSize = {params.N, params.OC, 1, 1};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input_);
  mt.texture()->allocateTemporaryStorage(textureSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  float* w = packedWeights.data_ptr<float>();
  float* b = bias.data_ptr<float>();
  MPSCNNFullyConnectedOp* fc = [MPSCNNFullyConnectedOp linear:params
                                                      weights:w
                                                         bias:b
                                                 neuronFilter:NeuronType::None];
  [fc encode:commandBuffer.buffer sourceImage:X destinationImage:Y];
  // The output texture becomes {N, OC, 1, 1}. Reshape it to {N, OC}.
  auto output =
      makeTensor(std::move(mt), input.options()).view({params.N, params.OC});
  return output;
}
61
62namespace prepack {
63
// Runs a prepacked linear (fully-connected) layer on Metal by lowering it to
// a 1x1 convolution. The compiled MPSCNNFullyConnectedOp is built lazily on
// first use and cached inside the context as an opaque bridged pointer, so
// subsequent calls skip weight repacking and op construction.
static Tensor linear(const Tensor& input, LinearOpContext& context) {
  TORCH_CHECK(input.is_metal());
  TORCH_CHECK(context.get_weight().device() == kCPU);
  // The prepacked weight is 4-D; size(0)/size(1) are used below as OC/IC.
  TORCH_CHECK(context.get_weight().dim() == 4);
  if(input.numel() == 0 || context.get_weight().numel() == 0){
    // Degenerate case: nothing to encode — return an empty {N, OC} tensor.
    return makeTensor({{input.size(0), context.get_weight().size(0)}}, input.options());
  }
  // Reshape the input tensor to {N, C, 1, 1}
  auto input_ = input.view({input.size(0), input.size(1), 1, 1});
  MPSImage* X = imageFromTensor(input_);
  Conv2DParams params;
  params.N = X.numberOfImages;
  params.OC = context.get_weight().size(0);
  params.IC = context.get_weight().size(1);
  params.KH = params.KW = 1;
  params.H = params.W = 1;
  // Reuse the cached GPU op if the context already holds one.
  MPSCNNFullyConnectedOp* op =
      (__bridge MPSCNNFullyConnectedOp*)(context.get_opaqueOpPtr());
  NeuronType nt =
      neuronType(context.get_output_min(), context.get_output_max());
  if (!op) {
    float* w = context.get_weight().data_ptr<float>();
    // Bias is optional; pass nullptr when absent.
    float* b = context.get_bias().has_value()
        ? ((*context.get_bias()).data_ptr<float>())
        : nullptr;
    op = [MPSCNNFullyConnectedOp linear:params
                                weights:w
                                   bias:b
                           neuronFilter:nt];
    // Hand ownership to the context: CFBridgingRetain keeps the op alive past
    // this scope, and the release callback balances it when the context dies.
    context.set_opaqueOpPtr((void*)CFBridgingRetain(op));
    context.set_releaseCallback(^(void* res) {
      if (res) {
        CFBridgingRelease(res);
      }
    });
  }
  MetalTensorImplStorage mt{{params.N, params.OC}};
  SmallVector<int64_t, 4> textureSize = {params.N, params.OC, 1, 1};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input_);
  mt.texture()->allocateTemporaryStorage(textureSize, commandBuffer);
  MPSImage* Y1 = mt.texture()->image();
  // HACK alert:
  // Here we force X to become static before encoding.
  // We've seen weird crashes in the MaskRCNN model complaining about
  // a "sub-image" was released before its readCount was zero.
  // TODO[T93395421]: Figure out the root cause and remove this line.
  X = createStaticImage((MPSTemporaryImage* )X, commandBuffer, NO);
  [op encode:commandBuffer.buffer sourceImage:X destinationImage:Y1];
  if (nt == NeuronType::Clamp) {
    // Clamp is not fused into the FC op above; encode a separate clamp kernel
    // into a second temporary image and publish that as the output texture.
    MPSImage* Y2 = createTemporaryImage(commandBuffer, [Y1 sizes]);
    float min = context.get_output_min().value().toFloat();
    float max = context.get_output_max().value().toFloat();
    MPSCNNClampOp* clampOp =
        [MPSCNNClampOp newWithTextures:@[ Y1, Y2 ] Args:@[ @(min), @(max) ]];
    [clampOp encode:commandBuffer.buffer];
    mt.texture()->setImage(Y2);
  }
  // The output texture becomes {N, oC, 1, 1}. Reshape it to {N, oC}
  auto output =
      makeTensor(std::move(mt), input.options()).view({params.N, params.OC});
  return output;
}
126
// Thin dispatch shim for metal_prepack::linear_run: unwraps the
// intrusive_ptr-held op context and forwards to the prepacked kernel.
static Tensor linear_run(
    const Tensor& input,
    const c10::intrusive_ptr<LinearOpContext>& op_context) {
  LinearOpContext& context = *op_context;
  return linear(input, context);
}

} // namespace prepack
134
// Register the Metal-backend kernel for aten::addmm.
// (Dropped the stray trailing semicolon after the macro block; the sibling
// registration below has none, and the extra empty declaration was noise.)
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::addmm"), TORCH_FN(addmm));
}
138
// Register the Metal-backend kernel for the prepacked linear op.
TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("metal_prepack::linear_run"), TORCH_FN(prepack::linear_run));
}
142
143} // namespace at::native::metal
144