#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_linear_bn.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace torch::jit {

namespace {

using Tensor = at::Tensor;

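// Matches nodes the folding pass can treat as a linear layer.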
bool supportedLinearNode(Node* n) {
  return n->kind() == aten::linear;
}

bool FoldFrozenLinearBatchnorm(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
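    // Recurse into sub-blocks first (e.g., bodies of prim::If / prim::Loop).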
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenLinearBatchnorm(block);
    }

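    // Match a batch_norm node whose input comes from a supported linear node.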
    if (n->kind() == aten::batch_norm &&
        supportedLinearNode(n->inputs().at(0)->node())) {
      auto linear = n->inputs().at(0)->node();
      auto bn = n;

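      // Fold only when every linear/bn parameter is a frozen constant;
      // folding non-constant parameters would bake in stale values.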
      if (nonConstantParameters(linear) || nonConstantParameters(bn)) {
        continue;
      }

      auto bn_rm_ivalue = bn->namedInput("running_mean");
      auto bn_rv_ivalue = bn->namedInput("running_var");

      // Check that running_mean and running_var have values; if either is
      // None (track_running_stats=False), skip the folding path.
      if (bn_rm_ivalue->type() == NoneType::get() ||
          bn_rv_ivalue->type() == NoneType::get()) {
        continue;
      }

      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
      auto linear_w = constant_as<Tensor>(linear->namedInput("weight")).value();

      int64_t linear_out_features = linear_w.size(0);
      int64_t bn_num_features = bn_rm.size(0);

      // Linear-BN fusion must preserve the shapes of the linear weight/bias.
      // For that, the channel dim of bn has to be broadcastable with the last
      // dim of linear, because bn operates over the channel dim, (N, C_in, H,
      // W), while linear operates over the last dim, (*, H_in). To be
      // broadcastable, the number of features in bn and the number of output
      // features from linear must satisfy one of the following:
      //   1. they are equal, or
      //   2. the number of features in bn is 1.
      // Otherwise, skip the folding path.
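      // Example: a linear weight of shape (64, 128) folds with a bn whose
      // num_features is 64 or 1, but a bn with num_features 32 is skipped.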
      if (!(linear_out_features == bn_num_features || bn_num_features == 1)) {
        continue;
      }

      // Implementation taken from torch/nn/utils/fusion.py.
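      // With scale = bn_w / sqrt(bn_rv + bn_eps), the fused parameters are
      //   fused_w = linear_w * scale.unsqueeze(-1)
      //   fused_b = (linear_b - bn_rm) * scale + bn_b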
      Tensor linear_b;
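      // No linear bias: synthesize a zero bias with bn's feature count. On
      // CUDA, a half/bfloat16 weight keeps its dtype for the bias so the
      // fused bias is not promoted to float.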
      if (linear->namedInput("bias")->type() == NoneType::get()) {
        at::ScalarType bias_dtype = bn_rm.scalar_type();
        at::ScalarType weight_dtype = linear_w.scalar_type();
        at::DeviceType weight_device = linear_w.device().type();
        if (weight_device == at::kCUDA &&
            (weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
            bias_dtype == at::kFloat) {
          bias_dtype = weight_dtype;
        }
        linear_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
      } else {
        linear_b = constant_as<Tensor>(linear->namedInput("bias")).value();
      }
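      // bn weight/bias are None when affine=False; default to ones/zeros.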
      Tensor bn_w;
      if (bn->namedInput("weight")->type() == NoneType::get()) {
        bn_w = at::ones_like(bn_rm);
      } else {
        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
      }
      Tensor bn_b;
      if (bn->namedInput("bias")->type() == NoneType::get()) {
        bn_b = at::zeros_like(bn_rm);
      } else {
        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
      }

      LinearBNParameters params;
      params.linear_w = linear_w;
      params.linear_b = linear_b;
      params.bn_rm = bn_rm;
      params.bn_rv = bn_rv;
      params.bn_eps = bn_eps;
      params.bn_w = bn_w;
      params.bn_b = bn_b;
      std::tuple<Tensor, Tensor> out =
          computeUpdatedLinearWeightAndBias(params);
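      // Splice the fused weight/bias into the graph as constants, point the
      // linear at them, and redirect bn's uses to linear's output. The bn
      // node becomes dead and is removed by EliminateDeadCode afterwards.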
      WithInsertPoint guard(linear);
      auto fused_linear_w = b->owningGraph()->insertConstant(std::get<0>(out));
      auto fused_linear_b = b->owningGraph()->insertConstant(std::get<1>(out));
      auto linear_w_value = linear->namedInput("weight");
      auto linear_b_value = linear->namedInput("bias");

      fused_linear_w->setDebugName(linear_w_value->debugName() + "_fused_bn");
      fused_linear_b->setDebugName(linear_b_value->debugName() + "_fused_bn");

      linear->replaceInputWith(linear_w_value, fused_linear_w);
      linear->replaceInputWith(linear_b_value, fused_linear_b);

      bn->output()->replaceAllUsesWith(linear->output());
      graph_modified = true;
    }
  }
  return graph_modified;
}

} // namespace

bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenLinearBatchnorm(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}
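
// Usage sketch (assumes the module has already been frozen, so its
// parameters are graph constants):
//   Module frozen = torch::jit::freeze(module);
//   auto graph = frozen.get_method("forward").graph();
//   bool changed = FoldFrozenLinearBatchnorm(graph);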

} // namespace torch::jit