1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/FusedAdagrad.h>
5
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_fused_adagrad.h>
11 #include <ATen/ops/_fused_adagrad_native.h>
12 #endif
13
14
15 namespace at::native {
16
_fused_adagrad_kernel_cpu_(at::TensorList params,at::TensorList grads,at::TensorList state_sums,at::TensorList state_steps,const double lr,const double lr_decay,const double weight_decay,const double eps,const bool maximize,const std::optional<at::Tensor> & grad_scale,const std::optional<at::Tensor> & found_inf)17 void _fused_adagrad_kernel_cpu_(
18 at::TensorList params,
19 at::TensorList grads,
20 at::TensorList state_sums,
21 at::TensorList state_steps,
22 const double lr,
23 const double lr_decay,
24 const double weight_decay,
25 const double eps,
26 const bool maximize,
27 const std::optional<at::Tensor>& grad_scale,
28 const std::optional<at::Tensor>& found_inf) {
29 const float* grad_scale_ptr =
30 grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
31 const float* found_inf_ptr =
32 found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
33 if (found_inf_ptr && *found_inf_ptr == 1.0) {
34 return;
35 }
36 size_t n_tensors = params.size();
37 TORCH_CHECK(grads.size() == n_tensors);
38 TORCH_CHECK(state_sums.size() == n_tensors);
39 TORCH_CHECK(state_steps.size() == n_tensors);
40 for (size_t i = 0; i < n_tensors; i++){
41 fused_adagrad_stub(
42 kCPU,
43 params[i],
44 grads[i],
45 state_sums[i],
46 state_steps[i],
47 lr,
48 lr_decay,
49 weight_decay,
50 eps,
51 maximize,
52 grad_scale_ptr);
53 }
54 }
55
// Instantiates the dispatch-table storage for fused_adagrad_stub; the
// per-device kernels (registered elsewhere via REGISTER_DISPATCH) fill it in.
DEFINE_DISPATCH(fused_adagrad_stub);

} // namespace at::native
59