/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/transfer_ops.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"

namespace tensorflow {

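// Base class for the asynchronous TPU transfer kernels. The constructor sets
// up a small dedicated thread pool (8 threads) on which the transfers run.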
TpuTransferAsyncOpKernelBase::TpuTransferAsyncOpKernelBase(
    OpKernelConstruction* ctx, const string& transfer_type,
    int number_of_threads, std::unique_ptr<TpuTransferOpInterface> transfer_op)
    : AsyncOpKernel(ctx),
      transfer_type_(transfer_type),
      transfer_op_(std::move(transfer_op)),
      thread_pool_(new thread::ThreadPool(
          ctx->env(),
          strings::StrCat(transfer_type, "_thread_",
                          SanitizeThreadSuffix(def().name())),
          /*num_threads=*/8)) {}

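// Registers a cancellation callback, then hands the actual transfer off to
// the kernel's thread pool so that ComputeAsync returns without blocking the
// executor thread. `done` is invoked from the scheduled closure once the
// transfer has finished.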
void TpuTransferAsyncOpKernelBase::ComputeAsync(OpKernelContext* ctx,
                                                DoneCallback done) {
  profiler::TraceMeProducer schedule_activity(
      "TpuTransferAsyncOpKernelBase::ComputeAsync");
  CancellationToken token =
      ctx->cancellation_manager()->get_cancellation_token();
  bool already_cancelled;
  {
    // Only protect registering the cancellation callback as mu_ cannot be held
    // at a point where `done` could be called.
    mutex_lock lock(mu_);
    already_cancelled =
        !ctx->cancellation_manager()->RegisterCallback(token, [this]() {
          mutex_lock lock(mu_);
          transfer_op_->Cancel();
        });
  }
  OP_REQUIRES_ASYNC(ctx, !already_cancelled,
                    errors::Cancelled("Infeed was cancelled."), done);
  thread_pool_->Schedule(
      [this, ctx, done, token,
       traceme_context_id = schedule_activity.GetContextId()]() {
        profiler::TraceMeConsumer compute_activity(
            [this] { return profiler::TraceMeOp(name(), type_string()); },
            traceme_context_id);
        Status s = RunTransfer(ctx);
        ctx->cancellation_manager()->DeregisterCallback(token);
        OP_REQUIRES_OK_ASYNC(ctx, s, done);
        done();
      });
}

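// Resolves the device ordinal (a negative value means "look it up from the
// op's own device" via transfer_op_->GetDeviceOrdinal) and then runs DoWork
// under a TraceMe annotation tagged with the resolved ordinal.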
Status TpuTransferAsyncOpKernelBase::RunTransferWithOrdinal(
    OpKernelContext* ctx, int device_ordinal) {
  int real_device_ordinal = device_ordinal;
  if (real_device_ordinal < 0) {
    TF_ASSIGN_OR_RETURN(real_device_ordinal,
                        transfer_op_->GetDeviceOrdinal(ctx));
  }

  profiler::TraceMe activity(
      [real_device_ordinal] {
        return profiler::TraceMeEncode(
            "RunTransferWithOrdinal",
            {{"device_ordinal", real_device_ordinal}});
      },
      profiler::kInfo);
  return DoWork(ctx, real_device_ordinal);
}

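// Variant with a static device ordinal: reads the "device_ordinal" attr at
// construction time. When the kernel is placed on CPU, the attr must name an
// explicit (non-negative) ordinal.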
TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(
    OpKernelConstruction* ctx, const string& transfer_type,
    int number_of_threads, std::unique_ptr<TpuTransferOpInterface> transfer_op)
    : TpuTransferAsyncOpKernelBase(ctx, transfer_type, number_of_threads,
                                   std::move(transfer_op)) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
    OP_REQUIRES(
        ctx, device_ordinal_ >= 0,
        errors::InvalidArgument(transfer_type,
                                " ops must specify a device_ordinal when "
                                "placed on CPU."));
  }
}

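// Runs the transfer against the ordinal captured from the attr (or, if it was
// negative, the ordinal of the device the op is placed on).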
Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
  return RunTransferWithOrdinal(ctx, device_ordinal_);
}

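// Variant that takes the device ordinal as a runtime input rather than an
// attr; the constructor only forwards to the base class.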
TpuTransferAsyncDynamicOrdinalOpKernel::TpuTransferAsyncDynamicOrdinalOpKernel(
    OpKernelConstruction* ctx, const string& transfer_type,
    int number_of_threads, std::unique_ptr<TpuTransferOpInterface> transfer_op)
    : TpuTransferAsyncOpKernelBase(ctx, transfer_type, number_of_threads,
                                   std::move(transfer_op)) {}

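// Reads the device ordinal from input 0 as a scalar int32. A negative ordinal
// is only accepted when the op is placed on a non-CPU XlaDevice, where the
// ordinal can be inferred from the device itself.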
Status TpuTransferAsyncDynamicOrdinalOpKernel::RunTransfer(
    OpKernelContext* ctx) {
  const Tensor& device_ordinal_tensor = ctx->input(0);
  const int device_ordinal = device_ordinal_tensor.scalar<int32>()();
  XlaDevice* xla_device =
      dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
  if (((xla_device == nullptr) || (xla_device->device_type() == DEVICE_CPU)) &&
      (device_ordinal < 0)) {
    return errors::InvalidArgument(transfer_type_,
                                   " ops must specify a device_ordinal when "
                                   "placed on CPU.");
  }
  return RunTransferWithOrdinal(ctx, device_ordinal);
}

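// TpuTransferOpInterface implementation backed by StreamExecutor: uses the
// registered TPU transfer manager and the registered TPU platform (without
// re-initializing the platform).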
StreamExecutorTransferOpImpl::StreamExecutorTransferOpImpl()
    : transfer_manager_(
          xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager()),
      tpu_platform_(tpu::TpuPlatformInterface::GetRegisteredPlatform(
          /*initialize_platform=*/false)) {}

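// Cancellation closes the TPU host; TF_CHECK_OK terminates the process if
// that fails.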
void StreamExecutorTransferOpImpl::Cancel() {
  TF_CHECK_OK(tpu::TpuNodeContext::CloseTpuHost());
}

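// Looks up the device ordinal from the XlaDevice metadata attached to the
// op's device.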
StatusOr<int> StreamExecutorTransferOpImpl::GetDeviceOrdinal(
    OpKernelContext* ctx) {
  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
  return metadata->device_ordinal();
}

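// The remaining methods resolve the ordinal to a StreamExecutor and delegate
// the actual data movement to the TPU transfer manager.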
Status StreamExecutorTransferOpImpl::TransferBuffersToInfeed(
    int device_ordinal,
    const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) {
  TF_ASSIGN_OR_RETURN(auto* executor, GetStreamExecutor(device_ordinal));
  return transfer_manager_->TransferBuffersToInfeed(executor, buffers);
}

Status StreamExecutorTransferOpImpl::TransferLiteralToInfeed(
    int device_ordinal, const xla::LiteralSlice& literal) {
  TF_ASSIGN_OR_RETURN(auto* executor, GetStreamExecutor(device_ordinal));
  return transfer_manager_->TransferLiteralToInfeed(executor, literal);
}

Status StreamExecutorTransferOpImpl::TransferLiteralFromOutfeed(
    int device_ordinal, xla::MutableBorrowingLiteral literal) {
  TF_ASSIGN_OR_RETURN(auto* executor, GetStreamExecutor(device_ordinal));
  return transfer_manager_->TransferLiteralFromOutfeed(executor, literal);
}

StatusOr<stream_executor::StreamExecutor*>
StreamExecutorTransferOpImpl::GetStreamExecutor(int device_ordinal) {
  return tpu_platform_->ExecutorForDevice(device_ordinal);
}

}  // namespace tensorflow