/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"

#include <algorithm>
#include <atomic>
#include <ostream>
#include <string>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

std::ostream& operator<<(std::ostream& out,
                         const LaunchDimensions& launch_dims) {
  LaunchDimensions::Dim3D block_counts = launch_dims.block_counts();
  LaunchDimensions::Dim3D thread_counts = launch_dims.thread_counts_per_block();
  out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]",
                         block_counts.x, block_counts.y, block_counts.z,
                         thread_counts.x, thread_counts.y, thread_counts.z);
  return out;
}

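// Returns the maximum number of threads per block to use on this device,
// falling back to the warp size (or 32 if even that is unknown) when the
// device description does not report a threads-per-block limit.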
static int64_t ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) {
  int64_t threads_per_block = gpu_device_info.threads_per_block_limit;
  if (threads_per_block <= 0) {
    static std::atomic<int64_t> log_count{0};
    if (log_count.fetch_add(1) < 8) {
      LOG(WARNING) << "Attempting to calculate launch dimensions for GPU "
                      "without full information about its capabilities.  "
                      "StreamExecutor's PopulateDeviceDescription should be "
                      "updated for this device.";
    }
    threads_per_block = gpu_device_info.threads_per_warp;
    if (threads_per_block == 0) {
      // Fall back to *something* if we can't even get num threads per warp.
      threads_per_block = 32;
    }
  }
  return threads_per_block;
}

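// Returns the number of threads to use along the block's x dimension when the
// last dimension of `shape` can be loaded in a row-vectorized way, or -1 if
// row vectorization should not be used for this shape.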
int64_t ThreadsPerBlockRowVectorized(const Shape& shape,
                                     GpuDeviceInfo gpu_device_info,
                                     LaunchDimensionsConfig dim_config) {
  if (shape.dimensions().empty()) {
    return -1;
  }
  int64_t threads_per_block_row_vectorized =
      shape.dimensions().back() / dim_config.unroll_factor;
  if (dim_config.row_vectorized &&
      shape.dimensions().back() % dim_config.unroll_factor == 0 &&
      // If the row size is a multiple of 256, then use the old code path
      // that uses a block size of 256. This gives a small speedup on V100;
      // vectorization of the row load was already happening there.
      (shape.dimensions().back() % 256) != 0 &&
      // We do not support rows that do not fit in one block.
      threads_per_block_row_vectorized <=
          gpu_device_info.threads_per_block_limit) {
    return threads_per_block_row_vectorized;
  }
  return -1;
}

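// Calculates the block count and threads per block used to cover every
// element of `shape`, taking the unroll factor, row vectorization, and the
// few_waves heuristic from `dim_config` into account.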
StatusOr<LaunchDimensions> CalculateLaunchDimensions(
    const Shape& shape, GpuDeviceInfo gpu_device_info,
    LaunchDimensionsConfig dim_config) {
  int64_t num_elements = ShapeUtil::ElementsIn(shape);
  if (num_elements <= 1) {
    return LaunchDimensions();
  }

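  // Each thread processes `unroll_factor` elements, so the element count must
  // divide evenly; the launch is then sized in units of unrolled elements.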
  CHECK_EQ(num_elements % dim_config.unroll_factor, 0);
  num_elements = num_elements / dim_config.unroll_factor;

  // Since we don't do any inter-warp communication, we're free to choose any
  // block size we want, subject to hardware constraints.  We choose the
  // largest block size allowed, as empirically, this is a performance win on
  // almost (but not all) benchmarks.
  //
  // My guess is that using a larger block size encourages ptxas to decrease
  // per-thread register usage, thus allowing for higher occupancy, but I
  // haven't verified this.
  //
  // TODO(jlebar): Investigate this further, and tune this heuristic so we can
  // run faster on the few benchmarks where smaller block size helps.
  int64_t threads_per_block_row_vectorized =
      ThreadsPerBlockRowVectorized(shape, gpu_device_info, dim_config);
  // If row vectorized, threads_per_block_x is the vectorized size.
  // Otherwise, we unroll kernels to make use of vectorized
  // loads/stores. This means we need more registers to hold
  // intermediate values. Reduce the number of threads per block to
  // increase the number of registers available to ptxas.  Make sure
  // we still have a multiple of 32.
  int64_t threads_per_block_x = [&]() {
    int64_t max_threads_per_block_x =
        threads_per_block_row_vectorized > 0
            ? threads_per_block_row_vectorized
            : RoundUpTo(ThreadsPerBlockLimit(gpu_device_info) /
                            dim_config.unroll_factor,
                        int64_t{32});
    if (num_elements < max_threads_per_block_x) {
      return num_elements;
    }
    return max_threads_per_block_x;
  }();
  // threads_per_block_y > 1 when we row vectorize and the row size is small.
  int64_t threads_per_block_y =
      threads_per_block_row_vectorized > 0 &&
              threads_per_block_row_vectorized < 128 && num_elements > 128
          ? CeilOfRatio(static_cast<int64_t>(128),
                        threads_per_block_row_vectorized)
          : 1;
  VLOG(2) << "Set # of threads per block to (.x=" << threads_per_block_x
          << ", .y=" << threads_per_block_y << ")";

  int64_t block_count =
      CeilOfRatio(num_elements, threads_per_block_x * threads_per_block_y);
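  // In few_waves mode, cap the number of blocks so that the grid is launched
  // as only a few waves of resident blocks across the device, instead of one
  // thread per unrolled element.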
  if (dim_config.few_waves && !dim_config.row_vectorized) {
    int64_t capped_threads_per_block_x =
        std::min<int64_t>(threads_per_block_x, 128);
    int64_t capped_block_count =
        gpu_device_info.core_count *
        (gpu_device_info.threads_per_core_limit /
         (capped_threads_per_block_x * threads_per_block_y));
    if (capped_block_count < block_count) {
      threads_per_block_x = capped_threads_per_block_x;
      block_count = capped_block_count;
      VLOG(2) << "Update the # of blocks to " << block_count
              << " and the # of threads per block to " << threads_per_block_x
              << " as the few waves mode is enabled.";
    }
  } else if (dim_config.few_waves && dim_config.row_vectorized) {
    int64_t min_block_count = gpu_device_info.core_count *
                              (gpu_device_info.threads_per_core_limit /
                               (threads_per_block_x * threads_per_block_y));
    int64_t capped_block_count = block_count;
    // This multiple of 32 was tuned to not cause regressions on multiple
    // benchmarks.  It isn't a value that is optimal for all kernels.
    // Looking at the arithmetic intensity of the kernels could help
    // specialize the multiple per kernel.
    while (capped_block_count > (32 * min_block_count)) {
      capped_block_count /= 2;
    }
    // Do not increase the number of blocks. This can happen for
    // small num_elements.
    if (capped_block_count < block_count) {
      VLOG(2) << "Update the # of blocks to " << capped_block_count
              << " as few_waves is enabled.";
      block_count = capped_block_count;
    }
  }
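  // The grid's x dimension is bounded by the hardware; fail explicitly rather
  // than launching an invalid grid.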
  if (gpu_device_info.block_dim_limit_x > 0 &&
      block_count >= gpu_device_info.block_dim_limit_x) {
    return tensorflow::errors::Unimplemented(
        "Kernel launch needs more blocks (", block_count,
        ") than allowed by hardware (", gpu_device_info.block_dim_limit_x,
        ").");
  }

  VLOG(2) << absl::StrFormat(
      "Initialized the block count to %d, the block size .x=%d and .y=%d"
      " for %d elements in the tensor.",
      block_count, threads_per_block_x, threads_per_block_y, num_elements);
  return LaunchDimensions({block_count, 1, 1},
                          {threads_per_block_x, threads_per_block_y, 1});
}

}  // namespace gpu
}  // namespace xla