/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"

#include <cstdint>

#include "absl/cleanup/cleanup.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_embedding_engine_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace {
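// Looks up the TpuMeshStateInterface resource registered by
// ConfigureDistributedTpuOp. Returns FailedPrecondition if the TPU system
// has not been initialized yet.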
Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
                                tpu::TpuMeshStateInterface** state) {
  if (!rmgr->Lookup(rmgr->default_container(),
                    tpu::kTpuMeshStateInterfaceResourceName, state)
           .ok()) {
    return errors::FailedPrecondition(
        "GetTpuMeshStateInterface: The TPU system has not been initialized.");
  }
  return OkStatus();
}

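// Creates the TpuFingerprintLookup resource in rmgr's default container if
// it does not already exist. The reference returned by LookupOrCreate is
// immediately dropped: only the side effect of registering the resource is
// needed here.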
Status CreateTpuFingerprintLookup(ResourceMgr* rmgr) {
  VLOG(1) << "CreateTpuFingerprintLookup";
  tpu::TpuFingerprintLookup* fingerprint_lookup;
  TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<tpu::TpuFingerprintLookup>(
      rmgr->default_container(), tpu::kFingerprintLookupResourceName,
      &fingerprint_lookup, [&](tpu::TpuFingerprintLookup** new_lookup) {
        *new_lookup = tpu::TpuFingerprintLookup::Create();
        return OkStatus();
      }));

  core::ScopedUnref fingerprint_lookup_ref(fingerprint_lookup);
  return OkStatus();
}

// Attempts to delete resource_name from resource_manager's default_container.
// Returns OK if the deletion succeeded or if the resource was not found;
// otherwise returns the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
                      const char* resource_name) {
  VLOG(1) << "Removing resource " << resource_name << " if it exists";
  Status status = resource_manager->Delete<ResourceT>(
      resource_manager->default_container(), resource_name);
  if (status.ok()) {
    VLOG(1) << "Removed existing resource " << resource_name;
    return OkStatus();
  }
  if (status.code() == error::NOT_FOUND) {
    VLOG(1) << "No resource " << resource_name << " to remove";
    return OkStatus();
  }
  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
  return status;
}
}  // namespace

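// Looks up the TpuCompilationCacheInterface resource, creating it via the
// registered factory function if it does not exist yet. On success the
// caller holds a reference and is responsible for unreffing it.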
Status CreateTpuCompilationCache(
    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
  return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
      rmgr->default_container(), tpu::kCompilationCacheResourceName,
      compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) {
        *new_cache = tpu::GetCompilationCacheCreateFn()();
        return OkStatus();
      });
}

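// Reads one scalar input per host, where input i holds the number of TPU
// chips on host i, and returns the per-host device counts. Verifies that
// every host reports the same chip count as host 0.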
xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
    OpKernelContext* ctx) {
  std::vector<int32_t> num_devices_per_host;
  int chips_per_host = -1;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& input_tensor = ctx->input(i);
    if (!TensorShapeUtils::IsScalar(input_tensor.shape())) {
      return errors::InvalidArgument("Input ", i,
                                     " should be a scalar but has ",
                                     input_tensor.dims(), " dimensions");
    }
    if (chips_per_host == -1) {
      chips_per_host = input_tensor.scalar<int32_t>()();
    } else {
      if (chips_per_host != input_tensor.scalar<int32_t>()()) {
        return errors::Internal("Host ", i, " has ",
                                input_tensor.scalar<int32_t>()(),
                                " TPU chips but host 0 has ", chips_per_host);
      }
    }
    num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
  }
  return num_devices_per_host;
}

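// Creates the global TPU system state on this task: the subgraph compilation
// cache, the TPU pod state, the mesh state, and the fingerprint lookup
// table. Outputs the serialized host config that the per-host
// initialization op consumes.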
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ConfigureDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");

  xla::StatusOr<std::vector<int32_t>> num_devices_per_host =
      ConstructDevicesPerHost(ctx);
  OP_REQUIRES_OK(ctx, num_devices_per_host.status());
  ResourceMgr* rmgr = GetTPUConfigResourceMgr();

  // Create the subgraph compilation cache and put it in the local resource
  // manager.
  tpu::TpuCompilationCacheInterface* compilation_cache;
  OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
  core::ScopedUnref compilation_cache_ref(compilation_cache);

  std::string host_config_output;
  OP_REQUIRES_OK(
      ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache,
                                &host_config_output));

  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
  OP_REQUIRES_OK(
      ctx, rmgr->Create(rmgr->default_container(),
                        tpu::kTpuMeshStateInterfaceResourceName, tpu_mesh));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() = std::move(host_config_output);

  OP_REQUIRES_OK(ctx, CreateTpuFingerprintLookup(rmgr));
  VLOG(1) << "ConfigureDistributedTpuOp done";
}

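// Blocks until the distributed TPU system is ready. Each input i is a vector
// mapping host i's device ordinals to global TPU core ids; the op forwards
// this mapping through the C API and outputs the serialized TPU topology.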
void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "WaitForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");

  size_t num_devices_per_host = -1;
  size_t num_hosts = ctx->num_inputs();

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    OP_REQUIRES(
        ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
        errors::InvalidArgument("Input ", i, " should be a vector but has ",
                                host_ordinal_to_global_device_id_tensor.dims(),
                                " dimensions"));
  }

  std::vector<std::vector<int32_t>> mapping;
  std::vector<int32_t*> mapping_arg;

  mapping.resize(ctx->num_inputs());

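  // Copy each host's ordinal-to-global-id vector into `mapping`, checking
  // that every host reports the same device count as host 0. `mapping_arg`
  // collects the raw row pointers in the form the C API expects.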
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    const auto host_ordinal_to_global_device_id =
        host_ordinal_to_global_device_id_tensor.flat<int>();
    if (num_devices_per_host == -1) {
      num_devices_per_host =
          host_ordinal_to_global_device_id_tensor.dim_size(0);
    } else {
      OP_REQUIRES(ctx,
                  num_devices_per_host ==
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                  errors::Internal(
                      "Host ", i, " has ",
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                      " TPU devices but host 0 has ", num_devices_per_host));
    }
    for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
         ++j) {
      int32_t global_device_id = host_ordinal_to_global_device_id(j);
      mapping[i].push_back(global_device_id);
    }
    mapping_arg.push_back(mapping[i].data());
  }

  tpu::TpuMeshStateInterface* mesh_state;
  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  // TODO(b/166858751): this code to check if `TpuPodState` exists is ported
  // from a legacy library that may be stale. A candidate for cleanup.
  TpuPodState* pod_state;
  OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state));
  core::ScopedUnref pod_state_unref(pod_state);

  size_t tpu_topology_output_size;
  char* tpu_topology_output = nullptr;
  TF_Status* status = TF_NewStatus();
  auto cleanup = absl::MakeCleanup([&status, &tpu_topology_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
  });

  auto* mesh_common_state = mesh_state->mesh_common_state();

  WaitForDistributedTpuOp_DoWork_Params params;
  params.struct_size = WaitForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.num_hosts = num_hosts;
  params.num_cores_per_host = num_devices_per_host;
  params.host_ordinal_to_global_core_id_map =
      const_cast<const int32_t**>(mapping_arg.data());
  params.tpu_mesh_common_state = mesh_common_state;
  params.tpu_topology_output_size = &tpu_topology_output_size;
  params.tpu_topology_output = &tpu_topology_output;
  params.status = status;

  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(&params);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() =
      std::string(tpu_topology_output, tpu_topology_output_size);

  VLOG(1) << "WaitForDistributedTpuOp done";
}

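// Tears down the global TPU system state in reverse order of creation:
// mesh state, pod state, then the compilation cache.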
void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ShutdownDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  OP_REQUIRES_OK(ctx,
                 DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
                          rmgr, tpu::kCompilationCacheResourceName));

  VLOG(1) << "ShutdownDistributedTpuOp done";
}

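// Runs once per host. Consumes the serialized host config produced by
// ConfigureDistributedTpuOp, initializes the local TPU platform through the
// C API, installs the compiled-proto cache lookup (local or RPC) and the
// embedding engine state, and outputs this host's global TPU core ids.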
void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "InitializeHostForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  auto tpu_host_config = ctx->input(0).scalar<tstring>()();

  // Reset the TPU embedding engine interface if we are not the master.
  // We need to reset the interface before initializing the host because the
  // resetting process resets the TPU platform.
  OP_REQUIRES_OK(ctx,
                 DeleteIfExists<tpu::TpuEmbeddingEngineStateInterface>(
                     rmgr, tpu::kTpuEmbeddingEngineStateInterfaceResourceName));

  bool is_master_worker =
      tpu::OpsApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
  if (!is_master_worker) {
    // Reset the mesh interface if we are not the master.
    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                            rmgr, tpu::kTpuMeshStateInterfaceResourceName));
    auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
    OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
                                     tpu::kTpuMeshStateInterfaceResourceName,
                                     mesh_state_interface));
  }

  VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheLookup>(
                          rmgr, tpu::kCompiledProtoCacheResourceName));

  if (enable_whole_mesh_compilations_) {
    // If this is a whole-mesh compilation mode, create the compilation
    // cache, if missing.
    tpu::TpuCompilationCacheInterface* compilation_cache;
    OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
    compilation_cache->Unref();
  }

  OP_REQUIRES_OK(ctx, internal::SetTpuCancellationClosesChips(
                          tpu_cancellation_closes_chips_));

  tpu::TpuCompilationCacheInterface* local_compilation_cache;
  Status s = rmgr->Lookup(rmgr->default_container(),
                          tpu::kCompilationCacheResourceName,
                          &local_compilation_cache);
  if (!s.ok()) {
    local_compilation_cache = nullptr;
  }

  TF_Status* status = TF_NewStatus();
  size_t device_id_output_size;
  int32_t* device_id_output = nullptr;
  auto cleanup = absl::MakeCleanup([&status, &device_id_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
  });

  InitializeHostForDistributedTpuOp_DoWork_Params params;
  params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.tpu_host_config_size = tpu_host_config.size();
  params.tpu_host_config = tpu_host_config.data();
  params.enable_whole_mesh_compilations = enable_whole_mesh_compilations_;
  params.is_master_worker = is_master_worker;
  params.core_id_output_size = &device_id_output_size;
  params.core_id_output = &device_id_output;
  params.status = status;

  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

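  // Install the compiled-proto cache lookup. If this host owns a local
  // compilation cache, wrap it directly; otherwise resolve the cache server
  // address from the host config and look up entries over RPC.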
  if (local_compilation_cache != nullptr) {
    local_compilation_cache->Unref();

    tpu::TpuCompilationCacheLookup* proto_lookup;
    proto_lookup =
        new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  } else {
    int64_t cache_size_bytes;
    tpu::OpsApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
        &cache_size_bytes);

    char* server_address_output = nullptr;
    auto cleanup_server_address =
        absl::MakeCleanup([&server_address_output]() {
          tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(
              server_address_output);
        });
    size_t server_address_output_size;

    TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params params;
    params.struct_size =
        TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE;
    params.priv = nullptr;
    params.tpu_host_config_size = tpu_host_config.size();
    params.tpu_host_config = tpu_host_config.data();
    params.server_address_output_size = &server_address_output_size;
    params.server_address_output = &server_address_output;
    params.status = status;

    tpu::OpsApiFn()
        ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
            &params);
    OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

    std::string server_address(server_address_output,
                               server_address_output_size);
    tpu::TpuCompilationCacheLookup* proto_lookup =
        new tpu::TpuCompilationCacheRpcLookup(server_address,
                                              cache_size_bytes);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  }

  auto* engine_state_interface =
      tpu::TpuEmbeddingEngineStateInterface::Create();
  OP_REQUIRES_OK(
      ctx, rmgr->Create(rmgr->default_container(),
                        tpu::kTpuEmbeddingEngineStateInterfaceResourceName,
                        engine_state_interface));

  Tensor* ctx_output;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(
               0, TensorShape({static_cast<long long>(device_id_output_size)}),
               &ctx_output));

  for (size_t i = 0; i < device_id_output_size; ++i) {
    ctx_output->flat<int32>()(i) = device_id_output[i];
  }
  if (ctx->function_library() != nullptr &&
      ctx->function_library()->device_mgr() != nullptr) {
    // If a DeviceMgr is available, set global IDs for TPU devices from the
    // topology.
    DeviceBase* tpu_system_device = ctx->device();
    const DeviceNameUtils::ParsedName& tpu_system_name =
        tpu_system_device->parsed_name();
    for (DeviceBase* device :
         ctx->function_library()->device_mgr()->ListDevices()) {
      const DeviceNameUtils::ParsedName& device_parsed_name =
          device->parsed_name();
      if (device_parsed_name.type == "TPU" &&
          DeviceNameUtils::IsSameAddressSpace(tpu_system_name,
                                              device_parsed_name)) {
        const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
            device->tensorflow_accelerator_device_info();
        if (accelerator_device_info && accelerator_device_info->stream) {
          int device_ordinal =
              accelerator_device_info->stream->parent()->device_ordinal();
          if (device_ordinal >= device_id_output_size) {
            OP_REQUIRES_OK(ctx,
                           errors::Internal(absl::StrCat(
                               "TPU core with ordinal ", device_ordinal,
                               " out of range for device ", device->name(),
                               ". Expected ordinals in range [0, ",
                               device_id_output_size, ") from topology.")));
          }
          int64_t global_id = device_id_output[device_ordinal];
          VLOG(1) << "Setting global/physical id for " << device->name()
                  << " to " << global_id;
          device->set_xla_global_id(global_id);
        }
      }
    }
  }
  VLOG(1) << "InitializeHostForDistributedTpuOp done";
}

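// Passes the serialized TPU topology (produced by WaitForDistributedTpuOp)
// to the C API so this host knows the global device layout.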
void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "SetGlobalTPUArrayOp";
  XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");

  auto tpu_topology = ctx->input(0).scalar<tstring>()();
  TF_Status* status = TF_NewStatus();
  // Delete the status even if OP_REQUIRES_OK below returns early on error.
  auto cleanup = absl::MakeCleanup([&status]() { TF_DeleteStatus(status); });

  tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
                                                tpu_topology.data(), status);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  VLOG(1) << "SetGlobalTPUArrayOp done";
}

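// Detaches this host from the distributed TPU system and outputs the number
// of TPU chips that were disconnected.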
void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "DisconnectDistributedTpuChipsOp";
  XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");

  TF_Status* status = TF_NewStatus();
  // Delete the status even if OP_REQUIRES_OK below returns early on error.
  auto cleanup = absl::MakeCleanup([&status]() { TF_DeleteStatus(status); });
  int32_t number_of_chips_output = 0;

  tpu::OpsApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
      &number_of_chips_output, status);

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<int32_t>()() = number_of_chips_output;

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  VLOG(1) << "DisconnectDistributedTpuChipsOp done";
}


// These ops execute on the TPU_SYSTEM device only.
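//
// A note on ordering (a sketch of the typical initialization sequence, as
// implied by the data flow above): _ConfigureDistributedTPU runs first and
// produces the host config; _InitializeHostForDistributedTPU runs on every
// host with that config; _WaitForDistributedTPU then produces the topology,
// which _SetGlobalTPUArray distributes to each host. _ShutdownDistributedTPU
// and _DisconnectHostFromDistributedTPUSystem tear the system down.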
REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("output"),
                        ConfigureDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("inputs")
                            .HostMemory("topology"),
                        WaitForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
    ShutdownDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input")
                            .HostMemory("tpu_ids"),
                        InitializeHostForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
    SetGlobalTPUArrayOp);
REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("number_of_tpu_chips"),
                        DisconnectDistributedTpuChipsOp);

}  // namespace tensorflow