1 // Copyright 2004-present Facebook. All Rights Reserved.
2
3 #pragma once
4
5 #include <c10/util/Exception.h>
6 #include <cstdint>
7 #include <functional>
8 #include <iterator>
9 #include <numeric>
10 #include <type_traits>
11 #include <utility>
12
13 namespace c10 {
14
/// Adds up every element of an integral container.
/// The running total is carried in int64_t regardless of the container's
/// element type, so narrow element types do not overflow their own width.
template <
    typename C,
    std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t sum_integers(const C& container) {
  // std::accumulate uses the type of its initial value as the accumulator
  // type; seeding with a 64-bit zero widens the whole computation.
  const int64_t zero = 0;
  const auto first = container.begin();
  const auto last = container.end();
  return std::accumulate(first, last, zero);
}
26
/// Adds up the integral values in the half-open range [begin, end).
/// The running total is carried in int64_t regardless of the iterator's
/// value type; an empty range yields 0.
template <
    typename Iter,
    std::enable_if_t<
        std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
        int> = 0>
inline int64_t sum_integers(Iter begin, Iter end) {
  // The accumulator type follows the initial value, so start from a 64-bit
  // zero to keep narrow element types from overflowing during the sum.
  return std::accumulate(begin, end, int64_t{0});
}
40
/// Multiplies together every element of an integral container.
/// The running product is carried in int64_t regardless of the container's
/// element type; an empty container yields 1.
template <
    typename C,
    std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t multiply_integers(const C& container) {
  // Seeding std::accumulate with a 64-bit one makes every intermediate
  // product 64-bit wide instead of the (possibly narrower) element type.
  const int64_t identity = 1;
  std::multiplies<> times;
  return std::accumulate(container.begin(), container.end(), identity, times);
}
55
/// Multiplies together the integral values in the half-open range
/// [begin, end). The running product is carried in int64_t regardless of the
/// iterator's value type; an empty range yields 1.
template <
    typename Iter,
    std::enable_if_t<
        std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
        int> = 0>
inline int64_t multiply_integers(Iter begin, Iter end) {
  // The 64-bit initial value fixes the accumulator type of std::accumulate,
  // which keeps the product from overflowing a narrow element type.
  constexpr int64_t kOne = 1;
  return std::accumulate(begin, end, kOne, std::multiplies<>{});
}
70
71 /// Return product of all dimensions starting from k
72 /// Returns 1 if k>=dims.size()
73 template <
74 typename C,
75 std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
numelements_from_dim(const int k,const C & dims)76 inline int64_t numelements_from_dim(const int k, const C& dims) {
77 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0);
78
79 if (k > static_cast<int>(dims.size())) {
80 return 1;
81 } else {
82 auto cbegin = dims.cbegin();
83 std::advance(cbegin, k);
84 return multiply_integers(cbegin, dims.cend());
85 }
86 }
87
88 /// Product of all dims up to k (not including dims[k])
89 /// Throws an error if k>dims.size()
90 template <
91 typename C,
92 std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
numelements_to_dim(const int k,const C & dims)93 inline int64_t numelements_to_dim(const int k, const C& dims) {
94 TORCH_INTERNAL_ASSERT(0 <= k);
95 TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size());
96
97 auto cend = dims.cbegin();
98 std::advance(cend, k);
99 return multiply_integers(dims.cbegin(), cend);
100 }
101
102 /// Product of all dims between k and l (including dims[k] and excluding
103 /// dims[l]) k and l may be supplied in either order
104 template <
105 typename C,
106 std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
numelements_between_dim(int k,int l,const C & dims)107 inline int64_t numelements_between_dim(int k, int l, const C& dims) {
108 TORCH_INTERNAL_ASSERT(0 <= k);
109 TORCH_INTERNAL_ASSERT(0 <= l);
110
111 if (k > l) {
112 std::swap(k, l);
113 }
114
115 TORCH_INTERNAL_ASSERT((unsigned)l < dims.size());
116
117 auto cbegin = dims.cbegin();
118 auto cend = dims.cbegin();
119 std::advance(cbegin, k);
120 std::advance(cend, l);
121 return multiply_integers(cbegin, cend);
122 }
123
124 } // namespace c10
125