xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/batchnorm.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/functional/batchnorm.h>
2 #include <torch/nn/modules/batchnorm.h>
3 
4 #include <torch/cuda.h>
5 #include <torch/types.h>
6 
7 #include <c10/util/Exception.h>
8 
9 #include <cstddef>
10 #include <ostream>
11 #include <utility>
12 #include <vector>
13 
14 namespace torch {
15 namespace nn {
16 
_check_input_dim(const Tensor & input)17 void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
18   TORCH_CHECK(
19       input.dim() == 2 || input.dim() == 3,
20       "expected 2D or 3D input (got ",
21       input.dim(),
22       "D input)");
23 }
24 
_check_input_dim(const Tensor & input)25 void BatchNorm2dImpl::_check_input_dim(const Tensor& input) {
26   TORCH_CHECK(
27       input.dim() == 4, "expected 4D input (got ", input.dim(), "D input)");
28 }
29 
_check_input_dim(const Tensor & input)30 void BatchNorm3dImpl::_check_input_dim(const Tensor& input) {
31   TORCH_CHECK(
32       input.dim() == 5, "expected 5D input (got ", input.dim(), "D input)");
33 }
34 
35 template class BatchNormImplBase<1, BatchNorm1dImpl>;
36 template class BatchNormImplBase<2, BatchNorm2dImpl>;
37 template class BatchNormImplBase<3, BatchNorm3dImpl>;
38 
39 } // namespace nn
40 } // namespace torch
41