/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/statusor.h"

namespace xla {

using llvm_ir::IrArray;

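// Fallback generator for fused instructions without a dedicated handler:
// wraps the elemental emitter's generator in a cache keyed by (instruction,
// index), so a previously emitted value can be reused while the builder is
// still inserting into the same basic block.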
StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::DefaultAction(
    const HloInstruction& instruction) {
  IndexedGenerator generator = elemental_emitter_.MakeElementGenerator(
      &instruction, indexed_generators_);

  return StatusOr<IndexedGenerator>([&, generator = std::move(generator)](
                                        const IrArray::Index& index)
                                        -> StatusOr<llvm::Value*> {
    // Look up (or create) the cache entry for this (instruction, index) pair.
    // insert() does not overwrite an existing entry, so `value` holds the
    // previously generated value on a cache hit and nullptr otherwise.
    ValueCacheKey key{&instruction, index.multidim()};
    llvm::Value* value = value_cache_.insert({key, nullptr}).first->second;

    if (value != nullptr) {
      if (const auto* generated_instruction =
              llvm::dyn_cast<llvm::Instruction>(value)) {
        const llvm::BasicBlock* bb = generated_instruction->getParent();

        // Ideally, we should be able to reuse the cached generated value if it
        // dominates the current insertion block. However, the check for
        // dominance can be expensive and unreliable while the function is
        // still being constructed.
        //
        // It's also worth experimenting with not doing any caching at all.
        // LLVM's CSE or GVN should be able to easily merge common
        // subexpressions that would be regenerated without caching. But this
        // might increase the JIT compilation time.
        llvm::IRBuilder<>* b = elemental_emitter_.b();

        if (bb == b->GetInsertBlock()) {
          VLOG(3) << "The cached generated value is reused.";
          return value;
        }

        VLOG(3)
            << "The cached generated value can't be reused, because it is in "
               "a different BB ("
            << bb->getName().str() << ") from the current insertion block ("
            << b->GetInsertBlock()->getName().str() << ").";
      }
    }

    TF_ASSIGN_OR_RETURN(value, generator(index));
    value_cache_[std::move(key)] = value;
    return value;
  });
}

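// Emits the fused constant as a private, read-only global in the module and
// returns a generator that reads the requested element from it.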
FusedIrEmitter::IndexedGenerator FusedIrEmitter::HandleConstant(
    const HloInstruction& constant) {
  llvm::Module* module = elemental_emitter_.module();
  llvm::IRBuilder<>* b = elemental_emitter_.b();

  llvm::Constant* initializer =
      llvm_ir::ConvertLiteralToIrConstant(constant.literal(), module);
  llvm::GlobalVariable* global = new llvm::GlobalVariable(
      *b->GetInsertBlock()->getModule(), initializer->getType(),
      /*isConstant=*/true,
      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/initializer,
      /*Name=*/"", /*InsertBefore=*/nullptr,
      /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
      /*AddressSpace=*/0,
      /*isExternallyInitialized=*/false);
  global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);

  llvm::Type* shape_type = llvm_ir::ShapeToIrType(constant.shape(), module);
  llvm::Constant* global_with_shape =
      llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
          global, shape_type->getPointerTo());

  IrArray array(global_with_shape, shape_type, constant.shape());

  return [&, b, array = std::move(array)](const IrArray::Index& index) {
    return array.EmitReadArrayElement(index, b, constant.name());
  };
}

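// Emits a generator for a fused kTuple instruction: for a given index it
// generates the element of every operand and packs the results into a single
// LLVM struct value.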
StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::HandleTuple(
    const HloInstruction& tuple) {
  std::vector<llvm::Type*> element_ir_types;
  element_ir_types.reserve(tuple.operand_count());
  for (const HloInstruction* operand : tuple.operands()) {
    element_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
        operand->shape().element_type(), elemental_emitter_.module()));
  }

  llvm::IRBuilder<>* b = elemental_emitter_.b();
  llvm::Type* type = llvm::StructType::get(b->getContext(), element_ir_types);

  return StatusOr<IndexedGenerator>(
      [&, b, type](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        llvm::Value* ret = llvm::UndefValue::get(type);
        for (size_t i = 0; i < tuple.operand_count(); ++i) {
          TF_ASSIGN_OR_RETURN(llvm::Value * value,
                              indexed_generators_.at(tuple.operand(i))(index));
          ret = b->CreateInsertValue(ret, value, i);
        }
        return ret;
      });
}

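// Heuristically decides whether fusing 'producer' into 'consumer' would cause
// too much code duplication (i.e. too many re-emissions of the same
// instructions) for the fused IR emitter to handle efficiently.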
bool FusedIrEmitter::IsFusedIrEmitterInefficient(
    const HloInstruction& consumer, const HloInstruction& producer) {
  if (consumer.opcode() != HloOpcode::kFusion) {
    return false;
  }
  FusionNodeIndexingEvaluation eval_consumer(&consumer);
  if (producer.opcode() != HloOpcode::kFusion) {
    return eval_consumer.CodeDuplicationTooHigh(&producer);
  }
  // If 'producer' is a fusion node as well, also evaluate it, passing in the
  // duplication 'producer' would incur if it were merged into 'consumer'.
  FusionNodeIndexingEvaluation eval_producer(
      &producer, eval_consumer.EvaluateEmittedInstructions(&producer));
  return eval_producer.MaxCodeDuplicationTooHigh();
}

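// Creates the element generator for a single fused instruction by dispatching
// on its opcode: constants and tuples get dedicated handlers, get-tuple-element
// and unbound parameters are rejected as errors, and everything else falls
// through to DefaultAction.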
StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::CreateGenerator(
    const HloInstruction& instruction) {
  switch (instruction.opcode()) {
    case HloOpcode::kConstant:
      return HandleConstant(instruction);
    case HloOpcode::kGetTupleElement:
      return InternalError("Tuple parameters are not supported for fusion");
    case HloOpcode::kParameter:
      return InvalidArgument("Unbound parameter: %s", instruction.ToString());
    case HloOpcode::kTuple:
      return HandleTuple(instruction);
    default:
      return DefaultAction(instruction);
  }
}

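// Returns the generator for 'instruction', first creating and caching
// generators for it and all of its transitive operands that are not cached
// yet, using an explicit work stack rather than recursion.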
StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator(
    const HloInstruction& instruction) {
  std::vector<const HloInstruction*> stack = {&instruction};
  while (!stack.empty()) {
    const HloInstruction& instr = *stack.back();
    stack.pop_back();

    IndexedGenerator& indexed_generator = indexed_generators_[&instr];
    if (indexed_generator != nullptr) continue;

    stack.insert(stack.end(), instr.operands().begin(), instr.operands().end());
    TF_ASSIGN_OR_RETURN(indexed_generator, CreateGenerator(instr));
  }
  return indexed_generators_[&instruction];
}

}  // namespace xla