/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"

#include <cstdlib>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

namespace conv_matchers {

bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  if (dnums.input_spatial_dimensions_size() > 3) {
    return false;
  }

  // CuDNN does not accept zero-element arguments
  if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
      ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
    return false;
  }

  // CuDNN can perform either cross correlation (no reversal),
  // or convolution (all dimensions reversed).
  if (dnums.input_spatial_dimensions_size() == 2
          ? !window_util::AllOrNoneReversed(conv->window())
          : window_util::HasWindowReversal(conv->window())) {
    return false;
  }
  return true;
}

// Try to match a backward filter pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardFilter(HloInstruction* conv) {
  VLOG(2) << "Trying to match convolution backward filter.";
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);

  if (conv->feature_group_count() > 1) {
    VLOG(1) << conv->ToString()
            << " is a forward convolution. All grouped backward filters are "
               "mapped to batch-grouped convolutions in the tf2xla bridge, so "
               "backward filter convolutions cannot have a feature group "
               "count greater than 1 at this point. No need to fold to a "
               "backward filter.";
    return no_match_result;
  }

  // Step 1: match the instruction pattern without considering the paddings and
  // dimension numbers just yet. We may need some generic pattern matcher
  // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
  //
  // Backward filter convolution is implemented in XLA as the forward
  // convolution of padded activations and dilated gradients. Padding on
  // activations and dilation on gradients are specified in the "window" field
  // of the forward convolution.
  //
  //        activations  gradients
  //              \         /
  //               v       v
  //              Convolution
  //                 conv
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());

  // Step 2: match paddings and dimension numbers of the forward convolution.
  const ConvolutionDimensionNumbers& conv_dnums =
      conv->convolution_dimension_numbers();
  auto input_batch_dim = conv_dnums.input_batch_dimension();
  auto input_feature_dim = conv_dnums.input_feature_dimension();
  auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
  auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
  auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
  auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
  auto output_batch_dim = conv_dnums.output_batch_dimension();
  auto output_feature_dim = conv_dnums.output_feature_dimension();
  auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.base_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no base (LHS) dilation.";
      return no_match_result;
    }
    if (window_dim.padding_low() < 0) {
      VLOG(1) << "Padding low should be non-negative.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
    // Padding high will be checked in Step 3.
  }
  // Mathematically, there is no difference between a forward convolution and
  // a backward filter convolution. A backward filter:
  //   [N, O, H+h-1, W+w-1] x [N, C, H, W] -> [O, C, h, w]
  // can be treated as a forward convolution with `N` treated as the new
  // contracting (feature) dimension, `O` treated as the new batch dimension,
  // and `C` treated as the new output feature dimension. The only differences
  // are layouts and performance.
  //
  // Since there is no way to precisely tell whether we want a forward conv or
  // a backward filter conv, we have to rely on heuristics. Empirically, forward
  // convolutions have very small kernel dimensions, while in the backward pass
  // the "kernel dimensions" are large. If the kernel dimensions are smaller
  // than the output dimensions, treat this as a forward conv; otherwise
  // proceed with a backward filter conv.
  if ((kernel_spatial_dims.empty() ||
       conv->operand(1)->shape().dimensions(kernel_spatial_dims[0]) <=
           conv->shape().dimensions(output_spatial_dims[0])) &&
      !window_util::HasWindowDilation(conv->window())) {
    VLOG(1) << conv->ToString()
            << " is a regular forward convolution. No need "
               "to fold it to a backward filter convolution.";
    return no_match_result;
  }

  // Step 3: fuse the matched HLOs into a backward convolution instruction.
  //
  // Compute the window of the backward convolution.
  Window backward_conv_window;
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    WindowDimension* dim = backward_conv_window.add_dimensions();
    // The window size of the backward convolution equals the output size of the
    // forward convolution.
    int64_t filter_size = conv->shape().dimensions(output_spatial_dims[i]);
    dim->set_size(filter_size);
    // The window stride equals the window dilation of the forward convolution.
    dim->set_stride(conv->window().dimensions(i).window_dilation());
    // The window's low padding is the same as the low padding of the
    // activations.
    dim->set_padding_low(conv->window().dimensions(i).padding_low());
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);

    int64_t input_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    int64_t output_size = conv->window().dimensions(i).size();
    // Compute the range of the amount of valid high padding. We first compute
    // min_padding_high, the amount of padding on the right/bottom to ensure the
    // last patch ends at the border, i.e.,
    //
    //   input_size + dim->padding_low() + min_padding_high
    //     = (output_size - 1) * stride + filter_size
    //
    // Because convolution ignores trailing incomplete windows, any amount of
    // padding high from min_padding_high to min_padding_high+stride-1
    // (max_padding_high) has the same effect.
    int64_t padded_input_size = filter_size + (output_size - 1) * dim->stride();
    int64_t min_padding_high =
        padded_input_size - input_size - dim->padding_low();
    int64_t max_padding_high = min_padding_high + dim->stride() - 1;
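    // For example (illustrative numbers): with filter_size = 3, stride = 2,
    // output_size = 5, input_size = 10, and padding_low = 0, we get
    // padded_input_size = 3 + (5 - 1) * 2 = 11, so min_padding_high = 1 and
    // max_padding_high = 2.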
    CHECK_GE(dim->padding_low(), 0);
    // In practice, since cuDNN convolution only supports even padding, we make
    // the amount of high padding the same as the amount of low padding as long
    // as it is between min_padding_high and max_padding_high. If it is not in
    // that range, we pick the one that's closest to dim->padding_low() and let
    // GpuConvPaddingLegalization canonicalize the resultant backward
    // convolution later. Picking the closest one minimizes the cost of the kPad
    // instruction to be inserted by GpuConvPaddingLegalization.
    if (dim->padding_low() >= min_padding_high &&
        dim->padding_low() <= max_padding_high) {
      dim->set_padding_high(dim->padding_low());
    } else {
      if (dim->padding_low() < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    if (dim->padding_high() < 0) {
      LOG(WARNING)
          << "Fusing this pattern to backward filter convolution would cause "
             "negative padding ("
          << dim->padding_high()
          << ") on right/bottom of the weight gradients, which is not "
             "supported by GpuConvPaddingLegalization (b/32744257). "
             "Falling back to "
             "unfused convolution for instruction: "
          << conv->ToString();
      return no_match_result;
    }
  }

  // Restore the dimension numbers of the backward convolution from the forward
  // convolution. The two activation dimensions are reversed (batch and
  // feature).
  ConvolutionDimensionNumbers backward_conv_dnums;
  backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
  backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
  }
  backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
  backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
  for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
  }
  // The dimension numbering of the output of the forward convolution (before
  // transposition) is the same as that of the activations (according to the
  // semantics of kConvolution). The batch dimension of the activations should
  // be treated as the input feature dimension, and the feature dimension should
  // be treated as the output feature.
  backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
  backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
  for (int i = 0; i < output_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
  }

  HloInstruction* lhs = conv->mutable_operand(0);
  return std::make_tuple(true, backward_conv_window, backward_conv_dnums, lhs);
}

// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardInput(HloInstruction* conv) {
  VLOG(2) << "Trying to match convolution backward input.";
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);

  // TODO(timshen): Theoretically cuDNN also supports grouped convolutions for
  // the backward input convolution, but in cuDNN's current state there is not
  // much performance improvement from using the cuDNN backward input API for
  // grouped convolutions. This needs to be re-evaluated for future cuDNN
  // versions. Note that the necessary code already exists below; the only
  // thing needed to enable it is to remove the following early return.
  if (conv->feature_group_count() > 1) {
    return no_match_result;
  }

  // Match instruction pattern.
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
  HloInstruction* reverse_filter = conv->mutable_operand(1);
  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();

  // Match BackwardInput for a depthwise convolution and lower it as a forward
  // convolution. The output feature dimension and the input feature dimension
  // have been swapped in the bridge, so to get the actual input feature count
  // we need to query the output feature dimension.
  auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
  auto kernel_out_features =
      reverse_filter->shape().dimensions(kernel_out_feature_dim);

  // For a depthwise convolution, the input feature count must equal the
  // feature_group_count. We can leverage this property to match a depthwise
  // convolution and leave it to be lowered as a forward conv.
  if (conv->feature_group_count() > 1 &&
      kernel_out_features == conv->feature_group_count()) {
    return no_match_result;
  }

  // We pattern-match to a backwards input conv if:
  //
  //  - all spatial dims of the filter are reversed
  //
  // OR
  //
  //  - filter is 1x1 or a constant AND
  //  - conv has base dilation (otherwise this is just a regular forward conv).
  //
  // The final criterion above is just for canonicalization; cudnn seems to run
  // just as fast if we canonicalize 1x1/constant filters without base dilation
  // to forward or backward convs.  We canonicalize to forward conv because (a)
  // it's more natural (constant filters usually show up when doing inference,
  // and having backwards convolutions in inference graphs would be weird), and
  // (b) cudnn has special fusions for forward conv plus bias and activation,
  // and we want to pattern-match to that after running this pass.
  bool is_reversed_filter =
      reverse_filter->opcode() == HloOpcode::kReverse &&
      absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
                             reverse_filter->dimensions());
  bool is_1x1_filter =
      absl::c_all_of(conv->window().dimensions(),
                     [](const WindowDimension& d) { return d.size() == 1; });
  if (!is_reversed_filter &&
      !(window_util::HasBaseDilation(conv->window()) &&
        (reverse_filter->IsConstant() || is_1x1_filter))) {
    VLOG(1) << "Can't match to backwards convolution. Either filter is not "
               "kReverse, or it's not a base-dilated conv with a 1x1 or "
               "constant filter.";
    return no_match_result;
  }

  // Match padding and dilation of the forward convolution.
  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.window_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no window dilation.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
  }

  const auto& input_spatial_dims = dnums.input_spatial_dimensions();
  const auto& output_spatial_dims = dnums.output_spatial_dimensions();
  CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
  CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());

  const Window& old_window = conv->window();
  Window new_window = old_window;
  for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
    // Restore backward convolution's padding config from the matched pattern.
    // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we
    // convert backward input convolution to a variant of forward convolution.
    //
    // The stride of the backward convolution
    // = the base dilation factor of the forward convolution
    auto dim = new_window.mutable_dimensions(i);
    dim->set_stride(old_window.dimensions(i).base_dilation());
    dim->set_base_dilation(1);

    // The low padding = kernel_size - 1 - low padding on the gradients
    // Make sure the low padding is not negative.
    auto kernel_size = old_window.dimensions(i).size();
    auto backward_padding_low =
        kernel_size - 1 - old_window.dimensions(i).padding_low();
    if (backward_padding_low < 0) {
      LOG(WARNING)
          << "The low padding of the backward convolution would be negative ("
          << backward_padding_low
          << "), which isn't supported by GpuConvPaddingLegalization "
             "for now (b/32744257).";
      return no_match_result;
    }
    dim->set_padding_low(backward_padding_low);

    // Compute the range of the amount of padding on the right/bottom of the
    // activations. XLA's convolution requires all patches to be within the
    // padded base. This gives us flexibility to choose the amount of high
    // padding from a set of values without changing the result of the backward
    // convolution. The minimum amount (min_padding_high) makes the last patch
    // end at the border. The maximum amount (max_padding_high) equals
    // min_padding_high + stride - 1; max_padding_high + 1 would cause the
    // output size to change.
    auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
    auto output_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
    auto total_pad_size = padded_input_size - unpadded_input_size;
    auto min_padding_high = total_pad_size - backward_padding_low;
    auto max_padding_high = min_padding_high + dim->stride() - 1;

    if (backward_padding_low >= min_padding_high &&
        backward_padding_low <= max_padding_high) {
      // In the best case (most likely), if backward_padding_low is in the range
      // of the amounts of valid high padding, we choose backward_padding_low
      // because cuDNN supports even padding only.
      dim->set_padding_high(backward_padding_low);
    } else {
      // Otherwise, we choose the amount that's closest to backward_padding_low,
      // and GpuConvPaddingLegalization will later insert kSlice
      // instructions to enforce even padding.
      //
      // For example, consider the backward convolution pattern
      //
      //   ab     xy
      //   | pad  | reverse
      //  .a.b    yx
      //     \   /
      //      ABC
      //
      // The amount of low padding on activations (in backward convolution) is
      //   backward_padding_low = kernel_size - 1 - forward_padding_low
      //                        = 2 - 1 - 1 = 0
      //
      // The amount of padding high must be between 1 and 2, in order to make
      // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
      // the range of [1,2], so we pick the closest valid amount of padding
      // high, which is 1 in this case. Therefore, we fuse the above pattern to
      //
      //   ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
      if (backward_padding_low < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    // GpuConvPaddingLegalization doesn't handle backward input
    // convolution with negative padding for now. So fall back to unfused
    // convolution in case of negative padding. For example,
    //   ABCD = Conv(abc, reverse(xy), padding_high=2)
    // could be fused to
    //   ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
    // with positive padding low but negative padding high.
    if (dim->padding_high() < 0) {
      LOG(WARNING) << "Fusing this pattern to backward convolution would cause "
                      "negative padding ("
                   << dim->padding_high()
                   << ") on right/bottom of the activations, which is not "
                      "supported by GpuConvPaddingLegalization (b/32744257). "
                      "Falling back to unfused convolution for instruction: "
                   << conv->ToString();
      return no_match_result;
    }
  }

  // OK, it's a match! Switch the input feature dimension with the output
  // feature dimension. Also switch the output with the input. This is the way
  // cuDNN expects it to be.
  auto conv_dnums = conv->convolution_dimension_numbers();
  dnums.set_kernel_input_feature_dimension(
      conv_dnums.kernel_output_feature_dimension());
  dnums.set_kernel_output_feature_dimension(
      conv_dnums.kernel_input_feature_dimension());
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    dnums.set_input_spatial_dimensions(i,
                                       conv_dnums.output_spatial_dimensions(i));
    dnums.set_output_spatial_dimensions(i,
                                        conv_dnums.input_spatial_dimensions(i));
  }
  dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension());
  dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension());
  dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension());
  dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension());

  // If we matched against a constant, we need to add a reverse op that can be
  // subsumed by the cuDNN call. algebraic-simplifier will later remove any
  // unnecessary reverses.
  if (reverse_filter->opcode() != HloOpcode::kReverse &&
      reverse_filter->IsConstant()) {
    // Create a double-reverse, which is a nop.
    HloComputation* c = conv->parent();
    reverse_filter = c->AddInstruction(
        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
                                      dnums.kernel_spatial_dimensions()));
    reverse_filter = c->AddInstruction(
        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
                                      dnums.kernel_spatial_dimensions()));
    TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter));
  }

  // Calculate the 'rhs' that goes into the backward input convolution.
  HloInstruction* rhs = reverse_filter;
  // One reverse is subsumed by the cuDNN call.
  if (rhs->opcode() == HloOpcode::kReverse) {
    rhs = rhs->mutable_operand(0);
  }
  if (conv->feature_group_count() == 1) {
    return std::make_tuple(true, new_window, dnums, rhs);
  }

  // Handle grouped convolutions. Because we swapped the input feature dimension
  // with the output feature dimension, we need to also reshape the kernel so
  // that the 'feature_group_count' parameter still makes sense. The
  // 'feature_group_count' parameter essentially specifies how often the
  // 'kernel_input_feature_dimension' is repeated. So when we swap these
  // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
  // 'feature_group_count' and multiply the new
  // 'kernel_output_feature_dimension' by 'feature_group_count'.
  int64_t input_feature_dimension = dnums.kernel_input_feature_dimension();
  int64_t output_feature_dimension = dnums.kernel_output_feature_dimension();
  // The following code assumes that input_feature_dimension and
  // output_feature_dimension are adjacent.
  if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
    return no_match_result;
  }

  int64_t input_features = rhs->shape().dimensions(input_feature_dimension);
  int64_t output_features = rhs->shape().dimensions(output_feature_dimension);

  // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G,
  // out_depth / G]
  std::vector<int64_t> reshape_dims = SpanToVector(rhs->shape().dimensions());
  auto num_groups = conv->feature_group_count();
  CHECK_EQ(input_features % num_groups, 0)
      << "Input feature count should be an exact multiple of feature group "
         "count";
  reshape_dims[input_feature_dimension] =
      reshape_dims[input_feature_dimension] / num_groups;
  reshape_dims.insert(reshape_dims.begin() + input_feature_dimension,
                      num_groups);

  HloComputation* c = conv->parent();
  rhs = c->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs));

  // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ...,
  // in_depth/G, G, out_depth / G]
  std::vector<int64_t> transpose_dims(rhs->shape().dimensions_size());
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  transpose_dims.erase(transpose_dims.begin() + input_feature_dimension);
  transpose_dims.insert(transpose_dims.begin() + output_feature_dimension,
                        input_feature_dimension);
  std::vector<int64_t> transpose_reshape_dims =
      SpanToVector(rhs->shape().dimensions());
  transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
                               input_feature_dimension);
  transpose_reshape_dims.insert(
      transpose_reshape_dims.begin() + output_feature_dimension, num_groups);
  rhs = c->AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims),
      rhs, transpose_dims));

  // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ...,
  // in_depth/G, out_depth]
  Shape new_shape = rhs->shape();
  new_shape.DeleteDimension(output_feature_dimension);
  new_shape.set_dimensions(output_feature_dimension,
                           output_features * num_groups);
  rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
  return std::make_tuple(true, new_window, dnums, rhs);
}

}  // namespace conv_matchers

namespace {

HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape,
                              HloInstruction* lhs, HloInstruction* rhs,
                              const Window& window,
                              const ConvolutionDimensionNumbers& dnums,
                              int64_t feature_group_count,
                              const OpMetadata& metadata) {
  HloComputation* computation = lhs->parent();

  // This call returns a tuple of (conv_result, scratch_memory), where
  // conv_result is the actual result of the convolution, and scratch_memory is
  // temporary memory used by cudnn.
  //
  // At the moment, we don't know how much scratch memory this conv is going to
  // use, so we put u8[0] in this place.  Later on another pass will choose
  // which conv algorithm to use, and at that point we'll modify the shape of
  // this second tuple element.
  Shape call_shape =
      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});

  HloInstruction* custom_call = computation->AddInstruction(
      HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
  custom_call->set_window(window);
  custom_call->set_convolution_dimension_numbers(dnums);
  custom_call->set_feature_group_count(feature_group_count);
  custom_call->set_metadata(metadata);

  // Give the custom call a user-friendly name.
  std::optional<std::string> name;
  if (call_target == kCudnnConvForwardCallTarget) {
    name = "cudnn-conv";
  } else if (call_target == kCudnnConvBackwardInputCallTarget) {
    name = "cudnn-conv-bw-input";
  } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
    name = "cudnn-conv-bw-filter";
  } else if (call_target == kCudnnConvBiasActivationForwardCallTarget) {
    name = "cudnn-conv-bias-activation";
  }
  if (name.has_value()) {
    computation->parent()->SetAndUniquifyInstrName(custom_call, *name);
  }

  return custom_call;
}

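// Converts a batch-grouped convolution (batch_group_count > 1,
// feature_group_count == 1) into an equivalent feature-grouped convolution:
// the group dimension is split out of the activations' batch dimension, moved
// next to the feature dimension, and merged into it, so the rewritten
// convolution can run with feature_group_count = batch_group_count.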
HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
    HloInstruction* conv) {
  CHECK_EQ(conv->feature_group_count(), 1);
  int64_t num_groups = conv->batch_group_count();
  auto dim_numbers = conv->convolution_dimension_numbers();
  auto lhs = conv->mutable_operand(0);
  auto rhs = conv->mutable_operand(1);

  int64_t input_batch_dimension = dim_numbers.input_batch_dimension();

  Shape output_shape = conv->shape();
  int64_t input_feature_dimension = dim_numbers.input_feature_dimension();
  int64_t input_feature = lhs->shape().dimensions(input_feature_dimension);

  HloComputation* computation = lhs->parent();
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation->AddInstruction(std::move(inst));
  };
  // Reshape batch_dim N -> [G, N/G]
  std::vector<int64_t> reshape_dims = SpanToVector(lhs->shape().dimensions());
  reshape_dims[input_batch_dimension] =
      reshape_dims[input_batch_dimension] / num_groups;
  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
  lhs = add(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));

  // Transpose G to the axis just before C, e.g. [G, N/G, H, W, C] -> [N/G, H,
  // W, G, C]
  std::vector<int64_t> transpose_dims(lhs->shape().dimensions_size());
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
                        input_batch_dimension);
  std::vector<int64_t> transpose_reshape_dims =
      ComposePermutations(lhs->shape().dimensions(), transpose_dims);
  lhs = add(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(lhs->shape().element_type(), transpose_reshape_dims),
      lhs, transpose_dims));

  // Merge [G,C] -> [C*G]
  Shape new_shape = lhs->shape();
  new_shape.DeleteDimension(input_feature_dimension);
  new_shape.set_dimensions(input_feature_dimension, input_feature * num_groups);
  lhs = add(HloInstruction::CreateReshape(new_shape, lhs));

  std::vector<HloInstruction*> new_operands = {lhs, rhs};
  auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
  new_conv->set_feature_group_count(num_groups);
  new_conv->set_batch_group_count(1);
  new_conv->set_convolution_dimension_numbers(dim_numbers);
  return computation->AddInstruction(std::move(new_conv));
}

CudnnConvBackendConfig GetDefaultBackendConfig() {
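  // A conv_result_scale of 1 means the convolution result is used as-is, with
  // no additional scaling applied.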
  CudnnConvBackendConfig config;
  config.set_conv_result_scale(1);
  return config;
}

// Helper function to create a custom_call instruction to replace the given
// conv instruction
static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
  bool match;
  Window window;
  ConvolutionDimensionNumbers dnums;
  HloInstruction* rhs;
  HloInstruction* lhs;

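  // Try the backward-input and backward-filter patterns first; only if neither
  // matches do we fall back to a plain forward convolution.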
  std::tie(match, window, dnums, rhs) = conv_matchers::MatchBackwardInput(conv);
  if (match) {
    return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
                         conv->mutable_operand(0), rhs, window, dnums,
                         conv->feature_group_count(), conv->metadata());
  }

  std::tie(match, window, dnums, lhs) =
      conv_matchers::MatchBackwardFilter(conv);
  if (match) {
    return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
                         conv->mutable_operand(1), window, dnums,
                         conv->batch_group_count(), conv->metadata());
  }

  // If all else fails, try a forward convolution.
  if (conv_matchers::CanImplementAsGpuForwardConv(conv)) {
    if (conv->batch_group_count() > 1) {
      conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
    }

    return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(),
                         conv->mutable_operand(0), conv->mutable_operand(1),
                         conv->window(), conv->convolution_dimension_numbers(),
                         conv->feature_group_count(), conv->metadata());
  }

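  // Returning nullptr signals that this convolution can't be rewritten into a
  // custom call; the caller leaves the instruction unchanged.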
  return nullptr;
}

// Tries to rewrite a single convolution into a call to cudnn/miopen.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);

  TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
                      CreateCustomCallHelper(conv));
  if (custom_call == nullptr) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      custom_call->set_backend_config(GetDefaultBackendConfig()));

  VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
          << custom_call->ToString();

  // The CustomCall returns a tuple (conv_result, scratch_memory).  Extract
  // out the conv result and replace `conv` with it.
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv,
      HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
  return true;
}

// Rewrites the convolutions in the given computation into calls to
// cudnn/miopen.
// Returns true if it made any changes.
StatusOr<bool> RunOnComputation(HloComputation* computation) {
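  // Collect the convolutions first: RunOnInstruction replaces instructions, so
  // we must not mutate the computation while iterating over its instructions.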
  std::vector<HloInstruction*> convs;
  for (auto* hlo : computation->instructions()) {
    if (hlo->opcode() == HloOpcode::kConvolution) {
      convs.push_back(hlo);
    }
  }

  bool changed = false;
  for (HloInstruction* conv : convs) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv));
    changed |= result;
  }
  return changed;
}
}  // namespace

StatusOr<bool> GpuConvRewriter::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace gpu
}  // namespace xla