xref: /aosp_15_r20/external/pytorch/c10/util/strides.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ArrayRef.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/DimVector.h>
4*da0073e9SAndroid Build Coastguard Worker #include <algorithm>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace c10 {
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker // Computes the contiguous strides of a tensor, given its sizes.
contiguous_strides(const IntArrayRef sizes)9*da0073e9SAndroid Build Coastguard Worker inline DimVector contiguous_strides(const IntArrayRef sizes) {
10*da0073e9SAndroid Build Coastguard Worker   using Int = IntArrayRef::value_type;
11*da0073e9SAndroid Build Coastguard Worker   const Int dims = static_cast<Int>(sizes.size());
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker   // With this initialisation we get the case dim == 0 or 1 right
14*da0073e9SAndroid Build Coastguard Worker   DimVector strides(dims, 1);
15*da0073e9SAndroid Build Coastguard Worker 
16*da0073e9SAndroid Build Coastguard Worker   for (auto i = dims - 2; i >= 0; --i) {
17*da0073e9SAndroid Build Coastguard Worker     // Strides can't be 0 even if sizes are 0.
18*da0073e9SAndroid Build Coastguard Worker     strides[i] = strides[i + 1] * std::max(sizes[i + 1], Int{1});
19*da0073e9SAndroid Build Coastguard Worker   }
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker   return strides;
22*da0073e9SAndroid Build Coastguard Worker }
23*da0073e9SAndroid Build Coastguard Worker 
24*da0073e9SAndroid Build Coastguard Worker } // namespace c10
25