xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <utility>
21 
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/Value.h"
26 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
27 #include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
32 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
33 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
34 #include "tensorflow/compiler/xla/shape.h"
35 #include "tensorflow/compiler/xla/shape_util.h"
36 #include "tensorflow/compiler/xla/status_macros.h"
37 #include "tensorflow/compiler/xla/statusor.h"
38 #include "tensorflow/compiler/xla/util.h"
39 #include "tensorflow/core/platform/errors.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/statusor.h"
42 
43 namespace xla {
44 
45 using llvm_ir::IrArray;
46 
DefaultAction(const HloInstruction & instruction)47 StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::DefaultAction(
48     const HloInstruction& instruction) {
49   IndexedGenerator generator = elemental_emitter_.MakeElementGenerator(
50       &instruction, indexed_generators_);
51 
52   return StatusOr<IndexedGenerator>([&, generator = std::move(generator)](
53                                         const IrArray::Index& index)
54                                         -> StatusOr<llvm::Value*> {
55     ValueCacheKey key{&instruction, index.multidim()};
56     llvm::Value* value = value_cache_.insert({key, nullptr}).first->second;
57 
58     if (value != nullptr) {
59       if (const auto* generated_instruction =
60               llvm::dyn_cast<llvm::Instruction>(value)) {
61         const llvm::BasicBlock* bb = generated_instruction->getParent();
62 
63         // Ideally, we should be able to reuse the cached generated value if it
64         // dominates the current insertion block. However, the check for
65         // dominance can be expensive and unreliable when the function is being
66         // constructed.
67         //
68         // It's also worth experimenting what if we don't do caching at all.
69         // LLVM's CSE or GVN should be able to easily merge common
70         // subexpressions that would be regenerated without caching. But this
71         // might increase the JIT compilation time.
72         llvm::IRBuilder<>* b = elemental_emitter_.b();
73 
74         if (bb == b->GetInsertBlock()) {
75           VLOG(3) << "The cached generated value is reused.";
76           return value;
77         }
78 
79         VLOG(3)
80             << "The cached generated value can't be reused, because it is in "
81                "a different BB ("
82             << bb->getName().str() << ") from the current insertion block ("
83             << b->GetInsertBlock()->getName().str() << ").";
84       }
85     }
86 
87     TF_ASSIGN_OR_RETURN(value, generator(index));
88     value_cache_[std::move(key)] = value;
89     return value;
90   });
91 }
92 
HandleConstant(const HloInstruction & constant)93 FusedIrEmitter::IndexedGenerator FusedIrEmitter::HandleConstant(
94     const HloInstruction& constant) {
95   llvm::Module* module = elemental_emitter_.module();
96   llvm::IRBuilder<>* b = elemental_emitter_.b();
97 
98   llvm::Constant* initializer =
99       llvm_ir::ConvertLiteralToIrConstant(constant.literal(), module);
100   llvm::GlobalVariable* global = new llvm::GlobalVariable(
101       *b->GetInsertBlock()->getModule(), initializer->getType(),
102       /*isConstant=*/true,
103       /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
104       /*Initializer=*/initializer,
105       /*Name=*/"", /*InsertBefore=*/nullptr,
106       /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
107       /*AddressSpace=*/0,
108       /*isExternallyInitialized=*/false);
109   global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
110 
111   llvm::Type* shape_type = llvm_ir::ShapeToIrType(constant.shape(), module);
112   llvm::Constant* global_with_shape =
113       llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
114           global, shape_type->getPointerTo());
115 
116   IrArray array(global_with_shape, shape_type, constant.shape());
117 
118   return [&, b, array = std::move(array)](const IrArray::Index& index) {
119     return array.EmitReadArrayElement(index, b, constant.name());
120   };
121 }
122 
HandleTuple(const HloInstruction & tuple)123 StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::HandleTuple(
124     const HloInstruction& tuple) {
125   std::vector<llvm::Type*> element_ir_types;
126   element_ir_types.reserve(tuple.operand_count());
127   for (const HloInstruction* operand : tuple.operands()) {
128     element_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
129         operand->shape().element_type(), elemental_emitter_.module()));
130   }
131 
132   llvm::IRBuilder<>* b = elemental_emitter_.b();
133   llvm::Type* type = llvm::StructType::get(b->getContext(), element_ir_types);
134 
135   return StatusOr<IndexedGenerator>(
136       [&, b, type](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
137         llvm::Value* ret = llvm::UndefValue::get(type);
138         for (size_t i = 0; i < tuple.operand_count(); ++i) {
139           TF_ASSIGN_OR_RETURN(llvm::Value * value,
140                               indexed_generators_.at(tuple.operand(i))(index));
141           ret = b->CreateInsertValue(ret, value, i);
142         }
143         return ret;
144       });
145 }
146 
IsFusedIrEmitterInefficient(const HloInstruction & consumer,const HloInstruction & producer)147 bool FusedIrEmitter::IsFusedIrEmitterInefficient(
148     const HloInstruction& consumer, const HloInstruction& producer) {
149   if (consumer.opcode() != HloOpcode::kFusion) {
150     return false;
151   }
152   FusionNodeIndexingEvaluation eval_consumer(&consumer);
153   if (producer.opcode() != HloOpcode::kFusion) {
154     return eval_consumer.CodeDuplicationTooHigh(&producer);
155   }
156   // If 'producer' is a fusion node as well, also evaluate it. Pass the
157   // evaluated duplication of the fusion node if it is merged into consumer.
158   FusionNodeIndexingEvaluation eval_producer(
159       &producer, eval_consumer.EvaluateEmittedInstructions(&producer));
160   return eval_producer.MaxCodeDuplicationTooHigh();
161 }
162 
CreateGenerator(const HloInstruction & instruction)163 StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::CreateGenerator(
164     const HloInstruction& instruction) {
165   switch (instruction.opcode()) {
166     case HloOpcode::kConstant:
167       return HandleConstant(instruction);
168     case HloOpcode::kGetTupleElement:
169       return InternalError("Tuple parameters are not supported for fusion");
170     case HloOpcode::kParameter:
171       return InvalidArgument("Unbound parameter: %s", instruction.ToString());
172     case HloOpcode::kTuple:
173       return HandleTuple(instruction);
174     default:
175       return DefaultAction(instruction);
176   }
177 }
178 
GetGenerator(const HloInstruction & instruction)179 StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator(
180     const HloInstruction& instruction) {
181   std::vector<const HloInstruction*> stack = {&instruction};
182   while (!stack.empty()) {
183     const HloInstruction& instr = *stack.back();
184     stack.pop_back();
185 
186     IndexedGenerator& indexed_generator = indexed_generators_[&instr];
187     if (indexed_generator != nullptr) continue;
188 
189     stack.insert(stack.end(), instr.operands().begin(), instr.operands().end());
190     TF_ASSIGN_OR_RETURN(indexed_generator, CreateGenerator(instr));
191   }
192   return indexed_generators_[&instruction];
193 }
194 
195 }  // namespace xla
196