1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/kernel.h> 4 5 namespace torch { 6 namespace jit { 7 namespace tensorexpr { 8 9 Tensor computeBatchNorm( 10 const std::vector<ArgValue>& inputs, 11 const std::vector<ExprHandle>& outputShape, 12 const std::vector<ExprHandle>& outputStrides, 13 const std::optional<ScalarType>& outputType, 14 at::Device device); 15 16 } // namespace tensorexpr 17 } // namespace jit 18 } // namespace torch 19