xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/target_machine_features.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "llvm/Analysis/TargetTransformInfo.h"
21 #include "llvm/Target/TargetMachine.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 
24 namespace xla {
25 namespace cpu {
26 
27 // Abstract interface for classes providing information about the target we're
28 // compiling for.
29 class TargetMachineFeatures {
30  public:
31   static constexpr int kX86AvxVectorByteSize = 32;
32 
33   // Input and output tensor buffers must be aligned to this many bytes if we
34   // want to call an Eigen backed GEMM or Convolution.
35   static constexpr int kEigenExpectedTensorAlignment = 16;
36 
37   // Return the vectorization factor, which is the number of bytes of data
38   // explicitly vectorized routines will try to process at once.
39   virtual int vectorization_factor_in_bytes() const = 0;
40 
41   // Return the size of the largest vector size in bytes.  We need to pass in
42   // "function" since llvm functions can contain annotations for specializing
43   // them to specific micro-architectures (though currently XLA does not use
44   // this functionality).
45   virtual int vector_register_byte_size(
46       const llvm::Function& function) const = 0;
47 
48   // Return the number of elements of type `type` that can fit into the largest
49   // vector register available.  We need to pass in "function" since llvm
50   // functions can contain annotations for specializing them to specific
51   // micro-architectures (though currently XLA does not use this functionality).
52   virtual int vector_register_num_elements(const llvm::Function& function,
53                                            PrimitiveType type) const = 0;
54 
55   // Return the number of vector registers.  We need to pass in
56   // "function" since llvm functions can contain annotations for specializing
57   // them to specific micro-architectures (though currently XLA does not use
58   // this functionality).
59   virtual int vector_register_count(const llvm::Function& function) const = 0;
60 
61   // Returns the minimum alignment for a buffer of size size_bytes.
62   virtual int64_t minimum_alignment_for_allocation(
63       int64_t size_bytes) const = 0;
64 
65   virtual ~TargetMachineFeatures() = default;
66 };
67 
68 // Implements the TargetMachineFeatures interface using an llvm::TargetMachine.
69 class LLVMTargetMachineFeatures : public TargetMachineFeatures {
70  public:
71   static constexpr int kX86AvxVectorByteSize = 32;
72 
LLVMTargetMachineFeatures(llvm::TargetMachine * target_machine)73   LLVMTargetMachineFeatures(llvm::TargetMachine* target_machine)
74       : target_machine_(target_machine) {}
75 
vectorization_factor_in_bytes()76   int vectorization_factor_in_bytes() const override {
77     // Ideally this should be a function of the cache line size (which we can
78     // get from llvm::TargetTransformInfo::getCacheLineSize) of the target
79     // machine.  Guess a value of 128 bytes for now.
80     return 128;
81   }
82 
vector_register_byte_size(const llvm::Function & function)83   int vector_register_byte_size(const llvm::Function& function) const override {
84     llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function);
85     return tti->getRegisterBitWidth(
86                llvm::TargetTransformInfo::RGK_FixedWidthVector) /
87            8;
88   }
89 
vector_register_num_elements(const llvm::Function & function,PrimitiveType type)90   int vector_register_num_elements(const llvm::Function& function,
91                                    PrimitiveType type) const override {
92     return vector_register_byte_size(function) /
93            (primitive_util::BitWidth(type) / 8);
94   }
95 
vector_register_count(const llvm::Function & function)96   int vector_register_count(const llvm::Function& function) const override {
97     llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function);
98     return static_cast<int>(tti->getNumberOfRegisters(
99         tti->getRegisterClassForType(/*Vector=*/true)));
100   }
101 
102   int64_t minimum_alignment_for_allocation(int64_t size_bytes) const override;
103 
104  private:
105   llvm::TargetTransformInfo* GetTargetTransformInfoFor(
106       const llvm::Function& function) const;
107 
108   // This cache saves us from having to create a llvm::TargetTransformInfo for
109   // every call to GetTargetTransformInfoFor (creating a TargetTransformInfo
110   // costs one heap allocation on X86).
111   //
112   // This is mutated from within `GetTargetTransformInfoFor` which is
113   // semantically a getter (and thus `const`); and is therefore declared
114   // mutable.  Making this mutable is okay because it has cache semantics.
115   mutable absl::flat_hash_map<const llvm::Function*, llvm::TargetTransformInfo>
116       target_transform_info_cache_;
117   llvm::TargetMachine* target_machine_;
118 };
119 
120 }  // namespace cpu
121 }  // namespace xla
122 
123 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
124