/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_

#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

// Abstract base class for translating HLO graphs to LLVM IR for a GPU.
//
// There are two concrete subclasses of IrEmitter: IrEmitterNested and
// IrEmitterUnnested.  In the unnested variety, each HLO gets its own kernel
// function, whereas in the nested version the whole computation is emitted as
// one *non-kernel* function.
//
// In XLA, kernel functions never call other kernel functions.  This means that
// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use
// an HLO computation as a "subroutine" -- e.g. the HLO computation that
// specifies how to reduce two elements -- then the subroutine computation must
// be emitted using IrEmitterNested.
//
// Fusion nodes are a special case.  A fusion node is emitted using
// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
// generator generator.  See comments on that class.
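//
// Illustrative call structure for a kReduce (a sketch only, not actual
// emitted code; names are hypothetical):
//
//   __global__ void reduce_kernel(...) {   // emitted by IrEmitterUnnested
//     ...
//     reducer(lhs, rhs, output);           // non-kernel device function
//     ...                                  // emitted by IrEmitterNested
//   }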
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
 public:
  IrEmitter(const IrEmitter&) = delete;
  IrEmitter& operator=(const IrEmitter&) = delete;

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;

  Status FinishVisit(HloInstruction* root) override { return OkStatus(); }

  llvm::IRBuilder<>* builder() { return &b_; }

 protected:
  // Constructs an IrEmitter with the given IrEmitter context.
  // ir_emitter_context is owned by the caller and should outlive the IrEmitter
  // object.
  explicit IrEmitter(const HloModuleConfig& hlo_module_config,
                     IrEmitterContext* ir_emitter_context, bool is_nested);

  // Helper for calling HloToIrBindings::GetIrArray.
  //
  // Gets the IrArray which contains inst.  This array has metadata that makes
  // it valid only within the IR that implements consumer.  If you are
  // implementing an HLO and want to get its own output buffer, call
  // GetIrArray(hlo, hlo).
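  //
  // Illustrative usage inside a hypothetical Handle* method (a sketch only;
  // `hlo` stands for the instruction being emitted):
  //
  //   llvm_ir::IrArray output = GetIrArray(*hlo, *hlo);
  //   llvm_ir::IrArray operand = GetIrArray(*hlo->operand(0), *hlo);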
  llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
                              const HloInstruction& consumer,
                              const ShapeIndex& shape_index = {}) {
    return bindings_.GetIrArray(inst, consumer, shape_index);
  }

  // A convenience helper for calling HloToIrBindings::GetBasePointer.
  llvm::Value* GetBasePointer(const HloInstruction& inst,
                              ShapeIndexView shape_index = {}) const {
    return bindings_.GetBasePointer(inst, shape_index);
  }

  // Generates the IrArray for each output of an HLO instruction and returns
  // a vector containing such IrArrays.
  std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
      const HloInstruction& hlo);

  // Emits a single-threaded or multithreaded loop that computes every element
  // in the result of the given HLO instruction. This produces a series of
  // nested loops (e.g. one for each dimension of the `hlo`'s shape). The body
  // of the innermost loop is provided by the body_emitter function.
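  //
  // Illustrative caller (a sketch; the element computation is a placeholder):
  //
  //   return EmitTargetElementLoop(
  //       *hlo, [&](const llvm_ir::IrArray::Index& index)
  //                 -> StatusOr<llvm::Value*> {
  //         // Compute and return this element's llvm::Value*.
  //       });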
  virtual Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) = 0;

  // Emits a call in IR to the given nested computation with the given operands
  // and output. If no IR function has been previously emitted for the
  // computation, also emits such a function.
  Status EmitCallToNestedComputation(const HloComputation& nested_computation,
                                     absl::Span<llvm::Value* const> operands,
                                     llvm::Value* output);

  // Emits an atomic operation that implements `nested_computation` in the
  // sequentially consistent memory model. `output_address` and
  // `source_address` are the arguments of the nested computation. For example,
  // atomicAdd(output_address, *source_address).
  Status EmitAtomicOperationForNestedComputation(
      const HloComputation& nested_computation, llvm::Value* output_address,
      llvm::Value* source_address, llvm::Type* element_type);

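  // Returns a NestedComputer callback that evaluates `computation` on the
  // given element values by delegating to ComputeNestedElement below.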
  GpuElementalIrEmitter::NestedComputer GetNestedComputer() {
    return [&](const HloComputation& computation,
               absl::Span<llvm::Value* const> parameter_elements) {
      return ComputeNestedElement(computation, parameter_elements);
    };
  }

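  // Evaluates `computation` on the given parameter element values and returns
  // the resulting element value(s). The *FromAddrs variant takes the
  // parameters by address instead.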
  StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
      const HloComputation& computation,
      absl::Span<llvm::Value* const> parameter_elements);

  StatusOr<std::vector<llvm::Value*>> ComputeNestedElementFromAddrs(
      const HloComputation& computation,
      absl::Span<llvm::Value* const> parameter_elements_addrs);

  IrEmitterContext* ir_emitter_context_;
  llvm::Module* module_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module.
  llvm::IRBuilder<> b_;

  // Mapping from HLO to its underlying LLVM value.
  HloToIrBindings bindings_;

  // HLO configuration data used during code generation.
  const HloModuleConfig& hlo_module_config_;

  // Binds all argument IrArrays of `fusion` to `fused_emitter`.
  void BindFusionArguments(const HloInstruction* fusion,
                           FusedIrEmitter* fused_emitter);

 private:
  // A helper method for EmitAtomicOperationForNestedComputation. Certain
  // computations, such as floating-point addition and integer maximization,
  // can be implemented directly with a single LLVM atomic instruction. If
  // `computation` is one of these, emits code to do that and returns true;
  // otherwise returns false.
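  //
  // For instance, a floating-point-add computation can map to a single
  // `atomicrmw fadd` (illustrative LLVM IR, not necessarily the exact
  // emitted form):
  //
  //   %old = atomicrmw fadd float* %output, float %source seq_cst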
  bool MaybeEmitDirectAtomicOperation(const HloComputation& computation,
                                      llvm::Value* output_address,
                                      llvm::Value* source_address);

  // A helper method for EmitAtomicOperationForNestedComputation. It implements
  // binary atomic operations using atomicCAS, with special handling to support
  // small data types.
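  //
  // The emitted IR follows the usual CAS loop shape (a sketch in CUDA-style
  // pseudocode; names are illustrative):
  //
  //   old = *output_address;
  //   do {
  //     assumed = old;
  //     old = atomicCAS(output_address, assumed,
  //                     computation(assumed, *source_address));
  //   } while (old != assumed);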
  Status EmitAtomicOperationUsingCAS(const HloComputation& computation,
                                     llvm::Value* output_address,
                                     llvm::Value* source_address,
                                     llvm::Type* element_type);

  // A helper method for HandleSort(). It adds the inner comparison loop where
  // we compare elements pointed to by 'keys_index' and 'compare_keys_index'.
  void EmitCompareLoop(int64_t dimension_to_sort,
                       const llvm_ir::IrArray::Index& keys_index,
                       const llvm_ir::IrArray::Index& compare_keys_index,
                       const llvm_ir::IrArray& keys_array);

  // A helper method for EmitAtomicOperationForNestedComputation. Emits an LLVM
  // IR function that implements `nested_computation` on elements of the given
  // type, for use by the atomic operation emitters above.
  StatusOr<llvm::Function*> EmitAtomicFunctionForNestedComputation(
      const HloComputation& nested_computation, llvm::Type* element_ir_type);

  // A convenience method to determine whether or not IR is emitted for AMDGPU.
  bool IsEmittingForAMDGPU() const;

  // Emits an atomic add operation for AMD GPUs.
  void EmitAMDGPUAtomicAdd(llvm::Value* output_address, llvm::Value* source);

  // A convenience method to determine the proper sync scope for an atomic op.
  llvm::SyncScope::ID DetermineSyncScope() const;

  // Maps nested computations to emitted IR functions. This serves as a cache
  // so that IrEmitter does not emit multiple functions for the same
  // HloComputation.
  std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_