xref: /aosp_15_r20/external/pytorch/aten/src/ATen/MatrixRef.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker #include <ATen/Utils.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ArrayRef.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker namespace at {
6*da0073e9SAndroid Build Coastguard Worker /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
7*da0073e9SAndroid Build Coastguard Worker /// we can easily view it as a multidimensional array.
8*da0073e9SAndroid Build Coastguard Worker ///
9*da0073e9SAndroid Build Coastguard Worker /// Like ArrayRef, this class does not own the underlying data, it is expected
10*da0073e9SAndroid Build Coastguard Worker /// to be used in situations where the data resides in some other buffer.
11*da0073e9SAndroid Build Coastguard Worker ///
12*da0073e9SAndroid Build Coastguard Worker /// This is intended to be trivially copyable, so it should be passed by
13*da0073e9SAndroid Build Coastguard Worker /// value.
14*da0073e9SAndroid Build Coastguard Worker ///
15*da0073e9SAndroid Build Coastguard Worker /// For now, 2D only (so the copies are actually cheap, without having
16*da0073e9SAndroid Build Coastguard Worker /// to write a SmallVector class) and contiguous only (so we can
17*da0073e9SAndroid Build Coastguard Worker /// return non-strided ArrayRef on index).
18*da0073e9SAndroid Build Coastguard Worker ///
19*da0073e9SAndroid Build Coastguard Worker /// P.S. dimension 0 indexes rows, dimension 1 indexes columns
20*da0073e9SAndroid Build Coastguard Worker template <typename T>
21*da0073e9SAndroid Build Coastguard Worker class MatrixRef {
22*da0073e9SAndroid Build Coastguard Worker  public:
23*da0073e9SAndroid Build Coastguard Worker   typedef size_t size_type;
24*da0073e9SAndroid Build Coastguard Worker 
25*da0073e9SAndroid Build Coastguard Worker  private:
26*da0073e9SAndroid Build Coastguard Worker   /// Underlying ArrayRef
27*da0073e9SAndroid Build Coastguard Worker   ArrayRef<T> arr;
28*da0073e9SAndroid Build Coastguard Worker 
29*da0073e9SAndroid Build Coastguard Worker   /// Stride of dim 0 (outer dimension)
30*da0073e9SAndroid Build Coastguard Worker   size_type stride0;
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker   // Stride of dim 1 is assumed to be 1
33*da0073e9SAndroid Build Coastguard Worker 
34*da0073e9SAndroid Build Coastguard Worker  public:
35*da0073e9SAndroid Build Coastguard Worker   /// Construct an empty Matrixref.
MatrixRef()36*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker   /// Construct an MatrixRef from an ArrayRef and outer stride.
MatrixRef(ArrayRef<T> arr,size_type stride0)39*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
40*da0073e9SAndroid Build Coastguard Worker       : arr(arr), stride0(stride0) {
41*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
42*da0073e9SAndroid Build Coastguard Worker         arr.size() % stride0 == 0,
43*da0073e9SAndroid Build Coastguard Worker         "MatrixRef: ArrayRef size ",
44*da0073e9SAndroid Build Coastguard Worker         arr.size(),
45*da0073e9SAndroid Build Coastguard Worker         " not divisible by stride ",
46*da0073e9SAndroid Build Coastguard Worker         stride0)
47*da0073e9SAndroid Build Coastguard Worker   }
48*da0073e9SAndroid Build Coastguard Worker 
49*da0073e9SAndroid Build Coastguard Worker   /// @}
50*da0073e9SAndroid Build Coastguard Worker   /// @name Simple Operations
51*da0073e9SAndroid Build Coastguard Worker   /// @{
52*da0073e9SAndroid Build Coastguard Worker 
53*da0073e9SAndroid Build Coastguard Worker   /// empty - Check if the matrix is empty.
empty()54*da0073e9SAndroid Build Coastguard Worker   bool empty() const {
55*da0073e9SAndroid Build Coastguard Worker     return arr.empty();
56*da0073e9SAndroid Build Coastguard Worker   }
57*da0073e9SAndroid Build Coastguard Worker 
data()58*da0073e9SAndroid Build Coastguard Worker   const T* data() const {
59*da0073e9SAndroid Build Coastguard Worker     return arr.data();
60*da0073e9SAndroid Build Coastguard Worker   }
61*da0073e9SAndroid Build Coastguard Worker 
62*da0073e9SAndroid Build Coastguard Worker   /// size - Get size a dimension
size(size_t dim)63*da0073e9SAndroid Build Coastguard Worker   size_t size(size_t dim) const {
64*da0073e9SAndroid Build Coastguard Worker     if (dim == 0) {
65*da0073e9SAndroid Build Coastguard Worker       return arr.size() / stride0;
66*da0073e9SAndroid Build Coastguard Worker     } else if (dim == 1) {
67*da0073e9SAndroid Build Coastguard Worker       return stride0;
68*da0073e9SAndroid Build Coastguard Worker     } else {
69*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(
70*da0073e9SAndroid Build Coastguard Worker           0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
71*da0073e9SAndroid Build Coastguard Worker     }
72*da0073e9SAndroid Build Coastguard Worker   }
73*da0073e9SAndroid Build Coastguard Worker 
numel()74*da0073e9SAndroid Build Coastguard Worker   size_t numel() const {
75*da0073e9SAndroid Build Coastguard Worker     return arr.size();
76*da0073e9SAndroid Build Coastguard Worker   }
77*da0073e9SAndroid Build Coastguard Worker 
78*da0073e9SAndroid Build Coastguard Worker   /// equals - Check for element-wise equality.
equals(MatrixRef RHS)79*da0073e9SAndroid Build Coastguard Worker   bool equals(MatrixRef RHS) const {
80*da0073e9SAndroid Build Coastguard Worker     return stride0 == RHS.stride0 && arr.equals(RHS.arr);
81*da0073e9SAndroid Build Coastguard Worker   }
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker   /// @}
84*da0073e9SAndroid Build Coastguard Worker   /// @name Operator Overloads
85*da0073e9SAndroid Build Coastguard Worker   /// @{
86*da0073e9SAndroid Build Coastguard Worker   ArrayRef<T> operator[](size_t Index) const {
87*da0073e9SAndroid Build Coastguard Worker     return arr.slice(Index * stride0, stride0);
88*da0073e9SAndroid Build Coastguard Worker   }
89*da0073e9SAndroid Build Coastguard Worker 
90*da0073e9SAndroid Build Coastguard Worker   /// Disallow accidental assignment from a temporary.
91*da0073e9SAndroid Build Coastguard Worker   ///
92*da0073e9SAndroid Build Coastguard Worker   /// The declaration here is extra complicated so that "arrayRef = {}"
93*da0073e9SAndroid Build Coastguard Worker   /// continues to select the move assignment operator.
94*da0073e9SAndroid Build Coastguard Worker   template <typename U>
95*da0073e9SAndroid Build Coastguard Worker   std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
96*da0073e9SAndroid Build Coastguard Worker       U&& Temporary) = delete;
97*da0073e9SAndroid Build Coastguard Worker 
98*da0073e9SAndroid Build Coastguard Worker   /// Disallow accidental assignment from a temporary.
99*da0073e9SAndroid Build Coastguard Worker   ///
100*da0073e9SAndroid Build Coastguard Worker   /// The declaration here is extra complicated so that "arrayRef = {}"
101*da0073e9SAndroid Build Coastguard Worker   /// continues to select the move assignment operator.
102*da0073e9SAndroid Build Coastguard Worker   template <typename U>
103*da0073e9SAndroid Build Coastguard Worker   std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
104*da0073e9SAndroid Build Coastguard Worker       std::initializer_list<U>) = delete;
105*da0073e9SAndroid Build Coastguard Worker };
106*da0073e9SAndroid Build Coastguard Worker 
107*da0073e9SAndroid Build Coastguard Worker } // end namespace at
108