/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_

#include <stddef.h>

#include <functional>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Target/TargetMachine.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace cpu {
// This class is the top-level API for the XLA HLO --> LLVM IR compiler.  It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
class IrEmitter : public DfsHloVisitorWithDefault,
                  public IrBuilderMixin<IrEmitter> {
  friend class CpuElementalIrEmitter;

 public:
  using GeneratorForOperandIrArrays =
      std::function<std::vector<llvm_ir::IrArray>()>;

  // Create a new LLVM IR emitter.
  //
  // hlo_module: the HLO module we are emitting IR for.
  // assignment: a BufferAssignment from which we know which buffers are used by
  //             the HLO nodes.
  // mlir_context: the MLIR context used for IR emission.
  // llvm_module: the LLVM module to emit IR into. It's built using the LLVM
  //              context inside of mlir_context.
  // instruction_to_profile_idx: the mapping from HLO instructions to their
  //              index in the profiling array.
  // computation_to_profile_idx: the mapping from HLO computations to their
  //              index in the profiling array.
  // computation_transitively_contains_custom_call: the mapping from HLO
  //   computations to whether or not they transitively contain a custom-call
  //   instruction. All computations in the module must have a key in this
  //   map.
  // emit_code_for_msan: whether emitted code should be compatible with msan.
  IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
            const BufferAssignment& assignment, llvm::Module* llvm_module,
            absl::flat_hash_map<const HloInstruction*, int64_t>
                instruction_to_profile_idx,
            absl::flat_hash_map<const HloComputation*, int64_t>
                computation_to_profile_idx,
            absl::flat_hash_map<const HloComputation*, bool>
                computation_transitively_contains_custom_call,
            const TargetMachineFeatures* target_machine,
            bool emit_code_for_msan);
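  // A minimal construction sketch (variable names are hypothetical):
  //
  //   IrEmitter ir_emitter(&mlir_context, *hlo_module, *buffer_assignment,
  //                        llvm_module.get(),
  //                        std::move(instruction_to_profile_idx),
  //                        std::move(computation_to_profile_idx),
  //                        std::move(contains_custom_call_map),
  //                        &target_machine_features,
  //                        /*emit_code_for_msan=*/false);
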
  ~IrEmitter() override;

  // Emit and return the given HLO computation as an LLVM IR
  // function.
  //
  // function_name_prefix is the desired name of the function. If the name is
  // not unique among already emitted functions then a suffix is appended to
  // make the name unique.
  //
  // 'is_top_level_computation' has the following meanings for each CPU backend:
  // *) sequential: indicates that this is the entry computation of the HLO
  //    module.
  // *) parallel: indicates that this is the callee of a kCall HLO in the entry
  //    computation of the HLO module.
  //
  // If 'instruction_order' is not NULL, then the HLO instructions are emitted
  // in the given order.  In this case, 'instruction_order' must be a
  // topological sort of the set of nodes accessible from the root of the
  // computation.
  //
  // If 'allow_reassociation' is true, the fast-math reassociation flag will
  // be enabled in the function's body. This is used when emitting reducers.
  StatusOr<llvm::Function*> EmitComputation(
      HloComputation* computation, const std::string& function_name_prefix,
      bool is_top_level_computation,
      absl::Span<HloInstruction* const> instruction_order,
      bool allow_reassociation);
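  // A typical call for the entry computation might look like this (sketch;
  // the schedule and variable names are hypothetical):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       llvm::Function* entry_function,
  //       ir_emitter.EmitComputation(
  //           module->entry_computation(), "entry",
  //           /*is_top_level_computation=*/true,
  //           schedule.sequence(module->entry_computation()).instructions(),
  //           /*allow_reassociation=*/false));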

  llvm::IRBuilder<>* b() { return &b_; }

  // builder() is for IrBuilderMixin.
  llvm::IRBuilder<>* builder() { return &b_; }

  // Emit an LLVM global variable for every constant buffer allocation.
  Status EmitConstantGlobals();

 protected:
  //
  // The following methods implement the DfsHloVisitor interface.
  //
  // Default action which emits code for most operations. Operations which are
  // special in some way are handled explicitly in HandleFoo methods.
  Status DefaultAction(HloInstruction* hlo) override;

  Status HandleAllToAll(HloInstruction* instruction) override;
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleConstant(HloInstruction* constant) override;
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleCollectivePermute(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSort(HloInstruction* hlo) override;
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleReduceWindow(HloInstruction* reduce_window) override;
  Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleSlice(HloInstruction* slice) override;
  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleRecv(HloInstruction* recv) override;
  Status HandleRecvDone(HloInstruction* recv_done) override;
  Status HandlePad(HloInstruction* pad) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleCall(HloInstruction* call) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleConcatenate(HloInstruction* concatenate) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleAfterAll(HloInstruction* after_all) override;
  Status HandleAddDependency(HloInstruction* add_dependency) override;
  Status HandlePartitionId(HloInstruction* hlo) override;
  Status HandleReplicaId(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* rng) override;
  Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override;
  Status FinishVisit(HloInstruction* root) override;

  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  BufferAllocation::Slice GetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
    return assignment_.GetUniqueSlice(&hlo, index).value();
  }

 private:
  Status HandleSliceToDynamic(HloInstruction* hlo);
  Status HandlePadToStatic(HloInstruction* hlo);
  Status HandleTopK(HloInstruction* hlo);
  Status HandleAllReduceSingleReplica(HloInstruction* crs);
  Status HandleAllReduceMultipleReplica(HloInstruction* crs);

  // Private helper to initialize an IR function for the computation.
  void InitializeIrFunction(const std::string& function_name);

  // Emits the copying epilogue for the function, which copies the returned
  // value into the reserved alloca. This is only necessary for thread-local
  // functions. Note that since the call graph is flattened, a function called
  // in both thread-local and non-thread-local contexts is codegen'd twice, so
  // we know at codegen time whether a given emission is thread-local.
  void EmitThreadLocalFunctionEpilogue(HloComputation* computation);

  // Convenience functions to generate a GEP into the profile counter parameter
  // which would correspond to the index for a given HLO instruction or
  // computation.
  llvm::Value* GetProfileCounterFor(const HloInstruction& instruction);
  llvm::Value* GetProfileCounterFor(const HloComputation& computation);

  // Helper function template for the implementation of the above two functions.
  template <typename T>
  llvm::Value* GetProfileCounterCommon(
      const T& hlo,
      const absl::flat_hash_map<const T*, int64_t>& profile_index_map);
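  // Conceptually, the emitted IR indexes the counters argument by the mapped
  // index (a sketch, not the exact emission):
  //
  //   int64_t idx = profile_index_map.at(&hlo);
  //   llvm::Value* counter_addr = b()->CreateConstInBoundsGEP1_64(
  //       b()->getInt64Ty(), GetProfileCountersArgument(), idx);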

  // Gets the IR Value emitted previously for the given hlo.
  //
  // Prefer calling GetIrArrayFor if the value you're reading is a buffer,
  // because GetIrArrayFor annotates the buffer's loads/stores with noalias
  // metadata.
  //
  // Make sure to call this only when you're certain a value *was* emitted - if
  // not found, this will log a fatal error.
  llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);

  // Gets an IrArray representing the given hlo.
  llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo);

  // Gets a list of IrArrays, one for each of hlo's operands.
  std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
      const HloInstruction* hlo);

  // Bind all argument IrArrays of `fusion` to `fused_emitter`.
  void BindFusionArguments(const HloInstruction* fusion,
                           FusedIrEmitter* fused_emitter);

  // Augments IrArray with aliasing information.
  void AddAliasingInformationToIrArray(const HloInstruction& hlo,
                                       llvm_ir::IrArray* array) {
    alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
  }

  // Convenience function to get the IR type matching the given shape.
  llvm::Type* IrShapeType(const Shape& shape);

  // Get the llvm::Value* that represents the "prof_counters" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetProfileCountersArgument();

  // Get the llvm::Value* that represents the "status" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetStatusArgument();

  // Get the llvm::Value* that represents the "run_options" argument
  // (an xla::ExecutableRunOptions*) of the computation function being emitted
  // by this emitter.
  llvm::Value* GetExecutableRunOptionsArgument();

  // Get the llvm::Value* that represents the "buffer_table" argument of the
  // computation function being emitted by this emitter.
  llvm::Value* GetBufferTableArgument();

  // Get the llvm::BasicBlock that contains the return instruction.
  llvm::BasicBlock* GetReturnBlock();

  // Emits code to check the state of the status object being threaded through
  // each computation and return early if it's in an error state.
  void EmitEarlyReturnIfErrorStatus();

  // Helper for EmitBufferPointer.
  llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
                                       const Shape& target_shape);

  // Helper for EmitBufferPointer.
  llvm::Value* EmitThreadLocalBufferPointer(
      const BufferAllocation::Slice& slice, const Shape& target_shape);

  // Emits code that computes the address of the given buffer allocation slice.
  llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
                                 const Shape& target_shape);

  // Emits a call to a thread local function (e.g. to the computation nested
  // within a reduce or a map).  Thread local callees (by definition) only write
  // to and read from thread local allocations.
  // Supports only functions returning scalars or tuples of scalars.
  //
  // `parameters` holds the *scalar values* that need to be passed to the
  // callee.  The return values are the scalars returned by the callee.
  std::vector<llvm::Value*> EmitThreadLocalCall(
      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
      absl::string_view name, bool is_reducer);

  // Similar to EmitThreadLocalCall, but assumes that the callee returns a
  // scalar.
  llvm::Value* EmitScalarReturningThreadLocalCall(
      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
      absl::string_view name);
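  // For instance, invoking a scalar `max` computation on two values might look
  // like this (sketch; variable names are hypothetical):
  //
  //   llvm::Value* max = EmitScalarReturningThreadLocalCall(
  //       *max_computation, {lhs_value, rhs_value}, "max_elements");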

  // Emits a call to a "global" function (e.g. to the computation nested within
  // a kWhile or a kCall).  Buffer assignment unambiguously assigns buffers to
  // the parameters and return values for these computations so there is no need
  // to explicitly pass parameters or return results.
  void EmitGlobalCall(const HloComputation& callee, absl::string_view name);

  // Returns the buffer to which a global call to `callee` would have written
  // its result.
  llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);

  // Verifies that the element types of all of the given operand instructions
  // match and are of one of the given supported types.
  Status ElementTypesSameAndSupported(
      const HloInstruction& instruction,
      absl::Span<const HloInstruction* const> operands,
      absl::Span<const PrimitiveType> supported_types);

  // Emit IR to perform a computation for every element in the given target op.
  // This produces a series of nested loops (one for each dimension of the op's
  // shape). The body of the inner-most loop is provided by the
  // element_generator function.
  //
  // desc is an optional human-readable string that's added to the loop name in
  // IR.  Regardless of whether desc is provided, target_op->name() is included
  // in the loop name.
  //
  // TODO(jingyue): target_op should be a `const HloInstruction*`.
  Status EmitTargetElementLoop(
      HloInstruction* target_op,
      const llvm_ir::ElementGenerator& element_generator);
  Status EmitTargetElementLoop(
      HloInstruction* target_op, absl::string_view desc,
      const llvm_ir::ElementGenerator& element_generator);

  // Emits a memcpy from the source instruction's result value to the
  // destination's.  Both source and destination must have an entry in the
  // emitted_value_ table.
  Status EmitMemcpy(const HloInstruction& source,
                    const HloInstruction& destination);

  // Emits IR to compute the target address of the buffer for the given op.
  // After calling this function, you can get a pointer to this buffer by
  // calling GetIrArrayFor or GetEmittedValueFor.
  Status EmitTargetAddressForOp(const HloInstruction* op);

  // Structurizes "array_elements" into a multidimensional array that
  // represents "shape". This is a recursive function, and "dimension_index"
  // indicates the index of the current dimension that the function is
  // considering (0 means the most-minor dimension).
  llvm::Constant* CreateInitializerForConstantArray(
      const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
      int64_t dimension_index);

  // Tries to codegen a reduction operation using vectorized instructions.
  // Returns true if successful, and false on failure.  On failure, sets
  // "failure_reason" to a string describing why it could not vectorize the
  // reduction.
  //
  // TODO(sanjoy): Some of the things we do here can be abstracted out into
  // concepts that generalize over other vectorizable operations.  We should
  // consider pulling out these abstractions into a VectorizingIrEmitter or
  // something similar.
  StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
                                      HloInstruction* arg,
                                      HloInstruction* init_value,
                                      absl::Span<const int64_t> dimensions,
                                      HloComputation* function,
                                      std::string* failure_reason);

  // We'd like to keep one or two cache lines' worth of data in registers
  // without generating IR with illegal (e.g. excessively large or
  // non-power-of-two) vector types.  We do this by introducing a layer of
  // abstraction: we introduce a high level vector-like concept called a
  // "sharded vector" that models data parallelism, and is mapped to a sequence
  // of scalar and vector llvm::Values.
  //
  // For example, we can represent 29 f32 elements by a sharded vector mapped to
  // a sequence of LLVM values of types [<16 x f32>, <8 x f32>, <4 x f32>, f32].
  // Note that the last element is scalar.
  //
  // There is no requirement on the ordering or the uniqueness of the elements
  // mapped to sharded vectors -- we allow repeated elements, and we allow
  // elements to appear in any order.
  using ShardedVector = std::vector<llvm::Value*>;

  // A sharded vector type is the element-wise llvm::Types of some
  // ShardedVector.
  using ShardedVectorType = std::vector<llvm::Type*>;

  // Create a sharded vector type corresponding to an "element_count"-long
  // sequence of "element_type" values.
  ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
                                            unsigned element_count);
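  // A sketch of the greedy power-of-two decomposition this implies
  // (hypothetical; the real shard sizes depend on the target's vector width,
  // obtained from target_machine_features_):
  //
  //   std::vector<int> ShardSizes(int element_count, int max_vector_width) {
  //     std::vector<int> sizes;
  //     for (int width = max_vector_width; width >= 1; width /= 2) {
  //       while (element_count >= width) {
  //         sizes.push_back(width);   // one shard of `width` elements
  //         element_count -= width;
  //       }
  //     }
  //     return sizes;  // e.g. {16, 8, 4, 1} for 29 f32s and width 16
  //   }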

  // Emit LLVM IR to store the sharded vector "value_to_store" to
  // "store_address".
  void EmitShardedVectorStore(llvm::Value* store_address,
                              const ShardedVector& value_to_store,
                              llvm::Align alignment,
                              const llvm_ir::IrArray& containing_array);

  using ReductionGenerator = std::function<llvm::Value*(
      llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;

  // Tries to match the reduction function "function" to a known reduction
  // pattern.  Returns a non-null ReductionGenerator on a successful match,
  // which can be used to generate the LLVM IR corresponding to said reduction.
  // On failure, this stores a reason string into "failure_reason".
  ReductionGenerator MatchReductionGenerator(HloComputation* function,
                                             std::string* failure_reason) const;
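  // For an f32 add-reduction, the matched generator would behave roughly like
  // this lambda (sketch only):
  //
  //   [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
  //     return b->CreateFAdd(lhs, rhs);
  //   }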

  // Emits the inner loop nest that runs the reduction.  Helper function for
  // EmitVectorizedReduce.
  StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
      const ReductionGenerator& reduction_generator,
      const llvm_ir::IrArray::Index& output_index,
      const ShardedVectorType& accumulator_type, HloInstruction* init_value,
      HloInstruction* arg, absl::Span<const int64_t> dimensions,
      llvm::Align element_alignment);

  // Tries to emit a fast concatenate operation using memcpy.  Returns true if
  // successful, and false on failure.  On failure, sets "failure_reason" to a
  // string describing why it could not emit a fast concatenate.
  StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
                                     absl::Span<HloInstruction* const> operands,
                                     std::string* failure_reason);

  // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
  // from the address "source" to the address "target".
  void EmitTransferElements(llvm::Value* target, llvm::Value* source,
                            int64_t element_count, PrimitiveType primitive_type,
                            const llvm_ir::IrArray& target_array,
                            const llvm_ir::IrArray& source_array);

  // Emits printing during the execution.
  llvm::Value* EmitPrintf(absl::string_view fmt,
                          absl::Span<llvm::Value* const> arguments);
  llvm::Value* EmitPrintfToStderr(absl::string_view fmt,
                                  absl::Span<llvm::Value* const> arguments);

  // Emits a call to a non-variadic function `func_name` with arguments
  // `arguments` assuming C calling convention.
  llvm::Value* EmitCallToFunc(
      std::string func_name, const std::vector<llvm::Value*>& arguments,
      llvm::Type* return_type, bool does_not_throw = true,
      bool only_accesses_arg_memory = false,
      bool only_accesses_inaccessible_mem_or_arg_mem = false);
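  // e.g., a call to a void-returning runtime helper might be emitted as
  // (sketch; the symbol name is hypothetical):
  //
  //   EmitCallToFunc("__xla_cpu_runtime_SomeHelper",
  //                  {GetExecutableRunOptionsArgument(), buffer_ptr},
  //                  b()->getVoidTy(), /*does_not_throw=*/true,
  //                  /*only_accesses_arg_memory=*/true);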

  // Assignment of the buffers needed by the computation and their shape
  // information.
  const BufferAssignment& assignment_;

  // The LLVM module into which IR will be emitted.
  llvm::Module* module_;

  // The target architecture.
  llvm::Triple::ArchType arch_type_;

  // Used to produce unique names for generated functions.
  NameUniquer name_uniquer_;

  struct ComputationToEmit {
    const HloComputation* computation;
    bool allow_reassociation;

    bool operator==(const ComputationToEmit& other) const {
      return computation == other.computation &&
             allow_reassociation == other.allow_reassociation;
    }

    template <typename H>
    friend H AbslHashValue(H h, const ComputationToEmit& c) {
      return H::combine(std::move(h), c.computation, c.allow_reassociation);
    }
    friend std::ostream& operator<<(std::ostream& os,
                                    const ComputationToEmit& c) {
      return os << c.computation->name() << ", " << c.allow_reassociation;
    }
  };

  // Map containing all previously emitted computations.
  absl::flat_hash_map<ComputationToEmit, llvm::Function*> emitted_functions_;
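  // (Cache-lookup sketch: both key fields participate in hashing and equality,
  //  so the same computation emitted with and without reassociation yields two
  //  distinct entries:
  //
  //    auto it = emitted_functions_.find(
  //        ComputationToEmit{computation, allow_reassociation});)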

  // Map containing all previously emitted thread-local temporary buffers.
  std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
      thread_local_buffers_;

  // The following fields track the IR emission state. According to LLVM memory
  // management rules, their memory is owned by the module (note that IrFunction
  // creates the encapsulated llvm::Function such that it is added to the LLVM
  // module's function list).
  std::unique_ptr<IrFunction> compute_function_;
  llvm::IRBuilder<> b_;
  mlir::MLIRContext* mlir_context_;
  bool allow_reassociation_;

  // The buffer allocation slice for the root of the computation being compiled.
  // Only relevant for thread local computations.
  BufferAllocation::Slice computation_root_allocation_;

  // Maps the buffer allocation slices for the parameters to the computation
  // being compiled to their parameter numbers.  Only relevant for thread local
  // computations.
  absl::flat_hash_map<BufferAllocation::Index, int64_t>
      computation_parameter_allocations_;

  // Maps HLO instructions to their index into the profile counter array.
  const absl::flat_hash_map<const HloInstruction*, int64_t>
      instruction_to_profile_idx_;

  // Maps HLO computations to their index into the profile counter array.
  const absl::flat_hash_map<const HloComputation*, int64_t>
      computation_to_profile_idx_;

  // Maps HLO computations to whether they contain a custom-call instruction
  // (either directly, or transitively by e.g. calling another computation that
  // does).
  const absl::flat_hash_map<const HloComputation*, bool>
      computation_transitively_contains_custom_call_;

  // Accessor for the custom-call mapping that enforces the precondition that
  // all computations must have a key in the map.
  bool ComputationTransitivelyContainsCustomCall(
      const HloComputation* computation) const {
    auto it = computation_transitively_contains_custom_call_.find(computation);
    CHECK(it != computation_transitively_contains_custom_call_.cend())
        << "Must provide 'contains CustomCall' annotation for all computations "
           "in the module";
    return it->second;
  }

  // Maps HLOs to Values emitted for them.
  absl::flat_hash_map<const HloInstruction*, llvm::Value*> emitted_value_;

  llvm_ir::AliasAnalysis alias_analysis_;

  // The number of outer dimensions of the root instruction's shape that
  // will be partitioned when emitting parallel loops. (See
  // ParallelLoopEmitter).
  int64_t num_dynamic_loop_bounds_ = 0;

  // Returns whether the given instruction should be emitted as a parallel loop.
  bool ShouldEmitParallelLoopFor(const HloInstruction& op) const {
    // Emit parallel loop for root instruction if dynamic outer-dimension loop
    // bounds were specified.
    return num_dynamic_loop_bounds_ > 0 &&
           op.parent()->root_instruction() == &op;
  }

  // This class contains all the state needed to emit instructions for
  // profiling a computation.
  class ProfilingState {
   public:
    ProfilingState() : use_rdtscp_(false) {}
    explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {}

    // Record the cycle counter before an HLO executes.
    void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo);
    // Record the number of cycles it took for an HLO to execute.
    void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* prof_counter);
    // Record the number of cycles it took for the entire computation to
    // execute.
    void RecordCompleteComputation(llvm::IRBuilder<>* b,
                                   llvm::Value* prof_counter);

    // Convenience function to generate a call to an intrinsic which reads the
    // CPU cycle counter.
    llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b);
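    // (Sketch: with use_rdtscp_ set this is expected to lower to the
    //  x86-specific @llvm.x86.rdtscp intrinsic, and otherwise to the generic
    //  @llvm.readcyclecounter intrinsic.)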

    // Store the cycle counter delta to the per-HLO profile counter.
    void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter,
                              llvm::Value* cycle_end, llvm::Value* cycle_start);

   private:
    // Should we use the x86-specific rdtscp or the generic readcyclecounter
    // intrinsic?
    bool use_rdtscp_;

    // The first read cycle counter in the program.
    llvm::Value* first_read_cycle_start_ = nullptr;

    // The last read cycle counter in the program.
    llvm::Value* last_read_cycle_end_ = nullptr;

    // Maps HLOs to the value the cycle counter contained right before the HLO
    // began to execute.
    absl::flat_hash_map<const HloInstruction*, llvm::Value*> cycle_starts_;
  };

  ProfilingState profiling_state_;

  class TracingState {
   public:
    TracingState() : enabled_(false) {}
    void set_enabled(bool value) { enabled_ = value; }
    void EmitTracingStart(llvm::IRBuilder<>* b, HloInstruction* hlo,
                          llvm::Value* run_options);
    void EmitTracingEnd(llvm::IRBuilder<>* b, HloInstruction* hlo,
                        llvm::Value* run_options);

   private:
    bool enabled_;
    // Maps from HLO to the activity id returned by xprof::TraceMe.
    absl::flat_hash_map<const HloInstruction*, llvm::Value*> activity_ids_;
  };
  TracingState tracing_state_;

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the alignment required by the shape or size.
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                      int64_t buffer_size);

  // Given a load instruction and a shape or buffer size, annotate the load's
  // result with the dereferenceable bytes required by the shape / buffer size.
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            const Shape& shape);
  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                            int64_t buffer_size);

  // Calculate the alignment of a buffer allocated for a given shape.
  int MinimumAlignmentForShape(const Shape& shape);

  // Calculate the alignment of a buffer allocated for a given primitive type.
  int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes within the shape.
  int64_t ByteSizeOf(const Shape& shape) const;

  enum class XfeedKind {
    kInfeed,
    kOutfeed,
  };

  // Emit IR to transfer between an {infeed,outfeed} buffer and an in-program
  // address.
  Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                           llvm::Value* program_buffer_address);

  // Emits a global for the given literal and returns a ConstantExpr bitcast
  // of it.
  llvm::Constant* EmitGlobalForLiteral(const Literal& literal);

  const HloModuleConfig& hlo_module_config_;

  bool is_top_level_computation_;

  const TargetMachineFeatures& target_machine_features_;

  struct LiteralPtrHashFunctor {
    size_t operator()(const Literal* literal) const {
      return absl::HashOf(*literal);
    }
  };

  struct LiteralPtrEqualityFunctor {
    bool operator()(const Literal* lhs, const Literal* rhs) const {
      return *lhs == *rhs;
    }
  };

  absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
                      LiteralPtrEqualityFunctor>
      emitted_literals_;

  absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
      constant_buffer_to_global_;

  std::vector<const HloComputation*> thread_local_computations_;
  std::vector<const HloComputation*> global_computations_;

  bool emit_code_for_msan_;

  IrEmitter(const IrEmitter&) = delete;
  IrEmitter& operator=(const IrEmitter&) = delete;
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_