/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"

#include <memory>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {

using se::dnn::DataLayout;
using se::dnn::FilterLayout;

// Returns (input, filter, output) layouts.
static std::tuple<DataLayout, FilterLayout, DataLayout>
HeuristicLayoutAssignment(const HloInstruction* instr,
                          se::StreamExecutor* stream_executor) {
  // DataLayout and FilterLayout use weird enum names. Translations:
  //   N <=> Batch or Output
  //   C <=> Depth or Input
  //   H <=> Y
  //   W <=> X
  //
  // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
  //
  // If you have trouble keeping these straight, consider that all that matters
  // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?

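  // The three candidate layout triples below are given in (input, filter,
  // output) order, matching the tuple this function returns.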
  constexpr auto kAllNCHW =
      std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
                      DataLayout::kBatchDepthYX);
  // kBatchDepthYX4 has the same layout as kBatchDepthYX32; they're both VECT_C
  // layouts as far as cudnn is concerned.
  constexpr auto kAllNCHW_VECT_C =
      std::make_tuple(DataLayout::kBatchDepthYX4, FilterLayout::kOutputInputYX4,
                      DataLayout::kBatchDepthYX4);
  constexpr auto kAllNHWC =
      std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
                      DataLayout::kBatchYXDepth);

  // Integer convolution must use NHWC or NCHW_VECT_C.
  //
  // TODO(jlebar): Do non-VECT_C int8_t convs still require NHWC with new
  // versions of cudnn?
  const ConvolutionDimensionNumbers& dnums =
      instr->convolution_dimension_numbers();
  Shape input_shape = instr->operand(0)->shape();
  PrimitiveType input_ty = instr->operand(0)->shape().element_type();
  if (primitive_util::IsIntegralType(input_ty)) {
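    // (A rank-5 s8 input for a 2D convolution presumably means the operands
    // were already reshaped into cudnn's int8x4 VECT_C form upstream, so keep
    // the vectorized NCHW_VECT_C layouts for it.)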
    if (input_ty == S8 && dnums.input_spatial_dimensions_size() == 2 &&
        input_shape.dimensions_size() == 5) {
      VLOG(2) << "Using NCHW_VECT_C for int8_t conv " << instr->ToString();
      return kAllNCHW_VECT_C;
    }
    VLOG(2) << "Using NHWC for int8_t conv " << instr->ToString();
    return kAllNHWC;
  }

  const DebugOptions& debug_options =
      instr->GetModule()->config().debug_options();

  if (debug_options.xla_gpu_force_conv_nchw()) {
    VLOG(2) << "Overriding layout to NCHW for " << instr->ToString();
    return kAllNCHW;
  }

  if (debug_options.xla_gpu_force_conv_nhwc()) {
    VLOG(2) << "Overriding layout to NHWC for " << instr->ToString();
    return kAllNHWC;
  }

  // If we're not Volta or not fp16, or not conv2D, the decision is easy: Use
  // NCHW.
  if (input_ty != F16 ||
      !stream_executor->GetDeviceDescription()
           .cuda_compute_capability()
           .IsAtLeast(se::CudaComputeCapability::VOLTA) ||
      instr->shape().tuple_shapes(0).dimensions_size() != 4) {
    return kAllNCHW;
  }

  VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();

  // Empirically we've found with Volta and cudnn <= 7.3 that backward-input
  // convs with stride are significantly faster with NCHW layouts.
  //
  // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW),
  // which on paper gives good performance. However, there are two
  // observations:
  // * a mixed layout combination is more cuDNN-bug prone, based on empirical
  //   evidence.
  // * we've also observed that for mixed layouts, cuDNN transposes data back
  //   and forth from a different layout combination. If we end up with
  //   transposes anyway, we prefer to have them in XLA, as they can be fused.
  if (auto* dnn = stream_executor->AsDnn()) {
    auto version_status = dnn->GetVersion();
    if (version_status.ok()) {
      auto version = std::move(version_status).value();
      if (std::make_tuple(version.major_version(), version.minor_version()) <=
              std::make_tuple(7, 3) &&
          instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
          window_util::HasStride(instr->window())) {
        return kAllNCHW;
      }
    }
  }

  // For other Volta f16 convolutions, use NHWC.
  return kAllNHWC;
}

// Adds layout constraints on the cudnn custom-call instruction. The layout
// constraints are represented in terms of minor_to_major fields of both
// operands and the output shape. Depending on the underlying algorithm, one of
// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
    HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
  Shape lhs_shape = instr->operand(0)->shape();
  Shape rhs_shape = instr->operand(1)->shape();
  Shape result_shape = instr->shape().tuple_shapes(0);

  Shape* input_shape;
  Shape* filter_shape;
  Shape* output_shape;

  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
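  // The (input, filter, output) triple maps onto different operands/results of
  // the custom call depending on the conv kind: for a backward-input conv the
  // call's result is the conv's "input", and for a backward-filter conv the
  // result is the "filter".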
  switch (kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      input_shape = &lhs_shape;
      filter_shape = &rhs_shape;
      output_shape = &result_shape;
      break;
    case CudnnConvKind::kBackwardInput:
      input_shape = &result_shape;
      filter_shape = &rhs_shape;
      output_shape = &lhs_shape;
      break;
    case CudnnConvKind::kBackwardFilter:
      input_shape = &lhs_shape;
      filter_shape = &result_shape;
      output_shape = &rhs_shape;
      break;
  }

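  // Pick (input, filter, output) layouts heuristically and translate them into
  // XLA minor_to_major layouts on the corresponding shapes.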
  {
    DataLayout input;
    FilterLayout filter;
    DataLayout output;
    std::tie(input, filter, output) =
        HeuristicLayoutAssignment(instr, stream_executor_);

    TF_ASSIGN_OR_RETURN(
        std::tie(*input_shape->mutable_layout(),
                 *filter_shape->mutable_layout(),
                 *output_shape->mutable_layout()),
        StreamExecutorConvLayoutsToXlaLayouts(
            instr->convolution_dimension_numbers(), input, filter, output));
  }

  // The custom call returns a tuple of (actual_result, scratch_buffer);
  // call_result_buf is the logical buffer for actual_result, the thing that
  // contains the result of the conv call.
  TF_ASSIGN_OR_RETURN(
      const LogicalBuffer* call_result_buf,
      points_to_analysis_->GetBufferDefinedAt(instr, /*index=*/{0}));

  // Set layouts of the instructions' shapes.
  TF_RETURN_IF_ERROR(SetOperandLayout(lhs_shape, instr, 0));
  TF_RETURN_IF_ERROR(SetOperandLayout(rhs_shape, instr, 1));
  TF_RETURN_IF_ERROR(SetBufferLayout(result_shape.layout(), *call_result_buf));
  // instr->operand(2), if it exists, is the bias buffer. There is no need to
  // assign a layout to it, as it has only one dimension.

  // instr->operand(3), if it exists, is the side input buffer.
  if (instr->operand_count() == 4) {
    if (kind != CudnnConvKind::kForwardActivation) {
      return InternalError(
          "Invalid convolution. Conv has a side input, but kind is not fused "
          "conv forward: %s",
          instr->ToString());
    }
    // The side input layout must match the output layout.
    TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, 3));
  }
  return OkStatus();
}

namespace {

// Imposes the default layout, with the two most minor dimensions swapped, on
// input `shape` (i.e. column-major order for the trailing matrix dimensions).
void SetFortranLayout(Shape* shape) {
  LayoutUtil::SetToDefaultLayout(shape);
  int n = shape->mutable_layout()->minor_to_major_size();
  CHECK_GE(n, 2);
  std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
            shape->mutable_layout()->mutable_minor_to_major()->at(1));
}

bool DotCanSupportShapeWithLayout(const HloInstruction* dot,
                                  const Shape& shape) {
  const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers();
  // If we are able to construct a `MatrixLayout` then the dot can support
  // this layout.
  return MatrixLayout::For(shape, dot_dims.lhs_batch_dimensions().size(),
                           dot_dims.lhs_contracting_dimensions().size(),
                           dot_dims.rhs_batch_dimensions().size(),
                           dot_dims.rhs_contracting_dimensions().size())
      .ok();
}

}  // namespace

Status GpuLayoutAssignment::AddBackendConstraints(
    LayoutConstraints* constraints) {
  // Add convolution constraints in reverse post-order so that the earliest
  // convolution's layout propagates first. This reduces the likelihood of
  // fusion nodes with copies.
  auto post_order = constraints->computation()->MakeInstructionPostOrder();
  for (auto iterator = post_order.rbegin(); iterator != post_order.rend();
       ++iterator) {
    HloInstruction* instruction = *iterator;
    if (IsCustomCallToDnnConvolution(*instruction)) {
      TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
          Cast<HloCustomCallInstruction>(instruction), constraints));
    }

    CHECK(!IsCublasGemm(*instruction))
        << "Gemm rewriting should run after layout assignment";

    if (IsMatrixMultiplication(*instruction)) {
      const Shape& output_shape = instruction->shape();
      const Shape& lhs_shape = instruction->operand(0)->shape();
      const Shape& rhs_shape = instruction->operand(1)->shape();
      const DotDimensionNumbers& dot_dims =
          instruction->dot_dimension_numbers();

      // Matmuls require the batch dimensions to be in consecutive physical
      // dimensions and likewise for the contracting and non-contracting
      // dimensions. Additionally, no batch dimension can be in the most
      // minor physical dimension for inputs or the output.
      absl::Span<const int64_t> lhs_batch_dims =
          dot_dims.lhs_batch_dimensions();
      absl::Span<const int64_t> lhs_col_dims =
          dot_dims.lhs_contracting_dimensions();
      TF_ASSIGN_OR_RETURN(
          std::vector<int64_t> lhs_row_dims,
          GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_col_dims));

      absl::Span<const int64_t> rhs_batch_dims =
          dot_dims.rhs_batch_dimensions();
      absl::Span<const int64_t> rhs_row_dims =
          dot_dims.rhs_contracting_dimensions();
      TF_ASSIGN_OR_RETURN(
          std::vector<int64_t> rhs_col_dims,
          GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_row_dims));

      // For unbatched S8xS8->S32 matrix multiplication, enforce a TN layout,
      // which allows NVIDIA GPUs to use TensorCores.
      bool is_s8_to_s32 = (output_shape.element_type() == PrimitiveType::S32 &&
                           lhs_shape.element_type() == PrimitiveType::S8 &&
                           rhs_shape.element_type() == PrimitiveType::S8 &&
                           output_shape.dimensions_size() == 2 &&
                           lhs_shape.dimensions_size() == 2 &&
                           rhs_shape.dimensions_size() == 2);

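      // Physically, the TN layout below puts the contracting dimension most
      // minor for both operands: the lhs gets (batch, rows, cols) and the rhs
      // gets (batch, cols, rows) in major-to-minor order.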
      if (is_s8_to_s32) {
        TF_RETURN_IF_ERROR(SetOperandBatchRowsColsLayout(
            instruction, 0, lhs_batch_dims, lhs_row_dims, lhs_col_dims));
        TF_RETURN_IF_ERROR(SetOperandBatchRowsColsLayout(
            instruction, 1, rhs_batch_dims, rhs_col_dims, rhs_row_dims));
        TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
      } else if (!lhs_batch_dims.empty()) {
        TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 0, lhs_batch_dims,
                                               lhs_row_dims, lhs_col_dims));
        TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 1, rhs_batch_dims,
                                               rhs_row_dims, rhs_col_dims));
        TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
      }
    } else if (instruction->opcode() == HloOpcode::kTranspose) {
      const HloInstruction* operand = instruction->operand(0);
      if ((operand->opcode() != HloOpcode::kDot) ||
          (operand->user_count() > 1)) {
        continue;
      }

      // If possible, set the layout of the dot operation so that the output of
      // the transpose (as a bitcast) has the default layout.
      Shape shape = operand->shape();
      *shape.mutable_layout() =
          LayoutUtil::MakeLayoutFromMajorToMinor(instruction->dimensions());

      if (DotCanSupportShapeWithLayout(operand, shape)) {
        TF_RETURN_IF_ERROR(
            SetOperandLayout(shape, instruction, /*operand_no=*/0));
      }
    } else if (instruction->opcode() == HloOpcode::kFft) {
      // cuFFT requires a dim0 major layout.
      Shape op0_shape = instruction->operand(0)->shape();
      LayoutUtil::SetToDefaultLayout(&op0_shape);
      Shape output_shape = instruction->shape();
      LayoutUtil::SetToDefaultLayout(&output_shape);
      TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
      TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
    } else if (instruction->opcode() == HloOpcode::kSort &&
               instruction->operand(0)->shape().rank() > 1) {
      // Make sure that all the operands and the output(s) have the same
      // layout.
      Shape keys_shape = instruction->operand(0)->shape();
      Layout keys_layout =
          LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank());
      for (int64_t i = 0; i < instruction->operand_count(); ++i) {
        Shape shape = instruction->operand(i)->shape();
        *shape.mutable_layout() = keys_layout;
        TF_RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i));
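        // A single-operand sort produces an array, so its output buffer lives
        // at the top-level index {}; a variadic sort produces a tuple, so the
        // i-th output buffer lives at tuple index {i}.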
        const LogicalBuffer* output_buffer;
        if (instruction->shape().IsArray()) {
          TF_ASSIGN_OR_RETURN(
              output_buffer,
              points_to_analysis_->GetBufferDefinedAt(instruction, {}));
        } else {
          TF_ASSIGN_OR_RETURN(
              output_buffer,
              points_to_analysis_->GetBufferDefinedAt(instruction, {i}));
        }
        TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer));
      }
    } else if (instruction->opcode() == HloOpcode::kTriangularSolve) {
      // TODO(phawkins): Ideally we would relax this constraint. What we
      // actually want is that:
      // a) the batch dimensions are major, in no particular order.
      // b) the two minor dimensions are in fortran (column-major) order,
      //    although for the 'a' argument we could potentially accept row-major
      //    order and fold the transpose into the operator.
      Shape op0_shape = instruction->operand(0)->shape();
      Shape op1_shape = instruction->operand(1)->shape();
      Shape output_shape = instruction->shape();
      SetFortranLayout(&op0_shape);
      SetFortranLayout(&op1_shape);
      SetFortranLayout(&output_shape);
      TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
      TF_RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1));
      TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
    } else if (instruction->opcode() == HloOpcode::kReduceScatter) {
      // XLA:GPU can only support reduce-scatter where the scatter dimension
      // is the most major dimension in the layout.
      auto ars = Cast<HloReduceScatterInstruction>(instruction);
      TF_RETURN_IF_ERROR(SetInstructionLayout(
          ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()),
          ars));
    } else if (instruction->opcode() == HloOpcode::kAllGather) {
      // XLA:GPU can only support all-gathers where the gather dimension is the
      // most major dimension in the layout.
      auto ag = Cast<HloAllGatherInstruction>(instruction);
      TF_RETURN_IF_ERROR(SetInstructionLayout(
          ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()),
          ag));
    } else if (instruction->opcode() == HloOpcode::kAllToAll &&
               instruction->shape().IsArray()) {
      // XLA:GPU can only support all-to-all with split dimensions where the
      // split dimension is the most major dimension in the layout.
      auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
      TF_RETURN_IF_ERROR(SetInstructionLayout(
          ShapeUtil::MoveDimToMajor(all_to_all->shape(),
                                    *all_to_all->split_dimension()),
          all_to_all));
    }
  }
  return OkStatus();
}

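// Assigns a layout to the given dot operand: an existing supported layout is
// kept, otherwise the default layout is tried, and as a last resort a
// (batch, rows, cols) major-to-minor layout is forced.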
Status GpuLayoutAssignment::SetDotOperandLayout(
    const HloInstruction* instruction, int64_t operand,
    absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
    absl::Span<const int64_t> col_dims) {
  Shape shape = instruction->operand(operand)->shape();

  // First, try to use the existing layout, if present.
  if (shape.has_layout() &&
      MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
    // Re-set the operand layout, so it becomes mandatory.
    return SetOperandLayout(shape, instruction, operand);

  // Next, try the default layout (for the sake of everybody's sanity).
  LayoutUtil::SetToDefaultLayout(&shape);
  if (MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
    return SetOperandLayout(shape, instruction, operand);

  // Otherwise, fall back to forcing a (batch, rows, cols) layout.
  return SetOperandBatchRowsColsLayout(instruction, operand, batch_dims,
                                       row_dims, col_dims);
}

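// Forces a layout in which the given batch, row, and column dimensions of the
// operand appear physically in (batch, rows, cols) major-to-minor order.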
Status GpuLayoutAssignment::SetOperandBatchRowsColsLayout(
    const HloInstruction* instruction, int64_t operand,
    absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
    absl::Span<const int64_t> col_dims) {
  std::vector<int64_t> major_to_minor;
  major_to_minor.reserve(batch_dims.size() + row_dims.size() +
                         col_dims.size());
  major_to_minor.insert(major_to_minor.end(), batch_dims.begin(),
                        batch_dims.end());
  major_to_minor.insert(major_to_minor.end(), row_dims.begin(), row_dims.end());
  major_to_minor.insert(major_to_minor.end(), col_dims.begin(), col_dims.end());

  Shape shape = instruction->operand(operand)->shape();
  *shape.mutable_layout() =
      LayoutUtil::MakeLayoutFromMajorToMinor(major_to_minor);
  return SetOperandLayout(shape, instruction, operand);
}

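// Picks a layout for a dot's result: a layout already requested by a user of
// the dot is reused if the dot can support it; otherwise the default layout is
// used.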
Status GpuLayoutAssignment::SetDotLayout(const HloInstruction* instruction,
                                         LayoutConstraints* constraints) {
  // If a user has requested a layout that we can support, use that.
  for (const HloInstruction* user : instruction->users()) {
    for (int64_t i = 0; i < user->operand_count(); ++i) {
      if (user->operand(i) != instruction) {
        continue;
      }

      const ShapeLayout* constraint = constraints->OperandLayout(user, i);
      if ((constraint != nullptr) &&
          DotCanSupportShapeWithLayout(instruction, constraint->shape())) {
        return SetInstructionLayout(constraint->shape(), instruction);
      }
    }
  }

  // Otherwise, use the default layout.
  return SetInstructionLayout(
      LayoutUtil::GetWithDefaultLayout(instruction->shape()), instruction);
}

}  // namespace gpu
}  // namespace xla