/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/gpu_device.h"

#include <map>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/stream_executor/device_memory.h"

#ifdef GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorflow/compiler/xla/pjrt/nccl_id_store.h"
#include "tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#endif  // TENSORFLOW_USE_ROCM

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device/device_host_allocator.h"
#include "tensorflow/core/common_runtime/device/device_id.h"
#include "tensorflow/core/common_runtime/device/device_mem_allocator.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"

namespace xla {
namespace {

#if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020

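// Builds a cudaMallocAsync-based allocator for every addressable device,
// capping each device at memory_fraction of its currently free memory.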
StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateCudaAsyncAllocator(
    se::Platform* platform,
    const std::map<int, std::unique_ptr<LocalDeviceState>>& addressable_devices,
    double memory_fraction, bool preallocate) {
  CHECK_GT(addressable_devices.size(), 0);
  std::vector<se::MultiDeviceAdapter::AllocatorWithStream> allocators;

  for (auto& ordinal_and_device : addressable_devices) {
    se::StreamExecutor* executor = ordinal_and_device.second->executor();
    int device_ordinal = executor->device_ordinal();

    int64_t free_memory;
    int64_t total_memory;
    if (!executor->DeviceMemoryUsage(&free_memory, &total_memory)) {
      return Unavailable("Failed to query available memory from device %i",
                         device_ordinal);
    }
    // Cap the async allocator at memory_fraction of the memory that is
    // currently free on the device.
    size_t allocator_memory = free_memory * memory_fraction;
    if (preallocate) {
      LOG(INFO) << "XLA backend allocating " << allocator_memory
                << " bytes on device " << device_ordinal
                << " for CudaAsyncAllocator.";
    } else {
      LOG(INFO) << "XLA backend will use up to " << allocator_memory
                << " bytes on device " << device_ordinal
                << " for CudaAsyncAllocator.";
    }

    auto allocator = std::make_unique<tensorflow::GpuCudaMallocAsyncAllocator>(
        tensorflow::PlatformDeviceId(device_ordinal), allocator_memory,
        preallocate);
    allocator->SetStreamAndPreallocateMemory(
        ordinal_and_device.second->compute_stream()
            ->implementation()
            ->GpuStreamMemberHack());
    allocators.emplace_back(std::move(allocator),
                            ordinal_and_device.second->compute_stream());
  }
  return std::make_unique<se::MultiDeviceAdapter>(platform,
                                                  std::move(allocators));
}

#else  // defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020

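// Stub for builds without CUDA >= 11.2: always reports that the CUDA async
// allocator is unavailable.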
StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateCudaAsyncAllocator(
    se::Platform* platform,
    const std::map<int, std::unique_ptr<LocalDeviceState>>& addressable_devices,
    double memory_fraction, bool preallocate) {
  return FailedPrecondition("CUDA async allocator requires CUDA >= 11.2");
}

#endif  // defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020

// A custom PjRtClient that overrides the device assignment method.
class GpuClient : public xla::PjRtStreamExecutorClient {
 public:
  using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;

  xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override;

  absl::string_view platform_version() const override {
#define STRINGIFY2(X) #X
#define STRINGIFY(X) STRINGIFY2(X)
#if TENSORFLOW_USE_ROCM && defined(TF_ROCM_VERSION)  // rocm
    // The TF_ROCM_VERSION format may change in the future; use it cautiously.
    return "rocm " STRINGIFY(TF_ROCM_VERSION);
#elif GOOGLE_CUDA && defined(CUDART_VERSION)  // cuda
    return "cuda " STRINGIFY(CUDART_VERSION);
#else
    return "<unknown>";
#endif  // TENSORFLOW_USE_ROCM && defined(TF_ROCM_VERSION)
  }
};

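// Prefers a trivial assignment over this client's addressable devices when a
// single-partition computation fits locally; otherwise defers to the default
// (potentially multi-node) assignment.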
xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
    int num_replicas, int num_partitions) const {
  if (num_partitions == 1 && num_replicas <= addressable_devices().size()) {
    xla::DeviceAssignment assignment(num_replicas, 1);
    for (int i = 0; i < num_replicas; ++i) {
      assignment(i, 0) = addressable_devices().at(i)->id();
    }
    return assignment;
  }
  // Fall back to the default global device assignment if we can't run locally.
  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
                                                              num_partitions);
}

// Builds an xla::LocalClient for the GPU platform.
StatusOr<LocalClient*> GetGpuXlaClient(
    const std::optional<std::string>& platform_name,
    const std::optional<std::set<int>>& allowed_devices) {
  TF_ASSIGN_OR_RETURN(
      se::Platform * platform,
      PlatformUtil::GetPlatform(platform_name ? *platform_name : "gpu"));
  if (platform->VisibleDeviceCount() <= 0) {
    return FailedPrecondition("No visible GPU devices.");
  }
  LocalClientOptions options;
  options.set_platform(platform);
  options.set_allowed_devices(allowed_devices);
  return ClientLibrary::GetOrCreateLocalClient(options);
}

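// Enables peer-to-peer access between every pair of visible GPUs that
// supports it; failures are logged as warnings but are not fatal.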
void EnablePeerAccess(absl::Span<se::StreamExecutor* const> executors) {
  for (int i = 0; i < executors.size(); ++i) {
    for (int j = 0; j < executors.size(); ++j) {
      if (i == j) {
        continue;
      }
      se::StreamExecutor* from = executors[i];
      se::StreamExecutor* to = executors[j];
      if (from->CanEnablePeerAccessTo(to)) {
        Status status = from->EnablePeerAccessTo(to);
        if (!status.ok()) {
          LOG(WARNING) << "Unable to enable peer access between GPUs " << i
                       << " and " << j << "; status: " << status;
        } else {
          VLOG(2) << "Enabled peer access from GPU " << i << " to GPU " << j;
        }
      }
    }
  }
}

// Builds a LocalDeviceState for each GPU present.
StatusOr<std::map<int, std::unique_ptr<LocalDeviceState>>>
BuildLocalDeviceStates(LocalClient* xla_client, bool asynchronous) {
  std::map<int, std::unique_ptr<LocalDeviceState>> addressable_devices;
  for (se::StreamExecutor* executor :
       xla_client->backend().stream_executors()) {
    addressable_devices.emplace(
        executor->device_ordinal(),
        std::make_unique<LocalDeviceState>(
            executor, xla_client, LocalDeviceState::kComputeSynchronized,
            /*max_inflight_computations=*/32,
            /*allow_event_reuse=*/true, /*use_callback_stream=*/true));
  }
  return std::move(addressable_devices);
}

// Builds a BFCAllocator for all local GPUs.
StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
    const std::map<int, std::unique_ptr<LocalDeviceState>>& addressable_devices,
    double memory_fraction, bool preallocate) {
  CHECK_GT(addressable_devices.size(), 0);
  const se::Platform* platform =
      addressable_devices.begin()->second->executor()->platform();
  std::vector<se::MultiDeviceAdapter::AllocatorWithStream> allocators;
  bool enable_unified_memory;
  Status status = tensorflow::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY",
                                                 false, &enable_unified_memory);
  if (!status.ok()) {
    LOG(ERROR) << "Unable to read TF_FORCE_UNIFIED_MEMORY: "
               << status.error_message();
  }

  for (auto& ordinal_and_device : addressable_devices) {
    se::StreamExecutor* executor = ordinal_and_device.second->executor();
    int device_ordinal = executor->device_ordinal();
    auto sub_allocator = std::make_unique<tensorflow::DeviceMemAllocator>(
        executor, tensorflow::PlatformDeviceId(device_ordinal),
        /*use_unified_memory=*/enable_unified_memory,
        /*alloc_visitors=*/std::vector<tensorflow::SubAllocator::Visitor>(),
        /*free_visitors=*/std::vector<tensorflow::SubAllocator::Visitor>());

    int64_t free_memory;
    int64_t total_memory;
    if (!executor->DeviceMemoryUsage(&free_memory, &total_memory)) {
      return Unavailable("Failed to query available memory from device %i",
                         device_ordinal);
    }
    // If unified memory is in use, make the full GPU memory visible to the
    // BFC allocator. When unified memory is enabled, GPU memory
    // oversubscription is allowed by setting memory_fraction > 1.
    size_t allocator_memory = enable_unified_memory
                                  ? total_memory * fmax(1.0, memory_fraction)
                                  : free_memory * memory_fraction;
    if (preallocate) {
      LOG(INFO) << "XLA backend allocating " << allocator_memory
                << " bytes on device " << device_ordinal
                << " for BFCAllocator.";
    } else {
      LOG(INFO) << "XLA backend will use up to " << allocator_memory
                << " bytes on device " << device_ordinal
                << " for BFCAllocator.";
    }

    tensorflow::BFCAllocator::Options opts;
    opts.allow_growth = !preallocate;
    auto gpu_bfc_allocator = std::make_unique<tensorflow::BFCAllocator>(
        std::move(sub_allocator), allocator_memory,
        absl::StrCat("GPU_", device_ordinal, "_bfc"), opts);
    allocators.emplace_back(std::move(gpu_bfc_allocator),
                            ordinal_and_device.second->compute_stream());
  }
  return std::make_unique<se::MultiDeviceAdapter>(platform,
                                                  std::move(allocators));
}

// Constructs a GPU device memory allocator to use, according to the allocator
// configuration the client requested.
StatusOr<std::unique_ptr<se::DeviceMemoryAllocator>> GetGpuDeviceAllocator(
    se::Platform* platform, const GpuAllocatorConfig& allocator_config,
    const std::map<int, std::unique_ptr<LocalDeviceState>>&
        addressable_devices) {
  std::unique_ptr<se::DeviceMemoryAllocator> allocator;
  switch (allocator_config.kind) {
    case GpuAllocatorConfig::Kind::kCudaAsync: {
      auto allocator_or = CreateCudaAsyncAllocator(
          platform, addressable_devices, allocator_config.memory_fraction,
          allocator_config.preallocate);
      if (allocator_or.ok()) {
        LOG(INFO) << "Using CUDA async allocator.";
        allocator = std::move(allocator_or.ValueOrDie());
        break;
      }
      LOG(ERROR) << "Failed to initialize CUDA async allocator: "
                 << allocator_or.status() << "; falling back to BFC.";
      [[fallthrough]];
    }

    case GpuAllocatorConfig::Kind::kDefault:
    case GpuAllocatorConfig::Kind::kBFC: {
      LOG(INFO) << "Using BFC allocator.";
      TF_ASSIGN_OR_RETURN(allocator,
                          CreateBFCAllocator(addressable_devices,
                                             allocator_config.memory_fraction,
                                             allocator_config.preallocate));
      break;
    }

    case GpuAllocatorConfig::Kind::kPlatform:
      LOG(INFO) << "Using platform allocator.";
      break;
  }
  return std::move(allocator);
}

// Returns a GPU pinned host memory allocator to use when staging host->GPU
// transfers. The allocator grows on demand up to a 64GiB limit of pinned
// memory.
std::unique_ptr<tensorflow::BFCAllocator> GetGpuHostAllocator(
    se::StreamExecutor* executor) {
  std::unique_ptr<tensorflow::SubAllocator> sub_allocator(
      new tensorflow::DeviceHostAllocator(executor, /*numa_node=*/0,
                                          /*alloc_visitors=*/{},
                                          /*free_visitors=*/{}));
  // TODO(phawkins): allow the user to tune this.
  const int64_t kGpuHostMemoryLimitBytes = 64 * (1LL << 30);

  tensorflow::BFCAllocator::Options opts;
  opts.allow_growth = true;
  return std::make_unique<tensorflow::BFCAllocator>(
      std::move(sub_allocator), kGpuHostMemoryLimitBytes,
      /*name=*/"xla_gpu_host_bfc", opts);
}

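// Wraps each LocalDeviceState in a GpuDevice for the single-process case,
// using the local device ordinal as the device id and node_id 0.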
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
    std::map<int, std::unique_ptr<LocalDeviceState>> local_device_states) {
  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
  for (auto& ordinal_and_device : local_device_states) {
    const se::DeviceDescription& description =
        ordinal_and_device.second->executor()->GetDeviceDescription();
    auto device = std::make_unique<GpuDevice>(
        ordinal_and_device.first, std::move(ordinal_and_device.second),
        description.name(), description.device_vendor(),
        /*node_id=*/0);
    devices.push_back(std::move(device));
  }
  return devices;
}

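// Exchanges the local device topology with the other nodes in the job via the
// DistributedRuntimeClient, assigns global device ids, and builds a GpuDevice
// for every device in the job (only devices on this node carry a
// LocalDeviceState). Under CUDA, also registers the NCCL unique-id callback
// used for cross-host collectives.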
Status BuildDistributedDevices(
    std::map<int, std::unique_ptr<LocalDeviceState>> local_device_states,
    std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
    gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
  LocalTopologyProto local_topology;
  local_topology.set_node_id(node_id);
  for (const auto& ordinal_and_device : local_device_states) {
    const se::Platform* platform =
        ordinal_and_device.second->executor()->platform();
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<xla::se::DeviceDescription> desc,
        platform->DescriptionForDevice(ordinal_and_device.first));
    DeviceProto* device_proto = local_topology.add_devices();
    device_proto->set_local_device_ordinal(ordinal_and_device.first);
    device_proto->set_name(desc->name());
    device_proto->set_vendor(desc->device_vendor());
  }

  GlobalTopologyProto global_topology;
  TF_RETURN_IF_ERROR(
      distributed_client->EnumerateDevices(local_topology, &global_topology));

  std::map<int, GlobalDeviceId> gpu_device_ids;
  absl::flat_hash_map<GlobalDeviceId, int> device_to_node;
  for (const LocalTopologyProto& node : global_topology.nodes()) {
    for (const DeviceProto& device_proto : node.devices()) {
      GlobalDeviceId global_device_id(device_proto.global_device_id());
      device_to_node[global_device_id] = node.node_id();
      std::unique_ptr<LocalDeviceState> local_device;
      if (node.node_id() == node_id) {
        auto it = local_device_states.find(device_proto.local_device_ordinal());
        TF_RET_CHECK(it != local_device_states.end())
            << device_proto.local_device_ordinal();
        TF_RET_CHECK(it->second != nullptr);
        local_device = std::move(it->second);
        gpu_device_ids[device_proto.local_device_ordinal()] = global_device_id;
      }
      auto device = std::make_unique<GpuDevice>(
          device_proto.global_device_id(), std::move(local_device),
          device_proto.name(), device_proto.vendor(), node.node_id());
      devices->push_back(std::move(device));
    }
  }
  for (const auto& device : local_device_states) {
    TF_RET_CHECK(device.second == nullptr);
  }
  std::vector<GlobalDeviceId> sorted_global_device_ids;
  sorted_global_device_ids.reserve(gpu_device_ids.size());
  for (const auto& e : gpu_device_ids) {
    sorted_global_device_ids.push_back(e.second);
  }
  gpu_executable_run_options->set_gpu_global_device_ids(
      std::move(sorted_global_device_ids));
#ifdef GOOGLE_CUDA
  auto nccl_id_store = std::make_shared<NcclIdStore>(
      node_id, distributed_client, device_to_node);
  gpu_executable_run_options->set_nccl_unique_id_callback(
      [nccl_id_store](const gpu::NcclCliqueKey& key) {
        return nccl_id_store->GetNcclUniqueId(key);
      });
#endif  // GOOGLE_CUDA
  return OkStatus();
}

}  // namespace

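// GpuDevice records the vendor as a PjRt device attribute and precomputes the
// string returned by ToString().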
GpuDevice::GpuDevice(int id,
                     std::unique_ptr<LocalDeviceState> local_device_state,
                     std::string device_kind, std::string device_vendor,
                     int node_id)
    : PjRtStreamExecutorDevice(id, std::move(local_device_state),
                               std::move(device_kind), node_id),
      device_vendor_(std::move(device_vendor)) {
  attributes_ = {
      {"device_vendor", PjRtDeviceAttribute(device_vendor_)},
  };
  to_string_ = absl::StrFormat("GpuDevice(id=%i, process_index=%i)", id,
                               process_index());
}

absl::string_view GpuDevice::device_vendor() { return device_vendor_; }

absl::string_view GpuDevice::ToString() const { return to_string_; }

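// Creates a PjRt GPU client: builds the XLA LocalClient and per-device state,
// sets up the device and pinned-host allocators, and, when a
// DistributedRuntimeClient is provided, exchanges topology with the other
// nodes in the job.
//
// Minimal usage sketch (illustrative only; from code that returns a Status or
// StatusOr, with the allocator and distributed runtime configured as needed):
//
//   GpuAllocatorConfig allocator_config;  // see gpu_device.h for the options
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PjRtClient> client,
//       GetGpuClient(/*asynchronous=*/true, allocator_config,
//                    /*distributed_client=*/nullptr, /*node_id=*/0,
//                    /*allowed_devices=*/std::nullopt,
//                    /*platform_name=*/std::nullopt));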
StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
    bool asynchronous, const GpuAllocatorConfig& allocator_config,
    std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
    const std::optional<std::set<int>>& allowed_devices,
    std::optional<std::string> platform_name) {
  TF_ASSIGN_OR_RETURN(LocalClient * xla_client,
                      GetGpuXlaClient(platform_name, allowed_devices));
  std::map<int, std::unique_ptr<LocalDeviceState>> local_device_states;
  TF_ASSIGN_OR_RETURN(local_device_states,
                      BuildLocalDeviceStates(xla_client, asynchronous));
  EnablePeerAccess(xla_client->backend().stream_executors());
  TF_ASSIGN_OR_RETURN(
      auto allocator,
      GetGpuDeviceAllocator(xla_client->platform(), allocator_config,
                            local_device_states));
  auto host_memory_allocator =
      GetGpuHostAllocator(local_device_states.begin()->second->executor());

  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
  auto gpu_run_options = std::make_unique<gpu::GpuExecutableRunOptions>();
  if (distributed_client) {
    TF_RETURN_IF_ERROR(BuildDistributedDevices(
        std::move(local_device_states), std::move(distributed_client), node_id,
        &devices, gpu_run_options.get()));
  } else {
    devices = BuildLocalDevices(std::move(local_device_states));
  }

  return std::unique_ptr<PjRtClient>(std::make_unique<GpuClient>(
      GpuName(), xla_client, std::move(devices),
      /*node_id=*/node_id, std::move(allocator),
      std::move(host_memory_allocator),
      /*should_stage_host_to_device_transfers=*/true,
      /*gpu_run_options=*/std::move(gpu_run_options)));
}

}  // namespace xla