xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/FusedAdam.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Tensor.h>
2 #include <ATen/native/DispatchStub.h>
3 
4 namespace at::native {
5 
// Tag passed to a fused Adam kernel to choose between two modes.
// The underlying type is uint8_t, and the numeric values are pinned
// explicitly (ORIGINAL == 0, ADAMW == 1).
// NOTE(review): the names suggest "classic Adam" versus "AdamW"
// (decoupled weight decay); this header does not show how kernels interpret
// the tag, so confirm against the kernel implementations.
enum class ADAM_MODE : uint8_t {
  ORIGINAL = 0,
  ADAMW = 1,
};
7 
8 using fused_adam_fn = void (*)(
9     const at::Tensor& param,
10     const at::Tensor& grad,
11     const at::Tensor& exp_avg,
12     const at::Tensor& exp_avg_sq,
13     const at::Tensor& max_exp_avg_sq,
14     const at::Tensor& state_step,
15     const double lr,
16     const double beta1,
17     const double beta2,
18     const double weight_decay,
19     const double eps,
20     const bool amsgrad,
21     const bool maximize,
22     const float* grad_scale_ptr,
23     const ADAM_MODE);
24 
// Declares the dispatch stub `fused_adam_stub` for kernels whose signature is
// `fused_adam_fn`.
// NOTE(review): DECLARE_DISPATCH presumably comes from
// <ATen/native/DispatchStub.h>, included at the top of this header. The
// matching stub definition and per-backend kernel registrations are not
// visible here, so confirm where they live.
25 DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);
26 
27 } // namespace at::native
28