/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
16
17 #include "tensorflow/compiler/xla/service/reduce_decomposer.h"
18
19 #include <functional>
20 #include <utility>
21 #include <vector>
22
23 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/status.h"
28
29 namespace xla {
30
31 namespace {
32
33 // Enforces property that all inputs to variadic reduction have same layout.
34 class VariadicReductionLayoutEqualizer : public DfsHloRewriteVisitor {
35 public:
HandleReduce(HloInstruction * hlo)36 Status HandleReduce(HloInstruction* hlo) override {
37 auto reduce = Cast<HloReduceInstruction>(hlo);
38 std::vector<HloInstruction*> new_inputs;
39 bool changed = false;
40 for (HloInstruction* input : reduce->inputs()) {
41 auto first_input = reduce->inputs()[0];
42 auto first_input_s = first_input->shape();
43 auto input_s = input->shape();
44 if (first_input_s.layout() != input_s.layout()) {
45 Shape new_input_s = ShapeUtil::MakeShapeWithLayout(
46 input_s.element_type(), input_s.dimensions(),
47 first_input_s.layout().minor_to_major());
48 auto copy = MakeCopyHlo(input, new_input_s);
49 changed = true;
50 new_inputs.push_back(copy);
51 } else {
52 new_inputs.push_back(input);
53 }
54 }
55
56 if (changed) {
57 TF_ASSIGN_OR_RETURN(
58 auto new_reduce,
59 MakeReduceHlo(new_inputs, reduce->init_values(), reduce->dimensions(),
60 reduce->called_computations()[0]));
61 TF_RETURN_IF_ERROR(ReplaceInstruction(reduce, new_reduce));
62 }
63
64 return Status::OK();
65 }
66 };
67
68 class ReduceDecomposerVisitor : public DfsHloRewriteVisitor {
69 public:
ReduceDecomposerVisitor(HloPredicate custom_layout_allowed)70 explicit ReduceDecomposerVisitor(HloPredicate custom_layout_allowed)
71 : custom_layout_allowed_(std::move(custom_layout_allowed)) {}
72
HandleReduce(HloInstruction * hlo)73 Status HandleReduce(HloInstruction* hlo) override {
74 auto reduce = Cast<HloReduceInstruction>(hlo);
75 auto shape = reduce->shape();
76 if (custom_layout_allowed_ && custom_layout_allowed_(reduce)) {
77 return OkStatus();
78 }
79
80 std::vector<Shape> expected_shapes(reduce->input_count());
81 for (int i = 0; i < reduce->input_count(); i++) {
82 expected_shapes[i] = ExpectedOutputShape(reduce, i);
83 TF_RET_CHECK(reduce->inputs()[i]->shape().layout() ==
84 reduce->inputs()[0]->shape().layout());
85 }
86
87 std::vector<Shape> output_shapes;
88 if (shape.IsTuple()) {
89 for (int i = 0; i < shape.tuple_shapes_size(); i++) {
90 output_shapes.push_back(ShapeUtil::GetTupleElementShape(shape, i));
91 TF_RET_CHECK(output_shapes[i].layout() == output_shapes[0].layout());
92 }
93 } else {
94 output_shapes.push_back(shape);
95 }
96
97 TF_RET_CHECK(!output_shapes.empty());
98 if (ShapeUtil::MakeMaybeTupleShape(expected_shapes) !=
99 ShapeUtil::MakeMaybeTupleShape(output_shapes)) {
100 TF_ASSIGN_OR_RETURN(auto r_prime,
101 MakeReduceHlo(reduce->inputs(), reduce->init_values(),
102 reduce->dimensions(),
103 reduce->called_computations()[0]));
104 TF_RET_CHECK(r_prime->shape() ==
105 ShapeUtil::MakeMaybeTupleShape(expected_shapes));
106
107 if (!shape.IsTuple()) {
108 auto copy = MakeCopyHlo(r_prime, shape);
109 TF_RETURN_IF_ERROR(ReplaceInstruction(reduce, copy));
110 return OkStatus();
111 }
112
113 std::vector<HloInstruction*> copies;
114 for (int i = 0; i < reduce->input_count(); i++) {
115 TF_ASSIGN_OR_RETURN(auto from, GetOutput(r_prime, i));
116 auto copy = MakeCopyHlo(from, output_shapes[i]);
117 copies.push_back(copy);
118 }
119 auto out = MaybeMakeTuple(copies);
120 TF_RETURN_IF_ERROR(ReplaceInstruction(reduce, out));
121 }
122 return OkStatus();
123 }
124
125 private:
GetOutput(HloInstruction * instr,int idx)126 StatusOr<HloInstruction*> GetOutput(HloInstruction* instr, int idx) {
127 if (instr->shape().IsTuple()) {
128 return MakeGetTupleElementHlo(instr, idx);
129 } else {
130 TF_RET_CHECK(idx == 0);
131 return instr;
132 }
133 }
134
ExpectedOutputShape(HloReduceInstruction * reduce,int input_idx)135 Shape ExpectedOutputShape(HloReduceInstruction* reduce, int input_idx) {
136 Shape reduce_shape = reduce->shape();
137 auto output_shape = reduce_shape.IsTuple()
138 ? reduce_shape.tuple_shapes(input_idx)
139 : reduce_shape;
140 auto* operand = reduce->inputs()[input_idx];
141 auto operand_shape = operand->shape();
142 return ShapeUtil::DeleteDimensions(reduce->dimensions(), operand_shape);
143 }
144
145 HloPredicate custom_layout_allowed_;
146 };
147
148 } // namespace
149
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)150 StatusOr<bool> ReduceDecomposer::Run(
151 HloModule* module,
152 const absl::flat_hash_set<absl::string_view>& execution_threads) {
153 TF_ASSIGN_OR_RETURN(bool changed1,
154 VariadicReductionLayoutEqualizer{}.RunOnModule(
155 module, execution_threads));
156 TF_ASSIGN_OR_RETURN(
157 bool changed2,
158 ReduceDecomposerVisitor{custom_layout_allowed_}.RunOnModule(
159 module, execution_threads));
160 return changed1 || changed2;
161 }
162
163 } // namespace xla
164