xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir_verifier.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/ir_verifier.h>
2 
3 #include <torch/csrc/jit/tensorexpr/ir.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5 #include <torch/csrc/jit/tensorexpr/reduction.h>
6 #include <torch/csrc/jit/tensorexpr/tensor.h>
7 
8 namespace torch::jit::tensorexpr {
9 
10 namespace detail {
11 template <typename T>
12 void deducer(BinaryOpNode<T>);
13 
14 bool deducer(...);
15 } // namespace detail
16 
17 template <
18     typename D,
19     std::enable_if_t<
20         std::is_same_v<decltype(detail::deducer(std::declval<D>())), void>>* =
21         nullptr>
verifyBitwiseOp(NodePtr<D> v,IRVerifier * verifier)22 void verifyBitwiseOp(NodePtr<D> v, IRVerifier* verifier) {
23   if (!v->lhs()->dtype().is_integral()) {
24     throw unsupported_dtype();
25   }
26   if (v->lhs()->dtype() != v->rhs()->dtype()) {
27     throw malformed_ir("lhs/rhs dtype mismatch");
28   }
29 }
30 
visit(const AndPtr & v)31 void IRVerifier::visit(const AndPtr& v) {
32   verifyBitwiseOp(v, this);
33   IRVisitor::visit(v);
34 }
35 
visit(const OrPtr & v)36 void IRVerifier::visit(const OrPtr& v) {
37   verifyBitwiseOp(v, this);
38   IRVisitor::visit(v);
39 }
40 
visit(const XorPtr & v)41 void IRVerifier::visit(const XorPtr& v) {
42   verifyBitwiseOp(v, this);
43   IRVisitor::visit(v);
44 }
45 
visit(const LshiftPtr & v)46 void IRVerifier::visit(const LshiftPtr& v) {
47   verifyBitwiseOp(v, this);
48   IRVisitor::visit(v);
49 }
50 
visit(const RshiftPtr & v)51 void IRVerifier::visit(const RshiftPtr& v) {
52   verifyBitwiseOp(v, this);
53   IRVisitor::visit(v);
54 }
55 
visit(const ModPtr & v)56 void IRVerifier::visit(const ModPtr& v) {
57   if (!v->dtype().is_integral() && !v->dtype().is_floating_point()) {
58     throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype()));
59   }
60   IRVisitor::visit(v);
61 }
62 
visit(const CompareSelectPtr & v)63 void IRVerifier::visit(const CompareSelectPtr& v) {
64   if (v->ret_val1()->dtype() != v->ret_val2()->dtype()) {
65     throw malformed_ir("bad dtype in CompareSelect");
66   }
67   if (v->lhs()->dtype() != v->rhs()->dtype()) {
68     throw malformed_ir("bad dtype in CompareSelect");
69   }
70   IRVisitor::visit(v);
71 }
72 
visit(const RampPtr & v)73 void IRVerifier::visit(const RampPtr& v) {
74   if (v->stride()->dtype() != v->base()->dtype()) {
75     throw malformed_ir("Bad stride in Ramp");
76   }
77   IRVisitor::visit(v);
78 }
79 
visit(const LoadPtr & v)80 void IRVerifier::visit(const LoadPtr& v) {
81   auto indices = v->indices();
82   if (!indices.empty() && v->buf()->base_handle()->dtype() != kHandle) {
83     throw malformed_ir(
84         "Load base handle dtype must be Handle", v->buf()->base_handle());
85   }
86 
87   Dtype index_dtype = !indices.empty() ? indices.at(0)->dtype() : kInt;
88   if (indices.size() > 1) {
89     for (size_t i = 1; i < indices.size(); ++i) {
90       if (indices.at(i)->dtype() != index_dtype) {
91         throw malformed_ir("dtype mismatch in Load indices");
92       }
93     }
94   }
95   if (indices.size() > 1 && index_dtype.lanes() > 1) {
96     throw malformed_ir("Multilane is only allowed in a flattened index");
97   }
98   if (index_dtype.scalar_type() != ScalarType::Int &&
99       index_dtype.scalar_type() != ScalarType::Long) {
100     throw malformed_ir("Index scalar dtype is not Int or Long!");
101   }
102 
103   IRVisitor::visit(v);
104 }
105 
visit(const IfThenElsePtr & v)106 void IRVerifier::visit(const IfThenElsePtr& v) {
107   if (!v->condition()->dtype().is_integral()) {
108     throw unsupported_dtype();
109   }
110   if (v->condition()->dtype().lanes() != 1) {
111     throw unsupported_dtype();
112   }
113   if (v->true_value()->dtype() != v->false_value()->dtype()) {
114     throw malformed_ir("Bad dtype in IfThenElse");
115   }
116   IRVisitor::visit(v);
117 }
118 
visit(const IntrinsicsPtr & v)119 void IRVerifier::visit(const IntrinsicsPtr& v) {
120   if (v->op_type() == kIsNan) {
121     if (v->dtype().scalar_type() != c10::kInt) {
122       throw malformed_ir("bad dtype in intrinsic arg");
123     }
124     IRVisitor::visit(v);
125     return;
126   }
127   // TODO: add a check for OpArgCount and op_type
128   for (auto const& param : v->params()) {
129     if (param->dtype() != v->dtype()) {
130       throw malformed_ir("bad dtype in intrinsic arg");
131     }
132   }
133   IRVisitor::visit(v);
134 }
135 
visit(const StorePtr & v)136 void IRVerifier::visit(const StorePtr& v) {
137   auto indices = v->indices();
138   if (!indices.empty() && v->buf()->base_handle()->dtype() != kHandle) {
139     throw malformed_ir(
140         "Store base handle dtype must be Handle", v->buf()->base_handle());
141   }
142 
143   Dtype index_dtype = !indices.empty() ? indices.at(0)->dtype() : kInt;
144   if (indices.size() > 1) {
145     for (size_t i = 1; i < indices.size(); ++i) {
146       if (indices.at(i)->dtype() != index_dtype) {
147         throw malformed_ir("dtype mismatch in Store indices");
148       }
149     }
150   }
151   if (indices.size() > 1 && index_dtype.lanes() > 1) {
152     throw malformed_ir("Multilane is only allowed in a flattened index");
153   }
154   if (index_dtype.scalar_type() != ScalarType::Int &&
155       index_dtype.scalar_type() != ScalarType::Long) {
156     throw malformed_ir("Index scalar dtype is not Int or Long!");
157   }
158   if (v->buf()->dtype() != v->value()->dtype()) {
159     throw malformed_ir("buf and value dtype mismatch in Store");
160   }
161 
162   IRVisitor::visit(v);
163 }
164 
visit(const ForPtr & v)165 void IRVerifier::visit(const ForPtr& v) {
166   if (!v->var()) {
167     throw malformed_ir("nullptr Var in For loop");
168   } else if (!v->start()) {
169     throw malformed_ir("nullptr Start in For loop");
170   } else if (!v->stop()) {
171     throw malformed_ir("nullptr Stop in For loop");
172   } else if (!v->body()) {
173     throw malformed_ir("invalid Body in For loop");
174   }
175   IRVisitor::visit(v);
176 }
177 
visit(const BlockPtr & v)178 void IRVerifier::visit(const BlockPtr& v) {
179   for (const StmtPtr& s : v->stmts()) {
180     if (s->get_parent() != v) {
181       throw malformed_ir("Broken child-parent link inside a Block");
182     }
183   }
184   IRVisitor::visit(v);
185 }
186 
visit(const ExternalCallPtr & v)187 void IRVerifier::visit(const ExternalCallPtr& v) {
188   IRVisitor::visit(v);
189 }
190 
verify(const StmtPtr & s)191 void verify(const StmtPtr& s) {
192   IRVerifier verifier;
193   s->accept(&verifier);
194 }
195 
verify(const ExprPtr & e)196 void verify(const ExprPtr& e) {
197   IRVerifier verifier;
198   e->accept(&verifier);
199 }
200 
verify(const ExprHandle & e)201 void verify(const ExprHandle& e) {
202   verify(e.node());
203 }
204 
205 } // namespace torch::jit::tensorexpr
206