xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/FusedSGD.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/FusedSGD.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_fused_sgd.h>
11 #include <ATen/ops/_fused_sgd_native.h>
12 #endif
13 
14 
15 namespace at::native {
16 
17 
_fused_sgd_kernel_cpu_(at::TensorList params,at::TensorList grads,at::TensorList momentum_buffer_list,const double weight_decay,const double momentum,const double lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const std::optional<at::Tensor> & grad_scale,const std::optional<at::Tensor> & found_inf)18 void _fused_sgd_kernel_cpu_(
19     at::TensorList params,
20     at::TensorList grads,
21     at::TensorList momentum_buffer_list,
22     const double weight_decay,
23     const double momentum,
24     const double lr,
25     const double dampening,
26     const bool nesterov,
27     const bool maximize,
28     const bool is_first_step,
29     const std::optional<at::Tensor>& grad_scale,
30     const std::optional<at::Tensor>& found_inf) {
31   const float* grad_scale_ptr =
32       grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
33   const float* found_inf_ptr =
34       found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
35   if (found_inf_ptr && *found_inf_ptr == 1.0) {
36       return;
37   }
38   size_t n_tensors = params.size();
39   TORCH_CHECK(grads.size() == n_tensors);
40   bool no_momentum_buffer = momentum == 0.0;
41   if (no_momentum_buffer) {
42     TORCH_CHECK(momentum_buffer_list.empty());
43   } else {
44     TORCH_CHECK(momentum_buffer_list.size() == n_tensors);
45   }
46   for (size_t i = 0; i < n_tensors; i++){
47     fused_sgd_stub(
48       kCPU,
49       params[i],
50       grads[i],
51       no_momentum_buffer ? Tensor() : momentum_buffer_list[i],
52       weight_decay,
53       momentum,
54       lr,
55       dampening,
56       nesterov,
57       maximize,
58       is_first_step,
59       grad_scale_ptr);
60   }
61 }
62 
_fused_sgd_kernel_cpu_(at::TensorList params,at::TensorList grads,at::TensorList momentum_buffer_list,const double weight_decay,const double momentum,const at::Tensor & lr,const double dampening,const bool nesterov,const bool maximize,const bool is_first_step,const std::optional<at::Tensor> & grad_scale,const std::optional<at::Tensor> & found_inf)63 void _fused_sgd_kernel_cpu_(
64     at::TensorList params,
65     at::TensorList grads,
66     at::TensorList momentum_buffer_list,
67     const double weight_decay,
68     const double momentum,
69     const at::Tensor& lr,
70     const double dampening,
71     const bool nesterov,
72     const bool maximize,
73     const bool is_first_step,
74     const std::optional<at::Tensor>& grad_scale,
75     const std::optional<at::Tensor>& found_inf) {
76     _fused_sgd_kernel_cpu_(
77         params, grads, momentum_buffer_list, weight_decay,
78         momentum, lr.item<double>(), dampening, nesterov,
79         maximize, is_first_step, grad_scale, found_inf
80     );
81 }
82 
83 DEFINE_DISPATCH(fused_sgd_stub); // Defines the fused_sgd_stub dispatch object called by the kernels in this file (presumably declared in ATen/native/FusedSGD.h -- confirm).
84 
85 }
86