xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/tensor_cord.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_CORD_H_
17 #define TENSORFLOW_CORE_KERNELS_TENSOR_CORD_H_
18 
19 #include <array>
20 #include <numeric>
21 
22 #include "absl/container/inlined_vector.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/core/framework/variant_tensor_data.h"
27 
28 namespace tensorflow {
29 
30 typedef void (*CordRepReleaser)(void*);
31 
class TensorCord {
  // A TensorCord keeps a view into some data, and a cleanup method to clean up
  // that data when the TensorCord destructor is called.  Copying a TensorCord
  // increments a reference count to the cleanup method, and so the cleanup
  // method is only called when all copies of the original TensorCord are
  // cleared.
  //
  // Example:
  //
  // const string& s = t.scalar<string>()();
  // TensorCord tc(s, &t);
  // ASSERT_EQ(s, tc.view());
  // TensorCord copy(tc);
  // tc = TensorCord();  // cleanup not called; the reference is held by `copy`.
  // copy = TensorCord();  // cleanup happens now, the reference is destroyed.
  //
  // Another example:
  //
  // void TensorProtoDeleter(void* ptr) {
  //   delete static_cast<TensorProto*>(ptr);
  // }
  //
  // auto p = absl::MakeUnique<TensorProto>(...);
  // absl::string_view content(p->tensor_content());
  // TensorCord tc(content, TensorProtoDeleter, p.release());
  //

 public:
  // Type name used when registering/identifying this type with the Variant
  // machinery (see TypeName(), Encode(), Decode() below).
  static constexpr const char kTypeName[] = "tensorflow::TensorCord";

  // Constructs an empty cord: no chunks, size() == 0, empty() == true.
  TensorCord() : chunks_() {}

  // Unrefs every chunk; a chunk's releaser runs when its refcount reaches 0.
  ~TensorCord();

  // Args:
  //   `view`: should point to a location in memory that is guaranteed to remain
  //           valid until `releaser` is called.
  //   `releaser`: A callback that will be executed when there are no references
  //               left on `view`.  It will be called via `releaser(memory)`.
  //   `memory`: The argument passed to `releaser` when it is called.
  //
  // You are STRONGLY advised to provide a non-null `releaser`, and a pointer
  // to the underlying data (while ensuring that the data will not be deleted
  // until `releaser(memory)` is called).  Otherwise the TensorCord may
  // outlive the data backing `view`.
  TensorCord(absl::string_view view, CordRepReleaser releaser,
             void* memory = nullptr)
      : chunks_({new CordRep(view, releaser, memory)}) {}

  // Args:
  //   `view`: should point to a location in memory backed by `tensor`,
  //      e.g., `view` is a string_view on a tstring which is an element
  //      of `tensor`.  Furthermore, the associated tstring is not expected
  //      to be modified in such a way that the underlying memory will
  //      be changed after this TensorCord is created.
  TensorCord(absl::string_view view, Tensor* tensor)
      : chunks_({NewCordRepFromTensor(view, tensor)}) {}

  // Disallow construction with empty callback or empty tensor.
  TensorCord(absl::string_view view, std::nullptr_t, void* memory) = delete;
  TensorCord(absl::string_view view, std::nullptr_t) = delete;

  // Copy shares the underlying chunks: each CordRep's refcount is incremented.
  TensorCord(const TensorCord& other);

  // Move steals the chunk list; `other` is left empty (no refcount changes).
  TensorCord(TensorCord&& other) noexcept;

  TensorCord& operator=(const TensorCord& other);

  TensorCord& operator=(TensorCord&& other) noexcept;

  // Appends (and shares ownership of) all of `other`'s chunks.
  void Append(const TensorCord& other);

  // Appends a new external chunk; see the (view, releaser, memory) ctor for
  // the lifetime contract.
  void Append(absl::string_view view, CordRepReleaser releaser,
              void* memory = nullptr);

  // Appends a chunk backed by `tensor`; see the (view, tensor) ctor.
  void Append(absl::string_view view, Tensor* tensor);

  // Disallow Appends with empty callbacks or empty tensors.
  void Append(absl::string_view view, std::nullptr_t, void* memory) = delete;
  void Append(absl::string_view view, std::nullptr_t) = delete;

  // Total number of bytes across all chunks.  O(#chunks).
  size_t size() const;
  bool empty() const { return size() == 0; }

  // NOTE: This performs an expensive copy of the underlying data.
  explicit operator string() const;

  // Input iterator over the cord's chunks, yielding one absl::string_view
  // per chunk.  Views remain valid only while the cord (and its reps) live.
  class ChunkIterator {
   public:
    using iterator_category = std::input_iterator_tag;
    using value_type = absl::string_view;
    using difference_type = ptrdiff_t;
    using pointer = const value_type*;
    using reference = value_type;

    ChunkIterator& operator++();

    ChunkIterator operator++(int) {
      ChunkIterator tmp(*this);
      operator++();
      return tmp;
    }

    bool operator==(const ChunkIterator& other) const {
      return (cord_ == other.cord_ && chunk_index_ == other.chunk_index_);
    }

    bool operator!=(const ChunkIterator& other) const {
      return !(*this == other);
    }
    reference operator*() const {
      assert(cord_ != nullptr);
      return view_;
    }
    pointer operator->() const {
      assert(cord_ != nullptr);
      return &view_;
    }

    friend class TensorCord;

   private:
    // Constructs a `begin()` iterator from `cord`.
    explicit ChunkIterator(const TensorCord* cord, int chunk_index);

    const TensorCord* const cord_;
    int chunk_index_;
    // Cached view of the current chunk; refreshed on increment.
    absl::string_view view_;
  };

  // Lightweight range adaptor so `for (auto chunk : cord.Chunks())` works.
  class ChunkRange {
   public:
    explicit ChunkRange(const TensorCord* cord) : cord_(cord) {}

    ChunkIterator begin() const { return ChunkIterator(cord_, 0); }

    ChunkIterator end() const {
      return ChunkIterator(cord_, cord_->chunks_.size());
    }

   private:
    const TensorCord* cord_;
  };

  // Note that the ordinary caveats of temporary lifetime extension apply:
  //
  //   void Process() {
  //     for (absl::string_view chunk : CordFactory().Chunks()) {
  //       // The temporary Cord returned by CordFactory has been destroyed!
  //     }
  //   }
  ChunkRange Chunks() const { return ChunkRange(this); }

  ChunkIterator chunk_begin() const { return ChunkIterator(this, 0); }

  ChunkIterator chunk_end() const {
    return ChunkIterator(this, chunks_.size());
  }

  static string TypeName() { return kTypeName; }

  string DebugString() const {
    return absl::StrCat("<TensorCord size=", size(), ">");
  }

  // Variant serialization hooks (definitions elsewhere).
  void Encode(VariantTensorData* data) const;

  bool Decode(VariantTensorData data);

 private:
  // Unrefs all chunks and clears the chunk list.
  void Cleanup();

  // Refcounted chunk representation.  Small views are stored inline inside
  // the rep itself; larger views are stored externally with a releaser.
  class CordRep : public core::RefCounted {
   public:
    CordRep(absl::string_view view, CordRepReleaser releaser,
            void* arg = nullptr)
        : is_inline_(false), rep_(view, releaser, arg) {}

    // **WARNING** Only use this constructor if
    //    view.size() < CordRep::kMaxInlineSize.
    explicit CordRep(absl::string_view view) : is_inline_(true), rep_(view) {}

    ~CordRep() override;

    // For inline reps, byte 0 of `internal` holds the length and the string
    // bytes follow; external reps just return the stored view.
    absl::string_view view() const {
      if (is_inline_) {
        return absl::string_view(
            rep_.internal.data() + 1,
            *reinterpret_cast<const uint8*>(rep_.internal.data()));
      } else {
        return rep_.external.view;
      }
    }

   private:
    friend class TensorCord;

    // External storage: a borrowed view plus the callback (and its argument)
    // that releases the backing memory when the rep is destroyed.
    struct ExternalRep {
      absl::string_view view;
      CordRepReleaser releaser;
      void* arg;

      ExternalRep(absl::string_view view_, CordRepReleaser releaser_,
                  void* arg_)
          : view(view_), releaser(releaser_), arg(arg_) {}
    };

    // We save the size in the first byte, so subtract 1.
    static constexpr int kMaxInlineSize = sizeof(ExternalRep) - 1;
    static_assert(kMaxInlineSize < 255,
                  "Cannot store size of InlineRep in a single byte.");

    // The first byte stores the size as a uint8.  The rest of the bytes are the
    // string itself.
    using InlineRep = std::array<char, sizeof(ExternalRep)>;

    // Member variables.
    const bool is_inline_;
    // Tagged union discriminated by `is_inline_` above.
    const union _rep_union {
      InlineRep internal;
      ExternalRep external;

      _rep_union(absl::string_view view, CordRepReleaser releaser, void* arg)
          : external(view, releaser, arg) {}

      explicit _rep_union(absl::string_view view) {
        DCHECK_LT(view.size(), kMaxInlineSize);
        *reinterpret_cast<uint8*>(internal.data()) = view.size();
        std::memcpy(static_cast<char*>(internal.data() + 1), view.data(),
                    view.size());
      }
    } rep_;
  };

  // Returns `tensor`'s buffer with an extra reference taken.
  static TensorBuffer* TensorBufWithRef(Tensor* tensor);
  // Releaser that drops the reference taken by TensorBufWithRef.
  static void TensorBufReleaser(void* tensor_buffer);
  // Releaser that deletes a heap-allocated string.
  static void StringReleaser(void* str_ptr);
  // Chooses inline vs. tensor-buffer-backed storage based on view size.
  static CordRep* NewCordRepFromTensor(absl::string_view view, Tensor* tensor);

  // Owned (refcounted) chunk pointers, in order.  Two inline slots cover the
  // common case of cords with at most two chunks without heap allocation.
  absl::InlinedVector<CordRep*, 2> chunks_;
};
273 
// Copying shares every chunk with `other`: duplicate the pointer list, then
// take one additional reference on each rep so cleanup only fires after both
// cords are gone.
inline TensorCord::TensorCord(const TensorCord& other)
    : chunks_(other.chunks_) {
  for (CordRep* rep : chunks_) rep->Ref();
}
280 
TensorCord(TensorCord && other)281 inline TensorCord::TensorCord(TensorCord&& other) noexcept
282     : chunks_(std::move(other.chunks_)) {
283   other.chunks_.clear();
284 }
285 
286 inline TensorCord& TensorCord::operator=(const TensorCord& other) {
287   Cleanup();
288   chunks_ = other.chunks_;
289   for (auto* rep : chunks_) {
290     rep->Ref();
291   }
292   return *this;
293 }
294 
295 inline TensorCord& TensorCord::operator=(TensorCord&& other) noexcept {
296   Cleanup();
297   std::swap(chunks_, other.chunks_);
298   return *this;
299 }
300 
Append(const TensorCord & other)301 inline void TensorCord::Append(const TensorCord& other) {
302   for (auto* rep : other.chunks_) {
303     chunks_.push_back(rep);
304     rep->Ref();
305   }
306 }
307 
Append(absl::string_view view,CordRepReleaser releaser,void * memory)308 inline void TensorCord::Append(absl::string_view view, CordRepReleaser releaser,
309                                void* memory) {
310   chunks_.push_back(new CordRep(view, releaser, memory));
311 }
312 
Append(absl::string_view view,Tensor * tensor)313 inline void TensorCord::Append(absl::string_view view, Tensor* tensor) {
314   chunks_.push_back(NewCordRepFromTensor(view, tensor));
315 }
316 
size()317 inline size_t TensorCord::size() const {
318   return (chunks_.empty())
319              ? 0
320              : std::accumulate(chunk_begin(), chunk_end(), 0,
321                                [](size_t acc, absl::string_view b) {
322                                  return acc + b.size();
323                                });
324 }
325 
326 inline TensorCord::ChunkIterator& TensorCord::ChunkIterator::operator++() {
327   assert(cord_ != nullptr);
328   assert(chunk_index_ < cord_->chunks_.size());
329   chunk_index_ += 1;
330   if (chunk_index_ != cord_->chunks_.size()) {
331     view_ = cord_->chunks_[chunk_index_]->view();
332   }
333   return *this;
334 }
335 
// Positions the iterator at `index` within `cord`, caching that chunk's view
// unless the iterator starts out at (or past) end().
inline TensorCord::ChunkIterator::ChunkIterator(const TensorCord* cord,
                                                int index)
    : cord_(cord), chunk_index_(index) {
  const bool in_range = (index < cord_->chunks_.size());
  if (in_range) view_ = cord_->chunks_[index]->view();
}
343 
NewCordRepFromTensor(absl::string_view view,Tensor * tensor)344 inline TensorCord::CordRep* TensorCord::NewCordRepFromTensor(
345     absl::string_view view, Tensor* tensor) {
346   if (view.size() <= TensorCord::CordRep::kMaxInlineSize) {
347     return new CordRep(view);
348   } else {
349     return new CordRep(view, &TensorBufReleaser, TensorBufWithRef(tensor));
350   }
351 }
352 
Cleanup()353 inline void TensorCord::Cleanup() {
354   if (chunks_.empty()) return;
355   for (auto* rep : chunks_) {
356     rep->Unref();
357   }
358   chunks_.clear();
359 }
360 
361 }  // namespace tensorflow
362 
363 #endif  // TENSORFLOW_CORE_KERNELS_TENSOR_CORD_H_
364