#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/operators/norm.h>

namespace torch::jit::tensorexpr {

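// Lowers aten::batch_norm to a TensorExpr Compute node using the provided
// (running) mean and var:
//   output = (input - mean) / sqrt(var + eps) * weight + bias
// with the per-channel arguments broadcast along dimension 1.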
Tensor computeBatchNorm(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  bool hasWeight = true;
  bool hasBias = true;

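  // weight (inputs[1]) and bias (inputs[2]) are optional; when either is
  // absent, the defaults of 1 and 0 declared below are used instead.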
  if (std::holds_alternative<ArgNone>(inputs[1])) {
    hasWeight = false;
  }

  if (std::holds_alternative<ArgNone>(inputs[2])) {
    hasBias = false;
  }

  return Compute(
      "aten_batch_norm",
      outputShape,
      outputStrides,
      [&](const std::vector<VarHandle>& axes) {
        TORCH_INTERNAL_ASSERT(axes.size() >= 2);
        // axes: N, C, H, W
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        ExprHandle c = indices[1];

        // Parameter list:
        // input, weight, bias, mean, var, training, momentum, eps,
        // cudnn_enabled
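        // Only input, mean, var, and eps (inputs[0], [3], [4], [7]), plus the
        // optional weight and bias, feed the expression; training, momentum,
        // and cudnn_enabled are not used in this lowering.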
        std::vector<ExprHandle> exprInputs = {
            tensorOrConstant(inputs[0], indices), // input
            tensorOrConstant(inputs[3], {c}), // mean
            tensorOrConstant(inputs[4], {c}), // var
            constant(inputs[7]) // eps
        };

        ExprHandle weight = FloatImm::make(1);
        ExprHandle bias = FloatImm::make(0);
        if (hasWeight) {
          weight = tensorOrConstant(inputs[1], {c});
          exprInputs.push_back(weight);
        }
        if (hasBias) {
          bias = tensorOrConstant(inputs[2], {c});
          exprInputs.push_back(bias);
        }
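        // Promote the collected operands to a common scalar type before
        // combining them arithmetically.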
        promoteInputs(exprInputs);

        ExprHandle input = exprInputs[0];
        ExprHandle mean = exprInputs[1];
        ExprHandle var = exprInputs[2];
        ExprHandle eps = exprInputs[3];

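        // Fold normalization and the affine transform into a single
        // multiply-add:
        //   alpha = weight / sqrt(var + eps)
        //   beta  = bias - mean * alpha
        //   output = input * alpha + beta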
        auto inv_var = rsqrt(var + eps);
        auto alpha = inv_var * weight;
        auto beta = bias - mean * alpha;
        auto output = input * alpha + beta;
        return demoteOutput(output, outputType);
      });
}

} // namespace torch::jit::tensorexpr