1 #include <torch/csrc/jit/passes/fold_linear_bn.h> 2 3 #include <ATen/TensorOperators.h> 4 5 #ifndef AT_PER_OPERATOR_HEADERS 6 #include <ATen/Functions.h> 7 #else 8 #include <ATen/ops/rsqrt.h> 9 #endif 10 11 namespace torch::jit { 12 computeUpdatedLinearWeightAndBias(const LinearBNParameters & p)13std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias( 14 const LinearBNParameters& p) { 15 at::Tensor bn_scale = p.bn_w * at::rsqrt(p.bn_rv + p.bn_eps); 16 at::Tensor fused_w = p.linear_w * bn_scale.unsqueeze(-1); 17 at::Tensor fused_b = (p.linear_b - p.bn_rm) * bn_scale + p.bn_b; 18 19 auto linear_w_dtype = p.linear_w.dtype(); 20 auto linear_b_dtype = p.linear_b.dtype(); 21 22 return std::make_tuple( 23 fused_w.to(linear_w_dtype), fused_b.to(linear_b_dtype)); 24 } 25 26 } // namespace torch::jit 27