// torch/csrc/jit/passes/frozen_linear_folding.cpp
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_linear_bn.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace torch::jit {

namespace {

using Tensor = at::Tensor;

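// The folding currently recognizes only aten::linear as the producer node.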
bool supportedLinearNode(Node* n) {
  return n->kind() == aten::linear;
}

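// Recursively walks the block (and any sub-blocks) looking for batch_norm
// nodes fed directly by a supported linear node with constant (frozen)
// parameters, and folds the batch norm into the linear weight and bias.
// Returns true if any fold was performed.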
bool FoldFrozenLinearBatchnorm(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenLinearBatchnorm(block);
    }

    if (n->kind() == aten::batch_norm &&
        supportedLinearNode(n->inputs().at(0)->node())) {
      auto linear = n->inputs().at(0)->node();
      auto bn = n;

      if (nonConstantParameters(linear) || nonConstantParameters(bn)) {
        continue;
      }

      auto bn_rm_ivalue = bn->namedInput("running_mean");
      auto bn_rv_ivalue = bn->namedInput("running_var");

      // Check that running_mean and running_var have values; if they are
      // None (track_running_stats=False), skip the folding path.
      if (bn_rm_ivalue->type() == NoneType::get() &&
          bn_rv_ivalue->type() == NoneType::get()) {
        continue;
      }

      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
      auto linear_w = constant_as<Tensor>(linear->namedInput("weight")).value();

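      // The linear weight has shape (out_features, in_features), so size(0)
      // is the number of output features; the bn running stats have shape
      // (num_features,).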
      int64_t linear_out_features = linear_w.size(0);
      int64_t bn_num_features = bn_rm.size(0);

      // Linear-BN needs to be fused while preserving the shapes of the linear
      // weight/bias. To preserve those shapes, the channel dim of bn must be
      // broadcastable with the last dim of linear: bn operates over the
      // channel dim of (N, C_in, H, W), while linear operates over the last
      // dim of (*, H_in). To be broadcastable, the number of features in bn
      // and the number of output features from linear must satisfy one of
      // the following conditions:
      // 1. they are equal, or
      // 2. the number of features in bn is 1.
      // Otherwise, skip the folding path.
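      // For example, a linear layer with out_features=512 folds with a bn of
      // num_features=512 or num_features=1, but not with num_features=256.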
      if (!(linear_out_features == bn_num_features || bn_num_features == 1)) {
        continue;
      }

      // Implementation taken from torch/nn/utils/fusion.py.
      Tensor linear_b;
      if (linear->namedInput("bias")->type() == NoneType::get()) {
        at::ScalarType bias_dtype = bn_rm.scalar_type();
        at::ScalarType weight_dtype = linear_w.scalar_type();
        at::DeviceType weight_device = linear_w.device().type();
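        // When the weight is half/bfloat16 on CUDA but the bn stats are fp32
        // (common under autocast), create the synthetic zero bias in the
        // weight's dtype, presumably so the fused bias matches the reduced
        // precision of the linear op.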
        if (weight_device == at::kCUDA &&
            (weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
            bias_dtype == at::kFloat) {
          bias_dtype = weight_dtype;
        }
        linear_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
      } else {
        linear_b = constant_as<Tensor>(linear->namedInput("bias")).value();
      }
      Tensor bn_w;
      if (bn->namedInput("weight")->type() == NoneType::get()) {
        bn_w = at::ones_like(bn_rm);
      } else {
        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
      }
      Tensor bn_b;
      if (bn->namedInput("bias")->type() == NoneType::get()) {
        bn_b = at::zeros_like(bn_rm);
      } else {
        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
      }

      LinearBNParameters params;
      params.linear_w = linear_w;
      params.linear_b = linear_b;
      params.bn_rm = bn_rm;
      params.bn_rv = bn_rv;
      params.bn_eps = bn_eps;
      params.bn_w = bn_w;
      params.bn_b = bn_b;
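      // computeUpdatedLinearWeightAndBias applies the standard linear-BN
      // fusion (cf. torch/nn/utils/fusion.py): with
      //   scale = bn_w / sqrt(bn_rv + bn_eps),
      // the fused parameters are
      //   W' = W * scale.unsqueeze(-1)
      //   b' = (b - bn_rm) * scale + bn_b.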
      std::tuple<Tensor, Tensor> out =
          computeUpdatedLinearWeightAndBias(params);
      WithInsertPoint guard(linear);
      auto fused_linear_w = b->owningGraph()->insertConstant(std::get<0>(out));
      auto fused_linear_b = b->owningGraph()->insertConstant(std::get<1>(out));
      auto linear_w_value = linear->namedInput("weight");
      auto linear_b_value = linear->namedInput("bias");

      fused_linear_w->setDebugName(linear_w_value->debugName() + "_fused_bn");
      fused_linear_b->setDebugName(linear_b_value->debugName() + "_fused_bn");

      linear->replaceInputWith(linear_w_value, fused_linear_w);
      linear->replaceInputWith(linear_b_value, fused_linear_b);

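      // Reroute all uses of the batch_norm output to the linear output. The
      // bn node itself becomes dead and is cleaned up by the dead code
      // elimination run in the graph-level entry point below.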
      bn->output()->replaceAllUsesWith(linear->output());
      graph_modified = true;
    }
  }
  return graph_modified;
}

} // namespace

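// Entry point for the pass. A minimal usage sketch (hypothetical caller;
// assumes the module has been frozen so linear/bn parameters are constants):
//
//   Module frozen = freeze(module);
//   auto graph = frozen.get_method("forward").graph();
//   bool changed = FoldFrozenLinearBatchnorm(graph);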
bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenLinearBatchnorm(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

} // namespace torch::jit