xref: /aosp_15_r20/external/pytorch/c10/util/strides.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/util/ArrayRef.h>
3 #include <c10/util/DimVector.h>
4 #include <algorithm>
5 
6 namespace c10 {
7 
8 // Computes the contiguous strides of a tensor, given its sizes.
contiguous_strides(const IntArrayRef sizes)9 inline DimVector contiguous_strides(const IntArrayRef sizes) {
10   using Int = IntArrayRef::value_type;
11   const Int dims = static_cast<Int>(sizes.size());
12 
13   // With this initialisation we get the case dim == 0 or 1 right
14   DimVector strides(dims, 1);
15 
16   for (auto i = dims - 2; i >= 0; --i) {
17     // Strides can't be 0 even if sizes are 0.
18     strides[i] = strides[i + 1] * std::max(sizes[i + 1], Int{1});
19   }
20 
21   return strides;
22 }
23 
24 } // namespace c10
25