// Source: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <new>

#include <c10/macros/Macros.h>
#include <c10/util/Logging.h>
10 
11 /**
12  * Packed representation of input indices for ProcessedNode.
13  */
14 class ProcessedNodeInputs {
15  private:
16   // This keeps the size usage for inputs + outputs down to 16 bytes;
17   // we use 12 bytes, and then two 2-byte integers are used to store
18   // the outputs.
19   static constexpr size_t kMaxInlineInputs = 5;
20 
21  public:
ProcessedNodeInputs()22   ProcessedNodeInputs() : ProcessedNodeInputs(0) {}
23 
ProcessedNodeInputs(size_t size)24   explicit ProcessedNodeInputs(size_t size) {
25     TORCH_DCHECK_LT(size, (1 << 16));
26     if (size <= kMaxInlineInputs) {
27       repr_.inline_repr_.size = size;
28     } else {
29       new (&repr_.outline_repr_) HeapArrayPtr(size);
30     }
31   }
32 
33   uint16_t operator[](uint16_t idx) const {
34     return (*const_cast<ProcessedNodeInputs*>(this))[idx];
35   }
36 
37   uint16_t& operator[](uint16_t idx) {
38     if (C10_LIKELY(repr_.is_inline())) {
39       TORCH_DCHECK_LT(idx, repr_.inline_repr_.size);
40       return repr_.inline_repr_.inputs[idx];
41     } else {
42       return repr_.outline_repr_[idx];
43     }
44   }
45 
size()46   C10_NODISCARD uint16_t size() const {
47     if (C10_LIKELY(repr_.is_inline())) {
48       return repr_.inline_repr_.size;
49     } else {
50       return repr_.outline_repr_.size();
51     }
52   }
53 
empty()54   C10_NODISCARD bool empty() const {
55     return size() == 0;
56   }
57 
58  private:
59   class HeapArrayPtr {
60    public:
61     HeapArrayPtr() = default;
62     ~HeapArrayPtr() = default;
63 
HeapArrayPtr(uint16_t size)64     explicit HeapArrayPtr(uint16_t size) : array_(alloc(size)) {}
65 
HeapArrayPtr(const HeapArrayPtr & rhs)66     HeapArrayPtr(const HeapArrayPtr& rhs) : array_(alloc(rhs.size())) {
67       if (rhs.array_) {
68         std::memcpy(
69             array_.get(),
70             rhs.array_.get(),
71             (rhs.size() + 1) * sizeof(uint16_t));
72       }
73     }
74 
75     HeapArrayPtr& operator=(const HeapArrayPtr& rhs) {
76       if (&rhs == this) {
77         return *this;
78       }
79 
80       if (size() != rhs.size()) {
81         array_ = alloc(rhs.size());
82       }
83 
84       if (rhs.array_) {
85         std::memcpy(
86             array_.get(),
87             rhs.array_.get(),
88             (rhs.size() + 1) * sizeof(uint16_t));
89       }
90       return *this;
91     }
92 
93     HeapArrayPtr(HeapArrayPtr&&) noexcept = default;
94     HeapArrayPtr& operator=(HeapArrayPtr&&) noexcept = default;
95 
empty()96     C10_NODISCARD bool empty() const {
97       return size() != 0;
98     }
99 
size()100     C10_NODISCARD uint16_t size() const {
101       return array_ ? array_[0] : 0;
102     }
103 
104     uint16_t operator[](uint16_t idx) const {
105       TORCH_DCHECK_LT(idx, size());
106       return array_[idx + 1];
107     }
108 
109     uint16_t& operator[](uint16_t idx) {
110       TORCH_DCHECK_LT(idx, size());
111       return array_[idx + 1];
112     }
113 
114    private:
115     // NOLINTNEXTLINE(modernize-avoid-c-arrays)
116     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
117     std::unique_ptr<uint16_t[]> array_;
118 
119     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
120     // NOLINTNEXTLINE(modernize-avoid-c-arrays)
alloc(uint16_t num_elts)121     static std::unique_ptr<uint16_t[]> alloc(uint16_t num_elts) {
122       if (num_elts) {
123         auto result = std::make_unique<uint16_t[]>(num_elts + 1);
124         result[0] = num_elts;
125         return result;
126       } else {
127         return nullptr;
128       }
129     }
130   };
131 
132   // We want ProcessedNode to be able to pack two more `uint16_t`
133   // fields after its ProcessedNodeInputs, and we'll end up being
134   // aligned to an 8-byte boundary anyway. We could avoid this pragma
135   // at the cost of having to move ProcessedNode::outputs_offset_ and
136   // ProcessedNode::num_outputs_ into this class, which would be
137   // awkward.
138 #pragma pack(push, 2)
139   union Repr {
is_inline()140     C10_NODISCARD bool is_inline() const {
141       uint8_t tag = 0;
142       // Use of reinterpret_cast to pointer to char or unsigned char
143       // is defined behavior; see
144       // https://en.cppreference.com/w/cpp/language/reinterpret_cast .
145       std::memcpy(&tag, reinterpret_cast<const uint8_t*>(this), 1);
146       // HeapArrayPtr will be represented as a plain old pointer,
147       // which will have alignment to at least a 2-byte boundary
148       // (because it's uint16_t*) and more likely an 8- or 16-byte
149       // boundary because malloc will tend to just align everything to
150       // one of those. So, we just set tag to 1 when inline_repr_ is
151       // active so as to be able to differentiate the two.
152       return (tag & 1) != 0;
153     }
154 
155     // NOLINTNEXTLINE(modernize-use-equals-default)
Repr()156     Repr() {}
157 
~Repr()158     ~Repr() {
159       destroyIfOutline();
160     }
161 
Repr(const Repr & rhs)162     Repr(const Repr& rhs) {
163       if (rhs.is_inline()) {
164         std::memcpy(&inline_repr_, &rhs.inline_repr_, sizeof(inline_repr_));
165       } else {
166         new (&outline_repr_) OutlineRepr(rhs.outline_repr_);
167       }
168     }
169 
170     Repr& operator=(const Repr& rhs) {
171       if (&rhs == this) {
172         return *this;
173       }
174       if (rhs.is_inline()) {
175         destroyIfOutline();
176         new (&inline_repr_) InlineRepr();
177         std::memcpy(&inline_repr_, &rhs.inline_repr_, sizeof(inline_repr_));
178       } else {
179         if (is_inline()) {
180           new (&outline_repr_) OutlineRepr(rhs.outline_repr_);
181         } else {
182           outline_repr_ = rhs.outline_repr_;
183         }
184       }
185       return *this;
186     }
187 
Repr(Repr && rhs)188     Repr(Repr&& rhs) noexcept {
189       if (rhs.is_inline()) {
190         std::memcpy(&inline_repr_, &rhs.inline_repr_, sizeof(inline_repr_));
191       } else {
192         new (&outline_repr_) OutlineRepr(std::move(rhs.outline_repr_));
193       }
194     }
195 
196     Repr& operator=(Repr&& rhs) noexcept {
197       if (&rhs == this) {
198         return *this;
199       }
200 
201       if (rhs.is_inline()) {
202         destroyIfOutline();
203         new (&inline_repr_) InlineRepr();
204         std::memcpy(&inline_repr_, &rhs.inline_repr_, sizeof(inline_repr_));
205       } else {
206         if (is_inline()) {
207           new (&outline_repr_) OutlineRepr(std::move(rhs.outline_repr_));
208         } else {
209           outline_repr_ = std::move(rhs.outline_repr_);
210         }
211       }
212 
213       return *this;
214     }
215 
216     struct InlineRepr {
217       uint8_t tag = 0x1;
218       uint8_t size{};
219       uint16_t inputs[kMaxInlineInputs]{};
220     };
221 
222     using OutlineRepr = HeapArrayPtr;
223 
224     InlineRepr inline_repr_{};
225     OutlineRepr outline_repr_;
226 
227    private:
destroyIfOutline()228     void destroyIfOutline() {
229       if (!is_inline()) {
230         outline_repr_.~OutlineRepr();
231       }
232     }
233   } repr_;
234 #pragma pack(pop)
235 };
236 
237 static_assert(
238     sizeof(ProcessedNodeInputs) == 12,
239     "ProcessedNodeInputs has the wrong size!");
240