/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_vectorize_convolutions.h"

#include <optional>
#include <tuple>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_support_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"

namespace xla {
namespace gpu {

// Finds convolutions that this pass may be able to transform, namely int8_t
// cudnn forward or forward-bias-activation convolutions.
//
// cudnn as of v8.2 supports the following data type combinations for forward
// and forward-bias-activation convolutions.  We have to make sure we only
// vectorize to one of these supported configs.
//
//   in       out
//   int8x1   int8x1
//   int8x1   float
//   int8x1   int32_t
//
//   int8x4   int8x4
//   int8x4   float
//
//   int8x32  int8x32
//
// https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionForward
//
// For now we restrict ourselves to only the int8xN -> int8xN cases.  We could
// allow the int8x4 -> float case in the future if desirable.
static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
    HloComputation* comp) {
  std::vector<HloCustomCallInstruction*> convs;
  for (HloInstruction* instr : comp->instructions()) {
    if (instr->opcode() != HloOpcode::kCustomCall ||
        (instr->custom_call_target() != kCudnnConvForwardCallTarget &&
         instr->custom_call_target() !=
             kCudnnConvBiasActivationForwardCallTarget) ||
        instr->operand_count() < 2) {
      continue;
    }

    PrimitiveType input_ty = instr->operand(0)->shape().element_type();
    PrimitiveType output_ty = instr->shape().tuple_shapes(0).element_type();
    if (input_ty == output_ty && (input_ty == S8 || input_ty == U8)) {
      convs.push_back(Cast<HloCustomCallInstruction>(instr));
    }
  }
  return convs;
}

// Converts an XlaBuilder into an HloComputation in the same module as
// `sibling_computation`.
//
// Yes, we serialize/deserialize as a proto.  :)
static StatusOr<HloComputation*> BuilderToHloComputation(
    XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) {
  TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape());
  HloModuleConfig config(program_shape);
  TF_ASSIGN_OR_RETURN(auto new_module,
                      HloModule::CreateFromProto(comp.proto(), config));

  HloModule* dest_module = sibling_computation->parent();
  HloCloneContext context(dest_module);
  return dest_module->DeepCloneComputation(new_module->entry_computation(),
                                           &context);
}

// Reshapes `instr` so that it has an extra dimension of size `vect_size` right
// after `dim`.
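//
// For example, given instr of shape s8[10, 32, 20], dim=1, and vect_size=4,
// the result is a reshape to s8[10, 8, 4, 20] (mirroring the SplitShapeAtDim
// example below).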
static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) {
  XlaBuilder& b = *instr.builder();
  Shape shape = b.GetShape(instr).ValueOrDie();
  DimensionVector new_dims(shape.dimensions().begin(),
                           shape.dimensions().end());
  CHECK_EQ(new_dims[dim] % vect_size, 0);
  new_dims[dim] /= vect_size;
  new_dims.insert(new_dims.begin() + dim + 1, vect_size);
  return Reshape(instr, new_dims);
}

// Reshapes `shape` so that there's an extra dimension of size `vect_size`
// right after `dim`.
//
// For example, given shape=s8[10, 32, 20], dim=1, vect_size=4, returns
// s8[10, 8, 4, 20].
static Shape SplitShapeAtDim(Shape shape, int64_t dim, int64_t vect_size) {
  DimensionVector new_dims(shape.dimensions().begin(),
                           shape.dimensions().end());
  CHECK_EQ(new_dims[dim] % vect_size, 0);
  new_dims[dim] /= vect_size;
  new_dims.insert(new_dims.begin() + dim + 1, vect_size);
  return ShapeUtil::MakeShape(shape.element_type(), new_dims);
}

// Transposes dimension `src` to right before `dst`.
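//
// For example, with rank 4, src=1, dst=3 this applies the permutation
// {0, 2, 1, 3}, so the old dimension 1 lands immediately before the old
// dimension 3.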
static XlaOp MoveDim(XlaOp instr, int64_t src, int64_t dst) {
  XlaBuilder& b = *instr.builder();
  int64_t rank = b.GetShape(instr)->dimensions_size();

  DimensionVector idxs(rank);
  absl::c_iota(idxs, 0);
  if (src < dst) {
    idxs.insert(idxs.begin() + dst, src);
    idxs.erase(idxs.begin() + src);
  } else {
    idxs.erase(idxs.begin() + src);
    idxs.insert(idxs.begin() + dst, src);
  }
  return Transpose(instr, idxs);
}

// Reshapes instr so that dimension `vect_dim` has size `vect_size`, by
// stealing elements from `dim`.
//
// Requires that this is possible without merging and re-splitting the two
// dimensions.  I.e. there should be some amount of dim that we can "split off"
// and add to vect_dim to get it to have size vect_size.
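//
// For example, given instr of shape s8[10, 64, 20, 4], dim=1, vect_dim=3, and
// vect_size=32, the result has shape s8[10, 8, 20, 32]: a factor of 8 is
// split off of dim 1 and folded into vect_dim.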
static XlaOp RevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
                              int64_t vect_size) {
  XlaBuilder& b = *instr.builder();
  Shape shape = b.GetShape(instr).ValueOrDie();
  auto size = [&](int64_t d) { return shape.dimensions(d); };

  CHECK_LE(size(vect_dim), vect_size);
  CHECK_EQ(vect_size % size(vect_dim), 0);

  int64_t split_factor = vect_size / size(vect_dim);
  CHECK_EQ(size(dim) % split_factor, 0);

  // Split dim into [C, split_factor].
  instr = SplitAtDim(instr, dim, split_factor);

  // SplitAtDim may have added a dimension before vect_dim.
  if (vect_dim > dim) {
    vect_dim++;
  }

  // Move the split_factor dimension to right before vect_dim.
  instr = MoveDim(instr, dim + 1, vect_dim);

  // Moving the split_factor dimension may have *removed* a dimension before
  // vect_dim.
  if (vect_dim > dim) {
    vect_dim--;
  }

  // Collapse the split_factor dimension into vect_dim.
  return Collapse(instr, {vect_dim, vect_dim + 1});
}

// Inverse of RevectorizeInstr.  Reshapes instr so that dimension `vect_dim`
// has size `vect_size`, moving excess elements into `dim`.
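//
// For example, given instr of shape s8[10, 8, 20, 32], dim=1, vect_dim=3, and
// orig_vect_size=4, the result has shape s8[10, 64, 20, 4]: the factor of 8
// removed from vect_dim is folded back into dim.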
static XlaOp UnrevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
                                int64_t orig_vect_size) {
  XlaBuilder& b = *instr.builder();
  Shape shape = b.GetShape(instr).ValueOrDie();
  auto size = [&](int64_t d) { return shape.dimensions(d); };

  CHECK_GE(size(vect_dim), orig_vect_size);
  CHECK_EQ(size(vect_dim) % orig_vect_size, 0);

  // Split vect_dim into [C, orig_vect_size].
  instr = SplitAtDim(instr, vect_dim, orig_vect_size);

  // SplitAtDim may have added a dimension before dim.
  if (dim > vect_dim) {
    dim++;
  }

  // Move the `C` dimension to right after `dim`.  Take into account that
  // SplitAtDim may have added a dimension before dim.
  instr = MoveDim(instr, vect_dim, dim + 1);

  // MoveDim may have *removed* a dimension before dim.
  if (dim > vect_dim) {
    dim--;
  }

  // Collapse the `C` and `dim` dimensions.
  return Collapse(instr, {dim, dim + 1});
}

// Adds a vectorized-feature dimension to dnums right after the current feature
// dimension.
//
// ConvolutionDimensionNumbers doesn't represent the vectorized-feature
// dimension explicitly, because the whole concept of a vectorized-feature
// dimension is specific to cudnn.  Rather, the vectorized-feature dimension is
// implicit; it's the first dimension that *doesn't* appear in the dnums.
//
// This function "makes room" in dnums for the new vectorized dimension by
// incrementing any dimensions which appear after the feature dim.  The
// implicit vector dim is then in this "empty" spot.
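//
// For example, with NCHW dnums (batch=0, feature=1, spatial={2, 3}), the
// spatial dims become {3, 4}, leaving dimension 2 unmentioned; that
// unmentioned slot is exactly where the implicit vector dim of the split
// [N, C/k, k, H, W] shape lives.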
static ConvolutionDimensionNumbers VectorizeDnums(
    ConvolutionDimensionNumbers dnums) {
  int64_t input_vect_dim = dnums.input_feature_dimension();
  if (dnums.input_batch_dimension() > input_vect_dim) {
    dnums.set_input_batch_dimension(dnums.input_batch_dimension() + 1);
  }
  for (int64_t& d : *dnums.mutable_input_spatial_dimensions()) {
    if (d > input_vect_dim) {
      ++d;
    }
  }

  int64_t kernel_vect_dim = dnums.kernel_input_feature_dimension();
  if (dnums.kernel_output_feature_dimension() > kernel_vect_dim) {
    dnums.set_kernel_output_feature_dimension(
        dnums.kernel_output_feature_dimension() + 1);
  }
  for (int64_t& d : *dnums.mutable_kernel_spatial_dimensions()) {
    if (d > kernel_vect_dim) {
      ++d;
    }
  }

  int64_t output_vect_dim = dnums.output_feature_dimension();
  if (dnums.output_batch_dimension() > output_vect_dim) {
    dnums.set_output_batch_dimension(dnums.output_batch_dimension() + 1);
  }
  for (int64_t& d : *dnums.mutable_output_spatial_dimensions()) {
    if (d > output_vect_dim) {
      ++d;
    }
  }

  return dnums;
}

// Tries to vectorize an already-vectorized convolution.
//
// That is, given a convolution of shape [N, C/k, H, W, k], changes it to have
// shape [N, C/vect_size, H, W, vect_size].  Similarly changes the filter from
// [H, W, I/k, k, O] to [H, W, I/vect_size, vect_size, O].
//
// (The dimensions can appear in any order; which dimension is N/C/etc. is
// determined by the convolution's dnums.)
static StatusOr<bool> TryRevectorizeConv(
    const se::CudaComputeCapability& compute_capability,
    HloCustomCallInstruction* conv, int vect_size) {
  const Shape& input_shape = conv->operand(0)->shape();
  const Shape& kernel_shape = conv->operand(1)->shape();
  const Shape& output_shape = conv->shape().tuple_shapes(0);
  const auto& dnums = conv->convolution_dimension_numbers();

  // Find the vectorized-features dim in the input/kernel/output.
  std::optional<int64_t> input_vect_dim;
  std::optional<int64_t> kernel_vect_dim;
  std::optional<int64_t> output_vect_dim;
  std::tie(input_vect_dim, kernel_vect_dim, output_vect_dim) =
      FindVectorizedFeatureDims(dnums, input_shape, kernel_shape,
                                output_shape);

  if (!input_vect_dim.has_value() || !kernel_vect_dim.has_value() ||
      !output_vect_dim.has_value()) {
    return false;
  }

  int64_t input_feat_size =
      input_shape.dimensions(dnums.input_feature_dimension());
  int64_t output_feat_size =
      output_shape.dimensions(dnums.output_feature_dimension());
  int64_t input_vect_size = input_shape.dimensions(*input_vect_dim);
  int64_t output_vect_size = output_shape.dimensions(*output_vect_dim);
  if (vect_size % input_vect_size != 0 || vect_size % output_vect_size != 0 ||
      input_feat_size % (vect_size / input_vect_size) != 0 ||
      output_feat_size % (vect_size / output_vect_size) != 0) {
    return false;
  }

  // If this is an integer convolution check that we only vectorize when cuDNN
  // supports the vectorized implementation.
  if (primitive_util::IsIntegralType(input_shape.element_type())) {
    TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
                        CudnnSupportsOptimizedIntegerConvolution(
                            compute_capability, *conv, vect_size));
    if (!supported_target_vectorization) {
      VLOG(3) << "Skipping re-vectorization of conv to vector size: "
              << vect_size << ": " << conv->ToString();
      return false;
    }
  }

  VLOG(1) << "Re-vectorizing conv channels from "
          << input_shape.dimensions(*input_vect_dim) << " to " << vect_size
          << ": " << conv->ToString();

  // We use XlaBuilder because it's a lot easier to get these tricky
  // reshape/transposes correct using that API.
  XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
  b.SetOpMetadata(conv->metadata());

  absl::InlinedVector<XlaOp, 4> new_operands = {
      RevectorizeInstr(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
                       dnums.input_feature_dimension(), *input_vect_dim,
                       vect_size),
      RevectorizeInstr(Parameter(&b, 1, conv->operand(1)->shape(), "filter"),
                       dnums.kernel_input_feature_dimension(), *kernel_vect_dim,
                       vect_size),
  };
  if (conv->operand_count() > 2) {
    // Bias, if present.  This is passed through unmodified.
    new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
  }
  if (conv->operand_count() > 3) {
    new_operands.push_back(RevectorizeInstr(
        Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
        dnums.input_feature_dimension(), *input_vect_dim, vect_size));
  }

  if (conv->operand_count() > 4) {
    return InvalidArgument(
        "Don't understand a conv with more than 4 arguments: %s",
        conv->ToString());
  }

  // The custom-call returns a tuple (new_output_shape, u8[0]), where the
  // second value in the tuple represents the convolution's scratch memory.
  DimensionVector new_output_dims(output_shape.dimensions().begin(),
                                  output_shape.dimensions().end());
  new_output_dims[dnums.output_feature_dimension()] /=
      (vect_size / output_vect_size);
  new_output_dims[*output_vect_dim] = vect_size;
  XlaOp new_conv = CustomCallWithConvDnums(
      &b, conv->custom_call_target(), new_operands,
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(output_shape.element_type(), new_output_dims),
           ShapeUtil::MakeShape(U8, {0})}),
      /*operand_shapes_with_layout=*/{},
      /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*window=*/conv->window(),
      /*dnums=*/conv->convolution_dimension_numbers());

  XlaOp new_conv_result = GetTupleElement(new_conv, 0);
  XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);

  XlaOp new_conv_result_unrevectorized = UnrevectorizeInstr(
      new_conv_result, dnums.output_feature_dimension(), *output_vect_dim,
      /*orig_vect_size=*/output_shape.dimensions(*output_vect_dim));

  TF_ASSIGN_OR_RETURN(
      HloComputation * new_conv_comp,
      BuilderToHloComputation(
          b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}),
          conv->parent()));

  // Set the name on the new conv.  This is purely cosmetic, but we attempt to
  // preserve e.g. "cudnn-conv.42" instead of "custom-call.42".
  auto new_conv_comp_instrs = new_conv_comp->instructions();
  auto new_conv_it =
      absl::c_find_if(new_conv_comp_instrs, [](HloInstruction* instr) {
        return instr->opcode() == HloOpcode::kCustomCall;
      });
  if (new_conv_it != new_conv_comp_instrs.end()) {
    new_conv_comp->parent()->SetAndUniquifyInstrName(*new_conv_it,
                                                     conv->name());
  }

  // Replace the old conv with a call to the computation we just created.
  VLOG(1) << "Re-vectorized conv to " << new_conv_comp->ToString();
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
                                       new_conv_comp)));

  return true;
}

// Tries to vectorize a convolution.
//
// Given a convolution of dimensions [N, C, H, W], tries to convert it to have
// shape [N, C/vect_size, H, W, vect_size].  Similarly, given a kernel of shape
// [H, W, I, O], tries to convert it to [H, W, I/vect_size, vect_size, O].
//
// This requires that C be a multiple of vect_size.  CudnnPadForConvolutions
// can add padding to make this true.
static StatusOr<bool> TryVectorizeConv(
    const se::CudaComputeCapability& compute_capability,
    HloCustomCallInstruction* conv, int64_t vect_size) {
  const Shape& input_shape = conv->operand(0)->shape();
  const Shape& output_shape = conv->shape().tuple_shapes(0);
  const auto& dnums = conv->convolution_dimension_numbers();
  int64_t in_channels =
      input_shape.dimensions(dnums.input_feature_dimension());
  int64_t out_channels =
      output_shape.dimensions(dnums.output_feature_dimension());

  if (in_channels % vect_size != 0 || out_channels % vect_size != 0) {
    return false;
  }

  if (input_shape.dimensions_size() >
      2 + dnums.input_spatial_dimensions_size()) {
    // Conv already has an extra dimension, which we assume is the vectorized
    // features dim.
    return false;
  }

  // If this is an integer convolution check that we only vectorize when cuDNN
  // supports the vectorized implementation.
  if (primitive_util::IsIntegralType(input_shape.element_type())) {
    TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
                        CudnnSupportsOptimizedIntegerConvolution(
                            compute_capability, *conv, vect_size));
    if (!supported_target_vectorization) {
      VLOG(3) << "Skipping vectorization of conv to vector size: " << vect_size
              << ": " << conv->ToString();
      return false;
    }
  }

  VLOG(1) << "Vectorizing conv channels by " << vect_size << ": "
          << conv->ToString();

  // We use XlaBuilder because it's a lot easier to get these tricky
  // reshape/transposes correct using that API.
  XlaBuilder b(absl::StrCat(conv->name(), ".vectorized"));
  b.SetOpMetadata(conv->metadata());

  absl::InlinedVector<XlaOp, 4> new_operands = {
      SplitAtDim(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
                 dnums.input_feature_dimension(), vect_size),
      SplitAtDim(Parameter(&b, 1, conv->operand(1)->shape(), "filter"),
                 dnums.kernel_input_feature_dimension(), vect_size),
  };
  if (conv->operand_count() > 2) {
    // Bias, if present.  This is passed through unmodified.
    new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
  }
  if (conv->operand_count() > 3) {
    // Handle side input, which has the same shape as the input.
    new_operands.push_back(
        SplitAtDim(Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
                   dnums.input_feature_dimension(), vect_size));
  }
  if (conv->operand_count() > 4) {
    return InvalidArgument(
        "Don't understand a conv with more than 4 arguments: %s",
        conv->ToString());
  }

  // The custom-call returns a tuple (new_output_shape, u8[0]), where the
  // second value in the tuple represents the convolution's scratch memory.
  Shape new_output_shape = SplitShapeAtDim(
      output_shape, dnums.output_feature_dimension(), vect_size);
  XlaOp new_conv = CustomCallWithConvDnums(
      &b, conv->custom_call_target(), new_operands,
      ShapeUtil::MakeTupleShape(
          {new_output_shape, ShapeUtil::MakeShape(U8, {0})}),
      /*operand_shapes_with_layout=*/{},
      /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*window=*/conv->window(),
      /*dnums=*/VectorizeDnums(dnums));

  XlaOp new_conv_result = GetTupleElement(new_conv, 0);
  XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);

  // Reshape back to the original shape.
  XlaOp conv_result_collapsed = Collapse(
      new_conv_result,
      {dnums.output_feature_dimension(), dnums.output_feature_dimension() + 1});

  TF_ASSIGN_OR_RETURN(
      HloComputation * new_conv_comp,
      BuilderToHloComputation(
          b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}),
          conv->parent()));

  // Create a tuple and replace the old conv with it!
  VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString();
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
                                       new_conv_comp)));
  return true;
}

StatusOr<bool> CudnnVectorizeConvolutions::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
      // Try to (re)vectorize to int8x32 if this is an sm75+ GPU.  If we can't,
      // fall back to int8x4.
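      // (int8x32 kernels require integer tensor cores, hence the sm75+ gate
      // below; CudnnSupportsOptimizedIntegerConvolution makes the final
      // per-conv determination.)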
      bool local_changed = false;
      if (compute_capability_.IsAtLeast(7, 5)) {
        TF_ASSIGN_OR_RETURN(local_changed,
                            TryRevectorizeConv(compute_capability_, conv, 32));
        if (!local_changed) {
          TF_ASSIGN_OR_RETURN(local_changed,
                              TryVectorizeConv(compute_capability_, conv, 32));
        }
      }
      if (!local_changed) {
        TF_ASSIGN_OR_RETURN(local_changed,
                            TryVectorizeConv(compute_capability_, conv, 4));
      }
      changed |= local_changed;
    }
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla