1 #include <torch/csrc/jit/tensorexpr/hash_provider.h>
2
3 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
4
5 #include <c10/util/irange.h>
6
7 namespace torch::jit::tensorexpr {
8
operator ==(const SimplifierHashType & other) const9 bool SimplifierHashType::operator==(const SimplifierHashType& other) const {
10 return _h == other._h;
11 }
12
operator !=(const SimplifierHashType & other) const13 bool SimplifierHashType::operator!=(const SimplifierHashType& other) const {
14 return _h != other._h;
15 }
16
operator <(const SimplifierHashType & other) const17 bool SimplifierHashType::operator<(const SimplifierHashType& other) const {
18 return _h < other._h;
19 }
20
operator ==(const size_t other) const21 bool SimplifierHashType::operator==(const size_t other) const {
22 return _h == other;
23 }
24
operator !=(const size_t other) const25 bool SimplifierHashType::operator!=(const size_t other) const {
26 return _h != other;
27 }
28
visit(const AddPtr & v)29 void HashProvider::visit(const AddPtr& v) {
30 CACHE_GUARD();
31 v->lhs()->accept(this);
32 v->rhs()->accept(this);
33 putHash(v, hash_combine(hashOf(v->lhs()), "+", hashOf(v->rhs())));
34 }
35
visit(const SubPtr & v)36 void HashProvider::visit(const SubPtr& v) {
37 CACHE_GUARD();
38 v->lhs()->accept(this);
39 v->rhs()->accept(this);
40 putHash(v, hash_combine(hashOf(v->lhs()), "-", hashOf(v->rhs())));
41 }
42
visit(const MulPtr & v)43 void HashProvider::visit(const MulPtr& v) {
44 CACHE_GUARD();
45 v->lhs()->accept(this);
46 v->rhs()->accept(this);
47 putHash(v, hash_combine(hashOf(v->lhs()), "*", hashOf(v->rhs())));
48 }
49
visit(const DivPtr & v)50 void HashProvider::visit(const DivPtr& v) {
51 CACHE_GUARD();
52 v->lhs()->accept(this);
53 v->rhs()->accept(this);
54 putHash(v, hash_combine(hashOf(v->lhs()), "/", hashOf(v->rhs())));
55 }
56
visit(const ModPtr & v)57 void HashProvider::visit(const ModPtr& v) {
58 CACHE_GUARD();
59 v->lhs()->accept(this);
60 v->rhs()->accept(this);
61 putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs())));
62 }
63
visit(const RoundOffPtr & v)64 void HashProvider::visit(const RoundOffPtr& v) {
65 CACHE_GUARD();
66 v->lhs()->accept(this);
67 v->rhs()->accept(this);
68 putHash(v, hash_combine(hashOf(v->lhs()), "rof", hashOf(v->rhs())));
69 }
70
visit(const MaxPtr & v)71 void HashProvider::visit(const MaxPtr& v) {
72 CACHE_GUARD();
73 v->lhs()->accept(this);
74 v->rhs()->accept(this);
75 putHash(v, hash_combine(hashOf(v->lhs()), "Mx", hashOf(v->rhs())));
76 }
77
visit(const MinPtr & v)78 void HashProvider::visit(const MinPtr& v) {
79 CACHE_GUARD();
80 v->lhs()->accept(this);
81 v->rhs()->accept(this);
82 putHash(v, hash_combine(hashOf(v->lhs()), "Mn", hashOf(v->rhs())));
83 }
84
visit(const AndPtr & v)85 void HashProvider::visit(const AndPtr& v) {
86 CACHE_GUARD();
87 v->lhs()->accept(this);
88 v->rhs()->accept(this);
89 putHash(v, hash_combine(hashOf(v->lhs()), "&", hashOf(v->rhs())));
90 }
91
visit(const OrPtr & v)92 void HashProvider::visit(const OrPtr& v) {
93 CACHE_GUARD();
94 v->lhs()->accept(this);
95 v->rhs()->accept(this);
96 putHash(v, hash_combine(hashOf(v->lhs()), "|", hashOf(v->rhs())));
97 }
98
visit(const XorPtr & v)99 void HashProvider::visit(const XorPtr& v) {
100 CACHE_GUARD();
101 v->lhs()->accept(this);
102 v->rhs()->accept(this);
103 putHash(v, hash_combine(hashOf(v->lhs()), "^", hashOf(v->rhs())));
104 }
105
visit(const LshiftPtr & v)106 void HashProvider::visit(const LshiftPtr& v) {
107 CACHE_GUARD();
108 v->lhs()->accept(this);
109 v->rhs()->accept(this);
110 putHash(v, hash_combine(hashOf(v->lhs()), "<<", hashOf(v->rhs())));
111 }
112
visit(const RshiftPtr & v)113 void HashProvider::visit(const RshiftPtr& v) {
114 CACHE_GUARD();
115 v->lhs()->accept(this);
116 v->rhs()->accept(this);
117 putHash(v, hash_combine(hashOf(v->lhs()), ">>", hashOf(v->rhs())));
118 }
119
visit(const CompareSelectPtr & v)120 void HashProvider::visit(const CompareSelectPtr& v) {
121 CACHE_GUARD();
122 v->lhs()->accept(this);
123 v->rhs()->accept(this);
124 v->ret_val1()->accept(this);
125 v->ret_val2()->accept(this);
126 putHash(
127 v,
128 hash_combine(
129 hashOf(v->lhs()),
130 (int)v->compare_select_op(),
131 hashOf(v->rhs()),
132 hashOf(v->ret_val1()),
133 hashOf(v->ret_val2())));
134 }
135
visit(const CastPtr & v)136 void HashProvider::visit(const CastPtr& v) {
137 CACHE_GUARD();
138 v->src_value()->accept(this);
139 putHash(v, hash_combine("cast", v->dtype(), hashOf(v->src_value())));
140 }
141
visit(const VarPtr & v)142 void HashProvider::visit(const VarPtr& v) {
143 CACHE_GUARD();
144 putHash(v, hash_combine("var", name_manager_.get_unique_name(v)));
145 }
146
visit(const RampPtr & v)147 void HashProvider::visit(const RampPtr& v) {
148 CACHE_GUARD();
149 v->base()->accept(this);
150 v->stride()->accept(this);
151 putHash(
152 v,
153 hash_combine("ramp", hashOf(v->base()), hashOf(v->stride()), v->lanes()));
154 }
155
visit(const LoadPtr & v)156 void HashProvider::visit(const LoadPtr& v) {
157 CACHE_GUARD();
158 v->base_handle()->accept(this);
159 SimplifierHashType indices_hash;
160 for (const ExprPtr& ind : v->indices()) {
161 ind->accept(this);
162 indices_hash = hash_combine(indices_hash, hashOf(ind));
163 }
164 putHash(v, hash_combine("load", hashOf(v->base_handle()), indices_hash));
165 }
166
visit(const StorePtr & v)167 void HashProvider::visit(const StorePtr& v) {
168 CACHE_GUARD();
169 v->base_handle()->accept(this);
170 SimplifierHashType indices_hash;
171 for (const ExprPtr& ind : v->indices()) {
172 ind->accept(this);
173 indices_hash = hash_combine(indices_hash, hashOf(ind));
174 }
175 v->value()->accept(this);
176 putHash(
177 v,
178 hash_combine(
179 "store", hashOf(v->base_handle()), indices_hash, hashOf(v->value())));
180 }
181
visit(const BlockPtr & v)182 void HashProvider::visit(const BlockPtr& v) {
183 CACHE_GUARD();
184 SimplifierHashType hash;
185
186 for (const StmtPtr& s : *v) {
187 s->accept(this);
188 hash = hash_combine(hash, hashOf(s));
189 }
190 putHash(v, hash);
191 }
192
visit(const ForPtr & v)193 void HashProvider::visit(const ForPtr& v) {
194 CACHE_GUARD();
195 v->var()->accept(this);
196 v->start()->accept(this);
197 v->stop()->accept(this);
198
199 SimplifierHashType hash = hash_combine(
200 "for", hashOf(v->var()), hashOf(v->start()), hashOf(v->stop()));
201 hash = hash_combine(hash, v->loop_options().ToString());
202 if (v->body()) {
203 v->body()->accept(this);
204 hash = hash_combine(hash, hashOf(v->body()));
205 }
206
207 putHash(v, hash);
208 }
209
visit(const BroadcastPtr & v)210 void HashProvider::visit(const BroadcastPtr& v) {
211 CACHE_GUARD();
212 v->value()->accept(this);
213 putHash(v, hash_combine("broadcast", hashOf(v->value()), v->lanes()));
214 }
215
visit(const IfThenElsePtr & v)216 void HashProvider::visit(const IfThenElsePtr& v) {
217 CACHE_GUARD();
218 v->condition()->accept(this);
219 v->true_value()->accept(this);
220 v->false_value()->accept(this);
221
222 putHash(
223 v,
224 hash_combine(
225 "ifthenelse",
226 hashOf(v->condition()),
227 hashOf(v->true_value()),
228 hashOf(v->false_value())));
229 }
230
visit(const IntrinsicsPtr & v)231 void HashProvider::visit(const IntrinsicsPtr& v) {
232 CACHE_GUARD();
233 // calls to rand are not symbolic and have a different value each time, they
234 // should not hash to anything and this is the best we can do.
235 if (v->op_type() == kRand) {
236 // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
237 putHash(v, (SimplifierHashType)rand());
238 return;
239 }
240
241 SimplifierHashType hash(te_hash(v->func_name()));
242 for (const auto i : c10::irange(v->nparams())) {
243 v->param(i)->accept(this);
244 hash = hash_combine(hash, hashOf(v->param(i)));
245 }
246
247 putHash(v, hash);
248 }
249
visit(const AllocatePtr & v)250 void HashProvider::visit(const AllocatePtr& v) {
251 CACHE_GUARD();
252 VarPtr buffer_var = v->buffer_var();
253 buffer_var->accept(this);
254
255 SimplifierHashType hash =
256 hash_combine("allocate", hashOf(buffer_var), v->dtype());
257
258 std::vector<ExprPtr> dims = v->dims();
259 for (const ExprPtr& dim : dims) {
260 dim->accept(this);
261 hash = hash_combine(hash, hashOf(dim));
262 }
263 putHash(v, hash);
264 }
265
visit(const FreePtr & v)266 void HashProvider::visit(const FreePtr& v) {
267 CACHE_GUARD();
268 VarPtr buffer_var = v->buffer_var();
269 buffer_var->accept(this);
270
271 putHash(v, hash_combine("free", hashOf(buffer_var)));
272 }
273
visit(const CondPtr & v)274 void HashProvider::visit(const CondPtr& v) {
275 CACHE_GUARD();
276 ExprPtr condition = v->condition();
277 StmtPtr true_stmt = v->true_stmt();
278 StmtPtr false_stmt = v->false_stmt();
279 condition->accept(this);
280
281 SimplifierHashType hash = hash_combine("cond", hashOf(condition));
282 if (true_stmt) {
283 true_stmt->accept(this);
284 hash = hash_combine(hash, hashOf(true_stmt));
285 }
286 if (false_stmt) {
287 false_stmt->accept(this);
288 hash = hash_combine(hash, hashOf(false_stmt));
289 }
290
291 putHash(v, hash);
292 }
293
visit(const TermPtr & v)294 void HashProvider::visit(const TermPtr& v) {
295 CACHE_GUARD();
296 v->scalar()->accept(this);
297
298 SimplifierHashType hash = hash_combine("term", hashOf(v->scalar()));
299 for (const auto& c : v->variables()) {
300 c->accept(this);
301 hash = hash_combine(hash, hashOf(c));
302 }
303
304 putHash(v, hash);
305 }
306
visit(const PolynomialPtr & v)307 void HashProvider::visit(const PolynomialPtr& v) {
308 CACHE_GUARD();
309 v->scalar()->accept(this);
310
311 SimplifierHashType hash = hash_combine("term", hashOf(v->scalar()));
312 for (const auto& c : v->variables()) {
313 c->accept(this);
314 hash = hash_combine(hash, hashOf(c));
315 }
316
317 putHash(v, hash);
318 }
319
visit(const MaxTermPtr & v)320 void HashProvider::visit(const MaxTermPtr& v) {
321 CACHE_GUARD();
322 SimplifierHashType hash = hash_combine("maxterm");
323 if (v->scalar()) {
324 v->scalar()->accept(this);
325 hash = hash_combine(hash, hashOf(v->scalar()));
326 }
327
328 for (const auto& c : v->variables()) {
329 c->accept(this);
330 hash = hash_combine(hash, hashOf(c));
331 }
332
333 putHash(v, hash);
334 }
335
visit(const MinTermPtr & v)336 void HashProvider::visit(const MinTermPtr& v) {
337 CACHE_GUARD();
338 SimplifierHashType hash = hash_combine("minterm");
339 if (v->scalar()) {
340 v->scalar()->accept(this);
341 hash = hash_combine(hash, hashOf(v->scalar()));
342 }
343
344 for (const auto& c : v->variables()) {
345 c->accept(this);
346 hash = hash_combine(hash, hashOf(c));
347 }
348
349 putHash(v, hash);
350 }
351
352 } // namespace torch::jit::tensorexpr
353