xref: /aosp_15_r20/external/pytorch/c10/util/irange.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 
3 #pragma once
4 
5 #include <c10/util/Exception.h>
6 #include <c10/util/TypeSafeSignMath.h>
7 
8 #include <algorithm>
9 #include <cstddef>
10 #include <iterator>
11 #include <type_traits>
12 
13 namespace c10 {
14 
15 namespace detail {
16 
// Minimal input iterator over consecutive values of an integral type `I`.
//
// Dereferencing yields the current value by copy, and incrementing advances
// it by one. When `one_sided` is true, equality additionally reports "equal"
// whenever the right-hand operand holds a negative value (see operator==).
template <
    typename I,
    bool one_sided = false,
    std::enable_if_t<std::is_integral_v<I>, int> = 0>
struct integer_iterator {
  using iterator_category = std::input_iterator_tag;
  using value_type = I;
  using difference_type = std::ptrdiff_t;
  using pointer = I*;
  using reference = I&;

  // Starts the iterator at `initial`.
  explicit integer_iterator(I initial) : value(initial) {}

  // Returns a copy of the current value.
  I operator*() const {
    return value;
  }

  // Returns the address of the stored value.
  I const* operator->() const {
    return &value;
  }

  // Pre-increment: steps to the next integer and returns this iterator.
  integer_iterator& operator++() {
    ++value;
    return *this;
  }

  // Post-increment: steps to the next integer and returns the state the
  // iterator had before stepping.
  integer_iterator operator++(int) {
    const integer_iterator previous = *this;
    ++value;
    return previous;
  }

  bool operator!=(const integer_iterator& other) const {
    return !operator==(other);
  }

  bool operator==(const integer_iterator& other) const {
    if constexpr (one_sided) {
      // A range-for loop keeps going while `begin != end`; it never checks
      // `begin < end`. For `c10::irange(n)` with n < 0 the range must be
      // empty, so a negative `end` is reported as equal to everything,
      // which makes the `begin != end` test fail immediately.
      if (is_negative(other.value)) {
        return true;
      }
      return value == other.value;
    } else {
      return value == other.value;
    }
    // Never reached. It exists only to silence "warning: missing return
    // statement at end of non-void function", which Nvidia's Robert Crovella
    // confirmed to be an NVCC compiler error on 2020-10-27, see
    // https://stackoverflow.com/a/64561686/752843
    // `__builtin_unreachable();` would express this better but is not
    // available with all compilers, so an arbitrary value is returned.
    return false; // Horrible hack
  }

 protected:
  I value;
};
75 
76 } // namespace detail
77 
78 template <
79     typename I,
80     bool one_sided = false,
81     std::enable_if_t<std::is_integral_v<I>, bool> = true>
82 struct integer_range {
83  public:
integer_rangeinteger_range84   integer_range(I begin, I end) : begin_(begin), end_(end) {}
85   using iterator = detail::integer_iterator<I, one_sided>;
begininteger_range86   iterator begin() const {
87     return begin_;
88   }
endinteger_range89   iterator end() const {
90     return end_;
91   }
92 
93  private:
94   iterator begin_;
95   iterator end_;
96 };
97 
98 /// Creates an integer range for the half-open interval [begin, end)
99 /// If end<=begin, then the range is empty.
100 /// The range has the type of the `end` integer; `begin` integer is
101 /// cast to this type.
102 template <
103     typename Integer1,
104     typename Integer2,
105     std::enable_if_t<std::is_integral_v<Integer1>, bool> = true,
106     std::enable_if_t<std::is_integral_v<Integer2>, bool> = true>
irange(Integer1 begin,Integer2 end)107 integer_range<Integer2> irange(Integer1 begin, Integer2 end) {
108   // If end<=begin then the range is empty; we can achieve this effect by
109   // choosing the larger of {begin, end} as the loop terminator
110   return {
111       static_cast<Integer2>(begin),
112       std::max(static_cast<Integer2>(begin), end)};
113 }
114 
115 /// Creates an integer range for the half-open interval [0, end)
116 /// If end<=begin, then the range is empty
117 template <
118     typename Integer,
119     std::enable_if_t<std::is_integral_v<Integer>, bool> = true>
irange(Integer end)120 integer_range<Integer, true> irange(Integer end) {
121   return {Integer(), end};
122 }
123 
124 } // namespace c10
125