1 #include <torch/nn/functional/instancenorm.h>
2 #include <torch/nn/modules/instancenorm.h>
3
4 namespace torch {
5 namespace nn {
6
_check_input_dim(const Tensor & input)7 void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
8 if (input.dim() != 3 && input.dim() != 2) {
9 TORCH_CHECK(
10 false, "expected 2D or 3D input (got ", input.dim(), "D input)");
11 }
12 }
13
_check_input_dim(const Tensor & input)14 void InstanceNorm2dImpl::_check_input_dim(const Tensor& input) {
15 if (input.dim() != 4 && input.dim() != 3) {
16 TORCH_CHECK(
17 false, "expected 3D or 4D input (got ", input.dim(), "D input)");
18 }
19 }
20
_check_input_dim(const Tensor & input)21 void InstanceNorm3dImpl::_check_input_dim(const Tensor& input) {
22 if (input.dim() != 5 &&
23 input.dim() != 4) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
24 TORCH_CHECK(
25 false, "expected 4D or 5D input (got ", input.dim(), "D input)");
26 }
27 }
28
29 template class InstanceNormImpl<1, InstanceNorm1dImpl>;
30 template class InstanceNormImpl<2, InstanceNorm2dImpl>;
31 template class InstanceNormImpl<3, InstanceNorm3dImpl>;
32
33 } // namespace nn
34 } // namespace torch
35