xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/batch_norm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

#include <type_traits>
5 
6 namespace at::native {
7 
8 using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
9     const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
10 using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&);
11 using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&,
12         const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
13 
14 DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub);
15 DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub);
16 DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub);
17 
18 // TensorAccessor when it is defined to work around undefined...
19 template <typename scalar_t>
conditional_accessor_1d(const Tensor & t)20 static TensorAccessor<scalar_t, 1> conditional_accessor_1d(const Tensor& t) {
21   if (! t.defined()) {
22     return TensorAccessor<scalar_t, 1>(nullptr, nullptr, nullptr);
23   }
24   return t.accessor<scalar_t, 1>();
25 }
26 
27 template <typename scalar_t>
conditional_data_ptr(const Tensor & t)28 static scalar_t* conditional_data_ptr(const Tensor& t) {
29   if constexpr (std::is_const_v<scalar_t>) {
30     return t.defined() ? t.contiguous().const_data_ptr<scalar_t>()
31                       : nullptr;
32   } else {
33     return t.defined() ? t.contiguous().data_ptr<scalar_t>()
34                       : nullptr;
35   }
36 }
37 
38 } // namespace at::native
39