xref: /aosp_15_r20/external/pytorch/c10/util/irange.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker // Copyright 2004-present Facebook. All Rights Reserved.
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #pragma once
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/util/TypeSafeSignMath.h>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker #include <algorithm>
9*da0073e9SAndroid Build Coastguard Worker #include <cstddef>
10*da0073e9SAndroid Build Coastguard Worker #include <iterator>
11*da0073e9SAndroid Build Coastguard Worker #include <type_traits>
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker namespace c10 {
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker namespace detail {
16*da0073e9SAndroid Build Coastguard Worker 
/// Iterator over consecutive integer values of type `I`, used as the
/// begin/end pair of `integer_range` (and therefore of `c10::irange`).
/// Dereferencing yields the current value by value; incrementing advances it
/// by one.
///
/// When `one_sided` is true, comparing against an `other` iterator whose
/// value is negative always reports equality. A range-for over a one-sided
/// range with a negative end therefore runs zero times.
template <
    typename I,
    bool one_sided = false,
    std::enable_if_t<std::is_integral_v<I>, int> = 0>
struct integer_iterator {
  using iterator_category = std::input_iterator_tag;
  using value_type = I;
  using difference_type = std::ptrdiff_t;
  // These must describe what operator-> / operator* actually return: a
  // pointer-to-const and a prvalue `I` (the same convention as
  // std::istreambuf_iterator). Advertising `I*` / `I&` made
  // `iterator_traits<It>::reference r = *it;` ill-formed.
  using pointer = const I*;
  using reference = I;

  explicit integer_iterator(I value) : value(value) {}

  I operator*() const {
    return value;
  }

  pointer operator->() const {
    return &value;
  }

  integer_iterator& operator++() {
    ++value;
    return *this;
  }

  integer_iterator operator++(int) {
    const auto copy = *this;
    ++*this;
    return copy;
  }

  bool operator==(const integer_iterator& other) const {
    if constexpr (one_sided) {
      // Range-for loops' end test is `begin != end`, not `begin < end`. To
      // handle `c10::irange(n)` where n < 0 (which should be empty), treat a
      // negative `end` as equal to everything so `begin != end` fails at once.
      if (is_negative(other.value)) {
        return true;
      }
    }
    // A single trailing return that every path reaches: NVCC has no
    // "missing return statement" to warn about, so no unreachable
    // placeholder return is needed.
    return value == other.value;
  }

  bool operator!=(const integer_iterator& other) const {
    return !(*this == other);
  }

 protected:
  I value;
};
75*da0073e9SAndroid Build Coastguard Worker 
76*da0073e9SAndroid Build Coastguard Worker } // namespace detail
77*da0073e9SAndroid Build Coastguard Worker 
78*da0073e9SAndroid Build Coastguard Worker template <
79*da0073e9SAndroid Build Coastguard Worker     typename I,
80*da0073e9SAndroid Build Coastguard Worker     bool one_sided = false,
81*da0073e9SAndroid Build Coastguard Worker     std::enable_if_t<std::is_integral_v<I>, bool> = true>
82*da0073e9SAndroid Build Coastguard Worker struct integer_range {
83*da0073e9SAndroid Build Coastguard Worker  public:
integer_rangeinteger_range84*da0073e9SAndroid Build Coastguard Worker   integer_range(I begin, I end) : begin_(begin), end_(end) {}
85*da0073e9SAndroid Build Coastguard Worker   using iterator = detail::integer_iterator<I, one_sided>;
begininteger_range86*da0073e9SAndroid Build Coastguard Worker   iterator begin() const {
87*da0073e9SAndroid Build Coastguard Worker     return begin_;
88*da0073e9SAndroid Build Coastguard Worker   }
endinteger_range89*da0073e9SAndroid Build Coastguard Worker   iterator end() const {
90*da0073e9SAndroid Build Coastguard Worker     return end_;
91*da0073e9SAndroid Build Coastguard Worker   }
92*da0073e9SAndroid Build Coastguard Worker 
93*da0073e9SAndroid Build Coastguard Worker  private:
94*da0073e9SAndroid Build Coastguard Worker   iterator begin_;
95*da0073e9SAndroid Build Coastguard Worker   iterator end_;
96*da0073e9SAndroid Build Coastguard Worker };
97*da0073e9SAndroid Build Coastguard Worker 
98*da0073e9SAndroid Build Coastguard Worker /// Creates an integer range for the half-open interval [begin, end)
99*da0073e9SAndroid Build Coastguard Worker /// If end<=begin, then the range is empty.
100*da0073e9SAndroid Build Coastguard Worker /// The range has the type of the `end` integer; `begin` integer is
101*da0073e9SAndroid Build Coastguard Worker /// cast to this type.
102*da0073e9SAndroid Build Coastguard Worker template <
103*da0073e9SAndroid Build Coastguard Worker     typename Integer1,
104*da0073e9SAndroid Build Coastguard Worker     typename Integer2,
105*da0073e9SAndroid Build Coastguard Worker     std::enable_if_t<std::is_integral_v<Integer1>, bool> = true,
106*da0073e9SAndroid Build Coastguard Worker     std::enable_if_t<std::is_integral_v<Integer2>, bool> = true>
irange(Integer1 begin,Integer2 end)107*da0073e9SAndroid Build Coastguard Worker integer_range<Integer2> irange(Integer1 begin, Integer2 end) {
108*da0073e9SAndroid Build Coastguard Worker   // If end<=begin then the range is empty; we can achieve this effect by
109*da0073e9SAndroid Build Coastguard Worker   // choosing the larger of {begin, end} as the loop terminator
110*da0073e9SAndroid Build Coastguard Worker   return {
111*da0073e9SAndroid Build Coastguard Worker       static_cast<Integer2>(begin),
112*da0073e9SAndroid Build Coastguard Worker       std::max(static_cast<Integer2>(begin), end)};
113*da0073e9SAndroid Build Coastguard Worker }
114*da0073e9SAndroid Build Coastguard Worker 
115*da0073e9SAndroid Build Coastguard Worker /// Creates an integer range for the half-open interval [0, end)
116*da0073e9SAndroid Build Coastguard Worker /// If end<=begin, then the range is empty
117*da0073e9SAndroid Build Coastguard Worker template <
118*da0073e9SAndroid Build Coastguard Worker     typename Integer,
119*da0073e9SAndroid Build Coastguard Worker     std::enable_if_t<std::is_integral_v<Integer>, bool> = true>
irange(Integer end)120*da0073e9SAndroid Build Coastguard Worker integer_range<Integer, true> irange(Integer end) {
121*da0073e9SAndroid Build Coastguard Worker   return {Integer(), end};
122*da0073e9SAndroid Build Coastguard Worker }
123*da0073e9SAndroid Build Coastguard Worker 
124*da0073e9SAndroid Build Coastguard Worker } // namespace c10
125