/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_platform_info.h"

#include <utility>

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/client/client_library.h"

namespace tensorflow {

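// Parses the comma-separated `visible_device_list` (the session's
// GPUOptions.visible_device_list string) into the set of platform device ids
// it names.  An empty string yields std::nullopt, meaning all devices are
// visible; an entry that is not an integer yields InvalidArgument.
//
// A minimal usage sketch (illustrative, not part of the original file):
//
//   TF_ASSIGN_OR_RETURN(std::optional<std::set<int>> ids,
//                       ParseVisibleDeviceList("0,2"));   // -> {0, 2}
//   TF_ASSIGN_OR_RETURN(ids, ParseVisibleDeviceList("")); // -> std::nullopt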
xla::StatusOr<std::optional<std::set<int>>> ParseVisibleDeviceList(
    absl::string_view visible_device_list) {
  std::set<int> gpu_ids;
  if (visible_device_list.empty()) {
    return {{std::nullopt}};
  }
  const std::vector<string> visible_devices =
      absl::StrSplit(visible_device_list, ',');
  for (const string& platform_device_id_str : visible_devices) {
    int32_t platform_device_id;
    if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
      return errors::InvalidArgument(
          "Could not parse entry in 'visible_device_list': '",
          platform_device_id_str,
          "'. visible_device_list = ", visible_device_list);
    }
    gpu_ids.insert(platform_device_id);
  }
  return {{gpu_ids}};
}

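// Allocates a new XlaCompilationCache for `device` in `*cache`, configured
// from the mark-for-compilation persistence flags.  On an XLA device the
// client and JIT device type come straight from the device's metadata;
// otherwise a LocalClient is created (or reused) for the device's platform,
// restricted to the GPUs named in the session's visible_device_list, and the
// compilation device type is looked up in the XlaOpRegistry.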
Status BuildXlaCompilationCache(DeviceBase* device, FunctionLibraryRuntime* flr,
                                const XlaPlatformInfo& platform_info,
                                XlaCompilationCache** cache) {
  XlaCompilationCache::Config cache_config(
      GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory,
      GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks,
      GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix);

  if (platform_info.xla_device_metadata()) {
    *cache = new XlaCompilationCache(
        std::move(cache_config), platform_info.xla_device_metadata()->client(),
        platform_info.xla_device_metadata()->jit_device_type());
    return OkStatus();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
  if (!platform.ok()) {
    return platform.status();
  }

  StatusOr<xla::Compiler*> compiler_for_platform =
      xla::Compiler::GetForPlatform(platform.ValueOrDie());
  if (!compiler_for_platform.ok()) {
    // In some rare cases (usually in unit tests with very small clusters) we
    // may end up transforming an XLA cluster with at least one GPU operation
    // (which would normally force the cluster to be compiled using XLA:GPU)
    // into an XLA cluster with no GPU operations (i.e. containing only CPU
    // operations).  Such a cluster can fail compilation (in a way that
    // MarkForCompilation could not have detected) if the CPU JIT is not linked
    // in.
    //
    // So bail out of _XlaCompile in this case, and let the executor handle the
    // situation for us.
    const Status& status = compiler_for_platform.status();
    if (status.code() == error::NOT_FOUND) {
      return errors::Unimplemented("Could not find compiler for platform ",
                                   platform.ValueOrDie()->Name(), ": ",
                                   status.ToString());
    }
  }

  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      device->tensorflow_cpu_worker_threads()->num_threads);

  if (flr->config_proto()) {
    string allowed_gpus =
        flr->config_proto()->gpu_options().visible_device_list();
    TF_ASSIGN_OR_RETURN(std::optional<std::set<int>> gpu_ids,
                        ParseVisibleDeviceList(allowed_gpus));
    client_options.set_allowed_devices(gpu_ids);
  }

  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   platform_info.device_type().type());
  }
  *cache = new XlaCompilationCache(
      std::move(cache_config), client.ValueOrDie(),
      DeviceType(registration->compilation_device_name));
  return OkStatus();
}

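// Derives an XlaPlatformInfo from `device_base`: the device type, the
// StreamExecutor platform id (the host platform for CPU, the stream's
// platform for GPU), and, on an XlaDevice, the device metadata together with
// the XLA client's shared memory allocator as the custom allocator.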
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
  auto device = static_cast<Device*>(device_base);
  se::Platform::Id platform_id = nullptr;
  const XlaDevice::Metadata* xla_device_metadata = nullptr;
  std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator;

  if (device->device_type() == DEVICE_CPU) {
    platform_id = se::host::kHostPlatformId;
  } else if (device->device_type() == DEVICE_GPU) {
    platform_id = device->tensorflow_accelerator_device_info()
                      ->stream->parent()
                      ->platform()
                      ->id();
  } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
                 .ok()) {
    // If we are on an XlaDevice, use the underlying XLA platform's allocator
    // directly. We could use the StreamExecutor's allocator, which may
    // theoretically be more correct, but XLA returns a nice OOM message in a
    // Status and StreamExecutor does not.
    //
    // Importantly, we can't use ctx->device()->GetAllocator() here: on an
    // XlaDevice that is a dummy allocator that returns XlaTensor objects,
    // while the XlaCompiler needs a real allocator to allocate real buffers.
    platform_id = xla_device_metadata->platform()->id();
    custom_allocator =
        xla_device_metadata->client()->backend().shared_memory_allocator();
  }

  return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
                         xla_device_metadata, custom_allocator);
}

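// Picks the allocator the compiled computation should allocate from: the
// platform info's custom allocator if one is set, otherwise the device's TF
// allocator wrapped in a TfAllocatorAdapter tied to `stream` (or to the
// platform itself when `stream` is null, as on the host).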
std::shared_ptr<se::DeviceMemoryAllocator> GetAllocator(
    DeviceBase* device, se::Stream* stream,
    const XlaPlatformInfo& platform_info) {
  if (platform_info.custom_allocator()) {
    return platform_info.custom_allocator();
  }
  auto* alloc = device->GetAllocator({});
  if (!stream) {
    // Stream is not set for the host platform.
    se::Platform* platform =
        se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
            .ValueOrDie();
    return std::make_shared<se::TfAllocatorAdapter>(alloc, platform);
  }
  return std::make_shared<se::TfAllocatorAdapter>(alloc, stream);
}

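// Assembles the XlaCompiler::Options for compiling a cluster: the cache's
// client and device type, the function library's definitions and GraphDef
// version, the allocator chosen by GetAllocator(), CPU custom calls on the
// host platform, shape determination functions from XLA device metadata when
// present, and passthrough-parameter aliasing when that is safe (no ref
// variables and not on an XLA device).
//
// A hedged sketch of a call site (variable names assumed, not from this
// file):
//
//   XlaCompiler::Options opts = GenerateCompilerOptions(
//       *cache, *flr, device, stream, platform_info, /*has_ref_vars=*/false);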
XlaCompiler::Options GenerateCompilerOptions(
    const XlaCompilationCache& cache,
    const FunctionLibraryRuntime& function_library, DeviceBase* device,
    se::Stream* stream, const XlaPlatformInfo& platform_info,
    bool has_ref_vars) {
  XlaCompiler::Options options;
  options.client = static_cast<xla::LocalClient*>(cache.client());
  if (stream != nullptr) {
    options.device_ordinal = stream->parent()->device_ordinal();
  }
  options.device_type = cache.device_type();
  options.flib_def = function_library.GetFunctionLibraryDefinition();
  options.graph_def_version = function_library.graph_def_version();
  options.allow_cpu_custom_calls =
      (platform_info.platform_id() == se::host::kHostPlatformId);
  options.device_allocator = GetAllocator(device, stream, platform_info);
  if (platform_info.xla_device_metadata()) {
    options.shape_determination_fns =
        platform_info.xla_device_metadata()->default_shape_determination_fns();
  }
  // If reference variables are not present in the graph, we can safely alias
  // passthrough parameters without performing a copy.
  options.alias_passthrough_params =
      !has_ref_vars && !platform_info.is_on_xla_device();
  return options;
}

}  // namespace tensorflow