1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
17
18 #include <algorithm>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
23 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/stream_executor/lib/statusor.h"
33
34 namespace xla {
35 namespace gpu {
36
37 class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor {
38 public:
HandleReduce(HloInstruction * hlo)39 Status HandleReduce(HloInstruction *hlo) override {
40 auto instr = Cast<HloReduceInstruction>(hlo);
41 absl::InlinedVector<HloInstruction *, 2> input_reshapes;
42 absl::InlinedVector<Shape, 2> canonical_reduce_shapes;
43
44 int idx = -1;
45 std::vector<int64_t> updated_reduced_dimensions;
46 for (HloInstruction *reduced_op : instr->inputs()) {
47 idx++;
48 const Shape &input_shape = reduced_op->shape();
49 const Shape &reduce_shape = instr->shape().IsTuple()
50 ? instr->shape().tuple_shapes(idx)
51 : instr->shape();
52
53 if (!ShapeUtil::HasDegenerateDimensions(reduced_op->shape())) {
54 return OkStatus();
55 }
56 Shape canonical_input_shape =
57 ShapeUtil::DropDegenerateDimensions(input_shape);
58
59 Shape canonical_reduce_shape =
60 ShapeUtil::DropDegenerateDimensions(reduce_shape);
61
62 auto reduced_dimensions = instr->dimensions();
63 int64_t shift = 0;
64
65 for (int dim = 0; dim < input_shape.rank(); dim++) {
66 if (input_shape.dimensions(dim) == 1) {
67 shift++;
68 } else {
69 if (absl::c_linear_search(reduced_dimensions, dim) && idx == 0) {
70 // Only populate on first iteration.
71 updated_reduced_dimensions.push_back(dim - shift);
72 }
73 }
74 }
75
76 if (updated_reduced_dimensions.empty()) {
77 std::unique_ptr<HloInstruction> reshape =
78 HloInstruction::CreateBitcast(reduce_shape, reduced_op);
79 return ReplaceWithNewInstruction(instr, std::move(reshape));
80 }
81
82 input_reshapes.push_back(instr->parent()->AddInstruction(
83 HloInstruction::CreateBitcast(canonical_input_shape, reduced_op)));
84 canonical_reduce_shapes.push_back(canonical_reduce_shape);
85 }
86
87 Shape canonical_reduce_shape =
88 ShapeUtil::MakeMaybeTupleShape(canonical_reduce_shapes);
89 const Shape &orig_reduce_shape = instr->shape();
90 std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
91 canonical_reduce_shape, input_reshapes, instr->init_values(),
92 updated_reduced_dimensions, instr->to_apply());
93
94 if (canonical_reduce_shape != instr->shape()) {
95 HloInstruction *wrapped_reduce =
96 instr->parent()->AddInstruction(std::move(new_reduce));
97 absl::InlinedVector<HloInstruction *, 2> out;
98 if (!canonical_reduce_shape.IsTuple()) {
99 new_reduce =
100 HloInstruction::CreateBitcast(orig_reduce_shape, wrapped_reduce);
101 } else {
102 for (int oidx = 0; oidx < instr->input_count(); oidx++) {
103 HloInstruction *gte = instr->parent()->AddInstruction(
104 HloInstruction::CreateGetTupleElement(wrapped_reduce, oidx));
105 out.push_back(
106 instr->parent()->AddInstruction(HloInstruction::CreateBitcast(
107 orig_reduce_shape.tuple_shapes(oidx), gte)));
108 }
109 new_reduce = HloInstruction::CreateTuple(out);
110 }
111 }
112
113 return ReplaceWithNewInstruction(instr, std::move(new_reduce));
114 }
115 };
116
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)117 StatusOr<bool> ReductionDegenerateDimRemover::Run(
118 HloModule *module,
119 const absl::flat_hash_set<absl::string_view> &execution_threads) {
120 TF_ASSIGN_OR_RETURN(bool changed,
121 ReductionDegenerateDimRemoverVisitor().RunOnModule(
122 module, execution_threads));
123 return changed;
124 }
125
126 } // namespace gpu
127 } // namespace xla
128