/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/tuple_util.h"

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"

namespace xla {

ExtractPrefix(HloInstruction * input_tuple,int64_t elements)22 /*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple,
23 int64_t elements) {
24 CHECK(input_tuple->shape().IsTuple());
25
26 HloComputation* computation = input_tuple->parent();
27 const Shape& input_shape = input_tuple->shape();
28
29 std::vector<HloInstruction*> tuple_elements;
30 tuple_elements.reserve(elements);
31 for (int i = 0; i < elements; i++) {
32 tuple_elements.push_back(
33 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
34 input_shape.tuple_shapes(i), input_tuple, i)));
35 }
36
37 return computation->AddInstruction(
38 HloInstruction::CreateTuple(tuple_elements));
39 }
AppendSuffix(HloInstruction * input_tuple,absl::Span<HloInstruction * const> trailing_values)41 /*static*/ HloInstruction* TupleUtil::AppendSuffix(
42 HloInstruction* input_tuple,
43 absl::Span<HloInstruction* const> trailing_values) {
44 CHECK(input_tuple->shape().IsTuple());
45
46 HloComputation* computation = input_tuple->parent();
47 const Shape& input_shape = input_tuple->shape();
48 std::vector<HloInstruction*> tuple_elements;
49 tuple_elements.reserve(input_shape.tuple_shapes_size());
50 for (int i = 0; i < input_shape.tuple_shapes_size(); i++) {
51 tuple_elements.push_back(
52 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
53 input_shape.tuple_shapes(i), input_tuple, i)));
54 }
55 tuple_elements.insert(tuple_elements.end(), trailing_values.begin(),
56 trailing_values.end());
57 return computation->AddInstruction(
58 HloInstruction::CreateTuple(tuple_elements));
59 }
ReplaceTupleWith(HloInstruction * new_instruction,HloInstruction * tuple,ShapeIndex shape_index,bool insert_bitcast_if_different_shape)61 /*static*/ StatusOr<HloInstruction*> TupleUtil::ReplaceTupleWith(
62 HloInstruction* new_instruction, HloInstruction* tuple,
63 ShapeIndex shape_index, bool insert_bitcast_if_different_shape) {
64 const Shape& tuple_shape = tuple->shape();
65 CHECK(tuple->shape().IsTuple())
66 << "ReplaceTupleWith was called for a non-tuple. Tuple = "
67 << tuple->ToString()
68 << ", new_instruction = " << new_instruction->ToString()
69 << ", shape_index = " << shape_index.ToString();
70 // Check if the new instruction is a get-tuple-element of the correct index of
71 // the tuple, and if so, simply return tuple.
72 const HloInstruction* instruction = new_instruction;
73 bool equivalent = true;
74 for (int i = shape_index.size() - 1; i >= 0; --i) {
75 int index = shape_index[i];
76 if (instruction->opcode() != HloOpcode::kGetTupleElement ||
77 instruction->tuple_index() != index) {
78 equivalent = false;
79 break;
80 }
81 instruction = instruction->operand(0);
82 }
83 if (equivalent && instruction == tuple) {
84 VLOG(4) << "Instruction " << new_instruction->ToShortString()
85 << " already exists at index " << shape_index.ToString() << " of "
86 << tuple->ToShortString();
87 return tuple;
88 }
89
90 HloComputation* computation = new_instruction->parent();
91 std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size());
92 CHECK_GE(tuple_shape.tuple_shapes_size(), shape_index[0]);
93 for (int i = 0; i < tuple_shape.tuple_shapes_size(); ++i) {
94 const Shape& subshape = tuple_shape.tuple_shapes(i);
95 // If tuple is a tuple instruction, we can get the tuple instruction's
96 // operand to construct the new tuple to improve compilation time
97 // performance.
98 auto get_operand = [&]() {
99 if (tuple->opcode() == HloOpcode::kTuple) {
100 return tuple->mutable_operand(i);
101 } else {
102 return computation->AddInstruction(
103 HloInstruction::CreateGetTupleElement(subshape, tuple, i));
104 }
105 };
106 if (i == shape_index[0]) {
107 // If the subshape is still a tuple, recurse and pass a new shape index
108 // for the one level deeper.
109 if (subshape.IsTuple()) {
110 TF_ASSIGN_OR_RETURN(tuple_args[i],
111 ReplaceTupleWith(new_instruction, get_operand(),
112 ShapeIndex(shape_index.begin() + 1,
113 shape_index.end())));
114 } else {
115 if (subshape != new_instruction->shape() &&
116 insert_bitcast_if_different_shape) {
117 VLOG(4) << "Old shape = " << subshape.ToString()
118 << ", new shape = " << new_instruction->shape().ToString()
119 << "; inserting a bitcast.";
120 new_instruction = computation->AddInstruction(
121 HloInstruction::CreateBitcast(subshape, new_instruction));
122 } else if (tuple->opcode() == HloOpcode::kTuple &&
123 tuple->operand(i) == new_instruction) {
124 // If the tuple element is the same as the new instruction, we
125 // actually don't have to create a new tuple, just return the original
126 // tuple.
127 VLOG(4) << "Tuple already contains the new instruction = "
128 << new_instruction->ToShortString()
129 << " tuple = " << tuple->ToShortString();
130 return tuple;
131 }
132 tuple_args[i] = new_instruction;
133 }
134 } else {
135 tuple_args[i] = get_operand();
136 }
137 }
138 if (shape_index[0] == tuple_shape.tuple_shapes_size()) {
139 // If shape_index[0] is equal to the tuple shape size, add the new
140 // instruction as an additional argument.
141 tuple_args.push_back(new_instruction);
142 }
143 return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args));
144 }
AddGetTupleElements(const HloPosition & position)146 /*static*/ HloInstruction* TupleUtil::AddGetTupleElements(
147 const HloPosition& position) {
148 HloInstruction* instruction = position.instruction;
149 HloComputation* computation = instruction->parent();
150
151 // If the instruction we're processing is a tuple, we (recursively) search or
152 // create kGetTupleElement instructions and copy that value.
153 for (int64_t index : position.index) {
154 // We first search if there already is a get-tuple-element with the correct
155 // index. If there is no such get-tuple-element, we create one.
156 auto gte_it = absl::c_find_if(
157 instruction->users(), [index](const HloInstruction* use) {
158 return use != use->parent()->root_instruction() &&
159 use->opcode() == HloOpcode::kGetTupleElement &&
160 use->tuple_index() == index;
161 });
162 if (gte_it != instruction->users().end()) {
163 instruction = *gte_it;
164 } else {
165 instruction =
166 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
167 instruction->shape().tuple_shapes(index), instruction, index));
168 }
169 }
170 return instruction;
171 }

}  // namespace xla