1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
18
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>

#include "absl/base/casts.h"
#include "absl/functional/function_ref.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
29
30 namespace tensorflow {
31 namespace tpu {
32
33 using BufferDeallocator = std::function<void(void*)>;
34 using OwnedDataPtr = std::unique_ptr<uint8_t[], BufferDeallocator>;
35 using BufferAllocator = absl::FunctionRef<OwnedDataPtr(size_t)>;
36
DefaultAllocator(size_t size)37 inline OwnedDataPtr DefaultAllocator(size_t size) {
38 return {static_cast<uint8_t*>(malloc(size)), free};
39 }
40
41 // Uncopyable buffer type with optional ownership of the underlying data. If
42 // data is not owned then ensuring lifetime of the data exceeds the lifetime of
43 // the buffer is the responsibility of the user.
44 class NoncopyableBuffer {
45 public:
46 NoncopyableBuffer() = default;
47
48 // Allocate an owning buffer without initializing the data. Useful when it
49 // will be filled by a subsequent function and want to avoid initialization
50 // cost. Size is specified in number of bytes.
51 explicit NoncopyableBuffer(size_t size,
52 BufferAllocator allocator = DefaultAllocator)
data_(allocator (size))53 : data_(allocator(size)), buf_(data_.get()), size_(size) {}
54
55 // Allocates an owning buffer and initializes it with the specified data. Size
56 // is specified in number of uint32's.
57 NoncopyableBuffer(size_t size_in_u32s, std::optional<uint32_t> value,
58 BufferAllocator allocator = DefaultAllocator)
NoncopyableBuffer(size_in_u32s * sizeof (uint32_t),allocator)59 : NoncopyableBuffer(size_in_u32s * sizeof(uint32_t), allocator) {
60 #ifndef MEMORY_SANITIZER
61 if (!value.has_value()) {
62 return;
63 }
64 #endif
65 uint32_t* data_u32 = reinterpret_cast<uint32_t*>(data_.get());
66 uint32_t v = value.value_or(0);
67 for (uint32_t *p = data_u32, *e = data_u32 + size_in_u32s; p < e; ++p) {
68 *p = v;
69 }
70 }
71
72 // Directly use buf pointer without copying it to owning data_. This delays
73 // the memcpy until mutable access is requested. "buf" is not owned by this
74 // data structure, so it is the user's duty to ensure the live range of "buf"
75 // is longer than this data structure.
NoncopyableBuffer(const uint8_t * buf,size_t size)76 NoncopyableBuffer(const uint8_t* buf, size_t size) // Size is in uint8's.
77 : buf_(buf), size_(size) {}
NoncopyableBuffer(const uint32_t * buf,size_t size_in_u32s)78 NoncopyableBuffer(const uint32_t* buf,
79 size_t size_in_u32s) // Size is in uint32_t's.
80 : buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
81
82 NoncopyableBuffer(const NoncopyableBuffer&) = delete;
83 NoncopyableBuffer(NoncopyableBuffer&&) = default;
84
85 NoncopyableBuffer& operator=(const NoncopyableBuffer&) = delete;
86 NoncopyableBuffer& operator=(NoncopyableBuffer&&) = default;
87
88 // Ensure that the buffer owns the data and returns a mutable view into the
89 // owned data for modification.
90 template <typename T>
mutable_data()91 absl::Span<T> mutable_data() {
92 static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
93 EnsureDataOwned();
94 DCHECK_EQ(size_ % sizeof(T), 0);
95 return absl::Span<T>(reinterpret_cast<T*>(data_.get()), size_ / sizeof(T));
96 }
97
98 template <typename T>
const_data()99 absl::Span<const T> const_data() const {
100 static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
101 DCHECK_EQ(size_ % sizeof(T), 0);
102 return absl::Span<const T>(static_cast<const T*>(buf_), size_ / sizeof(T));
103 }
104 // Clone the content to a given buffer.
CloneTo(void * buf)105 void CloneTo(void* buf) { memcpy(buf, buf_, size_); }
106
107 // Return true if data is owned by this buffer (have been copied to `data_`).
owns_data()108 bool owns_data() const { return data_ != nullptr; }
109
110 // Returns a copy of the object that owns its buffer.
111 NoncopyableBuffer Clone(size_t alignment = 1) const {
112 auto clone = alignment <= 1
113 ? NoncopyableBuffer(size_)
114 : NoncopyableBuffer(AlignedAlloc(size_, alignment), size_);
115 memcpy(clone.data_.get(), buf_, size_);
116 return clone;
117 }
118 // Returns a copy of the object that owns its buffer. It uses `allocator` to
119 // allocate the new buffer, which can have custom properties like special
120 // alignment.
Clone(BufferAllocator allocator)121 NoncopyableBuffer Clone(BufferAllocator allocator) const {
122 NoncopyableBuffer clone(size_, allocator);
123 memcpy(clone.data_.get(), buf_, size_);
124 return clone;
125 }
126
127 // Ensure that the buffer owns the data.
128 void EnsureDataOwned(BufferAllocator allocator = DefaultAllocator) {
129 if (data_ == nullptr) {
130 data_ = allocator(size_);
131 memcpy(data_.get(), buf_, size_);
132 buf_ = data_.get();
133 }
134 }
135
AlignedAlloc(size_t size,size_t alignment)136 static OwnedDataPtr AlignedAlloc(size_t size, size_t alignment) {
137 return OwnedDataPtr(
138 static_cast<uint8_t*>(port::AlignedMalloc(size, alignment)),
139 port::AlignedFree);
140 }
141
142 private:
NoncopyableBuffer(OwnedDataPtr data,size_t size)143 NoncopyableBuffer(OwnedDataPtr data, size_t size)
144 : data_(std::move(data)), buf_(data_.get()), size_(size) {}
145
146 // If data_ != nullptr then buf_ == data_.get()
147 OwnedDataPtr data_{nullptr, free}; // Owning data pointer.
148 const void* buf_; // Non-owning data pointer.
149 size_t size_; // Size in number of bytes.
150 };
151
152 } // namespace tpu
153 } // namespace tensorflow
154
155 #endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
156