1 #include <utility> 2 3 #pragma once 4 5 namespace at::native { 6 7 namespace { 8 9 // operator_brackets_proxy is used in 10 // CompositeRandomAccessor in place of operator[]. 11 // For some iterators, references returned by operator[] 12 // could become invalid, operator_brackets_proxy tries to 13 // resolve that by making accessor[n] to be equivalent to 14 // *(accessor + n). 15 template <typename Accessor> 16 class operator_brackets_proxy { 17 using reference = typename std::iterator_traits<Accessor>::reference; 18 using value_type = typename std::iterator_traits<Accessor>::value_type; 19 20 public: 21 C10_HOST_DEVICE operator_brackets_proxy(Accessor const & accessor)22 operator_brackets_proxy(Accessor const& accessor) 23 : accessor(accessor) 24 {} 25 26 C10_HOST_DEVICE reference()27 operator reference() { 28 return *accessor; 29 } 30 31 C10_HOST_DEVICE 32 reference operator*() { 33 return *accessor; 34 } 35 36 C10_HOST_DEVICE 37 operator_brackets_proxy& operator=(value_type const& val) { 38 *accessor = val; 39 return *this; 40 } 41 42 private: 43 Accessor accessor; 44 }; 45 46 } 47 48 // references_holder is used as a surrogate for the 49 // references type from std::iterator_traits in CompositeRandomAccessor. 50 // It is assumed in CompositeRandomAccessor that 51 // References = tuple<Types&...>, 52 // Values = tuple<Types...> by default, 53 // but they could be anything as long as References could be 54 // cast to Values. 55 // If you plan to use it with STL, for example, you will need to 56 // define 'swap` and `get`(aka std::get) methods. 57 template <typename Values, typename References> 58 class references_holder { 59 public: 60 using values = Values; 61 using references = References; 62 63 C10_HOST_DEVICE references_holder(references refs)64 references_holder(references refs) 65 : refs{std::move(refs)} 66 {} 67 68 C10_HOST_DEVICE references()69 operator references() { 70 return refs; 71 } 72 73 C10_HOST_DEVICE values()74 operator values() { 75 return refs; 76 } 77 78 C10_HOST_DEVICE 79 references_holder& operator=(values vals) { 80 refs = vals; 81 return *this; 82 } 83 84 C10_HOST_DEVICE data()85 references& data() { 86 return refs; 87 } 88 89 protected: 90 references refs; 91 }; 92 93 // CompositeRandomAccessor is essentially a simplified version of 94 // a random access iterator over two random access iterators. 95 // TupleInfo should contain a variadic type `tuple`, and a method `tie`, 96 // which constructs a tuple of references from a variadic list of arguments. 97 template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo> 98 class CompositeRandomAccessor { 99 using self_type = CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfo>; 100 101 using key_accessor_value_type = 102 typename std::iterator_traits<KeyAccessor>::value_type; 103 using value_accessor_value_type = 104 typename std::iterator_traits<ValueAccessor>::value_type; 105 using key_accessor_reference_type = 106 typename std::iterator_traits<KeyAccessor>::reference; 107 using value_accessor_reference_type = 108 typename std::iterator_traits<ValueAccessor>::reference; 109 110 using composite_value_type = typename TupleInfo::template tuple< 111 key_accessor_value_type, 112 value_accessor_value_type>; 113 using composite_reference = typename TupleInfo::template tuple< 114 key_accessor_reference_type, 115 value_accessor_reference_type>; 116 117 public: 118 using value_type = composite_value_type; 119 using reference = references_holder<composite_value_type, composite_reference>; 120 // Note that CompositeRandomAccessor does not hold key and values 121 // in a specific datastructure, which means that a pointer to a (key, value) 122 // is not defined. Hence we just use a pointer type of the KeyAccessor. 123 using pointer = typename std::iterator_traits<KeyAccessor>::pointer; 124 using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type; 125 using iterator_category = std::random_access_iterator_tag; 126 127 C10_HOST_DEVICE 128 CompositeRandomAccessor() = default; 129 130 C10_HOST_DEVICE CompositeRandomAccessor(KeyAccessor keys,ValueAccessor values)131 CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values) 132 : keys(keys), values(values) 133 {} 134 135 // Pointer-like operations { 136 C10_HOST_DEVICE 137 reference operator*() const { 138 return TupleInfo::tie(*keys, *values); 139 } 140 141 // operator->() is supposed to return a pointer type. 142 // Since CompositeRandomAccessor does not hold pointers to pairs, 143 // we just return a pointer to a key. 144 C10_HOST_DEVICE 145 auto* operator->() const { 146 return keys.operator->(); 147 } 148 149 C10_HOST_DEVICE 150 reference operator[](difference_type idx) { 151 return operator_brackets_proxy<self_type>( 152 CompositeRandomAccessor(keys + idx, values + idx) 153 ); 154 } 155 // } 156 157 // Prefix/postfix increment/decrement { 158 C10_HOST_DEVICE 159 CompositeRandomAccessor& operator++() { 160 ++keys; 161 ++values; 162 return *this; 163 } 164 165 C10_HOST_DEVICE 166 CompositeRandomAccessor operator++(int) { 167 CompositeRandomAccessor copy(*this); 168 ++*this; 169 return copy; 170 } 171 172 C10_HOST_DEVICE 173 CompositeRandomAccessor& operator--() { 174 --keys; 175 --values; 176 return *this; 177 } 178 179 C10_HOST_DEVICE 180 CompositeRandomAccessor operator--(int) { 181 CompositeRandomAccessor copy(*this); 182 --*this; 183 return copy; 184 } 185 // } 186 187 // Arithmetic operations { 188 C10_HOST_DEVICE 189 CompositeRandomAccessor& operator+=(difference_type offset) { 190 keys += offset; 191 values += offset; 192 return *this; 193 } 194 195 C10_HOST_DEVICE 196 CompositeRandomAccessor operator+(difference_type offset) const { 197 return CompositeRandomAccessor(keys + offset, values + offset); 198 } 199 200 C10_HOST_DEVICE 201 friend CompositeRandomAccessor operator+( 202 difference_type offset, 203 const CompositeRandomAccessor& accessor 204 ) { 205 return accessor + offset; 206 } 207 208 C10_HOST_DEVICE 209 CompositeRandomAccessor& operator-=(difference_type offset) { 210 keys -= offset; 211 values -= offset; 212 return *this; 213 } 214 215 C10_HOST_DEVICE 216 CompositeRandomAccessor operator-(difference_type offset) const { 217 return CompositeRandomAccessor(keys - offset, values - offset); 218 } 219 220 C10_HOST_DEVICE 221 difference_type operator-(const CompositeRandomAccessor& other) const { 222 return keys - other.keys; 223 } 224 // } 225 226 // Comparison operators { 227 C10_HOST_DEVICE 228 bool operator==(const CompositeRandomAccessor& other) const { 229 return keys == other.keys; 230 } 231 232 C10_HOST_DEVICE 233 bool operator!=(const CompositeRandomAccessor& other) const { 234 return keys != other.keys; 235 } 236 237 C10_HOST_DEVICE 238 bool operator<(const CompositeRandomAccessor& other) const { 239 return keys < other.keys; 240 } 241 242 C10_HOST_DEVICE 243 bool operator<=(const CompositeRandomAccessor& other) const { 244 return keys <= other.keys; 245 } 246 247 C10_HOST_DEVICE 248 bool operator>(const CompositeRandomAccessor& other) const { 249 return keys > other.keys; 250 } 251 252 C10_HOST_DEVICE 253 bool operator>=(const CompositeRandomAccessor& other) const { 254 return keys >= other.keys; 255 } 256 // } 257 258 protected: 259 KeyAccessor keys; 260 ValueAccessor values; 261 }; 262 263 } // namespace at::native 264