1 #pragma once 2 3 #include <ATen/core/TensorBase.h> 4 #include <c10/core/WrapDimMinimal.h> 5 6 namespace at { 7 8 // Return if the tensor geometry represented by `sizes` and `strides` is 9 // contiguous Although we cache is_contiguous in tensor now, this is till useful 10 // because it allows checking if a particular geometry is contiguous without 11 // explicitly constructing a tensor, e.g., when you want to choose a kernel 12 // strategy based on whether a subgeometry is contiguous. 13 TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides); 14 15 struct TORCH_API TensorGeometry { 16 TensorGeometry() = default; 17 TensorGeometryTensorGeometry18 explicit TensorGeometry(c10::SymIntArrayRef sizes) 19 : sizes_(sizes.vec()), 20 strides_(sizes.size()), 21 has_symbolic_sizes_strides_( 22 !c10::asIntArrayRefSlowOpt(sizes).has_value()) { 23 int64_t dim = static_cast<int64_t>(sizes.size()); 24 c10::SymInt expected_stride = 1; 25 for (int64_t i = dim - 1; i >= 0; i--) { 26 strides_[i] = expected_stride; 27 expected_stride *= sizes_[i]; 28 } 29 numel_ = expected_stride; 30 } 31 TensorGeometryTensorGeometry32 explicit TensorGeometry(const TensorBase& t) 33 : sizes_(t.sym_sizes().vec()), 34 strides_(t.sym_strides().vec()), 35 storage_offset_(t.sym_storage_offset()), 36 numel_(t.sym_numel()), 37 has_symbolic_sizes_strides_( 38 t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {} 39 40 // true if the tensor is contiguous 41 bool is_contiguous() const; 42 dimTensorGeometry43 int64_t dim() const { 44 return static_cast<int64_t>(sizes_.size()); 45 } 46 sizeTensorGeometry47 int64_t size(int64_t dim) const { 48 TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); 49 dim = c10::maybe_wrap_dim(dim, this->dim()); 50 return sizes_.at(static_cast<size_t>(dim)).as_int_unchecked(); 51 } sizesTensorGeometry52 c10::IntArrayRef sizes() const { 53 TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); 54 return c10::asIntArrayRefUnchecked(sizes_); 55 } strideTensorGeometry56 int64_t stride(int64_t dim) const { 57 TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); 58 dim = c10::maybe_wrap_dim(dim, this->dim()); 59 return strides_.at(static_cast<size_t>(dim)).as_int_unchecked(); 60 } stridesTensorGeometry61 c10::IntArrayRef strides() const { 62 TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); 63 return c10::asIntArrayRefUnchecked(strides_); 64 } storage_offsetTensorGeometry65 int64_t storage_offset() const { 66 TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); 67 return storage_offset_.as_int_unchecked(); 68 } numelTensorGeometry69 int64_t numel() const { 70 TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); 71 return numel_.as_int_unchecked(); 72 } 73 sym_sizeTensorGeometry74 c10::SymInt sym_size(int64_t dim) const { 75 dim = c10::maybe_wrap_dim(dim, this->dim()); 76 return sizes_.at(static_cast<size_t>(dim)); 77 } sym_sizesTensorGeometry78 c10::SymIntArrayRef sym_sizes() const { 79 return sizes_; 80 } sym_strideTensorGeometry81 c10::SymInt sym_stride(int64_t dim) const { 82 dim = c10::maybe_wrap_dim(dim, this->dim()); 83 return strides_.at(static_cast<size_t>(dim)); 84 } sym_stridesTensorGeometry85 c10::SymIntArrayRef sym_strides() const { 86 return strides_; 87 } sym_storage_offsetTensorGeometry88 c10::SymInt sym_storage_offset() const { 89 return storage_offset_; 90 } sym_numelTensorGeometry91 c10::SymInt sym_numel() const { 92 return numel_; 93 } 94 transposeTensorGeometry95 TensorGeometry transpose(int64_t dim0, int64_t dim1) { 96 TensorGeometry r = *this; // copy 97 TORCH_CHECK( 98 dim0 < dim(), 99 "transpose: dim0=", 100 dim0, 101 " out of range (dim=", 102 dim(), 103 ")") 104 TORCH_CHECK( 105 dim1 < dim(), 106 "transpose: dim1=", 107 dim1, 108 " out of range (dim=", 109 dim(), 110 ")") 111 std::swap(r.sizes_[dim0], r.sizes_[dim1]); 112 std::swap(r.strides_[dim0], r.strides_[dim1]); 113 return r; 114 } 115 mutable_sizesTensorGeometry116 std::vector<c10::SymInt>& mutable_sizes() { 117 return sizes_; 118 } mutable_stridesTensorGeometry119 std::vector<c10::SymInt>& mutable_strides() { 120 return strides_; 121 } mutable_storage_offsetTensorGeometry122 c10::SymInt& mutable_storage_offset() { 123 return storage_offset_; 124 } recomputeTensorGeometry125 void recompute() { 126 // recalculate numel after a change 127 c10::SymInt numel = 1; 128 for (const auto& i : sizes_) { 129 numel = numel * i; 130 } 131 numel_ = std::move(numel); 132 has_symbolic_sizes_strides_ = 133 !c10::asIntArrayRefSlowOpt(sizes_).has_value(); 134 } 135 136 private: 137 std::vector<c10::SymInt> sizes_; 138 std::vector<c10::SymInt> strides_; 139 c10::SymInt storage_offset_; 140 c10::SymInt numel_; 141 bool has_symbolic_sizes_strides_{false}; 142 }; 143 144 } // namespace at 145