/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"

#include <cstdint>

#include "absl/cleanup/cleanup.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_embedding_engine_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace {
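// Looks up the TpuMeshStateInterface in the default container of `rmgr`.
// Fails with FailedPrecondition if the resource does not exist, i.e. if the
// TPU system has not been initialized on this host yet.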
Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
                                tpu::TpuMeshStateInterface** state) {
  if (!rmgr->Lookup(rmgr->default_container(),
                    tpu::kTpuMeshStateInterfaceResourceName, state)
           .ok()) {
    return errors::FailedPrecondition(
        "GetTpuMeshStateInterface: The TPU system has not been initialized.");
  }
  return OkStatus();
}

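// Creates the TpuFingerprintLookup resource in the default container if it
// does not already exist. The extra reference returned by LookupOrCreate is
// dropped immediately via ScopedUnref; the resource manager keeps the
// resource alive.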
Status CreateTpuFingerprintLookup(ResourceMgr* rmgr) {
  VLOG(1) << "CreateTpuFingerprintLookup";
  tpu::TpuFingerprintLookup* fingerprint_lookup;
  TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<tpu::TpuFingerprintLookup>(
      rmgr->default_container(), tpu::kFingerprintLookupResourceName,
      &fingerprint_lookup, [&](tpu::TpuFingerprintLookup** new_lookup) {
        *new_lookup = tpu::TpuFingerprintLookup::Create();
        return OkStatus();
      }));

  core::ScopedUnref fingerprint_lookup_ref(fingerprint_lookup);
  return OkStatus();
}

// Attempts to delete resource_name from resource_manager's default_container.
// Returns OK if the deletion succeeded or if the resource was not found;
// otherwise returns the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
                      const char* resource_name) {
  VLOG(1) << "Removing resource " << resource_name << " if it exists";
  Status status = resource_manager->Delete<ResourceT>(
      resource_manager->default_container(), resource_name);
  if (status.ok()) {
    VLOG(1) << "Removed existing resource " << resource_name;
    return OkStatus();
  }
  if (status.code() == error::NOT_FOUND) {
    VLOG(1) << "No resource " << resource_name << " to remove";
    return OkStatus();
  }
  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
  return status;
}
}  // namespace

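// Returns the process-wide TPU compilation cache from the default container,
// creating it on first use via the registered factory. The caller receives a
// reference to *compilation_cache and must Unref it when done.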
Status CreateTpuCompilationCache(
    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
  return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
      rmgr->default_container(), tpu::kCompilationCacheResourceName,
      compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) {
        *new_cache = tpu::GetCompilationCacheCreateFn()();
        return OkStatus();
      });
}

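// Reads one scalar int32 input per host (the number of TPU chips on that
// host), checks that every host reports the same count, and returns the
// per-host device counts as a vector.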
xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
    OpKernelContext* ctx) {
  std::vector<int32_t> num_devices_per_host;
  int chips_per_host = -1;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& input_tensor = ctx->input(i);
    if (!TensorShapeUtils::IsScalar(input_tensor.shape())) {
      return errors::InvalidArgument("Input ", i,
                                     " should be a scalar but has ",
                                     input_tensor.dims(), " dimensions");
    }
    if (chips_per_host == -1) {
      chips_per_host = input_tensor.scalar<int32_t>()();
    } else {
      if (chips_per_host != input_tensor.scalar<int32_t>()()) {
        return errors::Internal("Host ", i, " has ",
                                input_tensor.scalar<int32_t>()(),
                                " TPU chips but host 0 has ", chips_per_host);
      }
    }
    num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
  }
  return num_devices_per_host;
}

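// Runs on the master worker. Creates the compilation cache and the
// TpuPodState, replaces any existing TpuMeshStateInterface, and emits the
// serialized host config that InitializeHostForDistributedTpuOp consumes on
// each host.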
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ConfigureDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");

  xla::StatusOr<std::vector<int32_t>> num_devices_per_host =
      ConstructDevicesPerHost(ctx);
  OP_REQUIRES_OK(ctx, num_devices_per_host.status());
  ResourceMgr* rmgr = GetTPUConfigResourceMgr();

  // Create the subgraph compilation cache and put it in the local resource
  // manager.
  tpu::TpuCompilationCacheInterface* compilation_cache;
  OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
  core::ScopedUnref compilation_cache_ref(compilation_cache);

  std::string host_config_output;
  OP_REQUIRES_OK(
      ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache,
                                &host_config_output));

  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
  OP_REQUIRES_OK(
      ctx, rmgr->Create(rmgr->default_container(),
                        tpu::kTpuMeshStateInterfaceResourceName, tpu_mesh));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() = std::move(host_config_output);

  OP_REQUIRES_OK(ctx, CreateTpuFingerprintLookup(rmgr));
  VLOG(1) << "ConfigureDistributedTpuOp done";
}

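// Blocks until the distributed TPU system is ready. Takes one vector input
// per host mapping host-local device ordinals to global device ids, forwards
// that mapping to the TPU driver through the C API, and outputs the
// serialized TPU topology.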
void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "WaitForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");

  size_t num_devices_per_host = -1;
  size_t num_hosts = ctx->num_inputs();

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    OP_REQUIRES(
        ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
        errors::InvalidArgument("Input ", i, " should be a vector but has ",
                                host_ordinal_to_global_device_id_tensor.dims(),
                                " dimensions"));
  }

  std::vector<std::vector<int32_t>> mapping;
  std::vector<int32_t*> mapping_arg;

  mapping.resize(ctx->num_inputs());

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    const auto host_ordinal_to_global_device_id =
        host_ordinal_to_global_device_id_tensor.flat<int>();
    if (num_devices_per_host == -1) {
      num_devices_per_host =
          host_ordinal_to_global_device_id_tensor.dim_size(0);
    } else {
      OP_REQUIRES(ctx,
                  num_devices_per_host ==
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                  errors::Internal(
                      "Host ", i, " has ",
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                      " TPU devices but host 0 has ", num_devices_per_host));
    }
    for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
         ++j) {
      int32_t global_device_id = host_ordinal_to_global_device_id(j);
      mapping[i].push_back(global_device_id);
    }
    mapping_arg.push_back(mapping[i].data());
  }

  tpu::TpuMeshStateInterface* mesh_state;
  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  // TODO(b/166858751): this check that `TpuPodState` exists was ported from
  // a legacy library and may be stale; it is a candidate for cleanup.
  TpuPodState* pod_state;
  OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state));
  core::ScopedUnref pod_state_unref(pod_state);

  size_t tpu_topology_output_size;
  char* tpu_topology_output = nullptr;
  TF_Status* status = TF_NewStatus();
  auto cleanup = absl::MakeCleanup([&status, &tpu_topology_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
  });

  auto* mesh_common_state = mesh_state->mesh_common_state();

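  // Populate the C API parameter struct and hand off to the TPU driver. The
  // driver allocates *tpu_topology_output; the cleanup above frees it along
  // with the TF_Status.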
  WaitForDistributedTpuOp_DoWork_Params params;
  params.struct_size = WaitForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.num_hosts = num_hosts;
  params.num_cores_per_host = num_devices_per_host;
  params.host_ordinal_to_global_core_id_map =
      const_cast<const int32_t**>(mapping_arg.data());
  params.tpu_mesh_common_state = mesh_common_state;
  params.tpu_topology_output_size = &tpu_topology_output_size;
  params.tpu_topology_output = &tpu_topology_output;
  params.status = status;

  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(&params);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() =
      std::string(tpu_topology_output, tpu_topology_output_size);

  VLOG(1) << "WaitForDistributedTpuOp done";
}

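// Tears down the distributed TPU system by deleting the mesh state, the
// TpuPodState, and the compilation cache from the config resource manager.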
void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ShutdownDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  OP_REQUIRES_OK(ctx,
                 DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
                          rmgr, tpu::kCompilationCacheResourceName));

  VLOG(1) << "ShutdownDistributedTpuOp done";
}

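// Runs on every host. Consumes the serialized host config produced by
// ConfigureDistributedTpuOp, initializes the local TPU platform through the
// C API, wires up a local or RPC compilation cache lookup, and outputs the
// global ids of the TPU cores attached to this host.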
void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "InitializeHostForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  auto tpu_host_config = ctx->input(0).scalar<tstring>()();

  // Reset the TPU embedding engine interface if we are not the master.
  // The interface must be reset before initializing the host, because
  // resetting it also resets the TPU platform.
  OP_REQUIRES_OK(ctx,
                 DeleteIfExists<tpu::TpuEmbeddingEngineStateInterface>(
                     rmgr, tpu::kTpuEmbeddingEngineStateInterfaceResourceName));

  bool is_master_worker =
      tpu::OpsApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
  if (!is_master_worker) {
    // Reset the mesh interface if we are not the master.
    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                            rmgr, tpu::kTpuMeshStateInterfaceResourceName));
    auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
    OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
                                     tpu::kTpuMeshStateInterfaceResourceName,
                                     mesh_state_interface));
  }

  VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheLookup>(
                          rmgr, tpu::kCompiledProtoCacheResourceName));

  if (enable_whole_mesh_compilations_) {
    // In whole-mesh compilation mode, create the compilation cache if it is
    // missing.
    tpu::TpuCompilationCacheInterface* compilation_cache;
    OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
    compilation_cache->Unref();
  }

  OP_REQUIRES_OK(ctx, internal::SetTpuCancellationClosesChips(
                          tpu_cancellation_closes_chips_));

  tpu::TpuCompilationCacheInterface* local_compilation_cache;
  Status s = rmgr->Lookup(rmgr->default_container(),
                          tpu::kCompilationCacheResourceName,
                          &local_compilation_cache);
  if (!s.ok()) {
    local_compilation_cache = nullptr;
  }

  TF_Status* status = TF_NewStatus();
  size_t device_id_output_size;
  int32_t* device_id_output = nullptr;
  auto cleanup = absl::MakeCleanup([&status, &device_id_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
  });

  InitializeHostForDistributedTpuOp_DoWork_Params params;
  params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.tpu_host_config_size = tpu_host_config.size();
  params.tpu_host_config = tpu_host_config.data();
  params.enable_whole_mesh_compilations = enable_whole_mesh_compilations_;
  params.is_master_worker = is_master_worker;
  params.core_id_output_size = &device_id_output_size;
  params.core_id_output = &device_id_output;
  params.status = status;

  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

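  // If this host owns the compilation cache, serve lookups locally; the
  // Unref below drops the reference taken by the Lookup above (the local
  // lookup is expected to hold its own reference). Otherwise, read the cache
  // server address from the host config and go through RPC.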
  if (local_compilation_cache != nullptr) {
    local_compilation_cache->Unref();

    tpu::TpuCompilationCacheLookup* proto_lookup;
    proto_lookup =
        new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  } else {
    int64_t cache_size_bytes;
    tpu::OpsApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
        &cache_size_bytes);

    char* server_address_output = nullptr;
    auto cleanup_server_address = absl::MakeCleanup([&server_address_output]() {
      tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(
          server_address_output);
    });
    size_t server_address_output_size;

    TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params params;
    params.struct_size =
        TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE;
    params.priv = nullptr;
    params.tpu_host_config_size = tpu_host_config.size();
    params.tpu_host_config = tpu_host_config.data();
    params.server_address_output_size = &server_address_output_size;
    params.server_address_output = &server_address_output;
    params.status = status;

    tpu::OpsApiFn()
        ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
            &params);
    OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

    std::string server_address(server_address_output,
                               server_address_output_size);
    tpu::TpuCompilationCacheLookup* proto_lookup =
        new tpu::TpuCompilationCacheRpcLookup(server_address, cache_size_bytes);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  }

  auto* engine_state_interface =
      tpu::TpuEmbeddingEngineStateInterface::Create();
  OP_REQUIRES_OK(
      ctx, rmgr->Create(rmgr->default_container(),
                        tpu::kTpuEmbeddingEngineStateInterfaceResourceName,
                        engine_state_interface));

  Tensor* ctx_output;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(
               0, TensorShape({static_cast<long long>(device_id_output_size)}),
               &ctx_output));

  for (size_t i = 0; i < device_id_output_size; ++i) {
    ctx_output->flat<int32_t>()(i) = device_id_output[i];
  }
  if (ctx->function_library() != nullptr &&
      ctx->function_library()->device_mgr() != nullptr) {
    // If a DeviceMgr is available, set global IDs for TPU devices from the
    // topology.
    DeviceBase* tpu_system_device = ctx->device();
    const DeviceNameUtils::ParsedName& tpu_system_name =
        tpu_system_device->parsed_name();
    for (DeviceBase* device :
         ctx->function_library()->device_mgr()->ListDevices()) {
      const DeviceNameUtils::ParsedName& device_parsed_name =
          device->parsed_name();
      if (device_parsed_name.type == "TPU" &&
          DeviceNameUtils::IsSameAddressSpace(tpu_system_name,
                                              device_parsed_name)) {
        const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
            device->tensorflow_accelerator_device_info();
        if (accelerator_device_info && accelerator_device_info->stream) {
          int device_ordinal =
              accelerator_device_info->stream->parent()->device_ordinal();
          if (device_ordinal >= device_id_output_size) {
            OP_REQUIRES_OK(ctx,
                           errors::Internal(absl::StrCat(
                               "TPU core with ordinal ", device_ordinal,
                               " out of range for device ", device->name(),
                               ". Expected ordinals in range [0, ",
                               device_id_output_size, ") from topology.")));
          }
          int64_t global_id = device_id_output[device_ordinal];
          VLOG(1) << "Setting global/physical id for " << device->name()
                  << " to " << global_id;
          device->set_xla_global_id(global_id);
        }
      }
    }
  }
  VLOG(1) << "InitializeHostForDistributedTpuOp done";
}

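// Forwards the serialized TPU topology to the driver so that this host knows
// the global layout of the system.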
void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "SetGlobalTPUArrayOp";
  XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");

  auto tpu_topology = ctx->input(0).scalar<tstring>()();
  TF_Status* status = TF_NewStatus();
  // Delete the status even if OP_REQUIRES_OK returns early on error.
  auto cleanup = absl::MakeCleanup([&status]() { TF_DeleteStatus(status); });

  tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
                                                tpu_topology.data(), status);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  VLOG(1) << "SetGlobalTPUArrayOp done";
}

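// Detaches this host's TPU chips from the distributed system and outputs the
// number of chips that were disconnected.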
void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "DisconnectDistributedTpuChipsOp";
  XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");

  TF_Status* status = TF_NewStatus();
  // Delete the status even if OP_REQUIRES_OK returns early on error.
  auto cleanup = absl::MakeCleanup([&status]() { TF_DeleteStatus(status); });
  int32_t number_of_chips_output = 0;

  tpu::OpsApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
      &number_of_chips_output, status);

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<int32_t>()() = number_of_chips_output;

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  VLOG(1) << "DisconnectDistributedTpuChipsOp done";
}

// These ops execute on the TPU_SYSTEM device only.
REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("output"),
                        ConfigureDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("inputs")
                            .HostMemory("topology"),
                        WaitForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
    ShutdownDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input")
                            .HostMemory("tpu_ids"),
                        InitializeHostForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
    SetGlobalTPUArrayOp);
REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("number_of_tpu_chips"),
                        DisconnectDistributedTpuChipsOp);

}  // namespace tensorflow