1 #pragma once
2
3 #include <ATen/core/Tensor.h>
4 #include <ATen/native/DispatchStub.h>
5
6 namespace at::native {
7
8 using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
9 const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
10 using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&);
11 using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&,
12 const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
13
14 DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub);
15 DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub);
16 DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub);
17
18 // TensorAccessor when it is defined to work around undefined...
19 template <typename scalar_t>
conditional_accessor_1d(const Tensor & t)20 static TensorAccessor<scalar_t, 1> conditional_accessor_1d(const Tensor& t) {
21 if (! t.defined()) {
22 return TensorAccessor<scalar_t, 1>(nullptr, nullptr, nullptr);
23 }
24 return t.accessor<scalar_t, 1>();
25 }
26
27 template <typename scalar_t>
conditional_data_ptr(const Tensor & t)28 static scalar_t* conditional_data_ptr(const Tensor& t) {
29 if constexpr (std::is_const_v<scalar_t>) {
30 return t.defined() ? t.contiguous().const_data_ptr<scalar_t>()
31 : nullptr;
32 } else {
33 return t.defined() ? t.contiguous().data_ptr<scalar_t>()
34 : nullptr;
35 }
36 }
37
38 } // namespace at::native
39