1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/FusedSGD.h>
5
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_fused_sgd.h>
11 #include <ATen/ops/_fused_sgd_native.h>
12 #endif
13
14
15 namespace at::native {
16
17
_fused_sgd_kernel_cpu_(at::TensorList params,at::TensorList grads,at::TensorList momentum_buffer_list,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const std::optional<at::Tensor> & grad_scale,const std::optional<at::Tensor> & found_inf)18 void _fused_sgd_kernel_cpu_(
19 at::TensorList params,
20 at::TensorList grads,
21 at::TensorList momentum_buffer_list,
22 const double weight_decay,
23 const double momentum,
24 const double lr,
25 const double dampening,
26 const bool nesterov,
27 const bool maximize,
28 const bool is_first_step,
29 const std::optional<at::Tensor>& grad_scale,
30 const std::optional<at::Tensor>& found_inf) {
31 const float* grad_scale_ptr =
32 grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
33 const float* found_inf_ptr =
34 found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
35 if (found_inf_ptr && *found_inf_ptr == 1.0) {
36 return;
37 }
38 size_t n_tensors = params.size();
39 TORCH_CHECK(grads.size() == n_tensors);
40 bool no_momentum_buffer = momentum == 0.0;
41 if (no_momentum_buffer) {
42 TORCH_CHECK(momentum_buffer_list.empty());
43 } else {
44 TORCH_CHECK(momentum_buffer_list.size() == n_tensors);
45 }
46 for (size_t i = 0; i < n_tensors; i++){
47 fused_sgd_stub(
48 kCPU,
49 params[i],
50 grads[i],
51 no_momentum_buffer ? Tensor() : momentum_buffer_list[i],
52 weight_decay,
53 momentum,
54 lr,
55 dampening,
56 nesterov,
57 maximize,
58 is_first_step,
59 grad_scale_ptr);
60 }
61 }
62
_fused_sgd_kernel_cpu_(at::TensorList params,at::TensorList grads,at::TensorList momentum_buffer_list,const double weight_decay,const double momentum,const at::Tensor & lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const std::optional<at::Tensor> & grad_scale,const std::optional<at::Tensor> & found_inf)63 void _fused_sgd_kernel_cpu_(
64 at::TensorList params,
65 at::TensorList grads,
66 at::TensorList momentum_buffer_list,
67 const double weight_decay,
68 const double momentum,
69 const at::Tensor& lr,
70 const double dampening,
71 const bool nesterov,
72 const bool maximize,
73 const bool is_first_step,
74 const std::optional<at::Tensor>& grad_scale,
75 const std::optional<at::Tensor>& found_inf) {
76 _fused_sgd_kernel_cpu_(
77 params, grads, momentum_buffer_list, weight_decay,
78 momentum, lr.item<double>(), dampening, nesterov,
79 maximize, is_first_step, grad_scale, found_inf
80 );
81 }
82
// Defines the fused_sgd_stub dispatch stub that the _fused_sgd_kernel_cpu_
// overloads in this file invoke with kCPU.
// NOTE(review): the stub's declaration and the kernels registered against it
// are not visible in this file -- presumably they live in
// ATen/native/FusedSGD.h and a CPU kernel source; confirm.
DEFINE_DISPATCH(fused_sgd_stub);
84
85 }
86