// xref: /aosp_15_r20/external/pytorch/c10/util/accumulate.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// Copyright 2004-present Facebook. All Rights Reserved.
2 
3 #pragma once
4 
5 #include <c10/util/Exception.h>
6 #include <cstdint>
7 #include <functional>
8 #include <iterator>
9 #include <numeric>
10 #include <type_traits>
11 #include <utility>
12 
13 namespace c10 {
14 
/// Sum of a list of integers; accumulates into the int64_t datatype.
///
/// Participates in overload resolution only when `C::value_type` is an
/// integral type.
///
/// @param container Container providing `begin()`/`end()` over integral
///   elements.
/// @return The sum of all elements as `int64_t`; 0 for an empty container.
template <
    typename C,
    std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t sum_integers(const C& container) {
  // std::accumulate infers return type from `init` type, so if the `init` type
  // is not large enough to hold the result, computation can overflow. We use
  // `int64_t` here to avoid this.
  return std::accumulate(
      container.begin(), container.end(), static_cast<int64_t>(0));
}
26 
/// Sum of integer elements referred to by iterators; accumulates into the
/// int64_t datatype.
///
/// Participates in overload resolution only when the iterator's `value_type`
/// (per `std::iterator_traits`) is an integral type.
///
/// @param begin Iterator to the first element of the range.
/// @param end Iterator one past the last element of the range.
/// @return The sum of the elements in `[begin, end)` as `int64_t`; 0 for an
///   empty range.
template <
    typename Iter,
    std::enable_if_t<
        std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
        int> = 0>
inline int64_t sum_integers(Iter begin, Iter end) {
  // std::accumulate infers return type from `init` type, so if the `init` type
  // is not large enough to hold the result, computation can overflow. We use
  // `int64_t` here to avoid this.
  return std::accumulate(begin, end, static_cast<int64_t>(0));
}
40 
/// Product of a list of integers; accumulates into the int64_t datatype.
///
/// Participates in overload resolution only when `C::value_type` is an
/// integral type.
///
/// @param container Container providing `begin()`/`end()` over integral
///   elements.
/// @return The product of all elements as `int64_t`; 1 for an empty
///   container.
template <
    typename C,
    std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t multiply_integers(const C& container) {
  // std::accumulate infers return type from `init` type, so if the `init` type
  // is not large enough to hold the result, computation can overflow. We use
  // `int64_t` here to avoid this.
  return std::accumulate(
      container.begin(),
      container.end(),
      static_cast<int64_t>(1),
      std::multiplies<>());
}
55 
/// Product of integer elements referred to by iterators; accumulates into the
/// int64_t datatype.
///
/// Participates in overload resolution only when the iterator's `value_type`
/// (per `std::iterator_traits`) is an integral type.
///
/// @param begin Iterator to the first element of the range.
/// @param end Iterator one past the last element of the range.
/// @return The product of the elements in `[begin, end)` as `int64_t`; 1 for
///   an empty range.
template <
    typename Iter,
    std::enable_if_t<
        std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
        int> = 0>
inline int64_t multiply_integers(Iter begin, Iter end) {
  // std::accumulate infers return type from `init` type, so if the `init` type
  // is not large enough to hold the result, computation can overflow. We use
  // `int64_t` here to avoid this.
  return std::accumulate(
      begin, end, static_cast<int64_t>(1), std::multiplies<>());
}
70 
71 /// Return product of all dimensions starting from k
72 /// Returns 1 if k>=dims.size()
73 template <
74     typename C,
75     std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
numelements_from_dim(const int k,const C & dims)76 inline int64_t numelements_from_dim(const int k, const C& dims) {
77   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0);
78 
79   if (k > static_cast<int>(dims.size())) {
80     return 1;
81   } else {
82     auto cbegin = dims.cbegin();
83     std::advance(cbegin, k);
84     return multiply_integers(cbegin, dims.cend());
85   }
86 }
87 
88 /// Product of all dims up to k (not including dims[k])
89 /// Throws an error if k>dims.size()
90 template <
91     typename C,
92     std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
numelements_to_dim(const int k,const C & dims)93 inline int64_t numelements_to_dim(const int k, const C& dims) {
94   TORCH_INTERNAL_ASSERT(0 <= k);
95   TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size());
96 
97   auto cend = dims.cbegin();
98   std::advance(cend, k);
99   return multiply_integers(dims.cbegin(), cend);
100 }
101 
102 /// Product of all dims between k and l (including dims[k] and excluding
103 /// dims[l]) k and l may be supplied in either order
104 template <
105     typename C,
106     std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
numelements_between_dim(int k,int l,const C & dims)107 inline int64_t numelements_between_dim(int k, int l, const C& dims) {
108   TORCH_INTERNAL_ASSERT(0 <= k);
109   TORCH_INTERNAL_ASSERT(0 <= l);
110 
111   if (k > l) {
112     std::swap(k, l);
113   }
114 
115   TORCH_INTERNAL_ASSERT((unsigned)l < dims.size());
116 
117   auto cbegin = dims.cbegin();
118   auto cend = dims.cbegin();
119   std::advance(cbegin, k);
120   std::advance(cend, l);
121   return multiply_integers(cbegin, cend);
122 }
123 
124 } // namespace c10
125