/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"

#include <cstdlib>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

namespace conv_matchers {
bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  if (dnums.input_spatial_dimensions_size() > 3) {
    return false;
  }

  // CuDNN does not accept zero-element arguments.
  if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
      ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
    return false;
  }

  // CuDNN can perform either cross correlation (no reversal),
  // or convolution (all dimensions reversed).
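  // In other words, a 2-D conv is accepted if its spatial dims are either all
  // reversed or all unreversed; for other ranks only the unreversed
  // (cross-correlation) form is handled here.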
  if (dnums.input_spatial_dimensions_size() == 2
          ? !window_util::AllOrNoneReversed(conv->window())
          : window_util::HasWindowReversal(conv->window())) {
    return false;
  }
  return true;
}

// Try to match a backward filter pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardFilter(HloInstruction* conv) {
  VLOG(2) << "Trying to match convolution backward filter.";
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);

  if (conv->feature_group_count() > 1) {
    VLOG(1) << conv->ToString()
            << " is a forward convolution. All grouped backward filters are "
               "mapped to batch-grouped convolutions in the tf2xla bridge, so "
               "backward filter convolutions cannot have a feature group "
               "count greater than 1 at this point. No need to fold to a "
               "backward filter.";
    return no_match_result;
  }

  // Step 1: match the instruction pattern without considering the paddings and
  // dimension numbers just yet. We may need some generic pattern matcher
  // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
  //
  // Backward filter convolution is implemented in XLA as the forward
  // convolution of padded activations and dilated gradients. Padding on
  // activations and dilation on gradients are specified in the "window" field
  // of the forward convolution.
  //
  //   activations  gradients
  //         \         /
  //          v       v
  //         Convolution
  //            conv
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());

  // Step 2: match paddings and dimension numbers of the forward convolution.
  const ConvolutionDimensionNumbers& conv_dnums =
      conv->convolution_dimension_numbers();
  auto input_batch_dim = conv_dnums.input_batch_dimension();
  auto input_feature_dim = conv_dnums.input_feature_dimension();
  auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
  auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
  auto kernel_output_feature_dim =
      conv_dnums.kernel_output_feature_dimension();
  auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
  auto output_batch_dim = conv_dnums.output_batch_dimension();
  auto output_feature_dim = conv_dnums.output_feature_dimension();
  auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.base_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no base (LHS) dilation.";
      return no_match_result;
    }
    if (window_dim.padding_low() < 0) {
      VLOG(1) << "Padding low should be non-negative.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
    // Padding high will be checked in Step 3.
  }
  // Mathematically, there is no difference between a forward convolution and
  // a backward filter convolution. A backward filter:
  //   [N, O, H+h-1, W+w-1] x [N, C, H, W] -> [O, C, h, w]
  // can be treated as a forward convolution with `N` treated as the new
  // contracting (feature) dimension, `O` treated as the new batch dimension,
  // and `C` treated as the new output feature dimension. The only difference
  // is layouts and performance.
  //
  // Since there is no way to precisely tell whether we want a forward conv or
  // a backward filter conv, we have to rely on heuristics. Empirically,
  // forward convolutions have very small kernel dimensions, while in the
  // backward pass the "kernel dimensions" are large. If the kernel dimensions
  // are smaller than the output dimensions, return forward conv; otherwise
  // proceed with backward filter conv.
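  // Illustrative (hypothetical) sizes: for the gradient of a 3x3 filter over
  // 224x224 activations, the "kernel" operand of the equivalent forward conv
  // has spatial size 222x222 while its output has spatial size 3x3, so the
  // kernel dimension exceeds the output dimension and the backward filter
  // match proceeds; a genuine forward conv with a 3x3 kernel is rejected by
  // the check below instead.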
  if ((kernel_spatial_dims.empty() ||
       conv->operand(1)->shape().dimensions(kernel_spatial_dims[0]) <=
           conv->shape().dimensions(output_spatial_dims[0])) &&
      !window_util::HasWindowDilation(conv->window())) {
    VLOG(1) << conv->ToString()
            << " is a regular forward convolution. No need "
               "to fold it to a backward filter convolution.";
    return no_match_result;
  }

  // Step 3: fuse the matched HLOs into a backward convolution instruction.
  //
  // Compute the window of the backward convolution.
  Window backward_conv_window;
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    WindowDimension* dim = backward_conv_window.add_dimensions();
    // The window size of the backward convolution equals the output size of
    // the forward convolution.
    int64_t filter_size = conv->shape().dimensions(output_spatial_dims[i]);
    dim->set_size(filter_size);
    // The window stride equals the window dilation of the forward
    // convolution.
    dim->set_stride(conv->window().dimensions(i).window_dilation());
    // The window's low padding is the same as the low padding of the
    // activations.
    dim->set_padding_low(conv->window().dimensions(i).padding_low());
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);

    int64_t input_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    int64_t output_size = conv->window().dimensions(i).size();
    // Compute the range of the amount of valid high padding. We first compute
    // min_padding_high, the amount of padding on the right/bottom to ensure
    // the last patch ends at the border, i.e.,
    //
    //   input_size + dim->padding_low() + min_padding_high
    //     = (output_size - 1) * stride + filter_size
    //
    // Because convolution ignores trailing incomplete windows, any amount of
    // padding high from min_padding_high to min_padding_high+stride-1
    // (max_padding_high) has the same effect.
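    // For illustration (hypothetical numbers): with filter_size = 3,
    // stride = 2, output_size = 5, input_size = 10, and padding_low = 1,
    // padded_input_size = 3 + (5 - 1) * 2 = 11, so min_padding_high =
    // 11 - 10 - 1 = 0 and max_padding_high = 0 + 2 - 1 = 1; padding_low = 1
    // lies in [0, 1], so padding_high is set to 1 below.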
    int64_t padded_input_size = filter_size + (output_size - 1) * dim->stride();
    int64_t min_padding_high =
        padded_input_size - input_size - dim->padding_low();
    int64_t max_padding_high = min_padding_high + dim->stride() - 1;
    CHECK_GE(dim->padding_low(), 0);
    // In practice, since cuDNN convolution only supports even padding, we make
    // the amount of high padding the same as the amount of low padding as long
    // as it is between min_padding_high and max_padding_high. If it is not in
    // that range, we pick the one that's closest to dim->padding_low() and let
    // GpuConvPaddingLegalization canonicalize the resultant backward
    // convolution later. Picking the closest one minimizes the cost of the
    // kPad instruction to be inserted by GpuConvPaddingLegalization.
    if (dim->padding_low() >= min_padding_high &&
        dim->padding_low() <= max_padding_high) {
      dim->set_padding_high(dim->padding_low());
    } else {
      if (dim->padding_low() < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    if (dim->padding_high() < 0) {
      LOG(WARNING)
          << "Fusing this pattern to backward filter convolution would cause "
             "negative padding ("
          << dim->padding_high()
          << ") on right/bottom of the weight gradients, which is not "
             "supported by GpuConvPaddingLegalization (b/32744257). "
             "Falling back to unfused convolution for instruction: "
          << conv->ToString();
      return no_match_result;
    }
  }

  // Restore the dimension numbers of the backward convolution from the
  // forward convolution. The two activation dimensions are reversed (batch
  // and feature).
  ConvolutionDimensionNumbers backward_conv_dnums;
  backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
  backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
  }
  backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
  backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
  for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
  }
  // The dimension numbering of the output of the forward convolution (before
  // transposition) is the same as that of the activations (according to the
  // semantics of kConvolution). The batch dimension of the activations should
  // be treated as the input feature dimension, and the feature dimension
  // should be treated as the output feature.
  backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
  backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
  for (int i = 0; i < output_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
  }

  HloInstruction* lhs = conv->mutable_operand(0);
  return std::make_tuple(true, backward_conv_window, backward_conv_dnums, lhs);
}

// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardInput(HloInstruction* conv) {
  VLOG(2) << "Trying to match convolution backward input.";
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);

  // TODO(timshen): Theoretically cuDNN also supports grouped convolutions for
  // the backward input convolution, but based on cuDNN's current state there
  // is not much performance improvement when using the cuDNN backward input
  // API for grouped convs. This needs to be re-evaluated for future cuDNN
  // versions. Note that we already have the necessary code down below; the
  // only thing needed to enable it is to remove the following early return.
  if (conv->feature_group_count() > 1) {
    return no_match_result;
  }

  // Match instruction pattern.
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
  HloInstruction* reverse_filter = conv->mutable_operand(1);
  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();

  // Match BackwardInput for a depthwise convolution and thunk it to a forward
  // convolution. The output feature dimension and the input feature dimension
  // have been swapped in the bridge, so to get the actual input feature count
  // we need to query the output feature dimension.
  auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
  auto kernel_out_features =
      reverse_filter->shape().dimensions(kernel_out_feature_dim);

  // For a depthwise convolution, the input feature count must equal the
  // feature_group_count. We can leverage this property to match a depthwise
  // convolution and thunk it to a forward conv.
  if (conv->feature_group_count() > 1 &&
      kernel_out_features == conv->feature_group_count()) {
    return no_match_result;
  }

  // We pattern-match to a backwards input conv if:
  //
  //   - all spatial dims of the filter are reversed
  //
  // OR
  //
  //   - filter is 1x1 or a constant AND
  //   - conv has base dilation (otherwise this is just a regular forward
  //     conv).
  //
  // The final criterion above is just for canonicalization; cudnn seems to run
  // just as fast if we canonicalize 1x1/constant filters without base dilation
  // to forward or backward convs. We canonicalize to forward conv because (a)
  // it's more natural (constant filters usually show up when doing inference,
  // and having backwards convolutions in inference graphs would be weird), and
  // (b) cudnn has special fusions for forward conv plus bias and activation,
  // and we want to pattern-match to that after running this pass.
  bool is_reversed_filter =
      reverse_filter->opcode() == HloOpcode::kReverse &&
      absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
                             reverse_filter->dimensions());
  bool is_1x1_filter =
      absl::c_all_of(conv->window().dimensions(),
                     [](const WindowDimension& d) { return d.size() == 1; });
  if (!is_reversed_filter &&
      !(window_util::HasBaseDilation(conv->window()) &&
        (reverse_filter->IsConstant() || is_1x1_filter))) {
    VLOG(1) << "Can't match to backwards convolution. Either filter is not "
               "kReverse, or it's not a base-dilated conv with a 1x1 or "
               "constant filter.";
    return no_match_result;
  }

  // Match padding and dilation of the forward convolution.
  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.window_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no window dilation.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
  }

  const auto& input_spatial_dims = dnums.input_spatial_dimensions();
  const auto& output_spatial_dims = dnums.output_spatial_dimensions();
  CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
  CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());

  const Window& old_window = conv->window();
  Window new_window = old_window;
  for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
    // Restore backward convolution's padding config from the matched pattern.
    // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we
    // convert backward input convolution to a variant of forward convolution.
    //
    // The stride of the backward convolution
    //   = the base dilation factor of the forward convolution.
    auto dim = new_window.mutable_dimensions(i);
    dim->set_stride(old_window.dimensions(i).base_dilation());
    dim->set_base_dilation(1);

    // The low padding = kernel_size - 1 - low padding on the gradients.
    // Make sure the low padding is not negative.
    auto kernel_size = old_window.dimensions(i).size();
    auto backward_padding_low =
        kernel_size - 1 - old_window.dimensions(i).padding_low();
    if (backward_padding_low < 0) {
      LOG(WARNING)
          << "The low padding of the backward convolution would be negative ("
          << backward_padding_low
          << "), which isn't supported by GpuConvPaddingLegalization "
             "for now (b/32744257).";
      return no_match_result;
    }
    dim->set_padding_low(backward_padding_low);
    // Compute the range of the amount of padding on the right/bottom of the
    // activations. XLA's convolution requires all patches to be within the
    // padded base. This gives us flexibility to choose the amount of high
    // padding from a set of values without changing the result of the
    // backward convolution. The minimum amount (min_padding_high) makes the
    // last patch end at the border. The maximum amount (max_padding_high)
    // equals min_padding_high+stride-1 -- max_padding_high+1 would cause the
    // output size to change.
    auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
    auto output_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
    auto total_pad_size = padded_input_size - unpadded_input_size;
    auto min_padding_high = total_pad_size - backward_padding_low;
    auto max_padding_high = min_padding_high + dim->stride() - 1;

    if (backward_padding_low >= min_padding_high &&
        backward_padding_low <= max_padding_high) {
      // In the best case (most likely), if backward_padding_low is in the
      // range of the amounts of valid high padding, we choose
      // backward_padding_low because cuDNN supports even padding only.
      dim->set_padding_high(backward_padding_low);
    } else {
      // Otherwise, we choose the amount that's closest to
      // backward_padding_low, and GpuConvPaddingLegalization will later insert
      // kSlice instructions to enforce even padding.
      //
      // For example, consider the backward convolution pattern
      //
      //     ab       xy
      //     | pad    | reverse
      //   .a.b       yx
      //       \     /
      //        ABC
      //
      // The amount of low padding on activations (in backward convolution) is
      //   backward_padding_low = kernel_size - 1 - forward_padding_low
      //                        = 2 - 1 - 1 = 0
      //
      // The amount of padding high must be between 1 and 2, in order to make
      // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
      // the range of [1,2], so we pick the closest valid amount of padding
      // high, which is 1 in this case. Therefore, we fuse the above pattern to
      //
      //   ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
      if (backward_padding_low < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    // GpuConvPaddingLegalization doesn't handle backward input
    // convolution with negative padding for now. So fall back to unfused
    // convolution in case of negative padding. For example,
    //   ABCD = Conv(abc, reverse(xy), padding_high=2)
    // could be fused to
    //   ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
    // with positive padding low but negative padding high.
    if (dim->padding_high() < 0) {
      LOG(WARNING) << "Fusing this pattern to backward convolution would cause "
                      "negative padding ("
                   << dim->padding_high()
                   << ") on right/bottom of the activations, which is not "
                      "supported by GpuConvPaddingLegalization (b/32744257). "
                      "Falling back to unfused convolution for instruction: "
                   << conv->ToString();
      return no_match_result;
    }
  }

  // OK, it's a match! Switch the input feature dimension with the output
  // feature dimension. Also switch the output with the input. This is the way
  // cuDNN expects it to be.
  auto conv_dnums = conv->convolution_dimension_numbers();
  dnums.set_kernel_input_feature_dimension(
      conv_dnums.kernel_output_feature_dimension());
  dnums.set_kernel_output_feature_dimension(
      conv_dnums.kernel_input_feature_dimension());
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    dnums.set_input_spatial_dimensions(i,
                                       conv_dnums.output_spatial_dimensions(i));
    dnums.set_output_spatial_dimensions(i,
                                        conv_dnums.input_spatial_dimensions(i));
  }
  dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension());
  dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension());
  dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension());
  dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension());

  // If we matched against a constant, we need to add a reverse op that can be
  // subsumed by the cuDNN call. algebraic-simplifier will later remove any
  // unnecessary reverses.
  if (reverse_filter->opcode() != HloOpcode::kReverse &&
      reverse_filter->IsConstant()) {
    // Create a double-reverse, which is a nop.
    HloComputation* c = conv->parent();
    reverse_filter = c->AddInstruction(
        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
                                      dnums.kernel_spatial_dimensions()));
    reverse_filter = c->AddInstruction(
        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
                                      dnums.kernel_spatial_dimensions()));
    TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter));
  }

  // Calculate the 'rhs' that goes into the backward input convolution.
  HloInstruction* rhs = reverse_filter;
  // One reverse is subsumed by the cuDNN call.
  if (rhs->opcode() == HloOpcode::kReverse) {
    rhs = rhs->mutable_operand(0);
  }
  if (conv->feature_group_count() == 1) {
    return std::make_tuple(true, new_window, dnums, rhs);
  }

  // Handle grouped convolutions. Because we swapped the input feature
  // dimension with the output feature dimension, we need to also reshape the
  // kernel so that the 'feature_group_count' parameter still makes sense. The
  // 'feature_group_count' parameter essentially specifies how often the
  // 'kernel_input_feature_dimension' is repeated. So when we swap these
  // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
  // 'feature_group_count' and multiply the new
  // 'kernel_output_feature_dimension' by 'feature_group_count'.
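  // For illustration (hypothetical shapes): with feature_group_count G = 2 and
  // an rhs of shape [3, 3, 6, 5] (in_depth = 6, out_depth / G = 5), the steps
  // below produce [3, 3, 2, 3, 5] (reshape), then [3, 3, 3, 2, 5] (transpose),
  // and finally [3, 3, 3, 10] (reshape merging G into out_depth).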
  int64_t input_feature_dimension = dnums.kernel_input_feature_dimension();
  int64_t output_feature_dimension = dnums.kernel_output_feature_dimension();
  // The following code assumes that input_feature_dimension and
  // output_feature_dimension are adjacent.
  if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
    return no_match_result;
  }

  int64_t input_features = rhs->shape().dimensions(input_feature_dimension);
  int64_t output_features = rhs->shape().dimensions(output_feature_dimension);

  // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G,
  // out_depth / G]
  std::vector<int64_t> reshape_dims = SpanToVector(rhs->shape().dimensions());
  auto num_groups = conv->feature_group_count();
  CHECK_EQ(input_features % num_groups, 0)
      << "Input feature count should be an exact multiple of feature group "
         "count";
  reshape_dims[input_feature_dimension] =
      reshape_dims[input_feature_dimension] / num_groups;
  reshape_dims.insert(reshape_dims.begin() + input_feature_dimension,
                      num_groups);

  HloComputation* c = conv->parent();
  rhs = c->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs));

  // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ...,
  // in_depth/G, G, out_depth / G]
  std::vector<int64_t> transpose_dims(rhs->shape().dimensions_size());
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  transpose_dims.erase(transpose_dims.begin() + input_feature_dimension);
  transpose_dims.insert(transpose_dims.begin() + output_feature_dimension,
                        input_feature_dimension);
  std::vector<int64_t> transpose_reshape_dims =
      SpanToVector(rhs->shape().dimensions());
  transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
                               input_feature_dimension);
  transpose_reshape_dims.insert(
      transpose_reshape_dims.begin() + output_feature_dimension, num_groups);
  rhs = c->AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims),
      rhs, transpose_dims));

  // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ...,
  // in_depth/G, out_depth]
  Shape new_shape = rhs->shape();
  new_shape.DeleteDimension(output_feature_dimension);
  new_shape.set_dimensions(output_feature_dimension,
                           output_features * num_groups);
  rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
  return std::make_tuple(true, new_window, dnums, rhs);
}

}  // namespace conv_matchers

namespace {

HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape,
                              HloInstruction* lhs, HloInstruction* rhs,
                              const Window& window,
                              const ConvolutionDimensionNumbers& dnums,
                              int64_t feature_group_count,
                              const OpMetadata& metadata) {
  HloComputation* computation = lhs->parent();

  // This call returns a tuple of (conv_result, scratch_memory), where
  // conv_result is the actual result of the convolution, and scratch_memory is
  // temporary memory used by cudnn.
  //
  // At the moment, we don't know how much scratch memory this conv is going to
  // use, so we put u8[0] in this place. Later on another pass will choose
  // which conv algorithm to use, and at that point we'll modify the shape of
  // this second tuple element.
  Shape call_shape =
      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
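  // For example (illustrative): if the conv result is f32[8,28,28,64], the
  // custom call's shape is the tuple (f32[8,28,28,64], u8[0]).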

  HloInstruction* custom_call = computation->AddInstruction(
      HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
  custom_call->set_window(window);
  custom_call->set_convolution_dimension_numbers(dnums);
  custom_call->set_feature_group_count(feature_group_count);
  custom_call->set_metadata(metadata);

  // Give the custom call a user-friendly name.
  std::optional<std::string> name;
  if (call_target == kCudnnConvForwardCallTarget) {
    name = "cudnn-conv";
  } else if (call_target == kCudnnConvBackwardInputCallTarget) {
    name = "cudnn-conv-bw-input";
  } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
    name = "cudnn-conv-bw-filter";
  } else if (call_target == kCudnnConvBiasActivationForwardCallTarget) {
    name = "cudnn-conv-bias-activation";
  }
  if (name.has_value()) {
    computation->parent()->SetAndUniquifyInstrName(custom_call, *name);
  }

  return custom_call;
}

HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
    HloInstruction* conv) {
  CHECK_EQ(conv->feature_group_count(), 1);
  int64_t num_groups = conv->batch_group_count();
  auto dim_numbers = conv->convolution_dimension_numbers();
  auto lhs = conv->mutable_operand(0);
  auto rhs = conv->mutable_operand(1);

  int64_t input_batch_dimension = dim_numbers.input_batch_dimension();

  Shape output_shape = conv->shape();
  int64_t input_feature_dimension = dim_numbers.input_feature_dimension();
  int64_t input_feature = lhs->shape().dimensions(input_feature_dimension);

  HloComputation* computation = lhs->parent();
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation->AddInstruction(std::move(inst));
  };
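  // For illustration (hypothetical NHWC shapes): with batch_group_count G = 2
  // and an lhs of shape [8, H, W, C], the steps below produce [2, 4, H, W, C]
  // (reshape), then [4, H, W, 2, C] (transpose), and finally [4, H, W, 2*C]
  // (reshape merging G into the feature dimension).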
  // Reshape batch_dim N -> [G, N/G]
  std::vector<int64_t> reshape_dims = SpanToVector(lhs->shape().dimensions());
  reshape_dims[input_batch_dimension] =
      reshape_dims[input_batch_dimension] / num_groups;
  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension,
                      num_groups);
  lhs = add(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));

  // Transpose G to the axis before C, e.g. [G, N/G, H, W, C] -> [N/G, H, W,
  // G, C].
  std::vector<int64_t> transpose_dims(lhs->shape().dimensions_size());
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
                        input_batch_dimension);
  std::vector<int64_t> transpose_reshape_dims =
      ComposePermutations(lhs->shape().dimensions(), transpose_dims);
  lhs = add(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(lhs->shape().element_type(), transpose_reshape_dims),
      lhs, transpose_dims));

  // Merge [G, C] -> [G*C]
  Shape new_shape = lhs->shape();
  new_shape.DeleteDimension(input_feature_dimension);
  new_shape.set_dimensions(input_feature_dimension,
                           input_feature * num_groups);
  lhs = add(HloInstruction::CreateReshape(new_shape, lhs));

  std::vector<HloInstruction*> new_operands = {lhs, rhs};
  auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
  new_conv->set_feature_group_count(num_groups);
  new_conv->set_batch_group_count(1);
  new_conv->set_convolution_dimension_numbers(dim_numbers);
  return computation->AddInstruction(std::move(new_conv));
}

CudnnConvBackendConfig GetDefaultBackendConfig() {
  CudnnConvBackendConfig config;
  config.set_conv_result_scale(1);
  return config;
}

// Helper function to create a custom_call instruction to replace the given
// conv instruction.
static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
  bool match;
  Window window;
  ConvolutionDimensionNumbers dnums;
  HloInstruction* rhs;
  HloInstruction* lhs;

  std::tie(match, window, dnums, rhs) = conv_matchers::MatchBackwardInput(conv);
  if (match) {
    return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
                         conv->mutable_operand(0), rhs, window, dnums,
                         conv->feature_group_count(), conv->metadata());
  }

  std::tie(match, window, dnums, lhs) =
      conv_matchers::MatchBackwardFilter(conv);
  if (match) {
    return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
                         conv->mutable_operand(1), window, dnums,
                         conv->batch_group_count(), conv->metadata());
  }

  // If all else fails, try a forward convolution.
  if (conv_matchers::CanImplementAsGpuForwardConv(conv)) {
    if (conv->batch_group_count() > 1) {
      conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
    }

    return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(),
                         conv->mutable_operand(0), conv->mutable_operand(1),
                         conv->window(), conv->convolution_dimension_numbers(),
                         conv->feature_group_count(), conv->metadata());
  }

  return nullptr;
}

// Tries to rewrite a single convolution into a call to cudnn/miopen.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);

  TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
                      CreateCustomCallHelper(conv));
  if (custom_call == nullptr) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      custom_call->set_backend_config(GetDefaultBackendConfig()));

  VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
          << custom_call->ToString();

  // The CustomCall returns a tuple (conv_result, scratch_memory). Extract
  // out the conv result and replace `conv` with it.
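  // Schematically (illustrative HLO), the original %convolution becomes
  //   %custom-call = (result_shape, u8[0]) custom-call(...)
  //   %get-tuple-element = result_shape get-tuple-element(%custom-call), index=0
  // with the exact text depending on the module.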
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv,
      HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
  return true;
}

// Rewrites the convolutions in the given computation into calls to
// cudnn/miopen.
// Returns true if it made any changes.
StatusOr<bool> RunOnComputation(HloComputation* computation) {
  std::vector<HloInstruction*> convs;
  for (auto* hlo : computation->instructions()) {
    if (hlo->opcode() == HloOpcode::kConvolution) {
      convs.push_back(hlo);
    }
  }

  bool changed = false;
  for (HloInstruction* conv : convs) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv));
    changed |= result;
  }
  return changed;
}
}  // namespace

StatusOr<bool> GpuConvRewriter::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace gpu
}  // namespace xla