xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/device_description.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Describes the underlying platform for a StreamExecutor; e.g. OpenCL or CUDA
// device and platform properties. Also contains convenience functions for
// checking/calculating launch dimensionality based on device properties.

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
22 
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/stream_executor/launch_dim.h"
31 
namespace stream_executor {
namespace internal {
class DeviceDescriptionBuilder;
}  // namespace internal
36 
// CUDA compute capability, as reported by the device description.
struct CudaComputeCapability {
  int major = 0;
  int minor = 0;

  // MSVC does not like "PASCAL" symbol.
  enum CudaComputeCapabilities { PASCAL_ = 6, VOLTA = 7, AMPERE = 8 };

  CudaComputeCapability() = default;
  CudaComputeCapability(int major, int minor) : major(major), minor(minor) {}

  // Returns true iff this capability is at least (other_major, other_minor).
  bool IsAtLeast(int other_major, int other_minor = 0) const {
    return !(*this < CudaComputeCapability{other_major, other_minor});
  }

  // Lexicographic comparison on (major, minor).
  bool operator<(const CudaComputeCapability &other) const {
    return ToPair() < other.ToPair();
  }

  bool operator==(const CudaComputeCapability &other) const {
    return ToPair() == other.ToPair();
  }

  bool operator!=(const CudaComputeCapability &other) const {
    return !(*this == other);
  }

  // Maximum resident blocks per multiprocessor, values taken from
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities.
  int GetMaxResidentBlocksPerSM() const {
    if (IsAtLeast(8, 6)) {
      return 16;
    } else if (IsAtLeast(8)) {
      return 32;
    } else if (IsAtLeast(7, 5)) {
      return 16;
    }
    return 32;
  }

  // Maximum resident warps per multiprocessor, values taken from
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities.
  int GetMaxResidentWarpsPerSM() const {
    if (IsAtLeast(8, 6)) {
      return 48;
    } else if (IsAtLeast(8)) {
      return 64;
    } else if (IsAtLeast(7, 5)) {
      return 32;
    }
    return 64;
  }

  // Human-readable "major.minor" form, e.g. "8.6".
  std::string ToString() const {
    // std::to_string produces the same digits as absl::StrCat for ints, and
    // avoids depending on a transitively-included absl header.
    return std::to_string(major) + "." + std::to_string(minor);
  }

  std::pair<int, int> ToPair() const { return std::make_pair(major, minor); }
};
97 
// ROCm compute capability, as reported by the device description.
class RocmComputeCapability {
 public:
  // gcn_arch_name example --  gfx90a:sramecc+:xnack-
  // gfx_version is the "gfx90a" part of the gcn_arch_name.
  // Sink parameter: taken by value and moved into place.
  explicit RocmComputeCapability(std::string gcn_arch_name)
      : gcn_arch_name_(std::move(gcn_arch_name)) {}

  // Full architecture name, including feature flags (e.g. ":sramecc+").
  std::string gcn_arch_name() const { return gcn_arch_name_; }

  // The "gfxNNN" prefix of the architecture name: everything before the first
  // ':', or the whole string if there is no ':'.
  std::string gfx_version() const {
    return gcn_arch_name_.substr(0, gcn_arch_name_.find(':'));
  }

  bool is_supported_gfx_version() const {
    return supported_gfx_versions().count(gfx_version()) != 0;
  }

  // Comma-separated list of the supported gfx versions, in sorted order
  // (std::set iteration order).
  std::string supported_gfx_versions_str() const {
    std::string result;
    for (const std::string &version : supported_gfx_versions()) {
      if (!result.empty()) result += ", ";
      result += version;
    }
    return result;
  }

  bool has_nhwc_layout_support() const {
    return gfx_versions_with_nhwc_layout_support().count(gfx_version()) != 0;
  }

  bool has_bf16_dtype_support() const {
    return gfx_versions_with_fast_bf16_support().count(gfx_version()) != 0;
  }

  bool has_fast_fp16_support() const {
    return gfx_versions_with_fast_fp16_support().count(gfx_version()) != 0;
  }

  bool has_mfma_instr_support() const {
    return gfx_versions_with_mfma_instr_support().count(gfx_version()) != 0;
  }

  bool has_fp16_atomics_support() const {
    return gfx_versions_with_fp16_atomics_support().count(gfx_version()) != 0;
  }

 private:
  std::string gcn_arch_name_;

  // Each capability set is built once and shared; returning a const reference
  // avoids reconstructing a std::set on every query. Heap-allocated so the
  // objects are never destroyed (avoids shutdown destruction-order issues).
  static const std::set<std::string> &supported_gfx_versions() {
    static const auto *const kVersions = new std::set<std::string>{
        "gfx900",  // MI25
        "gfx906",  // MI50 / MI60
        "gfx908",  // MI100
        "gfx90a",  // MI200
        "gfx1030"  // Navi21
    };
    return *kVersions;
  }
  static const std::set<std::string> &gfx_versions_with_nhwc_layout_support() {
    static const auto *const kVersions =
        new std::set<std::string>{"gfx908", "gfx90a"};
    return *kVersions;
  }
  static const std::set<std::string> &gfx_versions_with_fast_bf16_support() {
    static const auto *const kVersions =
        new std::set<std::string>{"gfx908", "gfx90a"};
    return *kVersions;
  }
  static const std::set<std::string> &gfx_versions_with_fast_fp16_support() {
    static const auto *const kVersions =
        new std::set<std::string>{"gfx906", "gfx908", "gfx90a", "gfx1030"};
    return *kVersions;
  }
  static const std::set<std::string> &gfx_versions_with_mfma_instr_support() {
    static const auto *const kVersions =
        new std::set<std::string>{"gfx908", "gfx90a"};
    return *kVersions;
  }
  static const std::set<std::string> &gfx_versions_with_fp16_atomics_support() {
    static const auto *const kVersions = new std::set<std::string>{"gfx90a"};
    return *kVersions;
  }
};
170 
// Data that describes the execution target of the StreamExecutor, in terms of
// important logical parameters. These include dimensionality limits and
// physical parameters of interest, such as number of cores present on the
// device.
//
// Thread-safe: immutable post-initialization.
class DeviceDescription {
 public:
  // Returns the platform being run on; this value is primarily intended for
  // printing, and comes out something like "OpenCL 1.2" or "Compute Capability
  // 3.5".
  const std::string &platform_version() const { return platform_version_; }

  // Returns the driver version interfacing with the underlying platform.
  // Vendor dependent format.
  const std::string &driver_version() const { return driver_version_; }

  // Return the runtime version, if one is provided by the underlying platform.
  // Vendor dependent format / usefulness.
  const std::string &runtime_version() const { return runtime_version_; }

  // Returns the name that the device reports. Vendor dependent.
  const std::string &name() const { return name_; }

  // Returns the PCI bus identifier for this device, of the form
  // [domain]:[bus]:[device].[function]
  const std::string &pci_bus_id() const { return pci_bus_id_; }

  // Returns the NUMA node associated with this device, for use in
  // determining socket locality. If the NUMA node could not be determined, -1
  // is returned.
  int numa_node() const { return numa_node_; }

  // Number of cores (traditional notion of core; i.e. an SM on an NVIDIA
  // device or an AMD Compute Unit).
  int core_count() const { return core_count_; }

  // Returns the limit on the thread dimensionality values in each of the
  // respective dimensions. These limits affect what constitutes a legitimate
  // kernel launch request.
  const ThreadDim &thread_dim_limit() const { return thread_dim_limit_; }

  // Returns the limit on the block dimensionality values in each of the
  // respective dimensions. These limits may affect what constitutes a
  // legitimate kernel launch request.
  const BlockDim &block_dim_limit() const { return block_dim_limit_; }

  // Returns the limit on the total number of threads that can be launched in a
  // single block; i.e. the limit on x * y * z dimensions of a ThreadDim.
  // This limit affects what constitutes a legitimate kernel launch request.
  const int64_t &threads_per_block_limit() const {
    return threads_per_block_limit_;
  }

  // Returns the limit on the total number of threads that can be simultaneously
  // launched on a given multiprocessor.
  const int64_t &threads_per_core_limit() const {
    return threads_per_core_limit_;
  }

  // Returns the number of threads per warp/wavefront.
  const int64_t &threads_per_warp() const { return threads_per_warp_; }

  // Returns the limit on the total number of registers per core.
  const int64_t &registers_per_core_limit() const {
    return registers_per_core_limit_;
  }

  // Returns the limit on the total number of registers that can be
  // simultaneously used by a block.
  const int64_t &registers_per_block_limit() const {
    return registers_per_block_limit_;
  }

  // Returns the number of address bits available to kernel code running on the
  // platform. This affects things like the maximum allocation size and perhaps
  // types used in kernel code such as size_t.
  const int64_t &device_address_bits() const { return device_address_bits_; }

  // Returns the device memory size in bytes.
  int64_t device_memory_size() const { return device_memory_size_; }

  // Returns the device's memory bandwidth in bytes/sec.  (This is for
  // reads/writes to/from the device's own memory, not for transfers between the
  // host and device.)
  int64_t memory_bandwidth() const { return memory_bandwidth_; }

  // Returns the device's core clock rate in GHz.
  float clock_rate_ghz() const { return clock_rate_ghz_; }

  // Returns whether ECC is enabled.
  bool ecc_enabled() const { return ecc_enabled_; }

  // Returns the device vendor string, e.g., "NVIDIA Corporation", "Advanced
  // Micro Devices, Inc.", or "GenuineIntel".
  const std::string &device_vendor() const { return device_vendor_; }

  // Returns the CUDA compute capability if we're running on the CUDA platform.
  // If a CUDA compute capability is not available, the major version will be
  // zero.
  // NOTE(review): the stored default below is {-1, -1}; confirm that the
  // out-of-line definition normalizes the "not available" case to zero.
  CudaComputeCapability cuda_compute_capability() const;

  // Returns the ROCm compute capability if we're running on the ROCm platform.
  // If a ROCm compute capability is not available, the default gfx_arch will
  // be "gfx000" (which is an invalid gfx arch).
  RocmComputeCapability rocm_compute_capability() const;

  // Returns the maximum amount of shared memory present on a single core
  // (i.e. Streaming Multiprocessor on NVIDIA GPUs; Compute Unit for OpenCL
  // devices). Note that some devices, such as NVIDIA's have a configurable
  // partitioning between shared memory and L1 cache.
  int64_t shared_memory_per_core() const { return shared_memory_per_core_; }

  // Returns the maximum amount of shared memory available for a single block.
  int64_t shared_memory_per_block() const { return shared_memory_per_block_; }

  // TODO(leary): resident blocks per core will be useful.

  // Convenience typedef for the string-based DeviceDescription mapping.
  typedef std::map<std::string, std::string> Map;

  // Returns a mapping from readable names to readable values that describe the
  // device. This is useful for things like printing.
  std::unique_ptr<Map> ToMap() const;

  // For string values that are not available via the underlying platform, this
  // value will be provided.
  static const char *kUndefinedString;

 private:
  // Only the builder may construct and populate a DeviceDescription; the
  // member names below are therefore part of the builder's contract.
  friend class internal::DeviceDescriptionBuilder;

  DeviceDescription();

  // For description of the following members, see the corresponding accessor
  // above.
  //
  // N.B. If another field is added, update ToMap() above.
  std::string device_vendor_;
  std::string platform_version_;
  std::string driver_version_;
  std::string runtime_version_;
  std::string pci_bus_id_;
  std::string name_;

  ThreadDim thread_dim_limit_;
  BlockDim block_dim_limit_;

  int64_t threads_per_core_limit_;
  int64_t threads_per_block_limit_;
  int64_t threads_per_warp_;

  int64_t registers_per_core_limit_;
  int64_t registers_per_block_limit_;

  int64_t device_address_bits_;
  int64_t device_memory_size_;
  int64_t memory_bandwidth_;

  // Shared memory limits on a given device.
  int64_t shared_memory_per_core_;
  int64_t shared_memory_per_block_;

  float clock_rate_ghz_;

  // CUDA compute capability; {-1, -1} if not available.
  CudaComputeCapability cuda_compute_capability_{-1, -1};

  // ROCm gfx arch, "gfx000" if not available.
  RocmComputeCapability rocm_compute_capability_{"gfx000"};

  int numa_node_;
  int core_count_;
  bool ecc_enabled_;

  SE_DISALLOW_COPY_AND_ASSIGN(DeviceDescription);
};
348 
namespace internal {
350 
351 // Helper class the builds a device description, given that it has a large
352 // number of fields that would be easily confused in constructor form.
353 class DeviceDescriptionBuilder {
354  public:
355   DeviceDescriptionBuilder();
356 
357   // For descriptions of the following fields, see comments on the corresponding
358   // DeviceDescription::* accessors above.
359 
set_device_vendor(const std::string & value)360   void set_device_vendor(const std::string &value) {
361     device_description_->device_vendor_ = value;
362   }
set_platform_version(const std::string & value)363   void set_platform_version(const std::string &value) {
364     device_description_->platform_version_ = value;
365   }
set_driver_version(const std::string & value)366   void set_driver_version(const std::string &value) {
367     device_description_->driver_version_ = value;
368   }
set_runtime_version(const std::string & value)369   void set_runtime_version(const std::string &value) {
370     device_description_->runtime_version_ = value;
371   }
set_pci_bus_id(const std::string & value)372   void set_pci_bus_id(const std::string &value) {
373     device_description_->pci_bus_id_ = value;
374   }
set_name(const std::string & value)375   void set_name(const std::string &value) {
376     device_description_->name_ = value;
377   }
378 
set_thread_dim_limit(const ThreadDim & value)379   void set_thread_dim_limit(const ThreadDim &value) {
380     device_description_->thread_dim_limit_ = value;
381   }
set_block_dim_limit(const BlockDim & value)382   void set_block_dim_limit(const BlockDim &value) {
383     device_description_->block_dim_limit_ = value;
384   }
385 
set_threads_per_core_limit(int64_t value)386   void set_threads_per_core_limit(int64_t value) {
387     device_description_->threads_per_core_limit_ = value;
388   }
set_threads_per_block_limit(int64_t value)389   void set_threads_per_block_limit(int64_t value) {
390     device_description_->threads_per_block_limit_ = value;
391   }
set_threads_per_warp(int64_t value)392   void set_threads_per_warp(int64_t value) {
393     device_description_->threads_per_warp_ = value;
394   }
395 
set_registers_per_core_limit(int64_t value)396   void set_registers_per_core_limit(int64_t value) {
397     device_description_->registers_per_core_limit_ = value;
398   }
set_registers_per_block_limit(int64_t value)399   void set_registers_per_block_limit(int64_t value) {
400     device_description_->registers_per_block_limit_ = value;
401   }
402 
set_device_address_bits(int64_t value)403   void set_device_address_bits(int64_t value) {
404     device_description_->device_address_bits_ = value;
405   }
set_device_memory_size(int64_t value)406   void set_device_memory_size(int64_t value) {
407     device_description_->device_memory_size_ = value;
408   }
set_memory_bandwidth(int64_t value)409   void set_memory_bandwidth(int64_t value) {
410     device_description_->memory_bandwidth_ = value;
411   }
412 
set_shared_memory_per_core(int64_t value)413   void set_shared_memory_per_core(int64_t value) {
414     device_description_->shared_memory_per_core_ = value;
415   }
set_shared_memory_per_block(int64_t value)416   void set_shared_memory_per_block(int64_t value) {
417     device_description_->shared_memory_per_block_ = value;
418   }
419 
set_clock_rate_ghz(float value)420   void set_clock_rate_ghz(float value) {
421     device_description_->clock_rate_ghz_ = value;
422   }
423 
set_cuda_compute_capability(int major,int minor)424   void set_cuda_compute_capability(int major, int minor) {
425     device_description_->cuda_compute_capability_ =
426         CudaComputeCapability{major, minor};
427   }
428 
set_rocm_compute_capability(std::string gcn_arch_name)429   void set_rocm_compute_capability(std::string gcn_arch_name) {
430     device_description_->rocm_compute_capability_ =
431         RocmComputeCapability(gcn_arch_name);
432   }
433 
set_numa_node(int value)434   void set_numa_node(int value) { device_description_->numa_node_ = value; }
set_core_count(int value)435   void set_core_count(int value) { device_description_->core_count_ = value; }
set_ecc_enabled(bool value)436   void set_ecc_enabled(bool value) {
437     device_description_->ecc_enabled_ = value;
438   }
439 
440   // Returns a built DeviceDescription with ownership transferred to the
441   // caller. There are currently no restrictions on which fields must be set in
442   // order to build the descriptor.
443   //
444   // Once the description is built, this builder object should be discarded.
Build()445   std::unique_ptr<DeviceDescription> Build() {
446     return std::move(device_description_);
447   }
448 
449  private:
450   std::unique_ptr<DeviceDescription> device_description_;
451 
452   SE_DISALLOW_COPY_AND_ASSIGN(DeviceDescriptionBuilder);
453 };
454 
}  // namespace internal
456 
// Returns whether the given thread_dim is acceptable given the limits described
// in device_description. For detailed reasons for failing the predicate, enable
// VLOG(2) for this module.
bool ThreadDimOk(const DeviceDescription &device_description,
                 const ThreadDim &thread_dim);

// Equivalent to ceil(double(element_count) / threads_per_block).
ABSL_DEPRECATED("Use MathUtil::CeilOfRatio directly instead.")
int64_t DivideCeil(int64_t x, int64_t y);

// Calculate the number of threads/blocks required to process element_count
// elements. Note that you can still end up with more threads than
// element_count due to rounding, so kernels often start with an "is this
// thread id in the element_count range?" test.
void CalculateDimensionality(const DeviceDescription &device_description,
                             int64_t element_count, int64_t *threads_per_block,
                             int64_t *block_count);
474 
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
478