#pragma once

#include <ATen/ATen.h>
#include <ATen/core/jit_type.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <torch/csrc/Export.h>

#include <algorithm>
#include <ostream>
#include <vector>

namespace torch::jit::fuser {

// type information needed by the compiler for inputs/outputs
// contiguity[i] is true if dim i is contiguous with dim i + 1.
// contiguity.back() == true means strides.back() == 1.
struct TORCH_API TensorDesc {
  at::ScalarType scalar_type;
  std::vector<bool> contiguity;

  TensorDesc(const at::ScalarType& type, const std::vector<bool>& contiguity)
      : scalar_type{type}, contiguity{contiguity} {
    if (contiguity.empty()) {
      nDim_ = 0;
    } else {
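      // Dims i and i + 1 fuse when contiguity[i] is true, so the compressed
      // rank is the number of non-fusable dims, plus one more for the
      // innermost group when the last stride is 1.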
      nDim_ = std::count(contiguity.begin(), contiguity.end(), false) +
          (lastIsContiguous() ? 1 : 0);
    }
  }

  // Delegating constructors
  TensorDesc(
      const at::ScalarType& type,
      const at::IntArrayRef& sizes,
      const at::IntArrayRef& strides)
      : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}

  TensorDesc(const at::Tensor& t)
      : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {}

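  // Requires a complete type: the scalar type, sizes, and strides must all
  // be known, since value() throws on an empty optional.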
  TensorDesc(const c10::TensorTypePtr& type)
      : TensorDesc(
            type->scalarType().value(),
            type->sizes().concrete_sizes().value(),
            type->strides().concrete_sizes().value()) {}

  // number of dimensions after contiguity compression
  size_t nDim() const {
    return nDim_;
  }

  // True if the innermost stride is 1 (trivially true for 0-dim tensors)
  bool lastIsContiguous() const {
    return (contiguity.empty() || contiguity.back());
  }

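  // Computes the contiguity vector for the given sizes/strides, e.g.
  // sizes [2, 3, 4] with strides [12, 4, 1] yields {true, true, true},
  // while strides [24, 4, 1] yields {false, true, true}.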
  static std::vector<bool> findContiguous(
      const at::IntArrayRef& sizes,
      const at::IntArrayRef& strides) {
    AT_ASSERT(sizes.size() == strides.size());
    std::vector<bool> cont(sizes.size());
    for (size_t i = 0; i < sizes.size(); ++i) {
      const auto expected_stride =
          (i + 1 < sizes.size()) ? sizes[i + 1] * strides[i + 1] : 1;
      cont[i] = (strides[i] == expected_stride);
    }
    return cont;
  }

  bool operator==(const TensorDesc& desc) const {
    return scalar_type == desc.scalar_type && contiguity == desc.contiguity;
  }

  bool operator!=(const TensorDesc& desc) const {
    return !(*this == desc);
  }

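  // Hash combining the scalar type, compressed rank, and contiguity pattern.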
  static size_t hash(const TensorDesc& spec) {
    return c10::get_hash(
        spec.scalar_type,
        spec.nDim_,
        std::hash<std::vector<bool>>{}(spec.contiguity));
  }

 private:
  size_t nDim_;
};

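// Prints e.g. "Float[1;1;0;]": the scalar type followed by the contiguity flags.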
inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) {
  out << d.scalar_type << "[";
  for (const auto b : d.contiguity)
    out << b << ";";
  out << "]";
  return out;
}

} // namespace torch::jit::fuser