#pragma once

#include <ATen/core/TensorBase.h>
#include <c10/core/WrapDimMinimal.h>

namespace at {

// Returns whether the tensor geometry represented by `sizes` and `strides` is
// contiguous. Although we cache is_contiguous in tensor now, this is still
// useful because it allows checking whether a particular geometry is
// contiguous without explicitly constructing a tensor, e.g., when you want to
// choose a kernel strategy based on whether a subgeometry is contiguous.
TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
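//
// Example (illustrative sketch, not part of the upstream header): a 2x3
// geometry with row-major strides {3, 1} is contiguous, while the same sizes
// viewed with strides {8, 1} (e.g. rows taken out of a wider buffer) are not.
// The values are made up purely for illustration.
//
//   bool dense = at::geometry_is_contiguous({2, 3}, {3, 1});   // true
//   bool strided = at::geometry_is_contiguous({2, 3}, {8, 1}); // false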

struct TORCH_API TensorGeometry {
  TensorGeometry() = default;

  // Construct a contiguous geometry for the given sizes: strides are computed
  // as the row-major contiguous strides and numel as the product of the sizes.
  explicit TensorGeometry(c10::SymIntArrayRef sizes)
      : sizes_(sizes.vec()),
        strides_(sizes.size()),
        has_symbolic_sizes_strides_(
            !c10::asIntArrayRefSlowOpt(sizes).has_value()) {
    int64_t dim = static_cast<int64_t>(sizes.size());
    c10::SymInt expected_stride = 1;
    for (int64_t i = dim - 1; i >= 0; i--) {
      strides_[i] = expected_stride;
      expected_stride *= sizes_[i];
    }
    numel_ = expected_stride;
  }
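
  // For illustration (not part of the upstream header): constructing from
  // sizes {2, 3, 4} yields the row-major contiguous strides {12, 4, 1} and
  // numel 24.
  //
  //   std::vector<c10::SymInt> sizes = {2, 3, 4};
  //   at::TensorGeometry geom{c10::SymIntArrayRef(sizes)};
  //   // geom.sizes() == {2, 3, 4}, geom.strides() == {12, 4, 1},
  //   // geom.numel() == 24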
  // Construct from an existing tensor, snapshotting its (possibly symbolic)
  // sizes, strides, storage offset and numel.
  explicit TensorGeometry(const TensorBase& t)
      : sizes_(t.sym_sizes().vec()),
        strides_(t.sym_strides().vec()),
        storage_offset_(t.sym_storage_offset()),
        numel_(t.sym_numel()),
        has_symbolic_sizes_strides_(
            t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}
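
  // For illustration (with a hypothetical tensor `t`, not part of the
  // upstream header): TensorGeometry(t) snapshots t's geometry by value, so
  // the result can outlive t and does not keep t's storage alive.
  //
  //   at::TensorGeometry geom(t);
  //   // geom.sym_sizes() == t.sym_sizes(), geom.sym_strides() == t.sym_strides()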

  // true if the geometry is contiguous
  bool is_contiguous() const;

  int64_t dim() const {
    return static_cast<int64_t>(sizes_.size());
  }

  // The accessors below require concrete (non-symbolic) sizes and strides;
  // use the sym_* variants when has_symbolic_sizes_strides_ is true.
  int64_t size(int64_t dim) const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return sizes_.at(static_cast<size_t>(dim)).as_int_unchecked();
  }
  c10::IntArrayRef sizes() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return c10::asIntArrayRefUnchecked(sizes_);
  }
  int64_t stride(int64_t dim) const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return strides_.at(static_cast<size_t>(dim)).as_int_unchecked();
  }
  c10::IntArrayRef strides() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return c10::asIntArrayRefUnchecked(strides_);
  }
  int64_t storage_offset() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return storage_offset_.as_int_unchecked();
  }
  int64_t numel() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return numel_.as_int_unchecked();
  }

  c10::SymInt sym_size(int64_t dim) const {
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return sizes_.at(static_cast<size_t>(dim));
  }
  c10::SymIntArrayRef sym_sizes() const {
    return sizes_;
  }
  c10::SymInt sym_stride(int64_t dim) const {
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return strides_.at(static_cast<size_t>(dim));
  }
  c10::SymIntArrayRef sym_strides() const {
    return strides_;
  }
  c10::SymInt sym_storage_offset() const {
    return storage_offset_;
  }
  c10::SymInt sym_numel() const {
    return numel_;
  }

  TensorGeometry transpose(int64_t dim0, int64_t dim1) {
    TensorGeometry r = *this; // copy
    TORCH_CHECK(
        dim0 < dim(),
        "transpose: dim0=",
        dim0,
        " out of range (dim=",
        dim(),
        ")");
    TORCH_CHECK(
        dim1 < dim(),
        "transpose: dim1=",
        dim1,
        " out of range (dim=",
        dim(),
        ")");
    std::swap(r.sizes_[dim0], r.sizes_[dim1]);
    std::swap(r.strides_[dim0], r.strides_[dim1]);
    return r;
  }
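
  // For illustration (not part of the upstream header): transposing dims 0
  // and 1 of a contiguous {2, 3} geometry returns a copy with sizes {3, 2}
  // and strides {1, 3}; *this is left unchanged.
  //
  //   std::vector<c10::SymInt> sizes = {2, 3};
  //   at::TensorGeometry geom{c10::SymIntArrayRef(sizes)};
  //   at::TensorGeometry swapped = geom.transpose(0, 1);
  //   // swapped.sizes() == {3, 2}, swapped.strides() == {1, 3}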

  std::vector<c10::SymInt>& mutable_sizes() {
    return sizes_;
  }
  std::vector<c10::SymInt>& mutable_strides() {
    return strides_;
  }
  c10::SymInt& mutable_storage_offset() {
    return storage_offset_;
  }
  void recompute() {
    // recalculate the cached numel (and has_symbolic_sizes_strides_) after a
    // change to the sizes or strides
    c10::SymInt numel = 1;
    for (const auto& i : sizes_) {
      numel = numel * i;
    }
    numel_ = std::move(numel);
    has_symbolic_sizes_strides_ =
        !c10::asIntArrayRefSlowOpt(sizes_).has_value();
  }
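
  // For illustration (a plausible mutation pattern with a hypothetical
  // geometry `geom`, not prescribed by the upstream header): after editing
  // the geometry in place, call recompute() so the cached numel stays
  // consistent with the new sizes.
  //
  //   geom.mutable_sizes().push_back(5);
  //   geom.mutable_strides().push_back(1);
  //   geom.recompute();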

 private:
  std::vector<c10::SymInt> sizes_;
  std::vector<c10::SymInt> strides_;
  c10::SymInt storage_offset_;
  c10::SymInt numel_;
  bool has_symbolic_sizes_strides_{false};
};

} // namespace at