xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/CompositeRandomAccessorCommon.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <iterator>
#include <type_traits>
#include <utility>
4 
5 namespace at::native {
6 
7 namespace {
8 
9 // operator_brackets_proxy is used in
10 // CompositeRandomAccessor in place of operator[].
11 // For some iterators, references returned by operator[]
12 // could become invalid, operator_brackets_proxy tries to
13 // resolve that by making accessor[n] to be equivalent to
14 // *(accessor + n).
15 template <typename Accessor>
16 class operator_brackets_proxy {
17   using reference = typename std::iterator_traits<Accessor>::reference;
18   using value_type = typename std::iterator_traits<Accessor>::value_type;
19 
20 public:
21   C10_HOST_DEVICE
operator_brackets_proxy(Accessor const & accessor)22   operator_brackets_proxy(Accessor const& accessor)
23     : accessor(accessor)
24   {}
25 
26   C10_HOST_DEVICE
reference()27   operator reference() {
28     return *accessor;
29   }
30 
31   C10_HOST_DEVICE
32   reference operator*() {
33     return *accessor;
34   }
35 
36   C10_HOST_DEVICE
37   operator_brackets_proxy& operator=(value_type const& val) {
38     *accessor = val;
39     return *this;
40   }
41 
42 private:
43   Accessor accessor;
44 };
45 
46 }
47 
48 // references_holder is used as a surrogate for the
49 // references type from std::iterator_traits in CompositeRandomAccessor.
50 // It is assumed in CompositeRandomAccessor that
51 // References = tuple<Types&...>,
52 // Values = tuple<Types...> by default,
53 // but they could be anything as long as References could be
54 // cast to Values.
55 // If you plan to use it with STL, for example, you will need to
56 // define 'swap` and `get`(aka std::get) methods.
57 template <typename Values, typename References>
58 class references_holder {
59 public:
60   using values = Values;
61   using references = References;
62 
63   C10_HOST_DEVICE
references_holder(references refs)64   references_holder(references refs)
65     : refs{std::move(refs)}
66   {}
67 
68   C10_HOST_DEVICE
references()69   operator references() {
70     return refs;
71   }
72 
73   C10_HOST_DEVICE
values()74   operator values() {
75     return refs;
76   }
77 
78   C10_HOST_DEVICE
79   references_holder& operator=(values vals) {
80     refs = vals;
81     return *this;
82   }
83 
84   C10_HOST_DEVICE
data()85   references& data() {
86     return refs;
87   }
88 
89 protected:
90   references refs;
91 };
92 
93 // CompositeRandomAccessor is essentially a simplified version of
94 // a random access iterator over two random access iterators.
95 // TupleInfo should contain a variadic type `tuple`, and a method `tie`,
96 // which constructs a tuple of references from a variadic list of arguments.
97 template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo>
98 class CompositeRandomAccessor {
99   using self_type = CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfo>;
100 
101   using key_accessor_value_type =
102     typename std::iterator_traits<KeyAccessor>::value_type;
103   using value_accessor_value_type =
104     typename std::iterator_traits<ValueAccessor>::value_type;
105   using key_accessor_reference_type =
106     typename std::iterator_traits<KeyAccessor>::reference;
107   using value_accessor_reference_type =
108     typename std::iterator_traits<ValueAccessor>::reference;
109 
110   using composite_value_type = typename TupleInfo::template tuple<
111     key_accessor_value_type,
112     value_accessor_value_type>;
113   using composite_reference = typename TupleInfo::template tuple<
114     key_accessor_reference_type,
115     value_accessor_reference_type>;
116 
117 public:
118   using value_type = composite_value_type;
119   using reference = references_holder<composite_value_type, composite_reference>;
120   // Note that CompositeRandomAccessor does not hold key and values
121   // in a specific datastructure, which means that a pointer to a (key, value)
122   // is not defined. Hence we just use a pointer type of the KeyAccessor.
123   using pointer = typename std::iterator_traits<KeyAccessor>::pointer;
124   using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
125   using iterator_category = std::random_access_iterator_tag;
126 
127   C10_HOST_DEVICE
128   CompositeRandomAccessor() = default;
129 
130   C10_HOST_DEVICE
CompositeRandomAccessor(KeyAccessor keys,ValueAccessor values)131   CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
132     : keys(keys), values(values)
133   {}
134 
135   // Pointer-like operations {
136   C10_HOST_DEVICE
137   reference operator*() const {
138     return TupleInfo::tie(*keys, *values);
139   }
140 
141   // operator->() is supposed to return a pointer type.
142   // Since CompositeRandomAccessor does not hold pointers to pairs,
143   // we just return a pointer to a key.
144   C10_HOST_DEVICE
145   auto* operator->() const {
146     return keys.operator->();
147   }
148 
149   C10_HOST_DEVICE
150   reference operator[](difference_type idx) {
151     return operator_brackets_proxy<self_type>(
152       CompositeRandomAccessor(keys + idx, values + idx)
153     );
154   }
155   // }
156 
157   // Prefix/postfix increment/decrement {
158   C10_HOST_DEVICE
159   CompositeRandomAccessor& operator++() {
160     ++keys;
161     ++values;
162     return *this;
163   }
164 
165   C10_HOST_DEVICE
166   CompositeRandomAccessor operator++(int) {
167     CompositeRandomAccessor copy(*this);
168     ++*this;
169     return copy;
170   }
171 
172   C10_HOST_DEVICE
173   CompositeRandomAccessor& operator--() {
174     --keys;
175     --values;
176     return *this;
177   }
178 
179   C10_HOST_DEVICE
180   CompositeRandomAccessor operator--(int) {
181     CompositeRandomAccessor copy(*this);
182     --*this;
183     return copy;
184   }
185   // }
186 
187   // Arithmetic operations {
188   C10_HOST_DEVICE
189   CompositeRandomAccessor& operator+=(difference_type offset) {
190     keys += offset;
191     values += offset;
192     return *this;
193   }
194 
195   C10_HOST_DEVICE
196   CompositeRandomAccessor operator+(difference_type offset) const {
197     return CompositeRandomAccessor(keys + offset, values + offset);
198   }
199 
200   C10_HOST_DEVICE
201   friend CompositeRandomAccessor operator+(
202     difference_type offset,
203     const CompositeRandomAccessor& accessor
204   ) {
205     return accessor + offset;
206   }
207 
208   C10_HOST_DEVICE
209   CompositeRandomAccessor& operator-=(difference_type offset) {
210     keys -= offset;
211     values -= offset;
212     return *this;
213   }
214 
215   C10_HOST_DEVICE
216   CompositeRandomAccessor operator-(difference_type offset) const {
217     return CompositeRandomAccessor(keys - offset, values - offset);
218   }
219 
220   C10_HOST_DEVICE
221   difference_type operator-(const CompositeRandomAccessor& other) const {
222     return keys - other.keys;
223   }
224   // }
225 
226   // Comparison operators {
227   C10_HOST_DEVICE
228   bool operator==(const CompositeRandomAccessor& other) const {
229     return keys == other.keys;
230   }
231 
232   C10_HOST_DEVICE
233   bool operator!=(const CompositeRandomAccessor& other) const {
234     return keys != other.keys;
235   }
236 
237   C10_HOST_DEVICE
238   bool operator<(const CompositeRandomAccessor& other) const {
239     return keys < other.keys;
240   }
241 
242   C10_HOST_DEVICE
243   bool operator<=(const CompositeRandomAccessor& other) const {
244     return keys <= other.keys;
245   }
246 
247   C10_HOST_DEVICE
248   bool operator>(const CompositeRandomAccessor& other) const {
249     return keys > other.keys;
250   }
251 
252   C10_HOST_DEVICE
253   bool operator>=(const CompositeRandomAccessor& other) const {
254     return keys >= other.keys;
255   }
256   // }
257 
258 protected:
259   KeyAccessor keys;
260   ValueAccessor values;
261 };
262 
263 } // namespace at::native
264