1 #include <ATen/Utils.h> 2 3 #include <torch/csrc/jit/ir/constants.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 #include <torch/csrc/jit/ir/subgraph_matcher.h> 6 #include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h> 7 #include <torch/csrc/jit/passes/graph_rewrite_helper.h> 8 #include <torch/csrc/jit/passes/remove_mutation.h> 9 #include <torch/csrc/jit/passes/subgraph_rewrite.h> 10 #ifdef USE_CUDA 11 #include <ATen/cuda/CUDAConfig.h> 12 #endif 13 14 namespace torch::jit { 15 getFuseFrozenConvAddReluImpl()16std::function<void(std::shared_ptr<Graph>&)>& getFuseFrozenConvAddReluImpl() { 17 static std::function<void(std::shared_ptr<Graph>&)> impl; 18 return impl; 19 } 20 21 // Implementation is in frozen_conv_add_relu_fusion.cpp; at runtime the 22 // implementation is registered in _fuseFrozenConvAddReluImpl. This allows 23 // the GPU code to be built separately from CPU-only code. If you're 24 // expecting conv-add-relu fusion to occur but it's not happening, it's 25 // possible that the GPU code isn't being built or linked properly. FuseFrozenConvAddRelu(std::shared_ptr<Graph> & graph)26void FuseFrozenConvAddRelu(std::shared_ptr<Graph>& graph) { 27 if (getFuseFrozenConvAddReluImpl()) { 28 getFuseFrozenConvAddReluImpl()(graph); 29 } 30 } 31 32 } // namespace torch::jit 33