/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_

#include <stdint.h>

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

Status CreateTpuCompilationCache(
    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache);

xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
    OpKernelContext* ctx);

// The ConfigureDistributedTpu op is used to start a TPUDriver from
// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
// connection host:port for the CompilationCacheServer. The
// CompilationCacheServer will remain live until the device's Resource Manager
// is cleared or a ShutdownDistributedTpuOp is run on the same device.
class ConfigureDistributedTpuOp : public OpKernel {
 public:
  explicit ConfigureDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES(
        ctx, ctx->num_inputs() > 0,
        errors::Internal("_ConfigureDistributedTPU needs at least one input"));
  }
  void Compute(OpKernelContext* ctx) override;
  ~ConfigureDistributedTpuOp() override {}

 private:
  // ConfigureDistributedTpuOp is neither copyable nor movable.
  ConfigureDistributedTpuOp(const ConfigureDistributedTpuOp&) = delete;
  ConfigureDistributedTpuOp& operator=(const ConfigureDistributedTpuOp&) =
      delete;
};
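
// Illustrative only: a minimal sketch of how the kernel above might be
// registered. The op name "_ConfigureDistributedTPU" is taken from the error
// message in the constructor and the TPU_SYSTEM placement from the comment;
// the authoritative registration lives in the corresponding .cc file and may
// use additional builder options.
//
//   REGISTER_KERNEL_BUILDER(
//       Name("_ConfigureDistributedTPU").Device("TPU_SYSTEM"),
//       ConfigureDistributedTpuOp);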

// The WaitForDistributedTpu op is used to block execution until
// the distributed TPU system has started up. It must be run on
// the same TPU_SYSTEM device that ConfigureDistributedTpuOp was run
// on, after all of the InitializeHostForDistributedTpuOp Ops have
// completed.
class WaitForDistributedTpuOp : public OpKernel {
 public:
  explicit WaitForDistributedTpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("startup_timeout_sec", &startup_timeout_sec_));
    OP_REQUIRES(ctx, startup_timeout_sec_ > 0,
                errors::InvalidArgument("startup_timeout_sec ",
                                        startup_timeout_sec_, " must be >0"));
  }
  void Compute(OpKernelContext* ctx) override;
  ~WaitForDistributedTpuOp() override {}

 private:
  // The time to wait for all hosts to start up.
  int startup_timeout_sec_;

  // WaitForDistributedTpuOp is neither copyable nor movable.
  WaitForDistributedTpuOp(const WaitForDistributedTpuOp&) = delete;
  WaitForDistributedTpuOp& operator=(const WaitForDistributedTpuOp&) = delete;
};
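
// Illustrative only: the constructor above requires a positive integer
// attribute named "startup_timeout_sec" on the op it implements. A
// hypothetical attr declaration in the op registration might look like the
// following (the op name and the default value are assumptions, not taken
// from this header):
//
//   REGISTER_OP("_WaitForDistributedTPU")
//       .Attr("startup_timeout_sec: int = 20")
//       ...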

// The ShutdownDistributedTpu op is used to stop a running TPUDriver from
// TensorFlow. It should be run on the TPU_SYSTEM device where
// ConfigureDistributedTpuOp was run.
class ShutdownDistributedTpuOp : public OpKernel {
 public:
  explicit ShutdownDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~ShutdownDistributedTpuOp() override {}

 private:
  // ShutdownDistributedTpuOp is neither copyable nor movable.
  ShutdownDistributedTpuOp(const ShutdownDistributedTpuOp&) = delete;
  ShutdownDistributedTpuOp& operator=(const ShutdownDistributedTpuOp&) = delete;
};

// The InitializeHostForDistributedTpu op is used to initialize the
// TPUPlatform on a host in a distributed TPU system. It should be
// run on every host containing TPU devices before any other Ops that use
// TPU are run.
class InitializeHostForDistributedTpuOp : public OpKernel {
 public:
  explicit InitializeHostForDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    // Both attributes are optional: lookup errors are ignored, so graphs that
    // omit them fall back to the member defaults below.
    ctx->GetAttr("enable_whole_mesh_compilations",
                 &enable_whole_mesh_compilations_)
        .IgnoreError();
    ctx->GetAttr("tpu_cancellation_closes_chips",
                 &tpu_cancellation_closes_chips_)
        .IgnoreError();
  }

  void Compute(OpKernelContext* ctx) override;

  ~InitializeHostForDistributedTpuOp() override {}

 private:
  // InitializeHostForDistributedTpuOp is neither copyable nor movable.
  InitializeHostForDistributedTpuOp(const InitializeHostForDistributedTpuOp&) =
      delete;
  InitializeHostForDistributedTpuOp& operator=(
      const InitializeHostForDistributedTpuOp&) = delete;

  bool enable_whole_mesh_compilations_ = false;
  int tpu_cancellation_closes_chips_ = 0;
};
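
// Illustrative only: because both attributes above are optional, a NodeDef for
// this kernel may either omit them or set them explicitly. A hypothetical
// sketch using NodeDefBuilder (the op name "_InitializeHostForDistributedTPU"
// and its inputs are assumptions, not taken from this header):
//
//   NodeDefBuilder("init_host", "_InitializeHostForDistributedTPU")
//       .Attr("enable_whole_mesh_compilations", true)
//       .Attr("tpu_cancellation_closes_chips", 1)
//       ...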

// The SetGlobalTPUArray op is used to inform a host of the global TPU
// topology in a distributed TPU system. It should be run on every host
// containing TPU devices before any other Ops that use TPU are run.
class SetGlobalTPUArrayOp : public OpKernel {
 public:
  explicit SetGlobalTPUArrayOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~SetGlobalTPUArrayOp() override {}

 private:
  // SetGlobalTPUArrayOp is neither copyable nor movable.
  SetGlobalTPUArrayOp(const SetGlobalTPUArrayOp&) = delete;
  SetGlobalTPUArrayOp& operator=(const SetGlobalTPUArrayOp&) = delete;
};

// The DisconnectDistributedTpuChips op is used to disconnect all the chips on a
// host from a running TPUDriver instance. It should be run on every host
// containing TPU devices before the ShutdownDistributedTpuOp is run on
// the TPU_SYSTEM device.
class DisconnectDistributedTpuChipsOp : public OpKernel {
 public:
  explicit DisconnectDistributedTpuChipsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~DisconnectDistributedTpuChipsOp() override {}

 private:
  // DisconnectDistributedTpuChipsOp is neither copyable nor movable.
  DisconnectDistributedTpuChipsOp(const DisconnectDistributedTpuChipsOp&) =
      delete;
  DisconnectDistributedTpuChipsOp& operator=(
      const DisconnectDistributedTpuChipsOp&) = delete;
};
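
// Placement summary implied by the comments above (a sketch, not a definitive
// contract):
//
//   ConfigureDistributedTpuOp          TPU_SYSTEM device; starts the system
//   InitializeHostForDistributedTpuOp  every host with TPU devices, before
//                                      any other TPU-using ops
//   SetGlobalTPUArrayOp                every host with TPU devices, before
//                                      any other TPU-using ops
//   WaitForDistributedTpuOp            same TPU_SYSTEM device as Configure,
//                                      after all InitializeHost ops complete
//   DisconnectDistributedTpuChipsOp    every host with TPU devices, before
//                                      ShutdownDistributedTpuOp
//   ShutdownDistributedTpuOp           TPU_SYSTEM device where Configure ran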

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_