xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/operators/norm.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/operators/misc.h>
2 #include <torch/csrc/jit/tensorexpr/operators/norm.h>
3 
4 namespace torch::jit::tensorexpr {
5 
computeBatchNorm(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)6 Tensor computeBatchNorm(
7     const std::vector<ArgValue>& inputs,
8     const std::vector<ExprHandle>& outputShape,
9     const std::vector<ExprHandle>& outputStrides,
10     const std::optional<ScalarType>& outputType,
11     at::Device device) {
12   bool hasWeight = true;
13   bool hasBias = true;
14 
15   if (std::holds_alternative<ArgNone>(inputs[1])) {
16     hasWeight = false;
17   }
18 
19   if (std::holds_alternative<ArgNone>(inputs[2])) {
20     hasBias = false;
21   }
22 
23   return Compute(
24       "aten_batch_norm",
25       outputShape,
26       outputStrides,
27       [&](const std::vector<VarHandle>& axes) {
28         TORCH_INTERNAL_ASSERT(axes.size() >= 2);
29         // axes: N, C, H, W
30         std::vector<ExprHandle> indices(axes.begin(), axes.end());
31         ExprHandle c = indices[1];
32 
33         // Parameter list:
34         // input, weight, bias, mean, var, training, momentum, eps,
35         // cudnn_enabled
36         std::vector<ExprHandle> exprInputs = {
37             tensorOrConstant(inputs[0], indices), // input
38             tensorOrConstant(inputs[3], {c}), // mean
39             tensorOrConstant(inputs[4], {c}), // var
40             constant(inputs[7]) // eps
41         };
42 
43         ExprHandle weight = FloatImm::make(1);
44         ExprHandle bias = FloatImm::make(0);
45         if (hasWeight) {
46           weight = tensorOrConstant(inputs[1], {c});
47           exprInputs.push_back(weight);
48         }
49         if (hasBias) {
50           bias = tensorOrConstant(inputs[2], {c});
51           exprInputs.push_back(bias);
52         }
53         promoteInputs(exprInputs);
54 
55         ExprHandle input = exprInputs[0];
56         ExprHandle mean = exprInputs[1];
57         ExprHandle var = exprInputs[2];
58         ExprHandle eps = exprInputs[3];
59 
60         auto inv_var = rsqrt(var + eps);
61         auto alpha = inv_var * weight;
62         auto beta = bias - mean * alpha;
63         auto output = input * alpha + beta;
64         return demoteOutput(output, outputType);
65       });
66 }
67 
68 } // namespace torch::jit::tensorexpr
69