1 #include <c10/util/irange.h> 2 #include <torch/csrc/jit/ir/alias_analysis.h> 3 #include <torch/csrc/jit/ir/ir_views.h> 4 #include <torch/csrc/jit/passes/frozen_concat_linear.h> 5 #include <torch/csrc/jit/passes/frozen_conv_folding.h> 6 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h> 7 #include <torch/csrc/jit/passes/frozen_linear_folding.h> 8 #include <torch/csrc/jit/passes/remove_dropout.h> 9 #include <torch/csrc/jit/runtime/graph_executor.h> 10 11 namespace torch::jit { 12 OptimizeFrozenGraph(std::shared_ptr<Graph> & graph,bool optimize_numerics)13void OptimizeFrozenGraph( 14 std::shared_ptr<Graph>& graph, 15 bool optimize_numerics) { 16 removeDropout(graph); 17 FrozenConcatLinear(graph); 18 // run a couple times to capture Conv -> Mul -> Add etc 19 if (optimize_numerics) { 20 bool changed = false; 21 do { 22 changed = false; 23 changed |= FoldFrozenConvBatchnorm(graph); 24 changed |= FoldFrozenConvAddOrSub(graph); 25 changed |= FoldFrozenConvMulOrDiv(graph); 26 changed |= FoldFrozenLinearBatchnorm(graph); 27 } while (changed); 28 } 29 } 30 31 } // namespace torch::jit 32