//  Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ExpandUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/ops/linear_backward_native.h>
#include <ATen/ops/linear_native.h>

namespace at::native {

using namespace mps;

Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::optional<Tensor>& bias_opt) {
  // wT = transpose(weight)
  // y = x * wT + b

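  // A 1-D weight is treated as a single-row (1 x in_features) matrix; the
  // trailing singleton dimension it introduces is squeezed off again below.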
  auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg;

  TORCH_CHECK(supportedFloatingOrComplexType(input), "MPS device does not support linear for non-float inputs");
  TORCH_CHECK(input.is_mps(), "Tensor for argument input is on ", input.device(), " but expected on mps");
  TORCH_CHECK(supportedFloatingOrComplexType(weight_arg), "MPS device does not support linear for non-float weights");
  TORCH_CHECK(weight_arg.is_mps(), "Tensor for argument weight is on ", weight_arg.device(), " but expected on mps");

  const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt));
  const bool is_bias_defined = bias.defined();
  if (is_bias_defined) {
    TORCH_CHECK(bias.is_mps(), "Tensor for argument bias is on ", bias.device(), " but expected on mps");
    TORCH_CHECK(supportedFloatingOrComplexType(bias), "MPS device does not support linear for non-float bias");
  }

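  // The output keeps every leading input dim and replaces the last dim with
  // weight.size(0), i.e. out_features.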
  auto input_size = input.sizes();
  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
  output_size.push_back(weight.size(0));

  TORCH_CHECK(input.size(-1) == weight_arg.size(-1),
              "linear(): input and weight.T shapes cannot be multiplied (",
              input.size(-2),
              "x",
              input.size(-1),
              " and ",
              weight_arg.size(-1),
              "x",
              weight_arg.size(-2),
              ")");

  if (is_bias_defined) {
    // Check bias and output shapes compatibility only.
    inferExpandGeometry_dimvector(bias.sizes(), bias.strides(), output_size);
  }

  Tensor output =
      at::empty(output_size, input.scalar_type(), std::nullopt, kMPS, std::nullopt, input.suggest_memory_format());

  if (output.numel() == 0) {
    return output;
  }

  MPSStream* stream = getCurrentMPSStream();

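  // Graph I/O tensors; the compiled MPSGraph is looked up by a key built from
  // the tensors' shapes/dtypes, so it is built once per unique signature.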
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* biasTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
      MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);

      MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
                                                              dimension:-1
                                                          withDimension:-2
                                                                   name:nil];
      // matrixMultiplicationWithPrimary crashes for 5D tensors, see https://github.com/pytorch/pytorch/issues/114942
      bool doReshape = input.dim() > 4;
      if (!doReshape && is_bias_defined) {
        // workaround to improve the performance with 3D+ inputs
        doReshape =
            input_size.size() > 2 && input_size[0] > 1 && input_size[1] >= 1 && input_size[1] <= 32 && bias.dim() <= 1;
      }
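      // Flatten all leading dims into one so the matmul runs on a 2-D view;
      // the original shape is restored after the (optional) bias add below.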
      auto inputFlattened = doReshape ? [mpsGraph flatten2DTensor:inputTensor axis:-1 name:nil] : inputTensor;
      auto outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputFlattened
                                                          secondaryTensor:weightTransposeTensor
                                                                     name:nil];

      if (is_bias_defined) {
        newCachedGraph->biasTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, bias);
        outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor
                                           secondaryTensor:newCachedGraph->biasTensor_
                                                      name:nil];
      }
      if (doReshape) {
        outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:getMPSShape(output_size) name:nil];
      }

      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->weightTensor_ = weightTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
    Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
    Placeholder biasPlaceholder = Placeholder();
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

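    // Bind the live tensors to the graph placeholders; the bias feed is only
    // added when a bias was actually supplied.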
    NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
    feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData();
    if (is_bias_defined) {
      biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);
      feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
    }
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }

  // Shave off the '1' appended to the output shape when the weight was 1-D.
  if (weight_arg.dim() == 1) {
    // Build the output shape without the trailing singleton dimension.
    auto output_sizes = output.sizes();
    std::vector<int64_t> out_shape(output_sizes.begin(), output_sizes.end() - 1);
    return output.view(IntArrayRef(out_shape));
  }
  return output;
}

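// Computes grad_input = grad_output @ weight, with the original input shape.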
static Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight) {
  TORCH_CHECK(grad_output.is_mps(), "mps_linear_backward: grad_output needs to be mps layout");
  TORCH_CHECK(weight.device().is_mps() && supportedFloatingOrComplexType(weight),
              "mps_linear_backward: unsupported weights data type: ",
              weight.scalar_type());

  TORCH_CHECK(supportedFloatingOrComplexType(grad_output),
              "MPS device does not support linear backward for non-float inputs");

  const Tensor weight_reshaped = weight.is_contiguous() ? weight : weight.contiguous();

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  Tensor output = at::empty(
      input_size, grad_output.scalar_type(), std::nullopt, kMPS, std::nullopt, grad_output.suggest_memory_format());
  TORCH_CHECK(output.is_mps());
  if (grad_output.numel() == 0) {
    return output;
  }

  MPSStream* stream = getCurrentMPSStream();

  @autoreleasepool {
    string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
      newCachedGraph->weightTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped);
      newCachedGraph->gradOutputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);

      // MPS matrixMultiplication crashes for 5D+ tensors on 14.2.1 with `New volume should match old volume`
      // See https://github.com/pytorch/pytorch/issues/114942 for more details
      bool needReshape = grad_output.dim() > 4;
      auto gradOutputTensor = needReshape
          ? [mpsGraph flatten2DTensor:newCachedGraph->gradOutputTensor_ axis:-1 name:nil]
          : newCachedGraph->gradOutputTensor_;

      auto outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTensor
                                                          secondaryTensor:newCachedGraph->weightTensor_
                                                                     name:nil];
      if (needReshape) {
        outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:getMPSShape(output) name:nil];
      }

      newCachedGraph->outputTensor_ = outputTensor;
    });

    Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_reshaped);
    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

    auto feeds = dictionaryFromPlaceholders(weightPlaceholder, gradOutputPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);

    return output;
  }
}

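// grad_weight = grad_output^T @ input over the flattened batch dims;
// grad_bias = grad_output summed over the flattened batch dim (axis 0).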
static std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& grad_output,
                                                               const Tensor& input,
                                                               const Tensor& weight,
                                                               bool bias_defined) {
  TORCH_CHECK(grad_output.is_mps() && input.is_mps(),
              "_mps_linear_backward: grad_output and input need to be mps layout");

  TORCH_CHECK(supportedFloatingOrComplexType(grad_output),
              "MPS device does not support linear backward for non-float inputs");

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
    MPSGraphTensor* biasTensor_ = nil;
  };

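  // Collapse all leading batch dims so both operands are 2-D:
  // grad_output -> (N, out_features), input -> (N, in_features).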
  auto grad_output_reshaped =
      grad_output.dim() != 2 ? grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;
  auto input_reshaped = input.dim() != 2 ? input.reshape({-1, input.size(input.dim() - 1)}) : input;

  TORCH_CHECK(grad_output_reshaped.is_mps());
  TORCH_CHECK(input_reshaped.is_mps());

  Tensor output = at::empty({grad_output_reshaped.size(1), input_reshaped.size(1)},
                            grad_output.scalar_type(),
                            std::nullopt,
                            kMPS,
                            std::nullopt,
                            grad_output.suggest_memory_format());
  Tensor bias = at::empty({grad_output_reshaped.size(1)},
                          grad_output.scalar_type(),
                          std::nullopt,
                          kMPS,
                          std::nullopt,
                          grad_output.suggest_memory_format());
  TORCH_CHECK(output.is_mps());
  TORCH_CHECK(bias.is_mps());

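  // An empty grad_output contributes nothing, so return zero-filled gradients.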
  if (grad_output.numel() == 0) {
    output.zero_();
    bias.zero_();
    return std::tuple<Tensor, Tensor>{output, bias};
  }
  MPSStream* stream = getCurrentMPSStream();

  @autoreleasepool {
    string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" +
        getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);
      MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
      MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped);

      MPSGraphTensor* gradOutputTransposeTensor = [mpsGraph transposeTensor:gradOutputTensor
                                                                  dimension:-1
                                                              withDimension:-2
                                                                       name:nil];

      // grad_weight = grad_output^T @ input
      MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTransposeTensor
                                                                     secondaryTensor:inputTensor
                                                                                name:nil];
      MPSGraphTensor* biasTensor = nil;
      if (bias_defined) {
        // grad_bias = sum of grad_output over the batch axis
        biasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor axis:0 name:nil];
      }

      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->weightTensor_ = weightTensor;
      newCachedGraph->gradOutputTensor_ = gradOutputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
      newCachedGraph->biasTensor_ = biasTensor;
    });

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_reshaped);
    Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_reshaped);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
    Placeholder biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);

    auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, inputPlaceholder, weightPlaceholder);
    auto results = bias_defined ? dictionaryFromPlaceholders(outputPlaceholder, biasPlaceholder)
                                : dictionaryFromPlaceholders(outputPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, results);

    return std::tuple<Tensor, Tensor>{output, bias};
  }
}

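// output_mask selects which of grad_input, grad_weight, and grad_bias to compute.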
std::tuple<Tensor, Tensor, Tensor> mps_linear_backward(const Tensor& input,
                                                       const Tensor& grad_output,
                                                       const Tensor& weight,
                                                       std::array<bool, 3> output_mask) {
  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = _mps_linear_backward_input(input.sizes(), grad_output, weight);
  }
  if (output_mask[1] || output_mask[2]) {
    std::tie(grad_weight, grad_bias) = _mps_linear_backward_weights(grad_output, input, weight, output_mask[2]);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

} // namespace at::native