1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <ATen/core/jit_type.h> 5 #include <c10/util/Exception.h> 6 #include <c10/util/hash.h> 7 #include <torch/csrc/Export.h> 8 9 #include <algorithm> 10 #include <ostream> 11 #include <vector> 12 13 namespace torch::jit::fuser { 14 15 // type information needed by the compiler for input/outputs 16 // contiguity[i] is true if the dim i is contiguous with dim i + 1. 17 // contiguity.back() == true means strides.back() == 1. 18 struct TORCH_API TensorDesc { 19 at::ScalarType scalar_type; 20 std::vector<bool> contiguity; 21 TensorDescTensorDesc22 TensorDesc(const at::ScalarType& type, const std::vector<bool>& contiguity) 23 : scalar_type{type}, contiguity{contiguity} { 24 if (contiguity.empty()) { 25 nDim_ = 0; 26 } else { 27 nDim_ = std::count(contiguity.begin(), contiguity.end(), false) + 28 (lastIsContiguous() ? 1 : 0); 29 } 30 } 31 32 // Delegating constructors TensorDescTensorDesc33 TensorDesc( 34 const at::ScalarType& type, 35 const at::IntArrayRef& sizes, 36 const at::IntArrayRef& strides) 37 : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {} 38 TensorDescTensorDesc39 TensorDesc(const at::Tensor& t) 40 : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {} 41 TensorDescTensorDesc42 TensorDesc(const c10::TensorTypePtr& type) 43 : TensorDesc( 44 type->scalarType().value(), 45 type->sizes().concrete_sizes().value(), 46 type->strides().concrete_sizes().value()) {} 47 48 // number of dimensions after contiguity compression nDimTensorDesc49 size_t nDim() const { 50 return nDim_; 51 } 52 53 // True iff innermost stride is 1 lastIsContiguousTensorDesc54 bool lastIsContiguous() const { 55 return (contiguity.empty() || contiguity.back()); 56 } 57 findContiguousTensorDesc58 static std::vector<bool> findContiguous( 59 const at::IntArrayRef& sizes, 60 const at::IntArrayRef& strides) { 61 AT_ASSERT(sizes.size() == strides.size()); 62 std::vector<bool> cont(sizes.size()); 63 for (size_t i = 0; i < sizes.size(); ++i) { 64 const auto expected_stride = 65 (i + 1 < sizes.size()) ? sizes[i + 1] * strides[i + 1] : 1; 66 cont[i] = (strides[i] == expected_stride); 67 } 68 return cont; 69 } 70 71 bool operator==(const TensorDesc& desc) const { 72 return scalar_type == desc.scalar_type && contiguity == desc.contiguity; 73 } 74 75 bool operator!=(const TensorDesc& desc) const { 76 return !(*this == desc); 77 } 78 hashTensorDesc79 static size_t hash(const TensorDesc& spec) { 80 return c10::get_hash( 81 spec.scalar_type, 82 spec.nDim_, 83 std::hash<std::vector<bool>>{}(spec.contiguity)); 84 } 85 86 private: 87 size_t nDim_; 88 }; 89 90 inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) { 91 out << d.scalar_type << "["; 92 for (const auto b : d.contiguity) 93 out << b << ";"; 94 out << "]"; 95 return out; 96 } 97 98 } // namespace torch::jit::fuser 99