xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/frozen_conv_add_relu_fusion_cuda.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/Utils.h>

#include <ATen/code_template.h>
#include <ATen/cuda/CUDAConfig.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch::jit {

namespace {
void fuseFrozenConvAddReluImpl(std::shared_ptr<Graph>& graph) {
#if AT_CUDNN_ENABLED() || AT_ROCM_ENABLED()
  GRAPH_DEBUG("Before fuseFrozenConvAddReluImpl: ", *graph);
  SubgraphRewriter rewriter;

  // CUDNN does not support conv1d
  std::array<std::string, 2> conv_operators = {"conv2d", "conv3d"};
  std::array<std::string, 2> add_operators = {"add", "add_"};
  std::array<std::string, 2> relu_operators = {"relu", "relu_"};

  auto conv_relu_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      %res = aten::${relu}(%x)
      return (%res))");

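  // Backend-specific fused replacement graphs: MIOpen fused kernels on ROCm,
  // cuDNN fused kernels otherwise.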
#ifdef USE_ROCM
  std::string conv_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %res = aten::miopen_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%res))";
#else
  std::string conv_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %res = aten::cudnn_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%res))";
#endif

  auto conv_add_relu_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      %y = aten::${add}(%x, %z, %alpha)
      %res = aten::${relu}(%y)
      return (%res))");

#ifdef USE_ROCM
  std::string conv_add_relu_fused = R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %res = aten::miopen_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
        return (%res))";
#else
  std::string conv_add_relu_fused = R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %res = aten::cudnn_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
        return (%res))";
#endif

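  // Register a rewrite for every combination of conv2d/conv3d with relu/relu_,
  // and additionally add/add_ for the conv + add + relu pattern.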
  for (const auto& conv : conv_operators) {
    for (const auto& relu : relu_operators) {
      at::jit::TemplateEnv env;
      env.s("conv", conv);
      env.s("relu", relu);
      rewriter.RegisterRewritePattern(
          conv_relu_rstring.format(env), conv_relu_fused);
      for (const auto& add : add_operators) {
        env.s("add", add);
        rewriter.RegisterRewritePattern(
            conv_add_relu_rstring.format(env), conv_add_relu_fused);
      }
    }
  }

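  // Only fuse when the matched tensors are frozen constants on CUDA: the
  // weight must be contiguous, and any bias/z must match its dtype and
  // leading dimension (z must also be contiguous).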
  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    auto weight = toIValue(match.values_map.at(vmap.at("weight")));
    if (!weight.has_value() || !weight.value().isTensor()) {
      return false;
    }
    const at::Tensor& weight_t = weight.value().toTensor();
    if (!weight_t.device().is_cuda() || !weight_t.is_contiguous()) {
      return false;
    }

    // bias is optional
    if (vmap.find("bias") != vmap.end()) {
      auto bias = toIValue(match.values_map.at(vmap.at("bias")));
      if (bias.has_value() && bias.value().isTensor()) {
        const at::Tensor& bias_t = bias.value().toTensor();
        if (bias_t.dtype() != weight_t.dtype() || bias_t.ndimension() != 1 ||
            bias_t.size(0) != weight_t.size(0) || !bias_t.device().is_cuda()) {
          return false;
        }
      }
    }

    // z is optional
    if (vmap.find("z") != vmap.end()) {
      auto z = toIValue(match.values_map.at(vmap.at("z")));
      if (z.has_value() && z.value().isTensor()) {
        const at::Tensor& z_t = z.value().toTensor();
        if (z_t.dtype() != weight_t.dtype() ||
            z_t.size(0) != weight_t.size(0) || !z_t.is_contiguous() ||
            !z_t.device().is_cuda()) {
          return false;
        }
      }
    }
    return true;
  };

  // Convert _convolution and in-place operators for simpler replacement pattern
  // matching
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  rewriter.runOnGraph(graph, filter);
  GRAPH_DEBUG("After fuseFrozenConvAddReluImpl: ", *graph);
#endif
}

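// Register this CUDA/ROCm implementation with the generic frozen
// conv-add-relu fusion pass at static-initialization time.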
auto dummyInitializer = []() {
  getFuseFrozenConvAddReluImpl() = fuseFrozenConvAddReluImpl;
  return true;
}();

} // namespace

} // namespace torch::jit