// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ExpandUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/ops/linear_backward_native.h>
#include <ATen/ops/linear_native.h>

namespace at::native {

using namespace mps;

Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::optional<Tensor>& bias_opt) {
  // wT = transpose(weight)
  // y = x * wT + b

  auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg;

  TORCH_CHECK(supportedFloatingOrComplexType(input), "MPS device does not support linear for non-float inputs");
  TORCH_CHECK(input.is_mps(), "Tensor for argument input is on ", input.device(), " but expected on mps");
  TORCH_CHECK(supportedFloatingOrComplexType(weight_arg), "MPS device does not support linear for non-float weights");
  TORCH_CHECK(weight_arg.is_mps(), "Tensor for argument weight is on ", weight_arg.device(), " but expected on mps");

  // Bind the borrowed tensor to a named MaybeOwned so the reference cannot dangle
  // when bias_opt is empty and an owned empty tensor is materialized.
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  const bool is_bias_defined = bias.defined();
  if (is_bias_defined) {
    TORCH_CHECK(bias.is_mps(), "Tensor for argument bias is on ", bias.device(), " but expected on mps");
    TORCH_CHECK(supportedFloatingOrComplexType(bias), "MPS device does not support linear for non-float bias");
  }

  auto input_size = input.sizes();
  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
  output_size.push_back(weight.size(0));

  TORCH_CHECK(input.size(-1) == weight_arg.size(-1),
              "linear(): input and weight.T shapes cannot be multiplied (",
              input.size(-2),
              "x",
              input.size(-1),
              " and ",
              weight_arg.size(-1),
              "x",
              weight_arg.size(-2),
              ")");

  if (is_bias_defined) {
    // Only check that the bias shape can be expanded to the output shape.
    inferExpandGeometry_dimvector(bias.sizes(), bias.strides(), output_size);
  }

  Tensor output =
      at::empty(output_size, input.scalar_type(), std::nullopt, kMPS, std::nullopt, input.suggest_memory_format());

  if (output.numel() == 0) {
    return output;
  }

  MPSStream* stream = getCurrentMPSStream();

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* biasTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    std::string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
      MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);

      MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
                                                              dimension:-1
                                                          withDimension:-2
                                                                   name:nil];
      // matrixMultiplicationWithPrimary crashes for 5D tensors, see https://github.com/pytorch/pytorch/issues/114942
      bool doReshape = input.dim() > 4;
      if (!doReshape && is_bias_defined) {
        // Workaround to improve performance with 3D+ inputs
        doReshape =
            input_size.size() > 2 && input_size[0] > 1 && input_size[1] >= 1 && input_size[1] <= 32 && bias.dim() <= 1;
      }
      auto inputFlattened = doReshape ? [mpsGraph flatten2DTensor:inputTensor axis:-1 name:nil] : inputTensor;
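      // When doReshape is set, inputFlattened is a rank-2 tensor of shape
      // [N, in_features] (all leading dims collapsed into N), so the matmul
      // below computes [N, in_features] x [in_features, out_features].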
      auto outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputFlattened
                                                          secondaryTensor:weightTransposeTensor
                                                                     name:nil];

      if (is_bias_defined) {
        newCachedGraph->biasTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, bias);
        outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor
                                           secondaryTensor:newCachedGraph->biasTensor_
                                                      name:nil];
      }
      if (doReshape) {
        outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:getMPSShape(output_size) name:nil];
      }

      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->weightTensor_ = weightTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
    Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
    Placeholder biasPlaceholder = Placeholder();
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

    NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
    feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData();
    if (is_bias_defined) {
      biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);
      feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
    }
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }

  // A 1-D weight was reshaped to [1, in_features] above, which leaves a trailing
  // dimension of size 1 on the output; squeeze it off here.
  if (weight_arg.dim() == 1) {
    // New output shape without the trailing dimension
    auto output_sizes = output.sizes();
    std::vector<int64_t> out_shape(output_sizes.begin(), output_sizes.end() - 1);
    return output.view(IntArrayRef(out_shape));
  }
  return output;
}

static Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight) {
  TORCH_CHECK(grad_output.is_mps(), "mps_linear_backward: grad_output needs to be mps layout");
  TORCH_CHECK(weight.device().is_mps() && supportedFloatingOrComplexType(weight),
              "mps_linear_backward: unsupported weights data type: ",
              weight.scalar_type());

  TORCH_CHECK(supportedFloatingOrComplexType(grad_output),
              "MPS device does not support linear backward for non-float inputs");

  const Tensor weight_reshaped = weight.is_contiguous() ? weight : weight.contiguous();
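
  // grad_input = grad_output * weight (untransposed): [*, out_features] x
  // [out_features, in_features] -> [*, in_features], i.e. the shape in input_size.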

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  Tensor output = at::empty(
      input_size, grad_output.scalar_type(), std::nullopt, kMPS, std::nullopt, grad_output.suggest_memory_format());
  TORCH_CHECK(output.is_mps());
  if (grad_output.numel() == 0) {
    return output;
  }

  MPSStream* stream = getCurrentMPSStream();

  @autoreleasepool {
    std::string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
      newCachedGraph->weightTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped);
      newCachedGraph->gradOutputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);

      // MPS matrixMultiplication crashes for 5D+ tensors on 14.2.1 with `New volume should match old volume`
      // See https://github.com/pytorch/pytorch/issues/114942 for more details
      bool needReshape = grad_output.dim() > 4;
      auto gradOutputTensor = needReshape
          ? [mpsGraph flatten2DTensor:newCachedGraph->gradOutputTensor_ axis:-1 name:nil]
          : newCachedGraph->gradOutputTensor_;

      auto outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTensor
                                                          secondaryTensor:newCachedGraph->weightTensor_
                                                                     name:nil];
      if (needReshape) {
        outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:getMPSShape(output) name:nil];
      }

      newCachedGraph->outputTensor_ = outputTensor;
    });

    Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_reshaped);
    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

    auto feeds = dictionaryFromPlaceholders(weightPlaceholder, gradOutputPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);

    return output;
  }
}

static std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& grad_output,
                                                               const Tensor& input,
                                                               const Tensor& weight,
                                                               bool bias_defined) {
  TORCH_CHECK(grad_output.is_mps() && input.is_mps(),
              "_mps_linear_backward: grad_output and input need to be mps layout");

  TORCH_CHECK(supportedFloatingOrComplexType(grad_output),
              "MPS device does not support linear backward for non-float inputs");

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
    MPSGraphTensor* biasTensor_ = nil;
  };

  auto grad_output_reshaped =
      grad_output.dim() != 2 ? grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;
  auto input_reshaped = input.dim() != 2 ? input.reshape({-1, input.size(input.dim() - 1)}) : input;
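
  // Both operands are now 2-D: grad_output_reshaped is [N, out_features] and
  // input_reshaped is [N, in_features], with N collapsing all leading batch
  // dimensions, so grad_weight below has shape [out_features, in_features].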

  TORCH_CHECK(grad_output_reshaped.is_mps());
  TORCH_CHECK(input_reshaped.is_mps());

  Tensor output = at::empty({grad_output_reshaped.size(1), input_reshaped.size(1)},
                            grad_output.scalar_type(),
                            std::nullopt,
                            kMPS,
                            std::nullopt,
                            grad_output.suggest_memory_format());
  Tensor bias = at::empty({grad_output_reshaped.size(1)},
                          grad_output.scalar_type(),
                          std::nullopt,
                          kMPS,
                          std::nullopt,
                          grad_output.suggest_memory_format());
  TORCH_CHECK(output.is_mps());
  TORCH_CHECK(bias.is_mps());

  if (grad_output.numel() == 0) {
    output.zero_();
    bias.zero_();
    return std::tuple<Tensor, Tensor>{output, bias};
  }
  MPSStream* stream = getCurrentMPSStream();

  @autoreleasepool {
    std::string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" +
        getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);
      MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
      MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped);

      MPSGraphTensor* gradOutputTransposeTensor = [mpsGraph transposeTensor:gradOutputTensor
                                                                  dimension:-1
                                                              withDimension:-2
                                                                       name:nil];

      // grad_weight = transpose(grad_output) * input
      MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTransposeTensor
                                                                     secondaryTensor:inputTensor
                                                                                name:nil];
      MPSGraphTensor* biasTensor = nil;
      if (bias_defined) {
        // grad_bias = sum of grad_output over the batch dimension
        biasTensor = [mpsGraph reductionSumWithTensor:gradOutputTensor axis:0 name:nil];
      }

      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->weightTensor_ = weightTensor;
      newCachedGraph->gradOutputTensor_ = gradOutputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
      newCachedGraph->biasTensor_ = biasTensor;
    });

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_reshaped);
    Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_reshaped);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
    Placeholder biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias);

    auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, inputPlaceholder, weightPlaceholder);
    auto results = bias_defined ? dictionaryFromPlaceholders(outputPlaceholder, biasPlaceholder)
                                : dictionaryFromPlaceholders(outputPlaceholder);
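    // When grad_bias is not requested, biasTensor_ stays nil and only the
    // grad_weight placeholder is read back from the graph.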
    runMPSGraph(stream, cachedGraph->graph(), feeds, results);

    return std::tuple<Tensor, Tensor>{output, bias};
  }
}

std::tuple<Tensor, Tensor, Tensor> mps_linear_backward(const Tensor& input,
                                                       const Tensor& grad_output,
                                                       const Tensor& weight,
                                                       std::array<bool, 3> output_mask) {
  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = _mps_linear_backward_input(input.sizes(), grad_output, weight);
  }
  if (output_mask[1] || output_mask[2]) {
    std::tie(grad_weight, grad_bias) = _mps_linear_backward_weights(grad_output, input, weight, output_mask[2]);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

} // namespace at::native