xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/instancenorm.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/functional/instancenorm.h>
2 #include <torch/nn/modules/instancenorm.h>
3 
4 namespace torch {
5 namespace nn {
6 
_check_input_dim(const Tensor & input)7 void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
8   if (input.dim() != 3 && input.dim() != 2) {
9     TORCH_CHECK(
10         false, "expected 2D or 3D input (got ", input.dim(), "D input)");
11   }
12 }
13 
_check_input_dim(const Tensor & input)14 void InstanceNorm2dImpl::_check_input_dim(const Tensor& input) {
15   if (input.dim() != 4 && input.dim() != 3) {
16     TORCH_CHECK(
17         false, "expected 3D or 4D input (got ", input.dim(), "D input)");
18   }
19 }
20 
_check_input_dim(const Tensor & input)21 void InstanceNorm3dImpl::_check_input_dim(const Tensor& input) {
22   if (input.dim() != 5 &&
23       input.dim() != 4) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
24     TORCH_CHECK(
25         false, "expected 4D or 5D input (got ", input.dim(), "D input)");
26   }
27 }
28 
29 template class InstanceNormImpl<1, InstanceNorm1dImpl>;
30 template class InstanceNormImpl<2, InstanceNorm2dImpl>;
31 template class InstanceNormImpl<3, InstanceNorm3dImpl>;
32 
33 } // namespace nn
34 } // namespace torch
35