xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/tpu/noncopyable_buffer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
18 
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>

#include "absl/base/casts.h"
#include "absl/functional/function_ref.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
29 
30 namespace tensorflow {
31 namespace tpu {
32 
33 using BufferDeallocator = std::function<void(void*)>;
34 using OwnedDataPtr = std::unique_ptr<uint8_t[], BufferDeallocator>;
35 using BufferAllocator = absl::FunctionRef<OwnedDataPtr(size_t)>;
36 
DefaultAllocator(size_t size)37 inline OwnedDataPtr DefaultAllocator(size_t size) {
38   return {static_cast<uint8_t*>(malloc(size)), free};
39 }
40 
41 // Uncopyable buffer type with optional ownership of the underlying data. If
42 // data is not owned then ensuring lifetime of the data exceeds the lifetime of
43 // the buffer is the responsibility of the user.
44 class NoncopyableBuffer {
45  public:
46   NoncopyableBuffer() = default;
47 
48   // Allocate an owning buffer without initializing the data. Useful when it
49   // will be filled by a subsequent function and want to avoid initialization
50   // cost. Size is specified in number of bytes.
51   explicit NoncopyableBuffer(size_t size,
52                              BufferAllocator allocator = DefaultAllocator)
data_(allocator (size))53       : data_(allocator(size)), buf_(data_.get()), size_(size) {}
54 
55   // Allocates an owning buffer and initializes it with the specified data. Size
56   // is specified in number of uint32's.
57   NoncopyableBuffer(size_t size_in_u32s, std::optional<uint32_t> value,
58                     BufferAllocator allocator = DefaultAllocator)
NoncopyableBuffer(size_in_u32s * sizeof (uint32_t),allocator)59       : NoncopyableBuffer(size_in_u32s * sizeof(uint32_t), allocator) {
60 #ifndef MEMORY_SANITIZER
61     if (!value.has_value()) {
62       return;
63     }
64 #endif
65     uint32_t* data_u32 = reinterpret_cast<uint32_t*>(data_.get());
66     uint32_t v = value.value_or(0);
67     for (uint32_t *p = data_u32, *e = data_u32 + size_in_u32s; p < e; ++p) {
68       *p = v;
69     }
70   }
71 
72   // Directly use buf pointer without copying it to owning data_. This delays
73   // the memcpy until mutable access is requested. "buf" is not owned by this
74   // data structure, so it is the user's duty to ensure the live range of "buf"
75   // is longer than this data structure.
NoncopyableBuffer(const uint8_t * buf,size_t size)76   NoncopyableBuffer(const uint8_t* buf, size_t size)  // Size is in uint8's.
77       : buf_(buf), size_(size) {}
NoncopyableBuffer(const uint32_t * buf,size_t size_in_u32s)78   NoncopyableBuffer(const uint32_t* buf,
79                     size_t size_in_u32s)  // Size is in uint32_t's.
80       : buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
81 
82   NoncopyableBuffer(const NoncopyableBuffer&) = delete;
83   NoncopyableBuffer(NoncopyableBuffer&&) = default;
84 
85   NoncopyableBuffer& operator=(const NoncopyableBuffer&) = delete;
86   NoncopyableBuffer& operator=(NoncopyableBuffer&&) = default;
87 
88   // Ensure that the buffer owns the data and returns a mutable view into the
89   // owned data for modification.
90   template <typename T>
mutable_data()91   absl::Span<T> mutable_data() {
92     static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
93     EnsureDataOwned();
94     DCHECK_EQ(size_ % sizeof(T), 0);
95     return absl::Span<T>(reinterpret_cast<T*>(data_.get()), size_ / sizeof(T));
96   }
97 
98   template <typename T>
const_data()99   absl::Span<const T> const_data() const {
100     static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
101     DCHECK_EQ(size_ % sizeof(T), 0);
102     return absl::Span<const T>(static_cast<const T*>(buf_), size_ / sizeof(T));
103   }
104   // Clone the content to a given buffer.
CloneTo(void * buf)105   void CloneTo(void* buf) { memcpy(buf, buf_, size_); }
106 
107   // Return true if data is owned by this buffer (have been copied to `data_`).
owns_data()108   bool owns_data() const { return data_ != nullptr; }
109 
110   // Returns a copy of the object that owns its buffer.
111   NoncopyableBuffer Clone(size_t alignment = 1) const {
112     auto clone = alignment <= 1
113                      ? NoncopyableBuffer(size_)
114                      : NoncopyableBuffer(AlignedAlloc(size_, alignment), size_);
115     memcpy(clone.data_.get(), buf_, size_);
116     return clone;
117   }
118   // Returns a copy of the object that owns its buffer. It uses `allocator` to
119   // allocate the new buffer, which can have custom properties like special
120   // alignment.
Clone(BufferAllocator allocator)121   NoncopyableBuffer Clone(BufferAllocator allocator) const {
122     NoncopyableBuffer clone(size_, allocator);
123     memcpy(clone.data_.get(), buf_, size_);
124     return clone;
125   }
126 
127   // Ensure that the buffer owns the data.
128   void EnsureDataOwned(BufferAllocator allocator = DefaultAllocator) {
129     if (data_ == nullptr) {
130       data_ = allocator(size_);
131       memcpy(data_.get(), buf_, size_);
132       buf_ = data_.get();
133     }
134   }
135 
AlignedAlloc(size_t size,size_t alignment)136   static OwnedDataPtr AlignedAlloc(size_t size, size_t alignment) {
137     return OwnedDataPtr(
138         static_cast<uint8_t*>(port::AlignedMalloc(size, alignment)),
139         port::AlignedFree);
140   }
141 
142  private:
NoncopyableBuffer(OwnedDataPtr data,size_t size)143   NoncopyableBuffer(OwnedDataPtr data, size_t size)
144       : data_(std::move(data)), buf_(data_.get()), size_(size) {}
145 
146   // If data_ != nullptr then buf_ == data_.get()
147   OwnedDataPtr data_{nullptr, free};  // Owning data pointer.
148   const void* buf_;                   // Non-owning data pointer.
149   size_t size_;                       // Size in number of bytes.
150 };
151 
152 }  // namespace tpu
153 }  // namespace tensorflow
154 
155 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
156