// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/typed_queue.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
17 #define TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
18 
19 #include <deque>
20 #include <queue>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/kernels/queue_base.h"
25 #include "tensorflow/core/platform/mutex.h"
26 
27 namespace tensorflow {
28 
29 // TypedQueue builds on QueueBase, with backing class (SubQueue)
30 // known and stored within.  Shared methods that need to have access
31 // to the backed data sit in this class.
32 template <typename SubQueue>
33 class TypedQueue : public QueueBase {
34  public:
35   TypedQueue(const int32_t capacity, const DataTypeVector& component_dtypes,
36              const std::vector<TensorShape>& component_shapes,
37              const string& name);
38 
39   virtual Status Initialize();  // Must be called before any other method.
40 
41   int64_t MemoryUsed() const override;
42 
43  protected:
44   std::vector<SubQueue> queues_ TF_GUARDED_BY(mu_);
45 };  // class TypedQueue
46 
47 template <typename SubQueue>
TypedQueue(int32_t capacity,const DataTypeVector & component_dtypes,const std::vector<TensorShape> & component_shapes,const string & name)48 TypedQueue<SubQueue>::TypedQueue(
49     int32_t capacity, const DataTypeVector& component_dtypes,
50     const std::vector<TensorShape>& component_shapes, const string& name)
51     : QueueBase(capacity, component_dtypes, component_shapes, name) {}
52 
53 template <typename SubQueue>
Initialize()54 Status TypedQueue<SubQueue>::Initialize() {
55   if (component_dtypes_.empty()) {
56     return errors::InvalidArgument("Empty component types for queue ", name_);
57   }
58   if (!component_shapes_.empty() &&
59       component_dtypes_.size() != component_shapes_.size()) {
60     return errors::InvalidArgument(
61         "Different number of component types.  ",
62         "Types: ", DataTypeSliceString(component_dtypes_),
63         ", Shapes: ", ShapeListString(component_shapes_));
64   }
65 
66   mutex_lock lock(mu_);
67   queues_.reserve(num_components());
68   for (int i = 0; i < num_components(); ++i) {
69     queues_.push_back(SubQueue());
70   }
71   return OkStatus();
72 }
73 
// Generic fallback for SizeOf().  There is no way to estimate the size of an
// arbitrary SubQueue, so instantiating this template is a compile-time error;
// supported container types get their own specialization or overload.
template <typename SubQueue>
inline int64_t SizeOf(const SubQueue& sq) {
  // The condition depends on SubQueue, so it is only evaluated (and only
  // fails) when this fallback is actually instantiated.
  static_assert(sizeof(SubQueue) == 0, "SubQueue size unknown.");
  return 0;
}
79 
80 template <>
SizeOf(const std::deque<Tensor> & sq)81 inline int64_t SizeOf(const std::deque<Tensor>& sq) {
82   if (sq.empty()) {
83     return 0;
84   }
85   return sq.size() * sq.front().AllocatedBytes();
86 }
87 
88 template <>
SizeOf(const std::vector<Tensor> & sq)89 inline int64_t SizeOf(const std::vector<Tensor>& sq) {
90   if (sq.empty()) {
91     return 0;
92   }
93   return sq.size() * sq.front().AllocatedBytes();
94 }
95 
96 using TensorPair = std::pair<int64_t, Tensor>;
97 
98 template <typename U, typename V>
SizeOf(const std::priority_queue<TensorPair,U,V> & sq)99 int64_t SizeOf(const std::priority_queue<TensorPair, U, V>& sq) {
100   if (sq.empty()) {
101     return 0;
102   }
103   return sq.size() * (sizeof(TensorPair) + sq.top().second.AllocatedBytes());
104 }
105 
106 template <typename SubQueue>
MemoryUsed()107 inline int64_t TypedQueue<SubQueue>::MemoryUsed() const {
108   int memory_size = 0;
109   mutex_lock l(mu_);
110   for (const auto& sq : queues_) {
111     memory_size += SizeOf(sq);
112   }
113   return memory_size;
114 }
115 
116 }  // namespace tensorflow
117 
118 #endif  // TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
119