1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_ 17 #define TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_ 18 19 #include "tensorflow/compiler/jit/xla_device.h" 20 #include "tensorflow/core/framework/op_kernel.h" 21 #include "tensorflow/core/util/stream_executor_util.h" 22 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" 23 #include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h" 24 25 namespace tensorflow { 26 27 class TpuTransferOpInterface { 28 public: ~TpuTransferOpInterface()29 virtual ~TpuTransferOpInterface() {} 30 virtual void Cancel() = 0; 31 virtual StatusOr<int> GetDeviceOrdinal(OpKernelContext* ctx) = 0; 32 33 virtual Status TransferBuffersToInfeed( 34 int device_ordinal, 35 const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) = 0; 36 virtual Status TransferLiteralToInfeed(int device_ordinal, 37 const xla::LiteralSlice& literal) = 0; 38 virtual Status TransferLiteralFromOutfeed( 39 int device_ordinal, xla::MutableBorrowingLiteral literal) = 0; 40 }; 41 42 // Base class providing common functionality for async ops that transfer from 43 // host to TPU. 44 class TpuTransferAsyncOpKernelBase : public AsyncOpKernel { 45 public: 46 explicit TpuTransferAsyncOpKernelBase( 47 OpKernelConstruction* ctx, const string& transfer_type, 48 int number_of_threads, 49 std::unique_ptr<TpuTransferOpInterface> transfer_op); 50 51 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; 52 53 protected: 54 virtual Status DoWork(OpKernelContext* context, int device_ordinal) = 0; 55 56 Status RunTransferWithOrdinal(OpKernelContext* ctx, int device_ordinal); 57 std::string transfer_type_; 58 std::unique_ptr<TpuTransferOpInterface> transfer_op_; 59 60 private: 61 virtual Status RunTransfer(OpKernelContext* ctx) = 0; 62 63 std::unique_ptr<thread::ThreadPool> thread_pool_; 64 mutex mu_; 65 66 // TpuTransferAsyncOpKernelBase is neither copyable nor movable. 67 TpuTransferAsyncOpKernelBase(const TpuTransferAsyncOpKernelBase&) = delete; 68 TpuTransferAsyncOpKernelBase& operator=(const TpuTransferAsyncOpKernelBase&) = 69 delete; 70 }; 71 72 class TpuTransferAsyncOpKernel : public TpuTransferAsyncOpKernelBase { 73 public: 74 explicit TpuTransferAsyncOpKernel( 75 OpKernelConstruction* ctx, const string& transfer_type, 76 int number_of_threads, 77 std::unique_ptr<TpuTransferOpInterface> transfer_op); 78 79 private: 80 Status RunTransfer(OpKernelContext* ctx) override; 81 int device_ordinal_; 82 83 // TpuTransferAsyncOpKernel is neither copyable nor movable. 84 TpuTransferAsyncOpKernel(const TpuTransferAsyncOpKernel&) = delete; 85 TpuTransferAsyncOpKernel& operator=(const TpuTransferAsyncOpKernel&) = delete; 86 }; 87 88 class TpuTransferAsyncDynamicOrdinalOpKernel 89 : public TpuTransferAsyncOpKernelBase { 90 public: 91 explicit TpuTransferAsyncDynamicOrdinalOpKernel( 92 OpKernelConstruction* ctx, const string& transfer_type, 93 int number_of_threads, 94 std::unique_ptr<TpuTransferOpInterface> transfer_op); 95 96 private: 97 Status RunTransfer(OpKernelContext* ctx) override; 98 99 // TpuTransferAsyncDynamicOpKernel is neither copyable nor movable. 100 TpuTransferAsyncDynamicOrdinalOpKernel( 101 const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete; 102 TpuTransferAsyncDynamicOrdinalOpKernel& operator=( 103 const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete; 104 }; 105 106 class StreamExecutorTransferOpImpl : public TpuTransferOpInterface { 107 public: 108 explicit StreamExecutorTransferOpImpl(); 109 ~StreamExecutorTransferOpImpl() override = default; 110 void Cancel() override; 111 StatusOr<int> GetDeviceOrdinal(OpKernelContext* ctx) override; 112 113 Status TransferBuffersToInfeed( 114 int device_ordinal, 115 const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) override; 116 Status TransferLiteralToInfeed(int device_ordinal, 117 const xla::LiteralSlice& literal) override; 118 119 Status TransferLiteralFromOutfeed( 120 int device_ordinal, xla::MutableBorrowingLiteral literal) override; 121 122 private: 123 StatusOr<stream_executor::StreamExecutor*> GetStreamExecutor( 124 int device_ordinal); 125 xla::TpuTransferManagerInterface* transfer_manager_; 126 tpu::TpuPlatformInterface* tpu_platform_; 127 }; 128 129 } // namespace tensorflow 130 131 #endif // TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_ 132