xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/FusedAdagrad.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdagrad.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adagrad.h>
#include <ATen/ops/_fused_adagrad_native.h>
#endif

namespace at::native {

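// In-place fused Adagrad step for the CPU backend. Each tensor in params,
// grads, state_sums and state_steps is updated by the per-tensor kernel
// selected through fused_adagrad_stub; grad_scale and found_inf are the
// optional gradient-scaler tensors, and the step is skipped entirely when
// found_inf reports non-finite gradients.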
void _fused_adagrad_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList state_sums,
    at::TensorList state_steps,
    const double lr,
    const double lr_decay,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
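  // grad_scale and found_inf, when provided, are float tensors coming from
  // the gradient scaler; only their raw data pointers are used below.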
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
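  // The gradient scaler sets found_inf to 1.0 when it has seen non-finite
  // gradients; in that case the whole optimizer step is skipped.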
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
      return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(state_sums.size() == n_tensors);
  TORCH_CHECK(state_steps.size() == n_tensors);
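  // Apply the fused Adagrad update one tensor at a time; fused_adagrad_stub
  // forwards to the CPU kernel registered for this dispatch stub.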
  for (size_t i = 0; i < n_tensors; i++) {
    fused_adagrad_stub(
      kCPU,
      params[i],
      grads[i],
      state_sums[i],
      state_steps[i],
      lr,
      lr_decay,
      weight_decay,
      eps,
      maximize,
      grad_scale_ptr);
  }
}

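// Defines storage for the dispatch stub declared in ATen/native/FusedAdagrad.h;
// the actual CPU kernel is registered separately in the native/cpu kernel file.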
DEFINE_DISPATCH(fused_adagrad_stub);

}