1 #include <ATen/core/Tensor.h> 2 #include <ATen/native/DispatchStub.h> 3 4 namespace at::native { 5 6 using fused_sgd_fn = void (*)( 7 const at::Tensor& param, 8 const at::Tensor& grad, 9 const at::Tensor& momentum_buffer, 10 const double weight_decay, 11 const double momentum, 12 const double lr, 13 const double dampening, 14 const bool nesterov, 15 const bool maximize, 16 const bool is_first_step, 17 const float* grad_scale_ptr); 18 19 DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub); 20 21 } // namespace at::native 22