1 #pragma once 2 3 #include <cstddef> 4 #include <cstdint> 5 6 #include <memory> 7 8 #include <c10/macros/Macros.h> 9 #include <c10/util/Logging.h> 10 11 /** 12 * Packed representation of input indices for ProcessedNode. 13 */ 14 class ProcessedNodeInputs { 15 private: 16 // This keeps the size usage for inputs + outputs down to 16 bytes; 17 // we use 12 bytes, and then two 2-byte integers are used to store 18 // the outputs. 19 static constexpr size_t kMaxInlineInputs = 5; 20 21 public: ProcessedNodeInputs()22 ProcessedNodeInputs() : ProcessedNodeInputs(0) {} 23 ProcessedNodeInputs(size_t size)24 explicit ProcessedNodeInputs(size_t size) { 25 TORCH_DCHECK_LT(size, (1 << 16)); 26 if (size <= kMaxInlineInputs) { 27 repr_.inline_repr_.size = size; 28 } else { 29 new (&repr_.outline_repr_) HeapArrayPtr(size); 30 } 31 } 32 33 uint16_t operator[](uint16_t idx) const { 34 return (*const_cast<ProcessedNodeInputs*>(this))[idx]; 35 } 36 37 uint16_t& operator[](uint16_t idx) { 38 if (C10_LIKELY(repr_.is_inline())) { 39 TORCH_DCHECK_LT(idx, repr_.inline_repr_.size); 40 return repr_.inline_repr_.inputs[idx]; 41 } else { 42 return repr_.outline_repr_[idx]; 43 } 44 } 45 size()46 C10_NODISCARD uint16_t size() const { 47 if (C10_LIKELY(repr_.is_inline())) { 48 return repr_.inline_repr_.size; 49 } else { 50 return repr_.outline_repr_.size(); 51 } 52 } 53 empty()54 C10_NODISCARD bool empty() const { 55 return size() == 0; 56 } 57 58 private: 59 class HeapArrayPtr { 60 public: 61 HeapArrayPtr() = default; 62 ~HeapArrayPtr() = default; 63 HeapArrayPtr(uint16_t size)64 explicit HeapArrayPtr(uint16_t size) : array_(alloc(size)) {} 65 HeapArrayPtr(const HeapArrayPtr & rhs)66 HeapArrayPtr(const HeapArrayPtr& rhs) : array_(alloc(rhs.size())) { 67 if (rhs.array_) { 68 std::memcpy( 69 array_.get(), 70 rhs.array_.get(), 71 (rhs.size() + 1) * sizeof(uint16_t)); 72 } 73 } 74 75 HeapArrayPtr& operator=(const HeapArrayPtr& rhs) { 76 if (&rhs == this) { 77 return *this; 78 } 79 80 if (size() != rhs.size()) { 81 array_ = alloc(rhs.size()); 82 } 83 84 if (rhs.array_) { 85 std::memcpy( 86 array_.get(), 87 rhs.array_.get(), 88 (rhs.size() + 1) * sizeof(uint16_t)); 89 } 90 return *this; 91 } 92 93 HeapArrayPtr(HeapArrayPtr&&) noexcept = default; 94 HeapArrayPtr& operator=(HeapArrayPtr&&) noexcept = default; 95 empty()96 C10_NODISCARD bool empty() const { 97 return size() != 0; 98 } 99 size()100 C10_NODISCARD uint16_t size() const { 101 return array_ ? array_[0] : 0; 102 } 103 104 uint16_t operator[](uint16_t idx) const { 105 TORCH_DCHECK_LT(idx, size()); 106 return array_[idx + 1]; 107 } 108 109 uint16_t& operator[](uint16_t idx) { 110 TORCH_DCHECK_LT(idx, size()); 111 return array_[idx + 1]; 112 } 113 114 private: 115 // NOLINTNEXTLINE(modernize-avoid-c-arrays) 116 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) 117 std::unique_ptr<uint16_t[]> array_; 118 119 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) 120 // NOLINTNEXTLINE(modernize-avoid-c-arrays) alloc(uint16_t num_elts)121 static std::unique_ptr<uint16_t[]> alloc(uint16_t num_elts) { 122 if (num_elts) { 123 auto result = std::make_unique<uint16_t[]>(num_elts + 1); 124 result[0] = num_elts; 125 return result; 126 } else { 127 return nullptr; 128 } 129 } 130 }; 131 132 // We want ProcessedNode to be able to pack two more `uint16_t` 133 // fields after its ProcessedNodeInputs, and we'll end up being 134 // aligned to an 8-byte boundary anyway. We could avoid this pragma 135 // at the cost of having to move ProcessedNode::outputs_offset_ and 136 // ProcessedNode::num_outputs_ into this class, which would be 137 // awkward. 138 #pragma pack(push, 2) 139 union Repr { is_inline()140 C10_NODISCARD bool is_inline() const { 141 uint8_t tag = 0; 142 // Use of reinterpret_cast to pointer to char or unsigned char 143 // is defined behavior; see 144 // https://en.cppreference.com/w/cpp/language/reinterpret_cast . 145 std::memcpy(&tag, reinterpret_cast<const uint8_t*>(this), 1); 146 // HeapArrayPtr will be represented as a plain old pointer, 147 // which will have alignment to at least a 2-byte boundary 148 // (because it's uint16_t*) and more likely an 8- or 16-byte 149 // boundary because malloc will tend to just align everything to 150 // one of those. So, we just set tag to 1 when inline_repr_ is 151 // active so as to be able to differentiate the two. 152 return (tag & 1) != 0; 153 } 154 155 // NOLINTNEXTLINE(modernize-use-equals-default) Repr()156 Repr() {} 157 ~Repr()158 ~Repr() { 159 destroyIfOutline(); 160 } 161 Repr(const Repr & rhs)162 Repr(const Repr& rhs) { 163 if (rhs.is_inline()) { 164 std::memcpy(&inline_repr_, &rhs.inline_repr_, sizeof(inline_repr_)); 165 } else { 166 new (&outline_repr_) OutlineRepr(rhs.outline_repr_); 167 } 168 } 169 170 Repr& operator=(const Repr& rhs) { 171 if (&rhs == this) { 172 return *this; 173 } 174 if (rhs.is_inline()) { 175 destroyIfOutline(); 176 new (&inline_repr_) InlineRepr(); 177 std::memcpy(&inline_repr_, &rhs.inline_repr_, sizeof(inline_repr_)); 178 } else { 179 if (is_inline()) { 180 new (&outline_repr_) OutlineRepr(rhs.outline_repr_); 181 } else { 182 outline_repr_ = rhs.outline_repr_; 183 } 184 } 185 return *this; 186 } 187 Repr(Repr && rhs)188 Repr(Repr&& rhs) noexcept { 189 if (rhs.is_inline()) { 190 std::memcpy(&inline_repr_, &rhs.inline_repr_, sizeof(inline_repr_)); 191 } else { 192 new (&outline_repr_) OutlineRepr(std::move(rhs.outline_repr_)); 193 } 194 } 195 196 Repr& operator=(Repr&& rhs) noexcept { 197 if (&rhs == this) { 198 return *this; 199 } 200 201 if (rhs.is_inline()) { 202 destroyIfOutline(); 203 new (&inline_repr_) InlineRepr(); 204 std::memcpy(&inline_repr_, &rhs.inline_repr_, sizeof(inline_repr_)); 205 } else { 206 if (is_inline()) { 207 new (&outline_repr_) OutlineRepr(std::move(rhs.outline_repr_)); 208 } else { 209 outline_repr_ = std::move(rhs.outline_repr_); 210 } 211 } 212 213 return *this; 214 } 215 216 struct InlineRepr { 217 uint8_t tag = 0x1; 218 uint8_t size{}; 219 uint16_t inputs[kMaxInlineInputs]{}; 220 }; 221 222 using OutlineRepr = HeapArrayPtr; 223 224 InlineRepr inline_repr_{}; 225 OutlineRepr outline_repr_; 226 227 private: destroyIfOutline()228 void destroyIfOutline() { 229 if (!is_inline()) { 230 outline_repr_.~OutlineRepr(); 231 } 232 } 233 } repr_; 234 #pragma pack(pop) 235 }; 236 237 static_assert( 238 sizeof(ProcessedNodeInputs) == 12, 239 "ProcessedNodeInputs has the wrong size!"); 240