xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 
30 namespace xla {
31 namespace llvm_ir {
32 
LoopEmitter(const BodyEmitter & body_emitter,const Shape & shape,llvm::IRBuilder<> * b)33 LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
34                          llvm::IRBuilder<>* b)
35     : body_emitter_(body_emitter), shape_(shape), b_(b) {}
36 
LoopEmitter(const BodyEmitter & body_emitter,const Shape & shape,std::vector<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b)37 LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
38                          std::vector<llvm::Value*> dynamic_dims,
39                          llvm::IRBuilder<>* b)
40     : LoopEmitter::LoopEmitter(body_emitter, shape, b) {
41   CHECK_EQ(dynamic_dims.size(), shape_.dimensions_size());
42   dynamic_dims_ = std::move(dynamic_dims);
43 }
44 
LoopEmitter(const ElementGenerator & target_element_generator,const IrArray & target_array,llvm::IRBuilder<> * b)45 LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
46                          const IrArray& target_array, llvm::IRBuilder<>* b)
47     : body_emitter_(MakeBodyEmitter(target_element_generator, {target_array}, b,
48                                     /*is_tuple=*/false)),
49       shape_(target_array.GetShape()),
50       b_(b) {}
51 
LoopEmitter(const ElementGenerator & target_element_generator,absl::Span<const IrArray> target_arrays,llvm::IRBuilder<> * b)52 LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
53                          absl::Span<const IrArray> target_arrays,
54                          llvm::IRBuilder<>* b)
55     : body_emitter_(MakeBodyEmitter(target_element_generator, target_arrays, b,
56                                     /*is_tuple=*/true)),
57       shape_(target_arrays[0].GetShape()),
58       b_(b) {
59   // Sanity check: In multi-output fusion, all shapes produced must have the
60   // same dimensions.
61   for (const IrArray& array : target_arrays) {
62     CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape()))
63         << ": '" << shape_.ShortDebugString() << "' does not match '"
64         << array.GetShape().ShortDebugString() << "'";
65   }
66 }
67 
MakeBodyEmitter(const ElementGenerator & target_element_generator,absl::Span<IrArray const> target_arrays,llvm::IRBuilder<> * b,bool is_tuple)68 BodyEmitter MakeBodyEmitter(const ElementGenerator& target_element_generator,
69                             absl::Span<IrArray const> target_arrays,
70                             llvm::IRBuilder<>* b, bool is_tuple) {
71   std::vector<IrArray> target_arrays_vec(target_arrays.begin(),
72                                          target_arrays.end());
73   if (!is_tuple) {
74     CHECK_EQ(target_arrays.size(), 1);
75     return [=](const llvm_ir::IrArray::Index array_index) -> Status {
76       // Convert target_element_generator to a BodyEmitter.
77       TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
78                           target_element_generator(array_index));
79       target_arrays_vec[0].EmitWriteArrayElement(array_index, target_element,
80                                                  b);
81       return OkStatus();
82     };
83   }
84 
85   return [=](const llvm_ir::IrArray::Index array_index) {
86     TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
87                         target_element_generator(array_index));
88     CHECK(target_element->getType()->isStructTy())
89         << "This BodyEmitter is for multi-output, but target element "
90            "generator does not produce values of struct type.";
91     CHECK_EQ(target_element->getType()->getStructNumElements(),
92              target_arrays_vec.size());
93 
94     for (int64_t i = 0; i < target_arrays_vec.size(); ++i) {
95       target_arrays_vec[i].EmitWriteArrayElement(
96           array_index, b->CreateExtractValue(target_element, i), b);
97     }
98     return OkStatus();
99   };
100 }
101 
EmitStaticIndex(ForLoopNest * loop_nest,llvm::Type * index_type)102 IrArray::Index LoopEmitter::EmitStaticIndex(ForLoopNest* loop_nest,
103                                             llvm::Type* index_type) {
104   // Create loop nest with one for-loop for each dimension of the target shape.
105   // Loops are added from outermost to innermost order with the ForLoopNest
106   // class so emit loops in order from most-major dimension down to most-minor
107   // dimension (of the target shape).
108   std::vector<llvm::Value*> array_multi_index(shape_.dimensions_size());
109   for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
110     int64_t dimension = LayoutUtil::Major(shape_.layout(), i);
111     std::unique_ptr<ForLoop> loop = loop_nest->AddLoop(
112         /*start_index=*/0,
113         /*end_index=*/shape_.dimensions(dimension),
114         /*suffix=*/absl::StrFormat("dim.%d", dimension));
115     array_multi_index[dimension] = loop->GetIndVarValue();
116   }
117   return IrArray::Index(array_multi_index, shape_, index_type);
118 }
119 
EmitDynamicIndex(ForLoopNest * loop_nest,llvm::Type * index_type)120 IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest,
121                                              llvm::Type* index_type) {
122   CHECK_EQ(shape_.is_dynamic(), true);
123   // Create loop nest with one for-loop for each dynamic dimensions.
124   // Loops are added from outermost to innermost order with the ForLoopNest
125   // class so emit loops in order from most-major dimension down to most-minor
126   // dimension (of the target shape).
127   std::vector<llvm::Value*> array_multi_index(shape_.dimensions_size());
128   for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
129     int64_t dimension = LayoutUtil::Major(shape_.layout(), i);
130     std::unique_ptr<ForLoop> loop = loop_nest->AddLoop(
131         /*suffix=*/absl::StrFormat("dim.%d", dimension),
132         /*start_index=*/llvm::ConstantInt::get(index_type, 0),
133         /*end_index=*/dynamic_dims_[dimension]);
134     array_multi_index[dimension] = loop->GetIndVarValue();
135   }
136   return IrArray::Index(array_multi_index, shape_, index_type);
137 }
138 
EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,llvm::Type * index_type,llvm::Value * base_index)139 std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
140     absl::string_view loop_name, llvm::Type* index_type,
141     llvm::Value* base_index) {
142   CHECK_NE(index_type, nullptr);
143   CHECK_EQ(base_index, nullptr)
144       << "XLA CPU implementation of"
145       << " LoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
146       << " base_index, but it was requested.";
147 
148   if (ShapeUtil::IsScalar(shape_)) {
149     // No loop needed, so set exit_bb_ to nullptr.
150     exit_bb_ = nullptr;
151     return {IrArray::Index(index_type)};
152   }
153 
154   ForLoopNest loop_nest(loop_name, b_);
155 
156   IrArray::Index array_index = dynamic_dims_.empty()
157                                    ? EmitStaticIndex(&loop_nest, index_type)
158                                    : EmitDynamicIndex(&loop_nest, index_type);
159 
160   // Set IR builder insertion point to the loop body basic block of the
161   // innermost loop.
162   llvm::BasicBlock* innermost_body_bb = loop_nest.GetInnerLoopBodyBasicBlock();
163   b_->SetInsertPoint(innermost_body_bb,
164                      innermost_body_bb->getFirstInsertionPt());
165 
166   // Set exit_bb_ to the exit block of the loop nest.
167   exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock();
168   CHECK_NOTNULL(exit_bb_);
169 
170   return {array_index};
171 }
172 
EmitLoop(absl::string_view loop_name,llvm::Type * index_type)173 Status LoopEmitter::EmitLoop(absl::string_view loop_name,
174                              llvm::Type* index_type) {
175   if (index_type == nullptr) {
176     index_type = b_->getInt64Ty();
177   }
178 
179   for (const IrArray::Index& array_index :
180        EmitIndexAndSetExitBasicBlock(loop_name, index_type,
181                                      /*base_index*/ nullptr)) {
182     TF_RETURN_IF_ERROR(body_emitter_(array_index));
183   }
184 
185   // Set the insertion point of b_ to the loop exit, so that
186   // code emitted for later instructions will be correctly placed.
187   if (exit_bb_ != nullptr) {
188     b_->SetInsertPoint(exit_bb_);
189   }
190   return OkStatus();
191 }
192 
193 }  // namespace llvm_ir
194 }  // namespace xla
195