/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/jit/xla_platform_info.h"
17
18 #include <utility>
19
20 #include "tensorflow/compiler/jit/flags.h"
21 #include "tensorflow/compiler/xla/client/client_library.h"
22
23 namespace tensorflow {
24
ParseVisibleDeviceList(absl::string_view visible_device_list)25 xla::StatusOr<std::optional<std::set<int>>> ParseVisibleDeviceList(
26 absl::string_view visible_device_list) {
27 std::set<int> gpu_ids;
28 if (visible_device_list.empty()) {
29 return {{std::nullopt}};
30 }
31 const std::vector<string> visible_devices =
32 absl::StrSplit(visible_device_list, ',');
33 for (const string& platform_device_id_str : visible_devices) {
34 int32_t platform_device_id;
35 if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
36 return errors::InvalidArgument(
37 "Could not parse entry in 'visible_device_list': '",
38 platform_device_id_str,
39 "'. visible_device_list = ", visible_device_list);
40 }
41 gpu_ids.insert(platform_device_id);
42 }
43 return {{gpu_ids}};
44 }
45
BuildXlaCompilationCache(DeviceBase * device,FunctionLibraryRuntime * flr,const XlaPlatformInfo & platform_info,XlaCompilationCache ** cache)46 Status BuildXlaCompilationCache(DeviceBase* device, FunctionLibraryRuntime* flr,
47 const XlaPlatformInfo& platform_info,
48 XlaCompilationCache** cache) {
49 XlaCompilationCache::Config cache_config(
50 GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory,
51 GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks,
52 GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix);
53
54 if (platform_info.xla_device_metadata()) {
55 *cache = new XlaCompilationCache(
56 std::move(cache_config), platform_info.xla_device_metadata()->client(),
57 platform_info.xla_device_metadata()->jit_device_type());
58 return OkStatus();
59 }
60
61 auto platform =
62 se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
63 if (!platform.ok()) {
64 return platform.status();
65 }
66
67 StatusOr<xla::Compiler*> compiler_for_platform =
68 xla::Compiler::GetForPlatform(platform.ValueOrDie());
69 if (!compiler_for_platform.ok()) {
70 // In some rare cases (usually in unit tests with very small clusters) we
71 // may end up transforming an XLA cluster with at least one GPU operation
72 // (which would normally force the cluster to be compiled using XLA:GPU)
73 // into an XLA cluster with no GPU operations (i.e. containing only CPU
74 // operations). Such a cluster can fail compilation (in way that
75 // MarkForCompilation could not have detected) if the CPU JIT is not linked
76 // in.
77 //
78 // So bail out of _XlaCompile in this case, and let the executor handle the
79 // situation for us.
80 const Status& status = compiler_for_platform.status();
81 if (status.code() == error::NOT_FOUND) {
82 return errors::Unimplemented("Could not find compiler for platform ",
83 platform.ValueOrDie()->Name(), ": ",
84 status.ToString());
85 }
86 }
87
88 xla::LocalClientOptions client_options;
89 client_options.set_platform(platform.ValueOrDie());
90 client_options.set_intra_op_parallelism_threads(
91 device->tensorflow_cpu_worker_threads()->num_threads);
92
93 if (flr->config_proto()) {
94 string allowed_gpus =
95 flr->config_proto()->gpu_options().visible_device_list();
96 TF_ASSIGN_OR_RETURN(std::optional<std::set<int>> gpu_ids,
97 ParseVisibleDeviceList(allowed_gpus));
98 client_options.set_allowed_devices(gpu_ids);
99 }
100
101 auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
102 if (!client.ok()) {
103 return client.status();
104 }
105 const XlaOpRegistry::DeviceRegistration* registration;
106 if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
107 ®istration)) {
108 return errors::InvalidArgument("No JIT device registered for ",
109 platform_info.device_type().type());
110 }
111 *cache = new XlaCompilationCache(
112 std::move(cache_config), client.ValueOrDie(),
113 DeviceType(registration->compilation_device_name));
114 return OkStatus();
115 }
116
XlaPlatformInfoFromDevice(DeviceBase * device_base)117 XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
118 auto device = static_cast<Device*>(device_base);
119 se::Platform::Id platform_id = nullptr;
120 const XlaDevice::Metadata* xla_device_metadata = nullptr;
121 std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator;
122
123 if (device->device_type() == DEVICE_CPU) {
124 platform_id = se::host::kHostPlatformId;
125 } else if (device->device_type() == DEVICE_GPU) {
126 platform_id = device->tensorflow_accelerator_device_info()
127 ->stream->parent()
128 ->platform()
129 ->id();
130 } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
131 .ok()) {
132 // If we are on an XlaDevice, use the underlying XLA platform's allocator
133 // directly. We could use the StreamExecutor's allocator which may
134 // theoretically be more correct, but XLA returns a nice OOM message in a
135 // Status and StreamExecutor does not.
136 //
137 // Importantly we can't use ctx->device()->GetAllocator() as the allocator
138 // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
139 // allocator that returns XlaTensor objects. The XlaCompiler needs a real
140 // allocator to allocate real buffers.
141 platform_id = xla_device_metadata->platform()->id();
142 custom_allocator =
143 xla_device_metadata->client()->backend().shared_memory_allocator();
144 }
145
146 return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
147 xla_device_metadata, custom_allocator);
148 }
149
GetAllocator(DeviceBase * device,se::Stream * stream,const XlaPlatformInfo & platform_info)150 std::shared_ptr<se::DeviceMemoryAllocator> GetAllocator(
151 DeviceBase* device, se::Stream* stream,
152 const XlaPlatformInfo& platform_info) {
153 if (platform_info.custom_allocator()) {
154 return platform_info.custom_allocator();
155 }
156 auto* alloc = device->GetAllocator({});
157 if (!stream) {
158 // Stream is not set for the host platform.
159 se::Platform* platform =
160 se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
161 .ValueOrDie();
162 return std::make_shared<se::TfAllocatorAdapter>(alloc, platform);
163 }
164 return std::make_shared<se::TfAllocatorAdapter>(alloc, stream);
165 }
166
GenerateCompilerOptions(const XlaCompilationCache & cache,const FunctionLibraryRuntime & function_library,DeviceBase * device,se::Stream * stream,const XlaPlatformInfo & platform_info,bool has_ref_vars)167 XlaCompiler::Options GenerateCompilerOptions(
168 const XlaCompilationCache& cache,
169 const FunctionLibraryRuntime& function_library, DeviceBase* device,
170 se::Stream* stream, const XlaPlatformInfo& platform_info,
171 bool has_ref_vars) {
172 XlaCompiler::Options options;
173 options.client = static_cast<xla::LocalClient*>(cache.client());
174 if (stream != nullptr) {
175 options.device_ordinal = stream->parent()->device_ordinal();
176 }
177 options.device_type = cache.device_type();
178 options.flib_def = function_library.GetFunctionLibraryDefinition();
179 options.graph_def_version = function_library.graph_def_version();
180 options.allow_cpu_custom_calls =
181 (platform_info.platform_id() == se::host::kHostPlatformId);
182 options.device_allocator = GetAllocator(device, stream, platform_info);
183 if (platform_info.xla_device_metadata()) {
184 options.shape_determination_fns =
185 platform_info.xla_device_metadata()->default_shape_determination_fns();
186 }
187 // If reference variables are not present in the graph, we can safely alias
188 // passthrough parameters without performing a copy.
189 options.alias_passthrough_params =
190 !has_ref_vars && !platform_info.is_on_xla_device();
191 return options;
192 }
193
194 } // namespace tensorflow
195