xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/tuple_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_util.h"
17 #include "absl/types/span.h"
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 
20 namespace xla {
21 
ExtractPrefix(HloInstruction * input_tuple,int64_t elements)22 /*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple,
23                                                     int64_t elements) {
24   CHECK(input_tuple->shape().IsTuple());
25 
26   HloComputation* computation = input_tuple->parent();
27   const Shape& input_shape = input_tuple->shape();
28 
29   std::vector<HloInstruction*> tuple_elements;
30   tuple_elements.reserve(elements);
31   for (int i = 0; i < elements; i++) {
32     tuple_elements.push_back(
33         computation->AddInstruction(HloInstruction::CreateGetTupleElement(
34             input_shape.tuple_shapes(i), input_tuple, i)));
35   }
36 
37   return computation->AddInstruction(
38       HloInstruction::CreateTuple(tuple_elements));
39 }
40 
AppendSuffix(HloInstruction * input_tuple,absl::Span<HloInstruction * const> trailing_values)41 /*static*/ HloInstruction* TupleUtil::AppendSuffix(
42     HloInstruction* input_tuple,
43     absl::Span<HloInstruction* const> trailing_values) {
44   CHECK(input_tuple->shape().IsTuple());
45 
46   HloComputation* computation = input_tuple->parent();
47   const Shape& input_shape = input_tuple->shape();
48   std::vector<HloInstruction*> tuple_elements;
49   tuple_elements.reserve(input_shape.tuple_shapes_size());
50   for (int i = 0; i < input_shape.tuple_shapes_size(); i++) {
51     tuple_elements.push_back(
52         computation->AddInstruction(HloInstruction::CreateGetTupleElement(
53             input_shape.tuple_shapes(i), input_tuple, i)));
54   }
55   tuple_elements.insert(tuple_elements.end(), trailing_values.begin(),
56                         trailing_values.end());
57   return computation->AddInstruction(
58       HloInstruction::CreateTuple(tuple_elements));
59 }
60 
ReplaceTupleWith(HloInstruction * new_instruction,HloInstruction * tuple,ShapeIndex shape_index,bool insert_bitcast_if_different_shape)61 /*static*/ StatusOr<HloInstruction*> TupleUtil::ReplaceTupleWith(
62     HloInstruction* new_instruction, HloInstruction* tuple,
63     ShapeIndex shape_index, bool insert_bitcast_if_different_shape) {
64   const Shape& tuple_shape = tuple->shape();
65   CHECK(tuple->shape().IsTuple())
66       << "ReplaceTupleWith was called for a non-tuple. Tuple = "
67       << tuple->ToString()
68       << ", new_instruction = " << new_instruction->ToString()
69       << ", shape_index = " << shape_index.ToString();
70   // Check if the new instruction is a get-tuple-element of the correct index of
71   // the tuple, and if so, simply return tuple.
72   const HloInstruction* instruction = new_instruction;
73   bool equivalent = true;
74   for (int i = shape_index.size() - 1; i >= 0; --i) {
75     int index = shape_index[i];
76     if (instruction->opcode() != HloOpcode::kGetTupleElement ||
77         instruction->tuple_index() != index) {
78       equivalent = false;
79       break;
80     }
81     instruction = instruction->operand(0);
82   }
83   if (equivalent && instruction == tuple) {
84     VLOG(4) << "Instruction " << new_instruction->ToShortString()
85             << " already exists at index " << shape_index.ToString() << " of "
86             << tuple->ToShortString();
87     return tuple;
88   }
89 
90   HloComputation* computation = new_instruction->parent();
91   std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size());
92   CHECK_GE(tuple_shape.tuple_shapes_size(), shape_index[0]);
93   for (int i = 0; i < tuple_shape.tuple_shapes_size(); ++i) {
94     const Shape& subshape = tuple_shape.tuple_shapes(i);
95     // If tuple is a tuple instruction, we can get the tuple instruction's
96     // operand to construct the new tuple to improve compilation time
97     // performance.
98     auto get_operand = [&]() {
99       if (tuple->opcode() == HloOpcode::kTuple) {
100         return tuple->mutable_operand(i);
101       } else {
102         return computation->AddInstruction(
103             HloInstruction::CreateGetTupleElement(subshape, tuple, i));
104       }
105     };
106     if (i == shape_index[0]) {
107       // If the subshape is still a tuple, recurse and pass a new shape index
108       // for the one level deeper.
109       if (subshape.IsTuple()) {
110         TF_ASSIGN_OR_RETURN(tuple_args[i],
111                             ReplaceTupleWith(new_instruction, get_operand(),
112                                              ShapeIndex(shape_index.begin() + 1,
113                                                         shape_index.end())));
114       } else {
115         if (subshape != new_instruction->shape() &&
116             insert_bitcast_if_different_shape) {
117           VLOG(4) << "Old shape = " << subshape.ToString()
118                   << ", new shape = " << new_instruction->shape().ToString()
119                   << "; inserting a bitcast.";
120           new_instruction = computation->AddInstruction(
121               HloInstruction::CreateBitcast(subshape, new_instruction));
122         } else if (tuple->opcode() == HloOpcode::kTuple &&
123                    tuple->operand(i) == new_instruction) {
124           // If the tuple element is the same as the new instruction, we
125           // actually don't have to create a new tuple, just return the original
126           // tuple.
127           VLOG(4) << "Tuple already contains the new instruction = "
128                   << new_instruction->ToShortString()
129                   << " tuple = " << tuple->ToShortString();
130           return tuple;
131         }
132         tuple_args[i] = new_instruction;
133       }
134     } else {
135       tuple_args[i] = get_operand();
136     }
137   }
138   if (shape_index[0] == tuple_shape.tuple_shapes_size()) {
139     // If shape_index[0] is equal to the tuple shape size, add the new
140     // instruction as an additional argument.
141     tuple_args.push_back(new_instruction);
142   }
143   return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args));
144 }
145 
AddGetTupleElements(const HloPosition & position)146 /*static*/ HloInstruction* TupleUtil::AddGetTupleElements(
147     const HloPosition& position) {
148   HloInstruction* instruction = position.instruction;
149   HloComputation* computation = instruction->parent();
150 
151   // If the instruction we're processing is a tuple, we (recursively) search or
152   // create kGetTupleElement instructions and copy that value.
153   for (int64_t index : position.index) {
154     // We first search if there already is a get-tuple-element with the correct
155     // index. If there is no such get-tuple-element, we create one.
156     auto gte_it = absl::c_find_if(
157         instruction->users(), [index](const HloInstruction* use) {
158           return use != use->parent()->root_instruction() &&
159                  use->opcode() == HloOpcode::kGetTupleElement &&
160                  use->tuple_index() == index;
161         });
162     if (gte_it != instruction->users().end()) {
163       instruction = *gte_it;
164     } else {
165       instruction =
166           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
167               instruction->shape().tuple_shapes(index), instruction, index));
168     }
169   }
170   return instruction;
171 }
172 
173 }  // namespace xla
174