1 #include <torch/csrc/jit/tensorexpr/ir_verifier.h>
2
3 #include <torch/csrc/jit/tensorexpr/ir.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5 #include <torch/csrc/jit/tensorexpr/reduction.h>
6 #include <torch/csrc/jit/tensorexpr/tensor.h>
7
8 namespace torch::jit::tensorexpr {
9
10 namespace detail {
11 template <typename T>
12 void deducer(BinaryOpNode<T>);
13
14 bool deducer(...);
15 } // namespace detail
16
17 template <
18 typename D,
19 std::enable_if_t<
20 std::is_same_v<decltype(detail::deducer(std::declval<D>())), void>>* =
21 nullptr>
verifyBitwiseOp(NodePtr<D> v,IRVerifier * verifier)22 void verifyBitwiseOp(NodePtr<D> v, IRVerifier* verifier) {
23 if (!v->lhs()->dtype().is_integral()) {
24 throw unsupported_dtype();
25 }
26 if (v->lhs()->dtype() != v->rhs()->dtype()) {
27 throw malformed_ir("lhs/rhs dtype mismatch");
28 }
29 }
30
visit(const AndPtr & v)31 void IRVerifier::visit(const AndPtr& v) {
32 verifyBitwiseOp(v, this);
33 IRVisitor::visit(v);
34 }
35
visit(const OrPtr & v)36 void IRVerifier::visit(const OrPtr& v) {
37 verifyBitwiseOp(v, this);
38 IRVisitor::visit(v);
39 }
40
visit(const XorPtr & v)41 void IRVerifier::visit(const XorPtr& v) {
42 verifyBitwiseOp(v, this);
43 IRVisitor::visit(v);
44 }
45
visit(const LshiftPtr & v)46 void IRVerifier::visit(const LshiftPtr& v) {
47 verifyBitwiseOp(v, this);
48 IRVisitor::visit(v);
49 }
50
visit(const RshiftPtr & v)51 void IRVerifier::visit(const RshiftPtr& v) {
52 verifyBitwiseOp(v, this);
53 IRVisitor::visit(v);
54 }
55
visit(const ModPtr & v)56 void IRVerifier::visit(const ModPtr& v) {
57 if (!v->dtype().is_integral() && !v->dtype().is_floating_point()) {
58 throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype()));
59 }
60 IRVisitor::visit(v);
61 }
62
visit(const CompareSelectPtr & v)63 void IRVerifier::visit(const CompareSelectPtr& v) {
64 if (v->ret_val1()->dtype() != v->ret_val2()->dtype()) {
65 throw malformed_ir("bad dtype in CompareSelect");
66 }
67 if (v->lhs()->dtype() != v->rhs()->dtype()) {
68 throw malformed_ir("bad dtype in CompareSelect");
69 }
70 IRVisitor::visit(v);
71 }
72
visit(const RampPtr & v)73 void IRVerifier::visit(const RampPtr& v) {
74 if (v->stride()->dtype() != v->base()->dtype()) {
75 throw malformed_ir("Bad stride in Ramp");
76 }
77 IRVisitor::visit(v);
78 }
79
visit(const LoadPtr & v)80 void IRVerifier::visit(const LoadPtr& v) {
81 auto indices = v->indices();
82 if (!indices.empty() && v->buf()->base_handle()->dtype() != kHandle) {
83 throw malformed_ir(
84 "Load base handle dtype must be Handle", v->buf()->base_handle());
85 }
86
87 Dtype index_dtype = !indices.empty() ? indices.at(0)->dtype() : kInt;
88 if (indices.size() > 1) {
89 for (size_t i = 1; i < indices.size(); ++i) {
90 if (indices.at(i)->dtype() != index_dtype) {
91 throw malformed_ir("dtype mismatch in Load indices");
92 }
93 }
94 }
95 if (indices.size() > 1 && index_dtype.lanes() > 1) {
96 throw malformed_ir("Multilane is only allowed in a flattened index");
97 }
98 if (index_dtype.scalar_type() != ScalarType::Int &&
99 index_dtype.scalar_type() != ScalarType::Long) {
100 throw malformed_ir("Index scalar dtype is not Int or Long!");
101 }
102
103 IRVisitor::visit(v);
104 }
105
visit(const IfThenElsePtr & v)106 void IRVerifier::visit(const IfThenElsePtr& v) {
107 if (!v->condition()->dtype().is_integral()) {
108 throw unsupported_dtype();
109 }
110 if (v->condition()->dtype().lanes() != 1) {
111 throw unsupported_dtype();
112 }
113 if (v->true_value()->dtype() != v->false_value()->dtype()) {
114 throw malformed_ir("Bad dtype in IfThenElse");
115 }
116 IRVisitor::visit(v);
117 }
118
visit(const IntrinsicsPtr & v)119 void IRVerifier::visit(const IntrinsicsPtr& v) {
120 if (v->op_type() == kIsNan) {
121 if (v->dtype().scalar_type() != c10::kInt) {
122 throw malformed_ir("bad dtype in intrinsic arg");
123 }
124 IRVisitor::visit(v);
125 return;
126 }
127 // TODO: add a check for OpArgCount and op_type
128 for (auto const& param : v->params()) {
129 if (param->dtype() != v->dtype()) {
130 throw malformed_ir("bad dtype in intrinsic arg");
131 }
132 }
133 IRVisitor::visit(v);
134 }
135
visit(const StorePtr & v)136 void IRVerifier::visit(const StorePtr& v) {
137 auto indices = v->indices();
138 if (!indices.empty() && v->buf()->base_handle()->dtype() != kHandle) {
139 throw malformed_ir(
140 "Store base handle dtype must be Handle", v->buf()->base_handle());
141 }
142
143 Dtype index_dtype = !indices.empty() ? indices.at(0)->dtype() : kInt;
144 if (indices.size() > 1) {
145 for (size_t i = 1; i < indices.size(); ++i) {
146 if (indices.at(i)->dtype() != index_dtype) {
147 throw malformed_ir("dtype mismatch in Store indices");
148 }
149 }
150 }
151 if (indices.size() > 1 && index_dtype.lanes() > 1) {
152 throw malformed_ir("Multilane is only allowed in a flattened index");
153 }
154 if (index_dtype.scalar_type() != ScalarType::Int &&
155 index_dtype.scalar_type() != ScalarType::Long) {
156 throw malformed_ir("Index scalar dtype is not Int or Long!");
157 }
158 if (v->buf()->dtype() != v->value()->dtype()) {
159 throw malformed_ir("buf and value dtype mismatch in Store");
160 }
161
162 IRVisitor::visit(v);
163 }
164
visit(const ForPtr & v)165 void IRVerifier::visit(const ForPtr& v) {
166 if (!v->var()) {
167 throw malformed_ir("nullptr Var in For loop");
168 } else if (!v->start()) {
169 throw malformed_ir("nullptr Start in For loop");
170 } else if (!v->stop()) {
171 throw malformed_ir("nullptr Stop in For loop");
172 } else if (!v->body()) {
173 throw malformed_ir("invalid Body in For loop");
174 }
175 IRVisitor::visit(v);
176 }
177
visit(const BlockPtr & v)178 void IRVerifier::visit(const BlockPtr& v) {
179 for (const StmtPtr& s : v->stmts()) {
180 if (s->get_parent() != v) {
181 throw malformed_ir("Broken child-parent link inside a Block");
182 }
183 }
184 IRVisitor::visit(v);
185 }
186
visit(const ExternalCallPtr & v)187 void IRVerifier::visit(const ExternalCallPtr& v) {
188 IRVisitor::visit(v);
189 }
190
verify(const StmtPtr & s)191 void verify(const StmtPtr& s) {
192 IRVerifier verifier;
193 s->accept(&verifier);
194 }
195
verify(const ExprPtr & e)196 void verify(const ExprPtr& e) {
197 IRVerifier verifier;
198 e->accept(&verifier);
199 }
200
verify(const ExprHandle & e)201 void verify(const ExprHandle& e) {
202 verify(e.node());
203 }
204
205 } // namespace torch::jit::tensorexpr
206