xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/layout_normalization.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/layout_normalization.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "tensorflow/compiler/xla/permutation_util.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/shape.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 
34 namespace xla {
35 namespace {
36 
37 // Layout normalization visitor. Aims to achieve the global postcondition that
38 // every layout is strictly descending (the layout permutation is effectively
39 // applied to the shape itself).
40 //
41 // Local precondition for every call:
42 //    -> Input is a bitcast from a normalized layout.
43 //
44 // Local postcondition:
45 //    -> Input and output of a processed operation have descending layout*
46 //
47 // *: Because of current fusion limitations, this postcondition does not yet
48 // apply to unnested reductions.
class LayoutNormalizationVisitor : public DfsHloRewriteVisitor {
 public:
  // Default action: ensure local postcondition that any input is always a
  // bitcast from canonical layout for any rewrites of the HLO users.
  //
  // Bitcast to descending layout and then bitcast back to make sure that shapes
  // match.
  Status DefaultAction(HloInstruction* hlo) override {
    if (!hlo->user_count()) {
      // The local postcondition does not have to apply to the case when there
      // are no users.
      return OkStatus();
    }
    auto users = hlo->users();
    auto shape = hlo->shape();
    if (shape.IsTuple() || shape.IsToken()) {
      // GTEs will be transformed individually, tokens should be skipped.
      return OkStatus();
    }

    auto normalized_shape = Normalize(shape);
    auto bc_to_normalized = MakeBitcastHlo(hlo, normalized_shape);
    // Bitcast back to the original shape so the existing users still see the
    // shape they expect; their own rewrites then consume `bc_to_normalized`
    // through this bitcast pair.
    auto bc_to_orig = MakeBitcastHlo(bc_to_normalized, shape);
    TF_RETURN_IF_ERROR(hlo->ReplaceUsesWith(users, bc_to_orig));
    MarkAsChanged();
    return OkStatus();
  }

  // Converts concatenation to normalized layout.
  //
  // With respect to layouts, concatenations are simple, as they are
  // layout-preserving. However, there are some complications with respect to
  // degenerate dimensions: since our normalized form drops degenerate
  // dimensions, that might make the concatenation impossible, as the
  // corresponding concatenated dimension might not exist in the normalized
  // form.
  //
  // So we drop all degenerate dimensions EXCEPT for the one being concatenated.
  Status HandleConcatenate(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto orig_concat_dim = hlo->dimensions(0);

    std::vector<HloInstruction*> normalized_inputs;
    for (HloInstruction* operand : hlo->mutable_operands()) {
      TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand));
      auto normalized_input_s = normalized_input->shape();
      auto operand_s = operand->shape();

      // Drop all degenerate dimensions, unless it is being concatenated.
      auto operand_s_filtered = ShapeUtil::FilterDimensions(
          [&](int dim) {
            return operand_s.dimensions(dim) != 1 || dim == orig_concat_dim;
          },
          operand_s);

      auto operand_s_normalized =
          ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
              operand_s_filtered);
      // Avoid inserting a redundant bitcast when the normalized input already
      // has exactly the shape we need.
      auto new_operand =
          operand_s_normalized == normalized_input_s
              ? normalized_input
              : MakeBitcastHlo(normalized_input, operand_s_normalized);
      normalized_inputs.push_back(new_operand);
    }

    // Output shape with degenerate dimensions dropped, except for the
    // concatenated one (mirrors the per-operand filtering above).
    auto out_shape_degen_dropped = ShapeUtil::FilterDimensions(
        [&](int dim) {
          return s.dimensions(dim) != 1 || dim == orig_concat_dim;
        },
        s);
    auto normalized_w_degen =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(s);
    auto normalized_shape =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
            out_shape_degen_dropped);

    // Map the concatenation dimension through the layout permutation, then
    // subtract the number of degenerate dimensions that were dropped before
    // it in the normalized-with-degenerate shape.
    auto l = ToTransposeDimensions(s.layout());
    int64_t normalized_concat_dim = FindIndex(l, orig_concat_dim);
    // NOTE: the lambda parameter here is a dimension *size*; counting sizes
    // equal to 1 counts the degenerate dimensions preceding the concat dim.
    auto degen_delta = absl::c_count_if(
        normalized_w_degen.dimensions().subspan(0, normalized_concat_dim),
        [&](int dim) { return dim == 1; });
    auto normalized_concat = hlo->AddInstruction(
        HloInstruction::CreateConcatenate(normalized_shape, normalized_inputs,
                                          normalized_concat_dim - degen_delta));
    auto bc_to_orig = MakeBitcastHlo(normalized_concat, hlo->shape());
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // Converts broadcast input and output to normalized layout.
  //
  // Converts:
  //
  //  A{I} -> bitcast{L} -> broadcast[S]{L'}
  //
  // Into:
  //
  //  A{I} -> broadcast[S']{I} -> bitcast[S]{L'}
  Status HandleBroadcast(HloInstruction* hlo) override {
    VLOG(3) << "Input broadcast: " << hlo->ToString();
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand));
    auto normalized_shape = Normalize(s);
    // Broadcast dimensions with entries for degenerate dimensions removed,
    // since the normalized form drops size-1 dimensions.
    auto orig_br_dimensions =
        NoDegenerateDims(hlo->dimensions(), operand->shape(), s);
    auto layout_as_permutation = ToTransposeDimensions(
        ShapeUtil::DropDegenerateDimensions(operand->shape()).layout());
    auto orig_output_layout_as_permutation =
        ToTransposeDimensions(ShapeUtil::DropDegenerateDimensions(s).layout());
    std::vector<int64_t> br_dimensions;
    if (!hlo->dimensions().empty()) {
      // Reorder the broadcast dimensions by the operand layout...
      br_dimensions = Permute(orig_br_dimensions, layout_as_permutation);
    }
    // ...then re-express each one as a position in the normalized output.
    for (int64_t& d : br_dimensions) {
      d = FindIndex(orig_output_layout_as_permutation, d);
    }
    absl::c_sort(br_dimensions);
    auto normalized_broadcast =
        MakeBroadcastHlo(normalized_input, br_dimensions, normalized_shape);
    VLOG(3) << "Generated broadcast: " << normalized_broadcast->ToString();
    auto bc_to_orig = MakeBitcastHlo(normalized_broadcast, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // Pushes down the bitcast across the unary.
  // That is, converts:
  //
  //    H_0{I} -> B{L} -> U{L}
  //
  // into
  //
  //    H_0{I} -> U{I} -> B{L}
  //
  // where {I} denotes default layout.
  Status HandleElementwiseUnary(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    auto operand_shape = operand->shape();

    // Precondition: elementwise unary leaves layout intact.
    TF_RET_CHECK(s.layout() == operand_shape.layout())
        << "Unexpected non-layout preserving elementwise unary: "
        << hlo->ToString();
    TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand));

    PrimitiveType to_element_type = s.element_type();
    HloInstruction* new_unary;
    // Convert, reduce-precision and bitcast-convert carry extra attributes
    // (target element type / bit widths), so they need dedicated creators;
    // every other unary is rebuilt via the generic helper.
    if (hlo->opcode() == HloOpcode::kConvert) {
      new_unary = MakeConvertToHlo(normalized_input, to_element_type);
    } else if (hlo->opcode() == HloOpcode::kReducePrecision) {
      new_unary = MakeReducePrecisionHlo(normalized_input, hlo->exponent_bits(),
                                         hlo->mantissa_bits());
    } else if (hlo->opcode() == HloOpcode::kBitcastConvert) {
      new_unary = MakeBitcastConvertToHlo(normalized_input, to_element_type);
    } else {
      TF_ASSIGN_OR_RETURN(new_unary,
                          MakeUnaryHlo(hlo->opcode(), normalized_input));
    }
    auto bc_to_orig = MakeBitcastHlo(new_unary, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // Pushes down the bitcast across the binary. Converts:
  //
  //  A1{I} -> bitcast{L}
  //            \
  //            B{L}
  //            /
  //  A2{I} -> bitcast{L}
  //
  // Into:
  //
  //  A1{I}
  //        \
  //         B{I} - bitcast{L}
  //        /
  //  A2{I}
  Status HandleElementwiseBinary(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto a = hlo->mutable_operand(0);
    auto b = hlo->mutable_operand(1);
    // Precondition: elementwise binaries preserve the operand layout.
    TF_RET_CHECK(a->shape().layout() == s.layout());
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(a));
    TF_ASSIGN_OR_RETURN(auto b0, GetNormalizedInput(b));

    HloInstruction* new_binary;
    // Compare carries a comparison direction, so it needs its own creator.
    if (hlo->opcode() == HloOpcode::kCompare) {
      TF_ASSIGN_OR_RETURN(new_binary,
                          MakeCompareHlo(hlo->comparison_direction(), a0, b0));
    } else {
      TF_ASSIGN_OR_RETURN(new_binary, MakeBinaryHlo(hlo->opcode(), a0, b0));
    }
    auto bc_to_orig = MakeBitcastHlo(new_binary, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // The ReshapeDecomposer already gives us a precondition that a reshape is
  // bitcast. Converts:
  //
  // A{I} -> bitcast [S0]{L1} -> R [S]{L2}
  //
  // Into:
  //
  // A{I} -> R [S']{I} -> bitcast[S]{L2}
  //
  Status HandleReshape(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    TF_RET_CHECK(ShapeUtil::ReshapeIsBitcast(s, operand->shape()));
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto normalized_reshape_s = Normalize(s);
    TF_ASSIGN_OR_RETURN(auto new_reshape,
                        MakeReshapeHlo(normalized_reshape_s, a0));
    auto bc_to_orig = MakeBitcastHlo(new_reshape, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // For bitcasting transposes, converts:
  //
  // A{I} -> bitcast[S]{L} -> transpose{L2}
  //
  // Into:
  //
  // A{I} -> bitcast{L2}
  //
  // For non-bitcasting ones, converts:
  //
  // A{I} -> bitcast[S0]{L} -> transpose[S]{L2}
  //
  // Into:
  //
  // A{I} -> transpose[S']{I} -> bitcast{L2}
  //
  // Where S' is the normalization of [S]{L2}, and `dimensions` attribute is
  //
  // The `dimensions` of the new transposition is given by:
  //
  //  L^-1 o `dim_0` o L2
  //
  // where dim_0 is dimensions of the original transposition, and `o` denotes
  // permutation composition.
  Status HandleTranspose(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    auto operand_s = operand->shape();
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto normalized_shape = Normalize(s);
    VLOG(3) << "Input transpose: " << hlo->ToString();

    if (!ShapeUtil::TransposeIsBitcast(s, operand_s, hlo->dimensions())) {
      // Compose L^-1 o dims o L2 (all with degenerate dimensions dropped) to
      // get the dimensions of the new physical transposition.
      auto l0_perm = InversePermutation(ToTransposeDimensions(
          ShapeUtil::DropDegenerateDimensions(operand_s).layout()));
      auto l_perm = ToTransposeDimensions(
          ShapeUtil::DropDegenerateDimensions(s).layout());

      auto dims = NoDegenerateDims(hlo->dimensions(), s, operand_s);
      auto t = ComposePermutations(l0_perm, dims);
      auto dimensions = ComposePermutations(t, l_perm);
      auto normalized_transpose = hlo->AddInstruction(
          HloInstruction::CreateTranspose(normalized_shape, a0, dimensions));
      VLOG(3) << "Generated normalized physical transpose: "
              << normalized_transpose->ToString();
      auto bc_to_orig = MakeBitcastHlo(normalized_transpose, s);
      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    } else {
      // The transpose is a bitcast: no data movement is needed, a single
      // bitcast from the normalized input suffices.
      auto bc_to_orig = MakeBitcastHlo(a0, s);
      TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    }
    return OkStatus();
  }

  // Converts a purely physical copy into a physical+logical transposition.
  //
  // Converts:
  //
  //  A{I} -> bitcast{L} -> copy[S]{L'}
  //
  // Into:
  //
  //  A{I} -> transpose[S']{I} -> bitcast[S]{L'}
  //
  // Where S' is normalization of [S]{L'}, and transposition dimensions are
  // given by L'.
  Status HandleCopy(HloInstruction* hlo) override {
    VLOG(3) << "Processing copy: " << hlo->ToString();
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto s_normalized = Normalize(s);
    // The copy's physical relayout becomes a logical transpose whose
    // dimensions are L0^-1 o L' (degenerate dimensions dropped).
    auto l0_perm = InversePermutation(ToTransposeDimensions(
        ShapeUtil::DropDegenerateDimensions(operand->shape()).layout()));
    auto l_perm =
        ToTransposeDimensions(ShapeUtil::DropDegenerateDimensions(s).layout());
    auto dimensions = ComposePermutations(l0_perm, l_perm);
    auto t = hlo->AddInstruction(
        HloInstruction::CreateTranspose(s_normalized, a0, dimensions));
    auto bc_to_orig = MakeBitcastHlo(t, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // The reverse HLO has a list of dimensions it reverses, which again becomes
  // pretty interesting in the presence of degenerate dimensions: we need to
  // drop those from the list.
  //
  // Luckily, reverse is layout-preserving.
  Status HandleReverse(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto s_normalized = Normalize(s);
    auto normalized_w_degen =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(s);

    std::vector<int64_t> new_dimensions =
        TransformDimensionsForLayoutPreservingHlo(hlo, normalized_w_degen,
                                                  s_normalized);
    auto normalized_reverse = hlo->AddInstruction(
        HloInstruction::CreateReverse(a0->shape(), a0, new_dimensions));
    auto bc_to_orig = MakeBitcastHlo(normalized_reverse, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

  // Padding is layout-preserving, so we only have to permute values inside the
  // padding config.
  //
  // Like in broadcast, we have to be mindful that we can't remove degenerate
  // dimensions if they are padded.
  Status HandlePad(HloInstruction* hlo) override {
    auto s = hlo->shape();
    auto operand = hlo->mutable_operand(0);
    const auto& operand_s = operand->shape();
    auto padded_by = hlo->mutable_operand(1);
    TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand));
    auto padded_config = hlo->padding_config();

    // Keep degenerate dimensions that carry non-trivial padding; drop the
    // rest, as in the normalized form they no longer exist.
    auto operand_s_filtered = ShapeUtil::FilterDimensions(
        [&](int dim) {
          return operand_s.dimensions(dim) != 1 ||
                 !IsZeroPadding(hlo->padding_config().dimensions(dim));
        },
        operand->shape());
    auto operand_s_normalized =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
            operand_s_filtered);
    auto new_operand = operand_s_normalized == a0->shape()
                           ? a0
                           : MakeBitcastHlo(a0, operand_s_normalized);

    auto s_normalized = Normalize(s);
    auto l = ToTransposeDimensions(s.layout());

    auto normalized_w_degen =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(s);

    // Start with an all-default (zero) padding config of the right rank, then
    // copy over the config of each surviving dimension.
    PaddingConfig new_padding;
    new_padding.mutable_dimensions()->Reserve(s_normalized.dimensions_size());
    for (int dim = 0; dim < s_normalized.dimensions_size(); dim++) {
      new_padding.add_dimensions();
    }

    for (int dim = 0; dim < s.dimensions_size(); dim++) {
      if (s.dimensions(dim) == 1) {
        continue;
      }
      // Position after the layout permutation, then shifted left by the
      // degenerate dimensions dropped before it.
      int tr_dim = static_cast<int>(FindIndex(l, dim));
      int out_dim = tr_dim - DegenDimsUpTo(normalized_w_degen, tr_dim);
      *new_padding.mutable_dimensions(out_dim) = padded_config.dimensions(dim);
    }

    auto padded_normalized = hlo->AddInstruction(HloInstruction::CreatePad(
        s_normalized, new_operand, padded_by, new_padding));
    auto bc_to_orig = MakeBitcastHlo(padded_normalized, s);
    TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig));
    return OkStatus();
  }

 private:
  // Returns true if the padding config dimension `c` performs no padding at
  // all (no low/high edge padding and no interior padding).
  bool IsZeroPadding(const PaddingConfig::PaddingConfigDimension& c) {
    return c.edge_padding_high() == 0 && c.edge_padding_low() == 0 &&
           c.interior_padding() == 0;
  }

  // Returns a list of dimensions associated with `hlo` after layout
  // normalization: each dimension is mapped through the layout permutation
  // and, when the output shape has degenerate dimensions dropped, shifted by
  // the number of degenerate dimensions preceding it.
  std::vector<int64_t> TransformDimensionsForLayoutPreservingHlo(
      HloInstruction* hlo, const Shape& normalized_shape_w_degen,
      const Shape& normalized_out_shape) {
    // If the two shapes differ, degenerate dimensions were dropped and the
    // corresponding entries must be skipped/adjusted.
    bool skip_degen_dims = normalized_shape_w_degen != normalized_out_shape;
    std::vector<int64_t> new_dimensions;
    const auto& s = hlo->shape();
    auto l = ToTransposeDimensions(s.layout());

    for (int64_t dim : hlo->dimensions()) {
      if (s.dimensions(dim) == 1 && skip_degen_dims) {
        continue;
      }

      auto tr_dim = FindIndex(l, dim);
      auto degen_delta =
          skip_degen_dims ? DegenDimsUpTo(normalized_shape_w_degen, tr_dim) : 0;
      new_dimensions.push_back(tr_dim - degen_delta);
    }
    absl::c_sort(new_dimensions);
    return new_dimensions;
  }

  // Returns number of degenerate dimensions in `shape` up to (exclusive) a
  // `dim`.
  int DegenDimsUpTo(const Shape& shape, int dim) {
    return absl::c_count_if(shape.dimensions().subspan(0, dim),
                            [&](int d) { return d == 1; });
  }

  // Drops items from `dimensions` corresponding to degenerate dimensions in
  // `input_shape`, and renumbers the surviving values to account for
  // degenerate dimensions dropped from `output_shape`.
  std::vector<int64_t> NoDegenerateDims(absl::Span<int64_t const> dimensions,
                                        const Shape& input_shape,
                                        const Shape& output_shape) {
    std::vector<int64_t> out;
    for (int i = 0; i < dimensions.size(); i++) {
      if (input_shape.dimensions(i) != 1) {
        int64_t val = dimensions[i];

        // Count all preceding 1-sized dimensions.
        int64_t delta = 0;
        for (int o = 0; o < val; o++) {
          if (output_shape.dimensions(o) == static_cast<int64_t>(1)) {
            delta++;
          }
        }

        out.push_back(val - delta);
      }
    }
    return out;
  }

  // Converts a layout to a dimensions transposition necessary to get to that
  // layout from identity (i.e. minor_to_major reversed into major-to-minor
  // order).
  std::vector<int64_t> ToTransposeDimensions(const Layout& l) {
    std::vector<int64_t> out(l.minor_to_major().begin(),
                             l.minor_to_major().end());
    absl::c_reverse(out);
    return out;
  }

  // Due to Local Precondition we have, the input to all processed ops should
  // be HLO in descending layout piped through bitcast.
  StatusOr<HloInstruction*> GetNormalizedInput(HloInstruction* hlo) {
    TF_RET_CHECK(hlo->opcode() == HloOpcode::kBitcast)
        << "Unexpected HLO input: " << hlo->ToString();
    auto input = hlo->mutable_operand(0);
    auto input_shape = input->shape();
    TF_RET_CHECK(input_shape.layout() ==
                 LayoutUtil::GetDefaultLayoutForShape(input_shape));
    return input;
  }

  // Forces the layout to be descending and removes degenerate dimensions
  // without altering physical layout.
  Shape Normalize(const Shape& s) {
    return ShapeUtil::DropDegenerateDimensions(
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(s));
  }
};
521 
522 }  // end namespace
523 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)524 StatusOr<bool> LayoutNormalization::Run(
525     HloModule* module,
526     const absl::flat_hash_set<absl::string_view>& execution_threads) {
527   return LayoutNormalizationVisitor{}.RunOnModule(module, execution_threads);
528 }
529 
530 }  // end namespace xla
531