1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ArrayRef.h> 3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/DimVector.h> 4*da0073e9SAndroid Build Coastguard Worker #include <algorithm> 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Worker namespace c10 { 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker // Computes the contiguous strides of a tensor, given its sizes. contiguous_strides(const IntArrayRef sizes)9*da0073e9SAndroid Build Coastguard Workerinline DimVector contiguous_strides(const IntArrayRef sizes) { 10*da0073e9SAndroid Build Coastguard Worker using Int = IntArrayRef::value_type; 11*da0073e9SAndroid Build Coastguard Worker const Int dims = static_cast<Int>(sizes.size()); 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker // With this initialisation we get the case dim == 0 or 1 right 14*da0073e9SAndroid Build Coastguard Worker DimVector strides(dims, 1); 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker for (auto i = dims - 2; i >= 0; --i) { 17*da0073e9SAndroid Build Coastguard Worker // Strides can't be 0 even if sizes are 0. 18*da0073e9SAndroid Build Coastguard Worker strides[i] = strides[i + 1] * std::max(sizes[i + 1], Int{1}); 19*da0073e9SAndroid Build Coastguard Worker } 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker return strides; 22*da0073e9SAndroid Build Coastguard Worker } 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker } // namespace c10 25