// aten/src/ATen/native/FusedAdam.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdam.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adam.h>
#include <ATen/ops/_fused_adam_native.h>
#include <ATen/ops/_fused_adamw.h>
#include <ATen/ops/_fused_adamw_native.h>
#endif


namespace at::native {
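
// CPU entry points for the fused Adam / AdamW optimizer steps. Each kernel
// checks that the per-parameter state lists have matching lengths, skips the
// step when the gradient scaler has flagged an inf/nan, and then dispatches
// one fused update per parameter tensor through fused_adam_stub.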
void _fused_adam_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
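  // Extract raw float pointers from the optional grad-scaler tensors; the
  // scale is passed through to the per-tensor kernel rather than applied here.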
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
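  // If the grad scaler found an inf/nan in this iteration's gradients, skip
  // the optimizer step entirely.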
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
      return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(exp_avgs.size() == n_tensors);
  TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
  if (amsgrad) {
    TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
  } else {
    TORCH_CHECK(max_exp_avg_sqs.empty());
  }
  TORCH_CHECK(state_steps.size() == n_tensors);
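  // Dispatch one fused update per parameter tensor using the classic Adam
  // update (ADAM_MODE::ORIGINAL). max_exp_avg_sq stays an undefined Tensor
  // unless amsgrad is enabled.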
  at::Tensor max_exp_avg_sq = at::Tensor();
  for (size_t i = 0; i < n_tensors; i++){
    if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i];
    fused_adam_stub(
      kCPU,
      params[i],
      grads[i],
      exp_avgs[i],
      exp_avg_sqs[i],
      max_exp_avg_sq,
      state_steps[i],
      lr,
      beta1,
      beta2,
      weight_decay,
      eps,
      amsgrad,
      maximize,
      grad_scale_ptr,
      ADAM_MODE::ORIGINAL);
  }
}

// The following overload simply has a Tensor lr
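// It reduces lr to a scalar with lr.item<double>() and forwards to the
// double-lr overload above.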
void _fused_adam_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  _fused_adam_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
}
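
// Same as the Adam kernel above, except that the per-tensor update runs with
// ADAM_MODE::ADAMW, i.e. AdamW-style decoupled weight decay.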
void _fused_adamw_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
      return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(exp_avgs.size() == n_tensors);
  TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
  if (amsgrad) {
    TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
  } else {
    TORCH_CHECK(max_exp_avg_sqs.empty());
  }
  TORCH_CHECK(state_steps.size() == n_tensors);
  at::Tensor max_exp_avg_sq = at::Tensor();
  for (size_t i = 0; i < n_tensors; i++){
    if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i];
    fused_adam_stub(
      kCPU,
      params[i],
      grads[i],
      exp_avgs[i],
      exp_avg_sqs[i],
      max_exp_avg_sq,
      state_steps[i],
      lr,
      beta1,
      beta2,
      weight_decay,
      eps,
      amsgrad,
      maximize,
      grad_scale_ptr,
      ADAM_MODE::ADAMW);
  }
}

// The following overload simply has a Tensor lr
void _fused_adamw_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  _fused_adamw_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
}
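
// fused_adam_stub is the DispatchStub declared in ATen/native/FusedAdam.h;
// device-specific kernel implementations are registered against it in their
// own source files.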
DEFINE_DISPATCH(fused_adam_stub);

} // namespace at::native