#include <ATen/Utils.h>

#include <ATen/code_template.h>
#include <ATen/cuda/CUDAConfig.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch::jit {

namespace {
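
// Fuses conv2d/conv3d + (optional add) + relu sequences in a frozen graph
// into the single fused convolution ops provided by cuDNN (or MIOpen on
// ROCm builds). The rewrite only fires when the convolution weights are
// graph constants, which is the case after freezing.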
void fuseFrozenConvAddReluImpl(std::shared_ptr<Graph>& graph) {
#if AT_CUDNN_ENABLED() || AT_ROCM_ENABLED()
  GRAPH_DEBUG("Before fuseFrozenConvAddReluImpl: ", *graph);
  SubgraphRewriter rewriter;

  // CUDNN does not support conv1d
  std::array<std::string, 2> conv_operators = {"conv2d", "conv3d"};
  std::array<std::string, 2> add_operators = {"add", "add_"};
  std::array<std::string, 2> relu_operators = {"relu", "relu_"};

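  // Pattern to match: aten::${conv} followed directly by aten::${relu};
  // the ${conv} and ${relu} placeholders are filled in below for each
  // operator variant listed above.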
  auto conv_relu_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      %res = aten::${relu}(%x)
      return (%res))");

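  // Replacement pattern: a single fused convolution+relu op, backed by
  // MIOpen on ROCm builds and by cuDNN otherwise.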
#ifdef USE_ROCM
  std::string conv_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::miopen_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#else
  std::string conv_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::cudnn_convolution_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#endif

  auto conv_add_relu_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %x = aten::${conv}(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      %y = aten::${add}(%x, %z, %alpha)
      %res = aten::${relu}(%y)
      return (%res))");

#ifdef USE_ROCM
  std::string conv_add_relu_fused = R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::miopen_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#else
  std::string conv_add_relu_fused = R"(
    graph(%input, %weight, %bias, %z, %alpha, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
      %res = aten::cudnn_convolution_add_relu(%input, %weight, %z, %alpha, %bias, %stride, %padding, %dilation, %groups)
      return (%res))";
#endif

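  // Register a rewrite for every (conv, relu) pair and, for the
  // conv + add + relu case, every (conv, add, relu) combination.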
  for (const auto& conv : conv_operators) {
    for (const auto& relu : relu_operators) {
      at::jit::TemplateEnv env;
      env.s("conv", conv);
      env.s("relu", relu);
      rewriter.RegisterRewritePattern(
          conv_relu_rstring.format(env), conv_relu_fused);
      for (const auto& add : add_operators) {
        env.s("add", add);
        rewriter.RegisterRewritePattern(
            conv_add_relu_rstring.format(env), conv_add_relu_fused);
      }
    }
  }

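  // Only rewrite when the matched tensors are constants (true after
  // freezing) and meet the fused kernels' requirements: weight must be a
  // contiguous CUDA tensor; bias and z, when present as constant tensors,
  // must live on CUDA and match weight's dtype and leading dimension, and
  // z must additionally be contiguous.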
  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    auto weight = toIValue(match.values_map.at(vmap.at("weight")));
    if (!weight.has_value() || !weight.value().isTensor()) {
      return false;
    }
    const at::Tensor& weight_t = weight.value().toTensor();
    if (!weight_t.device().is_cuda() || !weight_t.is_contiguous()) {
      return false;
    }

    // bias is optional
    if (vmap.find("bias") != vmap.end()) {
      auto bias = toIValue(match.values_map.at(vmap.at("bias")));
      if (bias.has_value() && bias.value().isTensor()) {
        const at::Tensor& bias_t = bias.value().toTensor();
        if (bias_t.dtype() != weight_t.dtype() || bias_t.ndimension() != 1 ||
            bias_t.size(0) != weight_t.size(0) || !bias_t.device().is_cuda()) {
          return false;
        }
      }
    }

    // z is optional
    if (vmap.find("z") != vmap.end()) {
      auto z = toIValue(match.values_map.at(vmap.at("z")));
      if (z.has_value() && z.value().isTensor()) {
        const at::Tensor& z_t = z.value().toTensor();
        if (z_t.dtype() != weight_t.dtype() ||
            z_t.size(0) != weight_t.size(0) || !z_t.is_contiguous() ||
            !z_t.device().is_cuda()) {
          return false;
        }
      }
    }
    return true;
  };

  // Convert _convolution and in-place operators for simpler replacement
  // pattern matching
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  rewriter.runOnGraph(graph, filter);
  GRAPH_DEBUG("After fuseFrozenConvAddReluImpl: ", *graph);
#endif
}

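// At static-initialization time, assign this implementation to the function
// object returned by getFuseFrozenConvAddReluImpl(), so callers of that hook
// pick it up whenever this CUDA/ROCm translation unit is linked in.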
auto dummyInitializer = []() {
  getFuseFrozenConvAddReluImpl() = fuseFrozenConvAddReluImpl;
  return true;
}();
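
// Usage sketch (not part of this file): callers are expected to go through
// the public entry point declared in frozen_conv_add_relu_fusion.h rather
// than this anonymous-namespace function. Assuming that entry point is
// FuseFrozenConvAddRelu, something like:
//
//   std::shared_ptr<Graph> graph =
//       frozen_module.get_method("forward").graph();
//   FuseFrozenConvAddRelu(graph);  // dispatches to the impl registered above
//
// where frozen_module is an illustrative, already-frozen torch::jit::Module.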

} // namespace

} // namespace torch::jit