/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_

#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

// Abstract base class for translating HLO graphs to LLVM IR for a GPU.
//
// There are two concrete subclasses of IrEmitter: IrEmitterNested and
// IrEmitterUnnested. In the unnested variety, each HLO gets its own kernel
// function, whereas in the nested version the whole computation is emitted as
// one *non-kernel* function.
//
// In XLA, kernel functions never call other kernel functions. This means that
// if we have a kernel -- e.g. one implementing a kReduce HLO -- that wants to
// use an HLO computation as a "subroutine" -- e.g. the HLO computation that
// specifies how to reduce two elements -- then the subroutine computation must
// be emitted using IrEmitterNested.
//
// Fusion nodes are a special case. A fusion node is emitted using
// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
// generator generator. See the comments on that class.
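//
// As a rough illustration of the kernel/non-kernel split (the names below are
// hypothetical, not the symbols XLA actually emits), the IR for a kReduce HLO
// might look like:
//
//   ; one kernel per HLO, emitted via IrEmitterUnnested
//   define void @reduce_kernel(...) {
//     ...
//     ; the reducer is an ordinary device function, emitted via
//     ; IrEmitterNested -- kernels never call other kernels
//     %sum = call float @reducer(float %a, float %b)
//     ...
//   }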
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
 public:
  IrEmitter(const IrEmitter&) = delete;
  IrEmitter& operator=(const IrEmitter&) = delete;

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;

  Status FinishVisit(HloInstruction* root) override { return OkStatus(); }

  llvm::IRBuilder<>* builder() { return &b_; }

 protected:
  // Constructs an IrEmitter with the given IrEmitter context.
  // ir_emitter_context is owned by the caller and should outlive the IrEmitter
  // object.
  explicit IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested);

  // A helper for calling HloToIrBindings::GetIrArray.
  //
  // Gets the IrArray which contains inst. This array has metadata that makes
  // it valid only within the IR that implements consumer. If you are
  // implementing an HLO and want to get its own output buffer, call
  // GetIrArray(hlo, hlo).
  llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
                              const HloInstruction& consumer,
                              const ShapeIndex& shape_index = {}) {
    return bindings_.GetIrArray(inst, consumer, shape_index);
  }

  // A convenient helper for calling HloToIrBindings::GetBasePointer.
  llvm::Value* GetBasePointer(const HloInstruction& inst,
                              ShapeIndexView shape_index = {}) const {
    return bindings_.GetBasePointer(inst, shape_index);
  }
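  // A usage sketch for the two binding helpers above. This is illustrative
  // only; `hlo` stands for a hypothetical instruction being emitted, not a
  // member of this class:
  //
  //   llvm_ir::IrArray out = GetIrArray(*hlo, *hlo);
  //   llvm_ir::IrArray operand = GetIrArray(*hlo->operand(0), *hlo);
  //   llvm::Value* operand_ptr = GetBasePointer(*hlo->operand(0));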
  // Generates the IrArray for each output of an HLO instruction and returns
  // a vector containing such IrArrays.
  std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
      const HloInstruction& hlo);

  // Emits a single-threaded or multi-threaded loop that computes every element
  // in the result of the given HLO instruction. This produces a series of
  // nested loops (e.g. one for each dimension of the `hlo`'s shape). The body
  // of the inner-most loop is provided by the body_emitter function.
  virtual Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) = 0;

  // Emits a call in IR to the given nested computation with the given operands
  // and output. If no IR function has been previously emitted for the
  // computation, also emits such a function.
  Status EmitCallToNestedComputation(const HloComputation& nested_computation,
                                     absl::Span<llvm::Value* const> operands,
                                     llvm::Value* output);

  // Emits an atomic operation that implements `nested_computation` in the
  // sequentially consistent memory model. `output_address` and
  // `source_address` are the arguments of the nested computation. For example,
  // atomicAdd(output_address, *source_address).
  Status EmitAtomicOperationForNestedComputation(
      const HloComputation& nested_computation, llvm::Value* output_address,
      llvm::Value* source_address, llvm::Type* element_type);

  GpuElementalIrEmitter::NestedComputer GetNestedComputer() {
    return [&](const HloComputation& computation,
               absl::Span<llvm::Value* const> parameter_elements) {
      return ComputeNestedElement(computation, parameter_elements);
    };
  }

  StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
      const HloComputation& computation,
      absl::Span<llvm::Value* const> parameter_elements);

  StatusOr<std::vector<llvm::Value*>> ComputeNestedElementFromAddrs(
      const HloComputation& computation,
      absl::Span<llvm::Value* const> parameter_elements_addrs);

  IrEmitterContext* ir_emitter_context_;
  llvm::Module* module_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module.
  llvm::IRBuilder<> b_;

  // Mapping from HLO to its underlying LLVM value.
  HloToIrBindings bindings_;

  // HLO configuration data used during code generation.
  const HloModuleConfig& hlo_module_config_;

  // Binds all argument IrArrays of `fusion` to `fused_emitter`.
  void BindFusionArguments(const HloInstruction* fusion,
                           FusedIrEmitter* fused_emitter);

 private:
  // A helper method for EmitAtomicOperationForNestedComputation. Certain
  // computations, such as floating-point addition and integer maximization,
  // can be implemented directly with an LLVM atomic instruction. If
  // `computation` is one of this kind, emits code to do that and returns true;
  // otherwise, returns false.
  bool MaybeEmitDirectAtomicOperation(const HloComputation& computation,
                                      llvm::Value* output_address,
                                      llvm::Value* source_address);

  // A helper method for EmitAtomicOperationForNestedComputation. It implements
  // binary atomic operations using atomicCAS, with special handling to support
  // small data types.
  Status EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                     llvm::Value* output_address,
                                     llvm::Value* source_address,
                                     llvm::Type* element_type);
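  // Conceptually, the CAS-based helper above emits a loop of roughly the
  // following shape (a pseudocode sketch, not the exact IR; `computation`
  // stands for the nested binary computation):
  //
  //   old = *output_address;
  //   do {
  //     assumed = old;
  //     desired = computation(assumed, *source_address);
  //     old = atomicCAS(output_address, assumed, desired);
  //   } while (old != assumed);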
  // A helper method for HandleSort(). It adds the inner comparison loop where
  // we compare elements pointed to by `keys_index` and `compare_keys_index`.
  void EmitCompareLoop(int64_t dimension_to_sort,
                       const llvm_ir::IrArray::Index& keys_index,
                       const llvm_ir::IrArray::Index& compare_keys_index,
                       const llvm_ir::IrArray& keys_array);

  // Emits an IR function that implements `nested_computation` as an atomic
  // operation in the sequentially consistent memory model. The emitted
  // function takes the output address and the source address as its
  // arguments; for example, atomicAdd(output_address, *source_address).
  StatusOr<llvm::Function*> EmitAtomicFunctionForNestedComputation(
      const HloComputation& nested_computation, llvm::Type* element_ir_type);

  // A convenience method to determine whether or not IR is emitted for AMDGPU.
  bool IsEmittingForAMDGPU() const;

  // Emits an atomic add operation for AMD GPUs.
  void EmitAMDGPUAtomicAdd(llvm::Value* output_address, llvm::Value* source);

  // A convenience method to determine the proper sync scope for an atomic op.
  llvm::SyncScope::ID DetermineSyncScope() const;

  // Maps nested computations to emitted IR functions. This serves as a cache
  // so that IrEmitter does not emit multiple functions for the same
  // HloComputation.
  std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_