/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <deque>
#include <mutex>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {
class Buffer : public ResourceBase {
 public:
  using Tuple = std::vector<Tensor>;

  explicit Buffer(std::size_t capacity, std::size_t memory_limit)
      : capacity_(capacity), memory_limit_(memory_limit), current_bytes_(0) {}

  // the Buffer takes ownership of the Tuple
  Status Put(Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    std::size_t tuple_bytes = GetTupleBytes(*tuple);

    // Sanity check so that we don't block forever below
    if (memory_limit_ > 0 && tuple_bytes > memory_limit_) {
      return Status(
          errors::ResourceExhausted("Attempted to insert "
                                    "tensors with combined size of '",
                                    tuple_bytes,
                                    "' bytes into "
                                    "Staging Area with a memory limit of '",
                                    memory_limit_, "'."));
    }

    // If buffer capacity is bounded, wait until elements have been removed
    if (IsBounded()) {
      full_cond_var_.wait(lock, [tuple_bytes, this]() {
        // If there's a memory limit, check if there's space for insertion
        bool memory_limit_valid =
            memory_limit_ > 0 ? !WouldExceedMemoryLimit(tuple_bytes) : true;
        // If we're configured with a capacity, check if there's space for
        // insertion
        bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true;

        // Stop waiting upon success of both conditions
        return capacity_valid && memory_limit_valid;
      });
    }

    // Update bytes in the Staging Area
    current_bytes_ += tuple_bytes;

    // Store tuple
    buf_.push_back(std::move(*tuple));

    lock.unlock();
    // Notify all removers. Removers may be peeking at a specific element
    // or waiting for the element at the front of the deque. As we don't
    // know the appropriate one to wake up, we should wake them all.
    non_empty_cond_var_.notify_all();

    return OkStatus();
  }

  // Get tuple at front of the buffer
  void Get(Tuple* tuple) {  // TODO(zhifengc): Support cancellation.
    std::unique_lock<std::mutex> lock(mu_);

    // Wait for data if the buffer is empty
    non_empty_cond_var_.wait(lock, [this]() { return !buf_.empty(); });

    // Move data into the output tuple
    *tuple = std::move(buf_.front());
    buf_.pop_front();

    // Update bytes in the Staging Area
    current_bytes_ -= GetTupleBytes(*tuple);

    notify_inserters_if_bounded(&lock);
  }

  // Return tuple at index
  Status Peek(std::size_t index, Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    // Wait if the requested index is not available
    non_empty_cond_var_.wait(
        lock, [index, this]() { return index < this->buf_.size(); });

    // Place tensors in the output tuple
    for (const auto& tensor : buf_[index]) {
      tuple->push_back(tensor);
    }

    return OkStatus();
  }

  // Buffer size
  size_t Size() {
    std::unique_lock<std::mutex> lock(mu_);
    return buf_.size();
  }

  void Clear() {
    std::unique_lock<std::mutex> lock(mu_);
    buf_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded(&lock);
  }

  string DebugString() const override {
    std::unique_lock<std::mutex> lock(mu_);
    return strings::StrCat("Staging size: ", buf_.size());
  }

 private:
  // If the buffer is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded(std::unique_lock<std::mutex>* lock) {
    if (IsBounded()) {
      lock->unlock();
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_cond_var_.notify_all();
    }
  }

  // Is there a limit on the number of elements or a memory limit
  // configured on this buffer?
  bool IsBounded() const { return capacity_ > 0 || memory_limit_ > 0; }

  bool IsCapacityFull() const { return buf_.size() >= capacity_; }

  bool WouldExceedMemoryLimit(std::size_t bytes) const {
    return bytes + current_bytes_ > memory_limit_;
  }

  std::size_t GetTupleBytes(const Tuple& tuple) {
    // Use a std::size_t initial value so the accumulated byte count is not
    // truncated to int.
    return std::accumulate(tuple.begin(), tuple.end(), std::size_t(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  std::size_t capacity_;
  std::size_t memory_limit_;
  std::size_t current_bytes_;
  mutable std::mutex mu_;
  std::condition_variable non_empty_cond_var_;
  std::condition_variable full_cond_var_;
  std::deque<Tuple> buf_;
};

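// Looks up the Buffer resource for this node in the resource manager,
// creating it on first use from the node's "capacity" and "memory_limit"
// attributes. On success the caller holds a reference to *buf and must
// Unref it when done.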
Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](Buffer** ret) -> Status {
    int64_t capacity;
    int64_t memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new Buffer(capacity, memory_limit);
    return OkStatus();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<Buffer>(cinfo.container(), cinfo.name(),
                                                buf, create_fn));
  return OkStatus();
}

}  // namespace

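// Inserts the op's inputs into the staging buffer as a single tuple,
// blocking while a bounded buffer has no room for them.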
class StageOp : public OpKernel {
 public:
  explicit StageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;
    tuple.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      tuple.push_back(ctx->input(i));
    }
    OP_REQUIRES_OK(ctx, buf->Put(&tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp);
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_DEFAULT), StageOp);

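// Removes the tuple at the front of the staging buffer and returns its
// tensors as the op's outputs, blocking while the buffer is empty.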
class UnstageOp : public OpKernel {
 public:
  explicit UnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    buf->Get(&tuple);

    OP_REQUIRES(
        ctx, tuple.size() == (size_t)ctx->num_outputs(),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp);
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_DEFAULT), UnstageOp);

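// Returns the tuple at the given index in the staging buffer without
// removing it, blocking until an element exists at that index.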
class StagePeekOp : public OpKernel {
 public:
  explicit StagePeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(0).shape()),
                errors::InvalidArgument("index must be scalar"));
    std::size_t index = ctx->input(0).scalar<int>()();

    OP_REQUIRES_OK(ctx, buf->Peek(index, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == (size_t)ctx->num_outputs(),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), StagePeekOp);
REGISTER_KERNEL_BUILDER(
    Name("StagePeek").HostMemory("index").Device(DEVICE_DEFAULT), StagePeekOp);

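// Returns the number of tuples currently held in the staging buffer.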
class StageSizeOp : public OpKernel {
 public:
  explicit StageSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(buf->Size());
  }
};

REGISTER_KERNEL_BUILDER(Name("StageSize").Device(DEVICE_CPU), StageSizeOp);
REGISTER_KERNEL_BUILDER(
    Name("StageSize").HostMemory("size").Device(DEVICE_DEFAULT), StageSizeOp);

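// Removes all tuples from the staging buffer.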
class StageClearOp : public OpKernel {
 public:
  explicit StageClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    buf->Clear();
  }
};

REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_CPU), StageClearOp);
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_DEFAULT),
                        StageClearOp);

}  // namespace tensorflow