/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"

#include <memory>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {

using se::dnn::DataLayout;
using se::dnn::FilterLayout;

// Returns (input, filter, output) layouts.
static std::tuple<DataLayout, FilterLayout, DataLayout>
HeuristicLayoutAssignment(const HloInstruction* instr,
                          se::StreamExecutor* stream_executor) {
  // DataLayout and FilterLayout use weird enum names. Translations:
  //   N <=> Batch or Output
  //   C <=> Depth or Input
  //   H <=> Y
  //   W <=> X
  //
  // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
  //
  // If you have trouble keeping these straight, consider that all that matters
  // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?

  constexpr auto kAllNCHW =
      std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
                      DataLayout::kBatchDepthYX);
  // kBatchDepthYX4 has the same layout as kBatchDepthYX32; they're both VECT_C
  // layouts as far as cudnn is concerned.
  constexpr auto kAllNCHW_VECT_C =
      std::make_tuple(DataLayout::kBatchDepthYX4, FilterLayout::kOutputInputYX4,
                      DataLayout::kBatchDepthYX4);
  constexpr auto kAllNHWC =
      std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
                      DataLayout::kBatchYXDepth);

  // Integer convolutions must use NHWC or NCHW_VECT_C.
  //
  // TODO(jlebar): Do non-VECT_C int8_t convs still require NHWC with new
  // versions of cudnn?
  const ConvolutionDimensionNumbers& dnums =
      instr->convolution_dimension_numbers();
  Shape input_shape = instr->operand(0)->shape();
  PrimitiveType input_ty = instr->operand(0)->shape().element_type();
  if (primitive_util::IsIntegralType(input_ty)) {
    if (input_ty == S8 && dnums.input_spatial_dimensions_size() == 2 &&
        input_shape.dimensions_size() == 5) {
      VLOG(2) << "Using NCHW_VECT_C for int8_t conv " << instr->ToString();
      return kAllNCHW_VECT_C;
    }
    VLOG(2) << "Using NHWC for int8_t conv " << instr->ToString();
    return kAllNHWC;
  }

  const DebugOptions& debug_options =
      instr->GetModule()->config().debug_options();

  if (debug_options.xla_gpu_force_conv_nchw()) {
    VLOG(2) << "Overriding layout to NCHW for " << instr->ToString();
    return kAllNCHW;
  }

  if (debug_options.xla_gpu_force_conv_nhwc()) {
    VLOG(2) << "Overriding layout to NHWC for " << instr->ToString();
    return kAllNHWC;
  }

  // If we're not at least Volta, or the input isn't fp16, or this isn't a 2D
  // conv, the decision is easy: use NCHW.
  if (input_ty != F16 ||
      !stream_executor->GetDeviceDescription()
           .cuda_compute_capability()
           .IsAtLeast(se::CudaComputeCapability::VOLTA) ||
      instr->shape().tuple_shapes(0).dimensions_size() != 4) {
    return kAllNCHW;
  }

  VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();

  // Empirically we've found with Volta and cudnn <= 7.3 that backward-input
  // convs with stride are significantly faster with NCHW layouts.
  //
  // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW),
  // which on paper gives good performance. However, there are two observations:
  // * a mixed layout combination is more cuDNN-bug prone, based on empirical
  //   evidence.
  // * we've also observed that for mixed layouts, cuDNN transposes data back
  //   and forth from a different layout combination. If we end up with
  //   transposes anyway, we prefer to have them in XLA, as they can be fused.
  if (auto* dnn = stream_executor->AsDnn()) {
    auto version_status = dnn->GetVersion();
    if (version_status.ok()) {
      auto version = std::move(version_status).value();
      if (std::make_tuple(version.major_version(), version.minor_version()) <=
              std::make_tuple(7, 3) &&
          instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
          window_util::HasStride(instr->window())) {
        return kAllNCHW;
      }
    }
  }

  // For other Volta f16 convolutions, use NHWC.
  return kAllNHWC;
}

// Adds layout constraints on the cudnn custom-call instruction. The layout
// constraints are represented in terms of minor_to_major fields of both
// operands and the output shape. Depending on the underlying algorithm, one of
// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
    HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
  Shape lhs_shape = instr->operand(0)->shape();
  Shape rhs_shape = instr->operand(1)->shape();
  Shape result_shape = instr->shape().tuple_shapes(0);

  Shape* input_shape;
  Shape* filter_shape;
  Shape* output_shape;

  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
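  // Which custom-call operands/results correspond to the convolution's logical
  // (input, filter, output) depends on the conv kind: backward-input convs
  // produce the conv's input, and backward-filter convs produce the filter.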
  switch (kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      input_shape = &lhs_shape;
      filter_shape = &rhs_shape;
      output_shape = &result_shape;
      break;
    case CudnnConvKind::kBackwardInput:
      input_shape = &result_shape;
      filter_shape = &rhs_shape;
      output_shape = &lhs_shape;
      break;
    case CudnnConvKind::kBackwardFilter:
      input_shape = &lhs_shape;
      filter_shape = &result_shape;
      output_shape = &rhs_shape;
      break;
  }

  {
    DataLayout input;
    FilterLayout filter;
    DataLayout output;
    std::tie(input, filter, output) =
        HeuristicLayoutAssignment(instr, stream_executor_);

    TF_ASSIGN_OR_RETURN(
        std::tie(*input_shape->mutable_layout(),
                 *filter_shape->mutable_layout(),
                 *output_shape->mutable_layout()),
        StreamExecutorConvLayoutsToXlaLayouts(
            instr->convolution_dimension_numbers(), input, filter, output));
  }

  // The custom call returns a tuple of (actual_result, scratch_buffer);
  // call_result_buf is the logical buffer for actual_result, the thing that
  // contains the result of the conv call.
  TF_ASSIGN_OR_RETURN(
      const LogicalBuffer* call_result_buf,
      points_to_analysis_->GetBufferDefinedAt(instr, /*index=*/{0}));

  // Set layouts of the instructions' shapes.
  TF_RETURN_IF_ERROR(SetOperandLayout(lhs_shape, instr, 0));
  TF_RETURN_IF_ERROR(SetOperandLayout(rhs_shape, instr, 1));
  TF_RETURN_IF_ERROR(SetBufferLayout(result_shape.layout(), *call_result_buf));
  // instr->operand(2), if it exists, is the bias buffer. There is no need to
  // assign a layout to it, as it has only one dimension.

  // instr->operand(3), if it exists, is the side input buffer.
  if (instr->operand_count() == 4) {
    if (kind != CudnnConvKind::kForwardActivation) {
      return InternalError(
          "Invalid convolution. Conv has a side input, but kind is not fused "
          "conv forward: %s",
          instr->ToString());
    }
    // The side input layout must match the output layout.
    TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, 3));
  }
  return OkStatus();
}

namespace {

// Imposes the default layout, with the two most-minor dimensions swapped, on
// the input `shape`.
void SetFortranLayout(Shape* shape) {
  LayoutUtil::SetToDefaultLayout(shape);
  int n = shape->mutable_layout()->minor_to_major_size();
  CHECK_GE(n, 2);
  std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0),
            shape->mutable_layout()->mutable_minor_to_major()->at(1));
}

bool DotCanSupportShapeWithLayout(const HloInstruction* dot,
                                  const Shape& shape) {
  const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers();
  // If we are able to construct a `MatrixLayout` then the dot can support
  // this layout.
  return MatrixLayout::For(shape, dot_dims.lhs_batch_dimensions().size(),
                           dot_dims.lhs_contracting_dimensions().size(),
                           dot_dims.rhs_batch_dimensions().size(),
                           dot_dims.rhs_contracting_dimensions().size())
      .ok();
}

}  // namespace

Status GpuLayoutAssignment::AddBackendConstraints(
    LayoutConstraints* constraints) {
  // Add convolution constraints in reverse postorder so that the earliest
  // convolution layout propagates first. This reduces the likelihood of fusion
  // nodes with copies.
  auto post_order = constraints->computation()->MakeInstructionPostOrder();
  for (auto iterator = post_order.rbegin(); iterator != post_order.rend();
       ++iterator) {
    HloInstruction* instruction = *iterator;
    if (IsCustomCallToDnnConvolution(*instruction)) {
      TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
          Cast<HloCustomCallInstruction>(instruction), constraints));
    }

    CHECK(!IsCublasGemm(*instruction))
        << "Gemm rewriting should run after layout assignment";

    if (IsMatrixMultiplication(*instruction)) {
      const Shape& output_shape = instruction->shape();
      const Shape& lhs_shape = instruction->operand(0)->shape();
      const Shape& rhs_shape = instruction->operand(1)->shape();
      const DotDimensionNumbers& dot_dims =
          instruction->dot_dimension_numbers();

      // Matmuls require the batch dimensions to be in consecutive physical
      // dimensions and likewise for the contracting and non-contracting
      // dimensions. Additionally, no batch dimension can be in the most
      // minor physical dimension for inputs or the output.
      absl::Span<const int64_t> lhs_batch_dims =
          dot_dims.lhs_batch_dimensions();
      absl::Span<const int64_t> lhs_col_dims =
          dot_dims.lhs_contracting_dimensions();
      TF_ASSIGN_OR_RETURN(
          std::vector<int64_t> lhs_row_dims,
          GetNonContractingDims(lhs_shape, lhs_batch_dims, lhs_col_dims));

      absl::Span<const int64_t> rhs_batch_dims =
          dot_dims.rhs_batch_dimensions();
      absl::Span<const int64_t> rhs_row_dims =
          dot_dims.rhs_contracting_dimensions();
      TF_ASSIGN_OR_RETURN(
          std::vector<int64_t> rhs_col_dims,
          GetNonContractingDims(rhs_shape, rhs_batch_dims, rhs_row_dims));

      // For unbatched S8xS8->S32 matrix multiplication, enforce a TN layout,
      // which allows NVIDIA GPUs to use Tensor Cores.
      bool is_s8_to_s32 = (output_shape.element_type() == PrimitiveType::S32 &&
                           lhs_shape.element_type() == PrimitiveType::S8 &&
                           rhs_shape.element_type() == PrimitiveType::S8 &&
                           output_shape.dimensions_size() == 2 &&
                           lhs_shape.dimensions_size() == 2 &&
                           rhs_shape.dimensions_size() == 2);

      if (is_s8_to_s32) {
        TF_RETURN_IF_ERROR(SetOperandBatchRowsColsLayout(
            instruction, 0, lhs_batch_dims, lhs_row_dims, lhs_col_dims));
        TF_RETURN_IF_ERROR(SetOperandBatchRowsColsLayout(
            instruction, 1, rhs_batch_dims, rhs_col_dims, rhs_row_dims));
        TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
      } else if (!lhs_batch_dims.empty()) {
        TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 0, lhs_batch_dims,
                                               lhs_row_dims, lhs_col_dims));
        TF_RETURN_IF_ERROR(SetDotOperandLayout(instruction, 1, rhs_batch_dims,
                                               rhs_row_dims, rhs_col_dims));
        TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints));
      }
    } else if (instruction->opcode() == HloOpcode::kTranspose) {
      const HloInstruction* operand = instruction->operand(0);
      if ((operand->opcode() != HloOpcode::kDot) ||
          (operand->user_count() > 1)) {
        continue;
      }

      // If possible, set the layout of the dot operation such that the output
      // of the transpose (as a bitcast) has the default layout.
      Shape shape = operand->shape();
      *shape.mutable_layout() =
          LayoutUtil::MakeLayoutFromMajorToMinor(instruction->dimensions());

      if (DotCanSupportShapeWithLayout(operand, shape)) {
        TF_RETURN_IF_ERROR(
            SetOperandLayout(shape, instruction, /*operand_no=*/0));
      }
    } else if (instruction->opcode() == HloOpcode::kFft) {
      // cuFFT requires a dim0 major layout.
      Shape op0_shape = instruction->operand(0)->shape();
      LayoutUtil::SetToDefaultLayout(&op0_shape);
      Shape output_shape = instruction->shape();
      LayoutUtil::SetToDefaultLayout(&output_shape);
      TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
      TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
    } else if (instruction->opcode() == HloOpcode::kSort &&
               instruction->operand(0)->shape().rank() > 1) {
      // Make sure that all the operands and the output(s) have the same layout.
      Shape keys_shape = instruction->operand(0)->shape();
      Layout keys_layout =
          LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank());
      for (int64_t i = 0; i < instruction->operand_count(); ++i) {
        Shape shape = instruction->operand(i)->shape();
        *shape.mutable_layout() = keys_layout;
        TF_RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i));
        const LogicalBuffer* output_buffer;
        if (instruction->shape().IsArray()) {
          TF_ASSIGN_OR_RETURN(
              output_buffer,
              points_to_analysis_->GetBufferDefinedAt(instruction, {}));
        } else {
          TF_ASSIGN_OR_RETURN(
              output_buffer,
              points_to_analysis_->GetBufferDefinedAt(instruction, {i}));
        }
        TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer));
      }
    } else if (instruction->opcode() == HloOpcode::kTriangularSolve) {
      // TODO(phawkins): Ideally we would relax this constraint. What we
      // actually want is that:
      // a) the batch dimensions are major, in no particular order.
      // b) the two minor dimensions are in fortran (column-major) order,
      // although for the 'a' argument we could potentially accept row-major
      // order and fold the transpose into the operator.
      Shape op0_shape = instruction->operand(0)->shape();
      Shape op1_shape = instruction->operand(1)->shape();
      Shape output_shape = instruction->shape();
      SetFortranLayout(&op0_shape);
      SetFortranLayout(&op1_shape);
      SetFortranLayout(&output_shape);
      TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0));
      TF_RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1));
      TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction));
    } else if (instruction->opcode() == HloOpcode::kReduceScatter) {
      // XLA:GPU can only support reduce-scatter where the scatter dimension
      // is the most major dimension in the layout.
      auto ars = Cast<HloReduceScatterInstruction>(instruction);
      TF_RETURN_IF_ERROR(SetInstructionLayout(
          ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()),
          ars));
    } else if (instruction->opcode() == HloOpcode::kAllGather) {
      // XLA:GPU can only support all-gathers where the gather dimension is the
      // most major dimension in the layout.
      auto ag = Cast<HloAllGatherInstruction>(instruction);
      TF_RETURN_IF_ERROR(SetInstructionLayout(
          ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()),
          ag));
    } else if (instruction->opcode() == HloOpcode::kAllToAll &&
               instruction->shape().IsArray()) {
      // XLA:GPU can only support all-to-all with split dimensions where the
      // split dimension is the most major dimension in the layout.
      auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
      TF_RETURN_IF_ERROR(SetInstructionLayout(
          ShapeUtil::MoveDimToMajor(all_to_all->shape(),
                                    *all_to_all->split_dimension()),
          all_to_all));
    }
  }
  return OkStatus();
}

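// Picks a layout for a dot operand: prefer the operand's existing layout if it
// already yields a valid MatrixLayout, then the default layout, and otherwise
// force a (batch, rows, cols) major-to-minor ordering.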
Status GpuLayoutAssignment::SetDotOperandLayout(
    const HloInstruction* instruction, int64_t operand,
    absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
    absl::Span<const int64_t> col_dims) {
  Shape shape = instruction->operand(operand)->shape();

  // First, try to use the existing layout, if present.
  if (shape.has_layout() &&
      MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
    // Re-set the operand layout, so it becomes mandatory.
    return SetOperandLayout(shape, instruction, operand);

  // Next, try the default layout (for the sake of everybody's sanity).
  LayoutUtil::SetToDefaultLayout(&shape);
  if (MatrixLayout::For(shape, batch_dims, row_dims, col_dims).ok())
    return SetOperandLayout(shape, instruction, operand);

  // Otherwise, fall back to forcing a (batch, rows, cols) layout.
  return SetOperandBatchRowsColsLayout(instruction, operand, batch_dims,
                                       row_dims, col_dims);
}

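// Forces a layout whose major-to-minor order is (batch dims, row dims,
// col dims) on the given dot operand.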
Status GpuLayoutAssignment::SetOperandBatchRowsColsLayout(
    const HloInstruction* instruction, int64_t operand,
    absl::Span<const int64_t> batch_dims, absl::Span<const int64_t> row_dims,
    absl::Span<const int64_t> col_dims) {
  std::vector<int64_t> major_to_minor;
  major_to_minor.reserve(batch_dims.size() + row_dims.size() + col_dims.size());
  major_to_minor.insert(major_to_minor.end(), batch_dims.begin(),
                        batch_dims.end());
  major_to_minor.insert(major_to_minor.end(), row_dims.begin(), row_dims.end());
  major_to_minor.insert(major_to_minor.end(), col_dims.begin(), col_dims.end());

  Shape shape = instruction->operand(operand)->shape();
  *shape.mutable_layout() =
      LayoutUtil::MakeLayoutFromMajorToMinor(major_to_minor);
  return SetOperandLayout(shape, instruction, operand);
}

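// Picks the dot's output layout: if a user of the dot already has an operand
// layout constraint that the dot can support, reuse it; otherwise use the
// default layout.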
Status GpuLayoutAssignment::SetDotLayout(const HloInstruction* instruction,
                                         LayoutConstraints* constraints) {
  // If a user has requested a layout that we can support, use that.
  for (const HloInstruction* user : instruction->users()) {
    for (int64_t i = 0; i < user->operand_count(); ++i) {
      if (user->operand(i) != instruction) {
        continue;
      }

      const ShapeLayout* constraint = constraints->OperandLayout(user, i);
      if ((constraint != nullptr) &&
          DotCanSupportShapeWithLayout(instruction, constraint->shape())) {
        return SetInstructionLayout(constraint->shape(), instruction);
      }
    }
  }

  // Otherwise, use the default layout.
  return SetInstructionLayout(
      LayoutUtil::GetWithDefaultLayout(instruction->shape()), instruction);
}

}  // namespace gpu
}  // namespace xla