xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/AmpKernels.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/core/ATen_fwd.h>
5 
6 namespace at {
7 class Tensor;
8 
9 namespace native {
10 
11 using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
12     TensorList,
13     Tensor&,
14     const Tensor&);
15 
16 using _amp_update_scale_cpu__fn = Tensor& (*)(
17     Tensor&,
18     Tensor&,
19     const Tensor&,
20     double,
21     double,
22     int64_t);
23 
24 DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
25 DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
26 
27 } // namespace native
28 } // namespace at
29