/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_normalization.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

// Layout normalization visitor. Aims to achieve the global postcondition that
// every layout is strictly descending (the layout permutation is effectively
// applied to the shape itself).
//
// Local precondition for every call:
//   -> Input is a bitcast from a normalized layout.
//
// Local postcondition:
//   -> Input and output of a processed operation have descending layout*.
//
// *: Due to current fusion limitations, unnested reductions are currently the
// only exception to this postcondition.
class LayoutNormalizationVisitor : public DfsHloRewriteVisitor {
 public:
  // Default action: establishes the local postcondition by ensuring that every
  // user of `hlo` reads it through a bitcast from the canonical layout.
  //
  // Bitcast to the descending layout and then bitcast back to make sure that
  // the shapes match.
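  //
  // A hypothetical example: for an HLO producing f32[2,3]{0,1}, users are
  // rewired through
  //
  //   hlo -> bitcast to f32[3,2]{1,0} -> bitcast back to f32[2,3]{0,1}
  //
  // so that later rewrites of the users can consume the normalized form.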
  Status DefaultAction(HloInstruction* hlo) override {
    if (!hlo->user_count()) {
      // The local postcondition does not have to apply to the case when there
      // are no users.
      return OkStatus();
    }
    auto users = hlo->users();
    auto shape = hlo->shape();
    if (shape.IsTuple() || shape.IsToken()) {
      // GTEs will be transformed individually, tokens should be skipped.
      return OkStatus();
    }

    auto normalized_shape = Normalize(shape);
    auto bc_to_normalized = MakeBitcastHlo(hlo, normalized_shape);
    auto bc_to_orig = MakeBitcastHlo(bc_to_normalized, shape);
    TF_RETURN_IF_ERROR(hlo->ReplaceUsesWith(users, bc_to_orig));
    MarkAsChanged();
    return OkStatus();
  }

  // Converts concatenation to normalized layout.
  //
  // With respect to layouts, concatenations are simple, as they are
  // layout-preserving. However, there are some complications with respect to
  // degenerate dimensions: since our normalized form drops degenerate
  // dimensions, that might make the concatenation impossible, as the
  // corresponding concatenated dimension might not exist in the normalized
  // form.
  //
  // So we drop all degenerate dimensions EXCEPT for the one being concatenated.
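  //
  // An illustrative example: concatenating f32[1,4]{1,0} and f32[1,5]{1,0}
  // along dimension 1 yields f32[1,9]{1,0}. The degenerate dimension 0 is
  // dropped, so the normalized concatenation joins f32[4]{0} and f32[5]{0}
  // along dimension 0 and bitcasts the f32[9]{0} result back to f32[1,9]{1,0}.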
  Status HandleConcatenate(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto orig_concat_dim = hlo->dimensions(0);

    std::vector<HloInstruction*> normalized_inputs;
    for (HloInstruction* operand : hlo->mutable_operands()) {
      TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand));
      auto normalized_input_s = normalized_input->shape();
      auto operand_s = operand->shape();

      // Drop all degenerate dimensions, unless the dimension is the one being
      // concatenated.
      auto operand_s_filtered = ShapeUtil::FilterDimensions(
          [&](int dim) {
            return operand_s.dimensions(dim) != 1 || dim == orig_concat_dim;
          },
          operand_s);

      auto operand_s_normalized =
          ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
              operand_s_filtered);
      auto new_operand =
          operand_s_normalized == normalized_input_s
              ? normalized_input
              : MakeBitcastHlo(normalized_input, operand_s_normalized);
      normalized_inputs.push_back(new_operand);
    }

    auto out_shape_degen_dropped = ShapeUtil::FilterDimensions(
        [&](int dim) {
          return s.dimensions(dim) != 1 || dim == orig_concat_dim;
        },
        s);
    auto normalized_w_degen =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(s);
    auto normalized_shape =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
            out_shape_degen_dropped);

    auto l = ToTransposeDimensions(s.layout());
    int64_t normalized_concat_dim = FindIndex(l, orig_concat_dim);
    // Degenerate dimensions preceding the concatenation dimension get dropped
    // from the normalized shape, shifting the dimension index down.
    auto degen_delta = absl::c_count_if(
        normalized_w_degen.dimensions().subspan(0, normalized_concat_dim),
        [&](int64_t dim_size) { return dim_size == 1; });
    auto normalized_concat = hlo->AddInstruction(
        HloInstruction::CreateConcatenate(normalized_shape, normalized_inputs,
                                          normalized_concat_dim - degen_delta));
    auto bc_to_orig = MakeBitcastHlo(normalized_concat, hlo->shape());
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }
137
138 // Converts broadcast input and output to normalized layout.
139 //
140 // Converts:
141 //
142 // A{I} -> bitcast{L} -> broadcast[S]{L'}
143 //
144 // Into:
145 //
146 // A{I} -> broadcast[S']{I} -> bitcast[S]{L'}
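  //
  // An illustrative example: broadcasting f32[2]{0} into f32[3,2]{0,1} with
  // dimensions={1} is rewritten to a broadcast into the normalized shape
  // f32[2,3]{1,0} with dimensions={0}, followed by a bitcast to f32[3,2]{0,1}.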
  Status HandleBroadcast(HloInstruction* hlo) override {
    VLOG(3) << "Input broadcast: " << hlo->ToString();
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand));
    auto normalized_shape = Normalize(s);
    auto orig_br_dimensions =
        NoDegenerateDims(hlo->dimensions(), operand->shape(), s);
    auto layout_as_permutation = ToTransposeDimensions(
        ShapeUtil::DropDegenerateDimensions(operand->shape()).layout());
    auto orig_output_layout_as_permutation =
        ToTransposeDimensions(ShapeUtil::DropDegenerateDimensions(s).layout());
    std::vector<int64_t> br_dimensions;
    if (!hlo->dimensions().empty()) {
      br_dimensions = Permute(orig_br_dimensions, layout_as_permutation);
    }
    for (int64_t& d : br_dimensions) {
      d = FindIndex(orig_output_layout_as_permutation, d);
    }
    absl::c_sort(br_dimensions);
    auto normalized_broadcast =
        MakeBroadcastHlo(normalized_input, br_dimensions, normalized_shape);
    VLOG(3) << "Generated broadcast: " << normalized_broadcast->ToString();
    auto bc_to_orig = MakeBitcastHlo(normalized_broadcast, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // Pushes down the bitcast across the unary.
  // That is, converts:
  //
  //  H_0{I} -> B{L} -> U{L}
  //
  // into
  //
  //  H_0{I} -> U{I} -> B{L}
  //
  // where {I} denotes the default layout.
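  //
  // An illustrative example: negate(f32[2,3]{0,1}) becomes
  // negate(f32[3,2]{1,0}) on the normalized input, bitcast back to
  // f32[2,3]{0,1}.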
  Status HandleElementwiseUnary(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    auto operand_shape = operand->shape();

    // Precondition: elementwise unary leaves layout intact.
    TF_RET_CHECK(s.layout() == operand_shape.layout())
        << "Unexpected non-layout preserving elementwise unary: "
        << hlo->ToString();
    TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand));

    PrimitiveType to_element_type = s.element_type();
    HloInstruction* new_unary;
    if (hlo->opcode() == HloOpcode::kConvert) {
      new_unary = MakeConvertToHlo(normalized_input, to_element_type);
    } else if (hlo->opcode() == HloOpcode::kReducePrecision) {
      new_unary = MakeReducePrecisionHlo(normalized_input, hlo->exponent_bits(),
                                         hlo->mantissa_bits());
    } else if (hlo->opcode() == HloOpcode::kBitcastConvert) {
      new_unary = MakeBitcastConvertToHlo(normalized_input, to_element_type);
    } else {
      TF_ASSIGN_OR_RETURN(new_unary,
                          MakeUnaryHlo(hlo->opcode(), normalized_input));
    }
    auto bc_to_orig = MakeBitcastHlo(new_unary, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // Pushes down the bitcast across the binary. Converts:
  //
  //  A1{I} -> bitcast{L}
  //            \
  //             B{L}
  //            /
  //  A2{I} -> bitcast{L}
  //
  // Into:
  //
  //  A1{I}
  //        \
  //         B{I} - bitcast{L}
  //        /
  //  A2{I}
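  //
  // A hypothetical example: add(f32[2,3]{0,1}, f32[2,3]{0,1}) becomes
  // add(f32[3,2]{1,0}, f32[3,2]{1,0}) on the normalized inputs, bitcast back
  // to f32[2,3]{0,1}. kCompare is special-cased below because it carries a
  // comparison direction.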
  Status HandleElementwiseBinary(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto a = hlo->mutable_operand(0);
    auto b = hlo->mutable_operand(1);
    TF_RET_CHECK(a->shape().layout() == s.layout());
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(a));
    TF_ASSIGN_OR_RETURN(auto b0, GetNormalizedInput(b));

    HloInstruction* new_binary;
    if (hlo->opcode() == HloOpcode::kCompare) {
      TF_ASSIGN_OR_RETURN(new_binary,
                          MakeCompareHlo(hlo->comparison_direction(), a0, b0));
    } else {
      TF_ASSIGN_OR_RETURN(new_binary, MakeBinaryHlo(hlo->opcode(), a0, b0));
    }
    auto bc_to_orig = MakeBitcastHlo(new_binary, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // The ReshapeDecomposer already gives us a precondition that the reshape is
  // a bitcast. Converts:
  //
  //  A{I} -> bitcast [S0]{L1} -> R [S]{L2}
  //
  // Into:
  //
  //  A{I} -> R [S']{I} -> bitcast [S]{L2}
  //
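  // A trivially normalized illustration: a bitcast reshape
  // R: f32[2,3]{1,0} -> f32[6]{0} is re-emitted between the normalized shapes
  // (here identical to the originals), followed by a bitcast back to the
  // original shape.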
  Status HandleReshape(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    TF_RET_CHECK(ShapeUtil::ReshapeIsBitcast(s, operand->shape()));
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto normalized_reshape_s = Normalize(s);
    TF_ASSIGN_OR_RETURN(auto new_reshape,
                        MakeReshapeHlo(normalized_reshape_s, a0));
    auto bc_to_orig = MakeBitcastHlo(new_reshape, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // For bitcasting transposes, converts:
  //
  //  A{I} -> bitcast[S]{L} -> transpose{L2}
  //
  // Into:
  //
  //  A{I} -> bitcast{L2}
  //
  // For non-bitcasting ones, converts:
  //
  //  A{I} -> bitcast[S0]{L} -> transpose[S]{L2}
  //
  // Into:
  //
  //  A{I} -> transpose[S']{I} -> bitcast{L2}
  //
  // where S' is the normalization of [S]{L2}.
  //
  // The `dimensions` attribute of the new transpose is given by:
  //
  //   L^-1 o dim_0 o L2
  //
  // where dim_0 is the `dimensions` of the original transpose, and `o` denotes
  // permutation composition.
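  //
  // An illustrative example: transpose(f32[2,3]{0,1}, dimensions={1,0})
  // producing f32[3,2]{0,1} is not a bitcast. Its normalized input is
  // f32[3,2]{1,0}, so the rewrite emits transpose(..., dimensions={1,0})
  // producing f32[2,3]{1,0}, bitcast back to f32[3,2]{0,1}.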
  Status HandleTranspose(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    auto operand_s = operand->shape();
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto normalized_shape = Normalize(s);
    VLOG(3) << "Input transpose: " << hlo->ToString();

    if (!ShapeUtil::TransposeIsBitcast(s, operand_s, hlo->dimensions())) {
      auto l0_perm = InversePermutation(ToTransposeDimensions(
          ShapeUtil::DropDegenerateDimensions(operand_s).layout()));
      auto l_perm = ToTransposeDimensions(
          ShapeUtil::DropDegenerateDimensions(s).layout());

      auto dims = NoDegenerateDims(hlo->dimensions(), s, operand_s);
      auto t = ComposePermutations(l0_perm, dims);
      auto dimensions = ComposePermutations(t, l_perm);
      auto normalized_transpose = hlo->AddInstruction(
          HloInstruction::CreateTranspose(normalized_shape, a0, dimensions));
      VLOG(3) << "Generated normalized physical transpose: "
              << normalized_transpose->ToString();
      auto bc_to_orig = MakeBitcastHlo(normalized_transpose, s);
      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    } else {
      auto bc_to_orig = MakeBitcastHlo(a0, s);
      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    }
    return OkStatus();
  }

  // Converts a purely physical copy into a physical+logical transposition.
  //
  // Converts:
  //
  //  A{I} -> bitcast{L} -> copy[S]{L'}
  //
  // Into:
  //
  //  A{I} -> transpose[S']{I} -> bitcast[S]{L'}
  //
  // where S' is the normalization of [S]{L'}, and the transposition dimensions
  // are given by L'.
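  //
  // An illustrative example: a copy from f32[2,3]{1,0} to f32[2,3]{0,1}
  // becomes transpose(f32[2,3]{1,0}, dimensions={1,0}) producing
  // f32[3,2]{1,0}, bitcast to f32[2,3]{0,1}.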
  Status HandleCopy(HloInstruction* hlo) override {
    VLOG(3) << "Processing copy: " << hlo->ToString();
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto s_normalized = Normalize(s);
    auto l0_perm = InversePermutation(ToTransposeDimensions(
        ShapeUtil::DropDegenerateDimensions(operand->shape()).layout()));
    auto l_perm =
        ToTransposeDimensions(ShapeUtil::DropDegenerateDimensions(s).layout());
    auto dimensions = ComposePermutations(l0_perm, l_perm);
    auto t = hlo->AddInstruction(
        HloInstruction::CreateTranspose(s_normalized, a0, dimensions));
    auto bc_to_orig = MakeBitcastHlo(t, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // The reverse HLO has a list of dimensions it reverses, which again becomes
  // pretty interesting in the presence of degenerate dimensions: we need to
  // drop those from the list.
  //
  // Luckily, reverse is layout-preserving.
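  //
  // An illustrative example: reverse(f32[1,4]{1,0}, dimensions={1})
  // normalizes to reverse(f32[4]{0}, dimensions={0}), bitcast back to
  // f32[1,4]{1,0}.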
  Status HandleReverse(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto s_normalized = Normalize(s);
    auto normalized_w_degen =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(s);

    std::vector<int64_t> new_dimensions =
        TransformDimensionsForLayoutPreservingHlo(hlo, normalized_w_degen,
                                                  s_normalized);
    auto normalized_reverse = hlo->AddInstruction(
        HloInstruction::CreateReverse(a0->shape(), a0, new_dimensions));
    auto bc_to_orig = MakeBitcastHlo(normalized_reverse, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // Padding is layout-preserving, so we only have to permute values inside the
  // padding config.
  //
  // As with broadcast, we have to be mindful that we can't remove degenerate
  // dimensions if they are padded.
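  //
  // An illustrative example: padding f32[1,4]{1,0} to f32[1,7]{1,0} with
  // config 0_0x1_2 normalizes to padding f32[4]{0} to f32[7]{0} with config
  // 1_2.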
  Status HandlePad(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    const auto& operand_s = operand->shape();
    auto padded_by = hlo->mutable_operand(1);
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto padded_config = hlo->padding_config();

    auto operand_s_filtered = ShapeUtil::FilterDimensions(
        [&](int dim) {
          return operand_s.dimensions(dim) != 1 ||
                 !IsZeroPadding(hlo->padding_config().dimensions(dim));
        },
        operand->shape());
    auto operand_s_normalized =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
            operand_s_filtered);
    auto new_operand = operand_s_normalized == a0->shape()
                           ? a0
                           : MakeBitcastHlo(a0, operand_s_normalized);

    auto s_normalized = Normalize(s);
    auto l = ToTransposeDimensions(s.layout());

    auto normalized_w_degen =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(s);

    PaddingConfig new_padding;
    new_padding.mutable_dimensions()->Reserve(s_normalized.dimensions_size());
    for (int dim = 0; dim < s_normalized.dimensions_size(); dim++) {
      new_padding.add_dimensions();
    }

    for (int dim = 0; dim < s.dimensions_size(); dim++) {
      if (s.dimensions(dim) == 1) {
        continue;
      }
      int tr_dim = static_cast<int>(FindIndex(l, dim));
      int out_dim = tr_dim - DegenDimsUpTo(normalized_w_degen, tr_dim);
      *new_padding.mutable_dimensions(out_dim) = padded_config.dimensions(dim);
    }

    auto padded_normalized = hlo->AddInstruction(HloInstruction::CreatePad(
        s_normalized, new_operand, padded_by, new_padding));
    auto bc_to_orig = MakeBitcastHlo(padded_normalized, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

 private:
  bool IsZeroPadding(const PaddingConfig::PaddingConfigDimension& c) {
    return c.edge_padding_high() == 0 && c.edge_padding_low() == 0 &&
           c.interior_padding() == 0;
  }

  // Returns a list of dimensions associated with `hlo` after layout
  // normalization.
  std::vector<int64_t> TransformDimensionsForLayoutPreservingHlo(
      HloInstruction* hlo, const Shape& normalized_shape_w_degen,
      const Shape& normalized_out_shape) {
    bool skip_degen_dims = normalized_shape_w_degen != normalized_out_shape;
    std::vector<int64_t> new_dimensions;
    const auto& s = hlo->shape();
    auto l = ToTransposeDimensions(s.layout());

    for (int64_t dim : hlo->dimensions()) {
      if (s.dimensions(dim) == 1 && skip_degen_dims) {
        continue;
      }

      auto tr_dim = FindIndex(l, dim);
      auto degen_delta =
          skip_degen_dims ? DegenDimsUpTo(normalized_shape_w_degen, tr_dim) : 0;
      new_dimensions.push_back(tr_dim - degen_delta);
    }
    absl::c_sort(new_dimensions);
    return new_dimensions;
  }

  // Returns the number of degenerate dimensions in `shape` up to (but
  // excluding) `dim`.
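  //
  // An illustrative example: for a shape with dimensions [2,1,3,1,4] and
  // dim=3, this returns 1 (only the size-1 dimension at index 1 precedes it).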
  int DegenDimsUpTo(const Shape& shape, int dim) {
    return absl::c_count_if(shape.dimensions().subspan(0, dim),
                            [&](int64_t dim_size) { return dim_size == 1; });
  }

  // Drops items from `dimensions` corresponding to degenerate dimensions in
  // `input_shape`, and shifts the remaining values down to account for
  // degenerate dimensions of `output_shape` that precede them.
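  //
  // An illustrative example: with dimensions={0,2} for a broadcast from
  // input_shape f32[1,5] into output_shape f32[1,4,5], the result is {1}: the
  // entry for the degenerate input dimension is dropped, and 2 shifts down to
  // 1 because the degenerate output dimension 0 precedes it.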
  std::vector<int64_t> NoDegenerateDims(absl::Span<int64_t const> dimensions,
                                        const Shape& input_shape,
                                        const Shape& output_shape) {
    std::vector<int64_t> out;
    for (int i = 0; i < dimensions.size(); i++) {
      if (input_shape.dimensions(i) != 1) {
        int64_t val = dimensions[i];

        // Count all preceding 1-sized dimensions.
        int64_t delta = 0;
        for (int o = 0; o < val; o++) {
          if (output_shape.dimensions(o) == static_cast<int64_t>(1)) {
            delta++;
          }
        }

        out.push_back(val - delta);
      }
    }
    return out;
  }

  // Converts a layout to a dimensions transposition necessary to get to that
  // layout from identity.
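  //
  // An illustrative example: a layout with minor_to_major {0,1,2} yields
  // transpose dimensions {2,1,0}, while the default descending layout {2,1,0}
  // yields the identity {0,1,2}.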
  std::vector<int64_t> ToTransposeDimensions(const Layout& l) {
    std::vector<int64_t> out(l.minor_to_major().begin(),
                             l.minor_to_major().end());
    absl::c_reverse(out);
    return out;
  }

  // Due to the local precondition, the input to every processed op should be
  // an HLO in descending layout, piped through a bitcast.
  StatusOr<HloInstruction*> GetNormalizedInput(HloInstruction* hlo) {
    TF_RET_CHECK(hlo->opcode() == HloOpcode::kBitcast)
        << "Unexpected HLO input: " << hlo->ToString();
    auto input = hlo->mutable_operand(0);
    auto input_shape = input->shape();
    TF_RET_CHECK(input_shape.layout() ==
                 LayoutUtil::GetDefaultLayoutForShape(input_shape));
    return input;
  }

  // Forces the layout to be descending and removes degenerate dimensions,
  // without altering the physical layout.
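  //
  // For example, f32[1,2,3]{0,2,1} normalizes to f32[2,3]{1,0}: the
  // dimensions are listed in physical (major-to-minor) order and the size-1
  // dimension is dropped.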
  Shape Normalize(const Shape& s) {
    return ShapeUtil::DropDegenerateDimensions(
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(s));
  }
};

}  // end namespace

StatusOr<bool> LayoutNormalization::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  return LayoutNormalizationVisitor{}.RunOnModule(module, execution_threads);
}

}  // end namespace xla