// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/FusedSGD.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Tensor.h>
2 #include <ATen/native/DispatchStub.h>
3 
4 namespace at::native {
5 
6 using fused_sgd_fn = void (*)(
7     const at::Tensor& param,
8     const at::Tensor& grad,
9     const at::Tensor& momentum_buffer,
10     const double weight_decay,
11     const double momentum,
12     const double lr,
13     const double dampening,
14     const bool nesterov,
15     const bool maximize,
16     const bool is_first_step,
17     const float* grad_scale_ptr);
18 
19 DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub);
20 
21 } // namespace at::native
22