xref: /aosp_15_r20/external/pytorch/c10/util/accumulate.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker // Copyright 2004-present Facebook. All Rights Reserved.
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #pragma once
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
6*da0073e9SAndroid Build Coastguard Worker #include <cstdint>
7*da0073e9SAndroid Build Coastguard Worker #include <functional>
8*da0073e9SAndroid Build Coastguard Worker #include <iterator>
9*da0073e9SAndroid Build Coastguard Worker #include <numeric>
10*da0073e9SAndroid Build Coastguard Worker #include <type_traits>
11*da0073e9SAndroid Build Coastguard Worker #include <utility>
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker namespace c10 {
14*da0073e9SAndroid Build Coastguard Worker 
/// Sum of the integers held in a container.
///
/// The running total is kept in `int64_t` regardless of how narrow the
/// container's element type is.
///
/// @param container Any container exposing `begin()`/`end()` whose
///     `value_type` is an integral type (enforced via `std::enable_if_t`).
/// @return The sum of all elements as `int64_t`; 0 for an empty container.
template <
    typename C,
    std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t sum_integers(const C& container) {
  // std::accumulate takes its accumulator (and return) type from `init`.
  // Seeding it with an int64_t zero keeps the total from being computed in a
  // narrower element type, where it could overflow.
  return std::accumulate(
      container.begin(), container.end(), static_cast<int64_t>(0));
}
26*da0073e9SAndroid Build Coastguard Worker 
/// Sum of the integer elements in the iterator range [begin, end).
///
/// The running total is kept in `int64_t` regardless of the iterator's
/// `value_type`.
///
/// @param begin Iterator to the first element to add.
/// @param end Iterator one past the last element to add.
/// @return The sum of the elements as `int64_t`; 0 when the range is empty.
template <
    typename Iter,
    std::enable_if_t<
        std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
        int> = 0>
inline int64_t sum_integers(Iter begin, Iter end) {
  // std::accumulate takes its accumulator (and return) type from `init`.
  // Seeding it with an int64_t zero keeps the total from being computed in a
  // narrower element type, where it could overflow.
  return std::accumulate(begin, end, static_cast<int64_t>(0));
}
40*da0073e9SAndroid Build Coastguard Worker 
/// Product of the integers held in a container.
///
/// The running product is kept in `int64_t` regardless of how narrow the
/// container's element type is.
///
/// @param container Any container exposing `begin()`/`end()` whose
///     `value_type` is an integral type (enforced via `std::enable_if_t`).
/// @return The product of all elements as `int64_t`; 1 for an empty
///     container.
template <
    typename C,
    std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t multiply_integers(const C& container) {
  // std::accumulate takes its accumulator (and return) type from `init`.
  // Seeding it with an int64_t one keeps the product from being computed in a
  // narrower element type, where it could overflow.
  return std::accumulate(
      container.begin(),
      container.end(),
      static_cast<int64_t>(1),
      std::multiplies<>());
}
55*da0073e9SAndroid Build Coastguard Worker 
/// Product of the integer elements in the iterator range [begin, end).
///
/// The running product is kept in `int64_t` regardless of the iterator's
/// `value_type`.
///
/// @param begin Iterator to the first factor.
/// @param end Iterator one past the last factor.
/// @return The product of the elements as `int64_t`; 1 when the range is
///     empty.
template <
    typename Iter,
    std::enable_if_t<
        std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
        int> = 0>
inline int64_t multiply_integers(Iter begin, Iter end) {
  // std::accumulate takes its accumulator (and return) type from `init`.
  // Seeding it with an int64_t one keeps the product from being computed in a
  // narrower element type, where it could overflow.
  return std::accumulate(
      begin, end, static_cast<int64_t>(1), std::multiplies<>());
}
70*da0073e9SAndroid Build Coastguard Worker 
71*da0073e9SAndroid Build Coastguard Worker /// Return product of all dimensions starting from k
72*da0073e9SAndroid Build Coastguard Worker /// Returns 1 if k>=dims.size()
73*da0073e9SAndroid Build Coastguard Worker template <
74*da0073e9SAndroid Build Coastguard Worker     typename C,
75*da0073e9SAndroid Build Coastguard Worker     std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
numelements_from_dim(const int k,const C & dims)76*da0073e9SAndroid Build Coastguard Worker inline int64_t numelements_from_dim(const int k, const C& dims) {
77*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0);
78*da0073e9SAndroid Build Coastguard Worker 
79*da0073e9SAndroid Build Coastguard Worker   if (k > static_cast<int>(dims.size())) {
80*da0073e9SAndroid Build Coastguard Worker     return 1;
81*da0073e9SAndroid Build Coastguard Worker   } else {
82*da0073e9SAndroid Build Coastguard Worker     auto cbegin = dims.cbegin();
83*da0073e9SAndroid Build Coastguard Worker     std::advance(cbegin, k);
84*da0073e9SAndroid Build Coastguard Worker     return multiply_integers(cbegin, dims.cend());
85*da0073e9SAndroid Build Coastguard Worker   }
86*da0073e9SAndroid Build Coastguard Worker }
87*da0073e9SAndroid Build Coastguard Worker 
88*da0073e9SAndroid Build Coastguard Worker /// Product of all dims up to k (not including dims[k])
89*da0073e9SAndroid Build Coastguard Worker /// Throws an error if k>dims.size()
90*da0073e9SAndroid Build Coastguard Worker template <
91*da0073e9SAndroid Build Coastguard Worker     typename C,
92*da0073e9SAndroid Build Coastguard Worker     std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
numelements_to_dim(const int k,const C & dims)93*da0073e9SAndroid Build Coastguard Worker inline int64_t numelements_to_dim(const int k, const C& dims) {
94*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(0 <= k);
95*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size());
96*da0073e9SAndroid Build Coastguard Worker 
97*da0073e9SAndroid Build Coastguard Worker   auto cend = dims.cbegin();
98*da0073e9SAndroid Build Coastguard Worker   std::advance(cend, k);
99*da0073e9SAndroid Build Coastguard Worker   return multiply_integers(dims.cbegin(), cend);
100*da0073e9SAndroid Build Coastguard Worker }
101*da0073e9SAndroid Build Coastguard Worker 
102*da0073e9SAndroid Build Coastguard Worker /// Product of all dims between k and l (including dims[k] and excluding
103*da0073e9SAndroid Build Coastguard Worker /// dims[l]) k and l may be supplied in either order
104*da0073e9SAndroid Build Coastguard Worker template <
105*da0073e9SAndroid Build Coastguard Worker     typename C,
106*da0073e9SAndroid Build Coastguard Worker     std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
numelements_between_dim(int k,int l,const C & dims)107*da0073e9SAndroid Build Coastguard Worker inline int64_t numelements_between_dim(int k, int l, const C& dims) {
108*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(0 <= k);
109*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(0 <= l);
110*da0073e9SAndroid Build Coastguard Worker 
111*da0073e9SAndroid Build Coastguard Worker   if (k > l) {
112*da0073e9SAndroid Build Coastguard Worker     std::swap(k, l);
113*da0073e9SAndroid Build Coastguard Worker   }
114*da0073e9SAndroid Build Coastguard Worker 
115*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT((unsigned)l < dims.size());
116*da0073e9SAndroid Build Coastguard Worker 
117*da0073e9SAndroid Build Coastguard Worker   auto cbegin = dims.cbegin();
118*da0073e9SAndroid Build Coastguard Worker   auto cend = dims.cbegin();
119*da0073e9SAndroid Build Coastguard Worker   std::advance(cbegin, k);
120*da0073e9SAndroid Build Coastguard Worker   std::advance(cend, l);
121*da0073e9SAndroid Build Coastguard Worker   return multiply_integers(cbegin, cend);
122*da0073e9SAndroid Build Coastguard Worker }
123*da0073e9SAndroid Build Coastguard Worker 
124*da0073e9SAndroid Build Coastguard Worker } // namespace c10
125