xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/reduce_decomposer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 
2 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
3 
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7 
8     http://www.apache.org/licenses/LICENSE-2.0
9 
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16 
17 #include "tensorflow/compiler/xla/service/reduce_decomposer.h"
18 
19 #include <functional>
20 #include <utility>
21 #include <vector>
22 
23 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/status.h"
28 
29 namespace xla {
30 
31 namespace {
32 
33 // Enforces property that all inputs to variadic reduction have same layout.
34 class VariadicReductionLayoutEqualizer : public DfsHloRewriteVisitor {
35  public:
HandleReduce(HloInstruction * hlo)36   Status HandleReduce(HloInstruction* hlo) override {
37     auto reduce = Cast<HloReduceInstruction>(hlo);
38     std::vector<HloInstruction*> new_inputs;
39     bool changed = false;
40     for (HloInstruction* input : reduce->inputs()) {
41       auto first_input = reduce->inputs()[0];
42       auto first_input_s = first_input->shape();
43       auto input_s = input->shape();
44       if (first_input_s.layout() != input_s.layout()) {
45         Shape new_input_s = ShapeUtil::MakeShapeWithLayout(
46             input_s.element_type(), input_s.dimensions(),
47             first_input_s.layout().minor_to_major());
48         auto copy = MakeCopyHlo(input, new_input_s);
49         changed = true;
50         new_inputs.push_back(copy);
51       } else {
52         new_inputs.push_back(input);
53       }
54     }
55 
56     if (changed) {
57       TF_ASSIGN_OR_RETURN(
58           auto new_reduce,
59           MakeReduceHlo(new_inputs, reduce->init_values(), reduce->dimensions(),
60                         reduce->called_computations()[0]));
61       TF_RETURN_IF_ERROR(ReplaceInstruction(reduce, new_reduce));
62     }
63 
64     return Status::OK();
65   }
66 };
67 
68 class ReduceDecomposerVisitor : public DfsHloRewriteVisitor {
69  public:
ReduceDecomposerVisitor(HloPredicate custom_layout_allowed)70   explicit ReduceDecomposerVisitor(HloPredicate custom_layout_allowed)
71       : custom_layout_allowed_(std::move(custom_layout_allowed)) {}
72 
HandleReduce(HloInstruction * hlo)73   Status HandleReduce(HloInstruction* hlo) override {
74     auto reduce = Cast<HloReduceInstruction>(hlo);
75     auto shape = reduce->shape();
76     if (custom_layout_allowed_ && custom_layout_allowed_(reduce)) {
77       return OkStatus();
78     }
79 
80     std::vector<Shape> expected_shapes(reduce->input_count());
81     for (int i = 0; i < reduce->input_count(); i++) {
82       expected_shapes[i] = ExpectedOutputShape(reduce, i);
83       TF_RET_CHECK(reduce->inputs()[i]->shape().layout() ==
84                    reduce->inputs()[0]->shape().layout());
85     }
86 
87     std::vector<Shape> output_shapes;
88     if (shape.IsTuple()) {
89       for (int i = 0; i < shape.tuple_shapes_size(); i++) {
90         output_shapes.push_back(ShapeUtil::GetTupleElementShape(shape, i));
91         TF_RET_CHECK(output_shapes[i].layout() == output_shapes[0].layout());
92       }
93     } else {
94       output_shapes.push_back(shape);
95     }
96 
97     TF_RET_CHECK(!output_shapes.empty());
98     if (ShapeUtil::MakeMaybeTupleShape(expected_shapes) !=
99         ShapeUtil::MakeMaybeTupleShape(output_shapes)) {
100       TF_ASSIGN_OR_RETURN(auto r_prime,
101                           MakeReduceHlo(reduce->inputs(), reduce->init_values(),
102                                         reduce->dimensions(),
103                                         reduce->called_computations()[0]));
104       TF_RET_CHECK(r_prime->shape() ==
105                    ShapeUtil::MakeMaybeTupleShape(expected_shapes));
106 
107       if (!shape.IsTuple()) {
108         auto copy = MakeCopyHlo(r_prime, shape);
109         TF_RETURN_IF_ERROR(ReplaceInstruction(reduce, copy));
110         return OkStatus();
111       }
112 
113       std::vector<HloInstruction*> copies;
114       for (int i = 0; i < reduce->input_count(); i++) {
115         TF_ASSIGN_OR_RETURN(auto from, GetOutput(r_prime, i));
116         auto copy = MakeCopyHlo(from, output_shapes[i]);
117         copies.push_back(copy);
118       }
119       auto out = MaybeMakeTuple(copies);
120       TF_RETURN_IF_ERROR(ReplaceInstruction(reduce, out));
121     }
122     return OkStatus();
123   }
124 
125  private:
GetOutput(HloInstruction * instr,int idx)126   StatusOr<HloInstruction*> GetOutput(HloInstruction* instr, int idx) {
127     if (instr->shape().IsTuple()) {
128       return MakeGetTupleElementHlo(instr, idx);
129     } else {
130       TF_RET_CHECK(idx == 0);
131       return instr;
132     }
133   }
134 
ExpectedOutputShape(HloReduceInstruction * reduce,int input_idx)135   Shape ExpectedOutputShape(HloReduceInstruction* reduce, int input_idx) {
136     Shape reduce_shape = reduce->shape();
137     auto output_shape = reduce_shape.IsTuple()
138                             ? reduce_shape.tuple_shapes(input_idx)
139                             : reduce_shape;
140     auto* operand = reduce->inputs()[input_idx];
141     auto operand_shape = operand->shape();
142     return ShapeUtil::DeleteDimensions(reduce->dimensions(), operand_shape);
143   }
144 
145   HloPredicate custom_layout_allowed_;
146 };
147 
148 }  // namespace
149 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)150 StatusOr<bool> ReduceDecomposer::Run(
151     HloModule* module,
152     const absl::flat_hash_set<absl::string_view>& execution_threads) {
153   TF_ASSIGN_OR_RETURN(bool changed1,
154                       VariadicReductionLayoutEqualizer{}.RunOnModule(
155                           module, execution_threads));
156   TF_ASSIGN_OR_RETURN(
157       bool changed2,
158       ReduceDecomposerVisitor{custom_layout_allowed_}.RunOnModule(
159           module, execution_threads));
160   return changed1 || changed2;
161 }
162 
163 }  // namespace xla
164