xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/transfer_ops.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
17 #define TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
18 
19 #include "tensorflow/compiler/jit/xla_device.h"
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/util/stream_executor_util.h"
22 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
23 #include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
24 
25 namespace tensorflow {
26 
27 class TpuTransferOpInterface {
28  public:
~TpuTransferOpInterface()29   virtual ~TpuTransferOpInterface() {}
30   virtual void Cancel() = 0;
31   virtual StatusOr<int> GetDeviceOrdinal(OpKernelContext* ctx) = 0;
32 
33   virtual Status TransferBuffersToInfeed(
34       int device_ordinal,
35       const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) = 0;
36   virtual Status TransferLiteralToInfeed(int device_ordinal,
37                                          const xla::LiteralSlice& literal) = 0;
38   virtual Status TransferLiteralFromOutfeed(
39       int device_ordinal, xla::MutableBorrowingLiteral literal) = 0;
40 };
41 
42 // Base class providing common functionality for async ops that transfer from
43 // host to TPU.
44 class TpuTransferAsyncOpKernelBase : public AsyncOpKernel {
45  public:
46   explicit TpuTransferAsyncOpKernelBase(
47       OpKernelConstruction* ctx, const string& transfer_type,
48       int number_of_threads,
49       std::unique_ptr<TpuTransferOpInterface> transfer_op);
50 
51   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
52 
53  protected:
54   virtual Status DoWork(OpKernelContext* context, int device_ordinal) = 0;
55 
56   Status RunTransferWithOrdinal(OpKernelContext* ctx, int device_ordinal);
57   std::string transfer_type_;
58   std::unique_ptr<TpuTransferOpInterface> transfer_op_;
59 
60  private:
61   virtual Status RunTransfer(OpKernelContext* ctx) = 0;
62 
63   std::unique_ptr<thread::ThreadPool> thread_pool_;
64   mutex mu_;
65 
66   // TpuTransferAsyncOpKernelBase is neither copyable nor movable.
67   TpuTransferAsyncOpKernelBase(const TpuTransferAsyncOpKernelBase&) = delete;
68   TpuTransferAsyncOpKernelBase& operator=(const TpuTransferAsyncOpKernelBase&) =
69       delete;
70 };
71 
72 class TpuTransferAsyncOpKernel : public TpuTransferAsyncOpKernelBase {
73  public:
74   explicit TpuTransferAsyncOpKernel(
75       OpKernelConstruction* ctx, const string& transfer_type,
76       int number_of_threads,
77       std::unique_ptr<TpuTransferOpInterface> transfer_op);
78 
79  private:
80   Status RunTransfer(OpKernelContext* ctx) override;
81   int device_ordinal_;
82 
83   // TpuTransferAsyncOpKernel is neither copyable nor movable.
84   TpuTransferAsyncOpKernel(const TpuTransferAsyncOpKernel&) = delete;
85   TpuTransferAsyncOpKernel& operator=(const TpuTransferAsyncOpKernel&) = delete;
86 };
87 
88 class TpuTransferAsyncDynamicOrdinalOpKernel
89     : public TpuTransferAsyncOpKernelBase {
90  public:
91   explicit TpuTransferAsyncDynamicOrdinalOpKernel(
92       OpKernelConstruction* ctx, const string& transfer_type,
93       int number_of_threads,
94       std::unique_ptr<TpuTransferOpInterface> transfer_op);
95 
96  private:
97   Status RunTransfer(OpKernelContext* ctx) override;
98 
99   // TpuTransferAsyncDynamicOpKernel is neither copyable nor movable.
100   TpuTransferAsyncDynamicOrdinalOpKernel(
101       const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete;
102   TpuTransferAsyncDynamicOrdinalOpKernel& operator=(
103       const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete;
104 };
105 
106 class StreamExecutorTransferOpImpl : public TpuTransferOpInterface {
107  public:
108   explicit StreamExecutorTransferOpImpl();
109   ~StreamExecutorTransferOpImpl() override = default;
110   void Cancel() override;
111   StatusOr<int> GetDeviceOrdinal(OpKernelContext* ctx) override;
112 
113   Status TransferBuffersToInfeed(
114       int device_ordinal,
115       const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) override;
116   Status TransferLiteralToInfeed(int device_ordinal,
117                                  const xla::LiteralSlice& literal) override;
118 
119   Status TransferLiteralFromOutfeed(
120       int device_ordinal, xla::MutableBorrowingLiteral literal) override;
121 
122  private:
123   StatusOr<stream_executor::StreamExecutor*> GetStreamExecutor(
124       int device_ordinal);
125   xla::TpuTransferManagerInterface* transfer_manager_;
126   tpu::TpuPlatformInterface* tpu_platform_;
127 };
128 
129 }  // namespace tensorflow
130 
131 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
132