#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdam.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adam.h>
#include <ATen/ops/_fused_adam_native.h>
#include <ATen/ops/_fused_adamw.h>
#include <ATen/ops/_fused_adamw_native.h>
#endif

namespace at::native {

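// In-place fused Adam step for a list of CPU tensors. Each parameter, its
// gradient, and its optimizer state (exp_avg, exp_avg_sq and, with amsgrad,
// max_exp_avg_sq) are forwarded one tensor at a time to the vectorized
// per-tensor kernel registered behind fused_adam_stub. If a gradient scaler
// reported an overflow via found_inf, the whole step is skipped; grad_scale,
// when provided, is passed through so the kernel can unscale the gradients.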
void _fused_adam_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
    return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(exp_avgs.size() == n_tensors);
  TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
  if (amsgrad) {
    TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
  } else {
    TORCH_CHECK(max_exp_avg_sqs.empty());
  }
  TORCH_CHECK(state_steps.size() == n_tensors);
  at::Tensor max_exp_avg_sq = at::Tensor();
  for (size_t i = 0; i < n_tensors; i++) {
    if (amsgrad) {
      max_exp_avg_sq = max_exp_avg_sqs[i];
    }
    fused_adam_stub(
        kCPU,
        params[i],
        grads[i],
        exp_avgs[i],
        exp_avg_sqs[i],
        max_exp_avg_sq,
        state_steps[i],
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale_ptr,
        ADAM_MODE::ORIGINAL);
  }
}

// Overload that takes the learning rate as a Tensor; the scalar value is read
// out on the host via lr.item<double>() and forwarded to the double-lr
// overload above.
void _fused_adam_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  _fused_adam_kernel_cpu_(
      params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
      lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize,
      grad_scale, found_inf);
}

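// In-place fused AdamW step for a list of CPU tensors. Identical to the Adam
// entry point above except that the per-tensor kernel is invoked with
// ADAM_MODE::ADAMW, i.e. weight decay is applied directly to the parameters
// (decoupled weight decay) instead of being folded into the gradient.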
void _fused_adamw_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
    return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(exp_avgs.size() == n_tensors);
  TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
  if (amsgrad) {
    TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
  } else {
    TORCH_CHECK(max_exp_avg_sqs.empty());
  }
  TORCH_CHECK(state_steps.size() == n_tensors);
  at::Tensor max_exp_avg_sq = at::Tensor();
  for (size_t i = 0; i < n_tensors; i++) {
    if (amsgrad) {
      max_exp_avg_sq = max_exp_avg_sqs[i];
    }
    fused_adam_stub(
        kCPU,
        params[i],
        grads[i],
        exp_avgs[i],
        exp_avg_sqs[i],
        max_exp_avg_sq,
        state_steps[i],
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale_ptr,
        ADAM_MODE::ADAMW);
  }
}

// Overload that takes the learning rate as a Tensor; the scalar value is read
// out on the host via lr.item<double>() and forwarded to the double-lr
// overload above.
void _fused_adamw_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  _fused_adamw_kernel_cpu_(
      params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
      lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize,
      grad_scale, found_inf);
}
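
// Both the Adam and AdamW entry points funnel into the same fused_adam_stub;
// the ADAM_MODE argument selects the update rule, so only one dispatch stub
// needs to be registered below.
//
// A minimal usage sketch (assuming the usual registration of these kernels as
// the CPU backend of the _fused_adam_ / _fused_adamw_ ops; shapes and
// hyperparameter values are illustrative only):
//
//   std::vector<at::Tensor> params = {at::randn({16})};
//   std::vector<at::Tensor> grads = {at::randn({16})};
//   std::vector<at::Tensor> exp_avgs = {at::zeros({16})};
//   std::vector<at::Tensor> exp_avg_sqs = {at::zeros({16})};
//   std::vector<at::Tensor> state_steps = {at::zeros({}, at::kFloat)};
//   at::_fused_adam_(
//       params, grads, exp_avgs, exp_avg_sqs, /*max_exp_avg_sqs=*/{},
//       state_steps, /*lr=*/1e-3, /*beta1=*/0.9, /*beta2=*/0.999,
//       /*weight_decay=*/0.0, /*eps=*/1e-8, /*amsgrad=*/false,
//       /*maximize=*/false);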

DEFINE_DISPATCH(fused_adam_stub);

} // namespace at::native