xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/transfer_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/tpu/kernels/transfer_ops.h"
17 
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/kernels/ops_util.h"
21 #include "tensorflow/core/platform/tracing.h"
22 #include "tensorflow/core/profiler/lib/connected_traceme.h"
23 #include "tensorflow/core/profiler/lib/traceme.h"
24 #include "tensorflow/core/profiler/lib/traceme_encode.h"
25 #include "tensorflow/stream_executor/multi_platform_manager.h"
26 #include "tensorflow/stream_executor/tpu/tpu_node_context.h"
27 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
28 #include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
29 
30 namespace tensorflow {
31 
TpuTransferAsyncOpKernelBase(OpKernelConstruction * ctx,const string & transfer_type,int number_of_threads,std::unique_ptr<TpuTransferOpInterface> transfer_op)32 TpuTransferAsyncOpKernelBase::TpuTransferAsyncOpKernelBase(
33     OpKernelConstruction* ctx, const string& transfer_type,
34     int number_of_threads, std::unique_ptr<TpuTransferOpInterface> transfer_op)
35     : AsyncOpKernel(ctx),
36       transfer_type_(transfer_type),
37       transfer_op_(std::move(transfer_op)),
38       thread_pool_(new thread::ThreadPool(
39           ctx->env(),
40           strings::StrCat(transfer_type, "_thread_",
41                           SanitizeThreadSuffix(def().name())),
42           /*num_threads=*/8)) {}
43 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)44 void TpuTransferAsyncOpKernelBase::ComputeAsync(OpKernelContext* ctx,
45                                                 DoneCallback done) {
46   profiler::TraceMeProducer schedule_activity(
47       "TpuTransferAsyncOpKernelBase::ComputeAsync");
48   CancellationToken token =
49       ctx->cancellation_manager()->get_cancellation_token();
50   bool already_cancelled;
51   {
52     // Only protect registering the cancellation callback as mu_ cannot be held
53     // at a point where `done` could be called.
54     mutex_lock lock(mu_);
55     already_cancelled =
56         !ctx->cancellation_manager()->RegisterCallback(token, [this]() {
57           mutex_lock lock(mu_);
58           transfer_op_->Cancel();
59         });
60   }
61   OP_REQUIRES_ASYNC(ctx, !already_cancelled,
62                     errors::Cancelled("Infeed was cancelled."), done);
63   thread_pool_->Schedule(
64       [this, ctx, done, token,
65        traceme_context_id = schedule_activity.GetContextId()]() {
66         profiler::TraceMeConsumer compute_activity(
67             [this] { return profiler::TraceMeOp(name(), type_string()); },
68             traceme_context_id);
69         Status s = RunTransfer(ctx);
70         ctx->cancellation_manager()->DeregisterCallback(token);
71         OP_REQUIRES_OK_ASYNC(ctx, s, done);
72         done();
73       });
74 }
75 
RunTransferWithOrdinal(OpKernelContext * ctx,int device_ordinal)76 Status TpuTransferAsyncOpKernelBase::RunTransferWithOrdinal(
77     OpKernelContext* ctx, int device_ordinal) {
78 
79   int real_device_ordinal = device_ordinal;
80   if (real_device_ordinal < 0) {
81     TF_ASSIGN_OR_RETURN(real_device_ordinal,
82                         transfer_op_->GetDeviceOrdinal(ctx));
83   }
84 
85   profiler::TraceMe activity(
86       [real_device_ordinal] {
87         return profiler::TraceMeEncode(
88             "RunTransferWithOrdinal",
89             {{"device_ordinal", real_device_ordinal}});
90       },
91       profiler::kInfo);
92   return DoWork(ctx, real_device_ordinal);
93 }
94 
TpuTransferAsyncOpKernel(OpKernelConstruction * ctx,const string & transfer_type,int number_of_threads,std::unique_ptr<TpuTransferOpInterface> transfer_op)95 TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(
96     OpKernelConstruction* ctx, const string& transfer_type,
97     int number_of_threads, std::unique_ptr<TpuTransferOpInterface> transfer_op)
98     : TpuTransferAsyncOpKernelBase(ctx, transfer_type, number_of_threads,
99                                    std::move(transfer_op)) {
100   OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
101   if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
102     OP_REQUIRES(
103         ctx, device_ordinal_ >= 0,
104         errors::InvalidArgument(transfer_type,
105                                 " ops must specify a device_ordinal when "
106                                 "placed on CPU."));
107   }
108 }
109 
RunTransfer(OpKernelContext * ctx)110 Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
111   return RunTransferWithOrdinal(ctx, device_ordinal_);
112 }
113 
TpuTransferAsyncDynamicOrdinalOpKernel(OpKernelConstruction * ctx,const string & transfer_type,int number_of_threads,std::unique_ptr<TpuTransferOpInterface> transfer_op)114 TpuTransferAsyncDynamicOrdinalOpKernel::TpuTransferAsyncDynamicOrdinalOpKernel(
115     OpKernelConstruction* ctx, const string& transfer_type,
116     int number_of_threads, std::unique_ptr<TpuTransferOpInterface> transfer_op)
117     : TpuTransferAsyncOpKernelBase(ctx, transfer_type, number_of_threads,
118                                    std::move(transfer_op)) {}
119 
RunTransfer(OpKernelContext * ctx)120 Status TpuTransferAsyncDynamicOrdinalOpKernel::RunTransfer(
121     OpKernelContext* ctx) {
122   const Tensor& device_ordinal_tensor = ctx->input(0);
123   const int device_ordinal = device_ordinal_tensor.scalar<int32>()();
124   XlaDevice* xla_device =
125       dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
126   if (((xla_device == nullptr) || (xla_device->device_type() == DEVICE_CPU)) &&
127       (device_ordinal < 0)) {
128     return errors::InvalidArgument(transfer_type_,
129                                    " ops must specify a device_ordinal when "
130                                    "placed on CPU.");
131   }
132   return RunTransferWithOrdinal(ctx, device_ordinal);
133 }
134 
StreamExecutorTransferOpImpl()135 StreamExecutorTransferOpImpl::StreamExecutorTransferOpImpl()
136     : transfer_manager_(
137           xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager()),
138       tpu_platform_(tpu::TpuPlatformInterface::GetRegisteredPlatform(
139           /*initialize_platform=*/false)) {}
140 
Cancel()141 void StreamExecutorTransferOpImpl::Cancel() {
142   TF_CHECK_OK(tpu::TpuNodeContext::CloseTpuHost());
143 }
144 
GetDeviceOrdinal(OpKernelContext * ctx)145 StatusOr<int> StreamExecutorTransferOpImpl::GetDeviceOrdinal(
146     OpKernelContext* ctx) {
147   const XlaDevice::Metadata* metadata;
148   TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
149   return metadata->device_ordinal();
150 }
151 
TransferBuffersToInfeed(int device_ordinal,const std::deque<tensorflow::tpu::NoncopyableBuffer> & buffers)152 Status StreamExecutorTransferOpImpl::TransferBuffersToInfeed(
153     int device_ordinal,
154     const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) {
155   TF_ASSIGN_OR_RETURN(auto* executor, GetStreamExecutor(device_ordinal));
156   return transfer_manager_->TransferBuffersToInfeed(executor, buffers);
157 }
158 
TransferLiteralToInfeed(int device_ordinal,const xla::LiteralSlice & literal)159 Status StreamExecutorTransferOpImpl::TransferLiteralToInfeed(
160     int device_ordinal, const xla::LiteralSlice& literal) {
161   TF_ASSIGN_OR_RETURN(auto* executor, GetStreamExecutor(device_ordinal));
162   return transfer_manager_->TransferLiteralToInfeed(executor, literal);
163 }
164 
TransferLiteralFromOutfeed(int device_ordinal,xla::MutableBorrowingLiteral literal)165 Status StreamExecutorTransferOpImpl::TransferLiteralFromOutfeed(
166     int device_ordinal, xla::MutableBorrowingLiteral literal) {
167   TF_ASSIGN_OR_RETURN(auto* executor, GetStreamExecutor(device_ordinal));
168   return transfer_manager_->TransferLiteralFromOutfeed(executor, literal);
169 }
170 
171 StatusOr<stream_executor::StreamExecutor*>
GetStreamExecutor(int device_ordinal)172 StreamExecutorTransferOpImpl::GetStreamExecutor(int device_ordinal) {
173   return tpu_platform_->ExecutorForDevice(device_ordinal);
174 }
175 
176 }  // namespace tensorflow
177