1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker #include <ATen/Utils.h> 3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ArrayRef.h> 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Worker namespace at { 6*da0073e9SAndroid Build Coastguard Worker /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that 7*da0073e9SAndroid Build Coastguard Worker /// we can easily view it as a multidimensional array. 8*da0073e9SAndroid Build Coastguard Worker /// 9*da0073e9SAndroid Build Coastguard Worker /// Like ArrayRef, this class does not own the underlying data, it is expected 10*da0073e9SAndroid Build Coastguard Worker /// to be used in situations where the data resides in some other buffer. 11*da0073e9SAndroid Build Coastguard Worker /// 12*da0073e9SAndroid Build Coastguard Worker /// This is intended to be trivially copyable, so it should be passed by 13*da0073e9SAndroid Build Coastguard Worker /// value. 14*da0073e9SAndroid Build Coastguard Worker /// 15*da0073e9SAndroid Build Coastguard Worker /// For now, 2D only (so the copies are actually cheap, without having 16*da0073e9SAndroid Build Coastguard Worker /// to write a SmallVector class) and contiguous only (so we can 17*da0073e9SAndroid Build Coastguard Worker /// return non-strided ArrayRef on index). 18*da0073e9SAndroid Build Coastguard Worker /// 19*da0073e9SAndroid Build Coastguard Worker /// P.S. dimension 0 indexes rows, dimension 1 indexes columns 20*da0073e9SAndroid Build Coastguard Worker template <typename T> 21*da0073e9SAndroid Build Coastguard Worker class MatrixRef { 22*da0073e9SAndroid Build Coastguard Worker public: 23*da0073e9SAndroid Build Coastguard Worker typedef size_t size_type; 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker private: 26*da0073e9SAndroid Build Coastguard Worker /// Underlying ArrayRef 27*da0073e9SAndroid Build Coastguard Worker ArrayRef<T> arr; 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker /// Stride of dim 0 (outer dimension) 30*da0073e9SAndroid Build Coastguard Worker size_type stride0; 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker // Stride of dim 1 is assumed to be 1 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker public: 35*da0073e9SAndroid Build Coastguard Worker /// Construct an empty Matrixref. MatrixRef()36*da0073e9SAndroid Build Coastguard Worker /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {} 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker /// Construct an MatrixRef from an ArrayRef and outer stride. MatrixRef(ArrayRef<T> arr,size_type stride0)39*da0073e9SAndroid Build Coastguard Worker /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0) 40*da0073e9SAndroid Build Coastguard Worker : arr(arr), stride0(stride0) { 41*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK( 42*da0073e9SAndroid Build Coastguard Worker arr.size() % stride0 == 0, 43*da0073e9SAndroid Build Coastguard Worker "MatrixRef: ArrayRef size ", 44*da0073e9SAndroid Build Coastguard Worker arr.size(), 45*da0073e9SAndroid Build Coastguard Worker " not divisible by stride ", 46*da0073e9SAndroid Build Coastguard Worker stride0) 47*da0073e9SAndroid Build Coastguard Worker } 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker /// @} 50*da0073e9SAndroid Build Coastguard Worker /// @name Simple Operations 51*da0073e9SAndroid Build Coastguard Worker /// @{ 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker /// empty - Check if the matrix is empty. empty()54*da0073e9SAndroid Build Coastguard Worker bool empty() const { 55*da0073e9SAndroid Build Coastguard Worker return arr.empty(); 56*da0073e9SAndroid Build Coastguard Worker } 57*da0073e9SAndroid Build Coastguard Worker data()58*da0073e9SAndroid Build Coastguard Worker const T* data() const { 59*da0073e9SAndroid Build Coastguard Worker return arr.data(); 60*da0073e9SAndroid Build Coastguard Worker } 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker /// size - Get size a dimension size(size_t dim)63*da0073e9SAndroid Build Coastguard Worker size_t size(size_t dim) const { 64*da0073e9SAndroid Build Coastguard Worker if (dim == 0) { 65*da0073e9SAndroid Build Coastguard Worker return arr.size() / stride0; 66*da0073e9SAndroid Build Coastguard Worker } else if (dim == 1) { 67*da0073e9SAndroid Build Coastguard Worker return stride0; 68*da0073e9SAndroid Build Coastguard Worker } else { 69*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK( 70*da0073e9SAndroid Build Coastguard Worker 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1"); 71*da0073e9SAndroid Build Coastguard Worker } 72*da0073e9SAndroid Build Coastguard Worker } 73*da0073e9SAndroid Build Coastguard Worker numel()74*da0073e9SAndroid Build Coastguard Worker size_t numel() const { 75*da0073e9SAndroid Build Coastguard Worker return arr.size(); 76*da0073e9SAndroid Build Coastguard Worker } 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker /// equals - Check for element-wise equality. equals(MatrixRef RHS)79*da0073e9SAndroid Build Coastguard Worker bool equals(MatrixRef RHS) const { 80*da0073e9SAndroid Build Coastguard Worker return stride0 == RHS.stride0 && arr.equals(RHS.arr); 81*da0073e9SAndroid Build Coastguard Worker } 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker /// @} 84*da0073e9SAndroid Build Coastguard Worker /// @name Operator Overloads 85*da0073e9SAndroid Build Coastguard Worker /// @{ 86*da0073e9SAndroid Build Coastguard Worker ArrayRef<T> operator[](size_t Index) const { 87*da0073e9SAndroid Build Coastguard Worker return arr.slice(Index * stride0, stride0); 88*da0073e9SAndroid Build Coastguard Worker } 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker /// Disallow accidental assignment from a temporary. 91*da0073e9SAndroid Build Coastguard Worker /// 92*da0073e9SAndroid Build Coastguard Worker /// The declaration here is extra complicated so that "arrayRef = {}" 93*da0073e9SAndroid Build Coastguard Worker /// continues to select the move assignment operator. 94*da0073e9SAndroid Build Coastguard Worker template <typename U> 95*da0073e9SAndroid Build Coastguard Worker std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=( 96*da0073e9SAndroid Build Coastguard Worker U&& Temporary) = delete; 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker /// Disallow accidental assignment from a temporary. 99*da0073e9SAndroid Build Coastguard Worker /// 100*da0073e9SAndroid Build Coastguard Worker /// The declaration here is extra complicated so that "arrayRef = {}" 101*da0073e9SAndroid Build Coastguard Worker /// continues to select the move assignment operator. 102*da0073e9SAndroid Build Coastguard Worker template <typename U> 103*da0073e9SAndroid Build Coastguard Worker std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=( 104*da0073e9SAndroid Build Coastguard Worker std::initializer_list<U>) = delete; 105*da0073e9SAndroid Build Coastguard Worker }; 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker } // end namespace at 108