1 #include <torch/nn/functional/batchnorm.h>
2 #include <torch/nn/modules/batchnorm.h>
3
4 #include <torch/cuda.h>
5 #include <torch/types.h>
6
7 #include <c10/util/Exception.h>
8
9 #include <cstddef>
10 #include <ostream>
11 #include <utility>
12 #include <vector>
13
14 namespace torch {
15 namespace nn {
16
_check_input_dim(const Tensor & input)17 void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
18 TORCH_CHECK(
19 input.dim() == 2 || input.dim() == 3,
20 "expected 2D or 3D input (got ",
21 input.dim(),
22 "D input)");
23 }
24
_check_input_dim(const Tensor & input)25 void BatchNorm2dImpl::_check_input_dim(const Tensor& input) {
26 TORCH_CHECK(
27 input.dim() == 4, "expected 4D input (got ", input.dim(), "D input)");
28 }
29
_check_input_dim(const Tensor & input)30 void BatchNorm3dImpl::_check_input_dim(const Tensor& input) {
31 TORCH_CHECK(
32 input.dim() == 5, "expected 5D input (got ", input.dim(), "D input)");
33 }
34
35 template class BatchNormImplBase<1, BatchNorm1dImpl>;
36 template class BatchNormImplBase<2, BatchNorm2dImpl>;
37 template class BatchNormImplBase<3, BatchNorm3dImpl>;
38
39 } // namespace nn
40 } // namespace torch
41