xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Utils.h>
2 
3 #include <torch/csrc/jit/ir/constants.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <torch/csrc/jit/ir/subgraph_matcher.h>
6 #include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
7 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
8 #include <torch/csrc/jit/passes/remove_mutation.h>
9 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
10 #ifdef USE_CUDA
11 #include <ATen/cuda/CUDAConfig.h>
12 #endif
13 
14 namespace torch::jit {
15 
getFuseFrozenConvAddReluImpl()16 std::function<void(std::shared_ptr<Graph>&)>& getFuseFrozenConvAddReluImpl() {
17   static std::function<void(std::shared_ptr<Graph>&)> impl;
18   return impl;
19 }
20 
21 // Implementation is in frozen_conv_add_relu_fusion.cpp; at runtime the
22 // implementation is registered in _fuseFrozenConvAddReluImpl. This allows
23 // the GPU code to be built separately from CPU-only code. If you're
24 // expecting conv-add-relu fusion to occur but it's not happening, it's
25 // possible that the GPU code isn't being built or linked properly.
FuseFrozenConvAddRelu(std::shared_ptr<Graph> & graph)26 void FuseFrozenConvAddRelu(std::shared_ptr<Graph>& graph) {
27   if (getFuseFrozenConvAddReluImpl()) {
28     getFuseFrozenConvAddReluImpl()(graph);
29   }
30 }
31 
32 } // namespace torch::jit
33