xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/hash_provider.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/hash_provider.h>
2 
3 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
4 
5 #include <c10/util/irange.h>
6 
7 namespace torch::jit::tensorexpr {
8 
operator ==(const SimplifierHashType & other) const9 bool SimplifierHashType::operator==(const SimplifierHashType& other) const {
10   return _h == other._h;
11 }
12 
operator !=(const SimplifierHashType & other) const13 bool SimplifierHashType::operator!=(const SimplifierHashType& other) const {
14   return _h != other._h;
15 }
16 
operator <(const SimplifierHashType & other) const17 bool SimplifierHashType::operator<(const SimplifierHashType& other) const {
18   return _h < other._h;
19 }
20 
operator ==(const size_t other) const21 bool SimplifierHashType::operator==(const size_t other) const {
22   return _h == other;
23 }
24 
operator !=(const size_t other) const25 bool SimplifierHashType::operator!=(const size_t other) const {
26   return _h != other;
27 }
28 
visit(const AddPtr & v)29 void HashProvider::visit(const AddPtr& v) {
30   CACHE_GUARD();
31   v->lhs()->accept(this);
32   v->rhs()->accept(this);
33   putHash(v, hash_combine(hashOf(v->lhs()), "+", hashOf(v->rhs())));
34 }
35 
visit(const SubPtr & v)36 void HashProvider::visit(const SubPtr& v) {
37   CACHE_GUARD();
38   v->lhs()->accept(this);
39   v->rhs()->accept(this);
40   putHash(v, hash_combine(hashOf(v->lhs()), "-", hashOf(v->rhs())));
41 }
42 
visit(const MulPtr & v)43 void HashProvider::visit(const MulPtr& v) {
44   CACHE_GUARD();
45   v->lhs()->accept(this);
46   v->rhs()->accept(this);
47   putHash(v, hash_combine(hashOf(v->lhs()), "*", hashOf(v->rhs())));
48 }
49 
visit(const DivPtr & v)50 void HashProvider::visit(const DivPtr& v) {
51   CACHE_GUARD();
52   v->lhs()->accept(this);
53   v->rhs()->accept(this);
54   putHash(v, hash_combine(hashOf(v->lhs()), "/", hashOf(v->rhs())));
55 }
56 
visit(const ModPtr & v)57 void HashProvider::visit(const ModPtr& v) {
58   CACHE_GUARD();
59   v->lhs()->accept(this);
60   v->rhs()->accept(this);
61   putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs())));
62 }
63 
visit(const RoundOffPtr & v)64 void HashProvider::visit(const RoundOffPtr& v) {
65   CACHE_GUARD();
66   v->lhs()->accept(this);
67   v->rhs()->accept(this);
68   putHash(v, hash_combine(hashOf(v->lhs()), "rof", hashOf(v->rhs())));
69 }
70 
visit(const MaxPtr & v)71 void HashProvider::visit(const MaxPtr& v) {
72   CACHE_GUARD();
73   v->lhs()->accept(this);
74   v->rhs()->accept(this);
75   putHash(v, hash_combine(hashOf(v->lhs()), "Mx", hashOf(v->rhs())));
76 }
77 
visit(const MinPtr & v)78 void HashProvider::visit(const MinPtr& v) {
79   CACHE_GUARD();
80   v->lhs()->accept(this);
81   v->rhs()->accept(this);
82   putHash(v, hash_combine(hashOf(v->lhs()), "Mn", hashOf(v->rhs())));
83 }
84 
visit(const AndPtr & v)85 void HashProvider::visit(const AndPtr& v) {
86   CACHE_GUARD();
87   v->lhs()->accept(this);
88   v->rhs()->accept(this);
89   putHash(v, hash_combine(hashOf(v->lhs()), "&", hashOf(v->rhs())));
90 }
91 
visit(const OrPtr & v)92 void HashProvider::visit(const OrPtr& v) {
93   CACHE_GUARD();
94   v->lhs()->accept(this);
95   v->rhs()->accept(this);
96   putHash(v, hash_combine(hashOf(v->lhs()), "|", hashOf(v->rhs())));
97 }
98 
visit(const XorPtr & v)99 void HashProvider::visit(const XorPtr& v) {
100   CACHE_GUARD();
101   v->lhs()->accept(this);
102   v->rhs()->accept(this);
103   putHash(v, hash_combine(hashOf(v->lhs()), "^", hashOf(v->rhs())));
104 }
105 
visit(const LshiftPtr & v)106 void HashProvider::visit(const LshiftPtr& v) {
107   CACHE_GUARD();
108   v->lhs()->accept(this);
109   v->rhs()->accept(this);
110   putHash(v, hash_combine(hashOf(v->lhs()), "<<", hashOf(v->rhs())));
111 }
112 
visit(const RshiftPtr & v)113 void HashProvider::visit(const RshiftPtr& v) {
114   CACHE_GUARD();
115   v->lhs()->accept(this);
116   v->rhs()->accept(this);
117   putHash(v, hash_combine(hashOf(v->lhs()), ">>", hashOf(v->rhs())));
118 }
119 
visit(const CompareSelectPtr & v)120 void HashProvider::visit(const CompareSelectPtr& v) {
121   CACHE_GUARD();
122   v->lhs()->accept(this);
123   v->rhs()->accept(this);
124   v->ret_val1()->accept(this);
125   v->ret_val2()->accept(this);
126   putHash(
127       v,
128       hash_combine(
129           hashOf(v->lhs()),
130           (int)v->compare_select_op(),
131           hashOf(v->rhs()),
132           hashOf(v->ret_val1()),
133           hashOf(v->ret_val2())));
134 }
135 
visit(const CastPtr & v)136 void HashProvider::visit(const CastPtr& v) {
137   CACHE_GUARD();
138   v->src_value()->accept(this);
139   putHash(v, hash_combine("cast", v->dtype(), hashOf(v->src_value())));
140 }
141 
visit(const VarPtr & v)142 void HashProvider::visit(const VarPtr& v) {
143   CACHE_GUARD();
144   putHash(v, hash_combine("var", name_manager_.get_unique_name(v)));
145 }
146 
visit(const RampPtr & v)147 void HashProvider::visit(const RampPtr& v) {
148   CACHE_GUARD();
149   v->base()->accept(this);
150   v->stride()->accept(this);
151   putHash(
152       v,
153       hash_combine("ramp", hashOf(v->base()), hashOf(v->stride()), v->lanes()));
154 }
155 
visit(const LoadPtr & v)156 void HashProvider::visit(const LoadPtr& v) {
157   CACHE_GUARD();
158   v->base_handle()->accept(this);
159   SimplifierHashType indices_hash;
160   for (const ExprPtr& ind : v->indices()) {
161     ind->accept(this);
162     indices_hash = hash_combine(indices_hash, hashOf(ind));
163   }
164   putHash(v, hash_combine("load", hashOf(v->base_handle()), indices_hash));
165 }
166 
visit(const StorePtr & v)167 void HashProvider::visit(const StorePtr& v) {
168   CACHE_GUARD();
169   v->base_handle()->accept(this);
170   SimplifierHashType indices_hash;
171   for (const ExprPtr& ind : v->indices()) {
172     ind->accept(this);
173     indices_hash = hash_combine(indices_hash, hashOf(ind));
174   }
175   v->value()->accept(this);
176   putHash(
177       v,
178       hash_combine(
179           "store", hashOf(v->base_handle()), indices_hash, hashOf(v->value())));
180 }
181 
visit(const BlockPtr & v)182 void HashProvider::visit(const BlockPtr& v) {
183   CACHE_GUARD();
184   SimplifierHashType hash;
185 
186   for (const StmtPtr& s : *v) {
187     s->accept(this);
188     hash = hash_combine(hash, hashOf(s));
189   }
190   putHash(v, hash);
191 }
192 
visit(const ForPtr & v)193 void HashProvider::visit(const ForPtr& v) {
194   CACHE_GUARD();
195   v->var()->accept(this);
196   v->start()->accept(this);
197   v->stop()->accept(this);
198 
199   SimplifierHashType hash = hash_combine(
200       "for", hashOf(v->var()), hashOf(v->start()), hashOf(v->stop()));
201   hash = hash_combine(hash, v->loop_options().ToString());
202   if (v->body()) {
203     v->body()->accept(this);
204     hash = hash_combine(hash, hashOf(v->body()));
205   }
206 
207   putHash(v, hash);
208 }
209 
visit(const BroadcastPtr & v)210 void HashProvider::visit(const BroadcastPtr& v) {
211   CACHE_GUARD();
212   v->value()->accept(this);
213   putHash(v, hash_combine("broadcast", hashOf(v->value()), v->lanes()));
214 }
215 
visit(const IfThenElsePtr & v)216 void HashProvider::visit(const IfThenElsePtr& v) {
217   CACHE_GUARD();
218   v->condition()->accept(this);
219   v->true_value()->accept(this);
220   v->false_value()->accept(this);
221 
222   putHash(
223       v,
224       hash_combine(
225           "ifthenelse",
226           hashOf(v->condition()),
227           hashOf(v->true_value()),
228           hashOf(v->false_value())));
229 }
230 
visit(const IntrinsicsPtr & v)231 void HashProvider::visit(const IntrinsicsPtr& v) {
232   CACHE_GUARD();
233   // calls to rand are not symbolic and have a different value each time, they
234   // should not hash to anything and this is the best we can do.
235   if (v->op_type() == kRand) {
236     // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
237     putHash(v, (SimplifierHashType)rand());
238     return;
239   }
240 
241   SimplifierHashType hash(te_hash(v->func_name()));
242   for (const auto i : c10::irange(v->nparams())) {
243     v->param(i)->accept(this);
244     hash = hash_combine(hash, hashOf(v->param(i)));
245   }
246 
247   putHash(v, hash);
248 }
249 
visit(const AllocatePtr & v)250 void HashProvider::visit(const AllocatePtr& v) {
251   CACHE_GUARD();
252   VarPtr buffer_var = v->buffer_var();
253   buffer_var->accept(this);
254 
255   SimplifierHashType hash =
256       hash_combine("allocate", hashOf(buffer_var), v->dtype());
257 
258   std::vector<ExprPtr> dims = v->dims();
259   for (const ExprPtr& dim : dims) {
260     dim->accept(this);
261     hash = hash_combine(hash, hashOf(dim));
262   }
263   putHash(v, hash);
264 }
265 
visit(const FreePtr & v)266 void HashProvider::visit(const FreePtr& v) {
267   CACHE_GUARD();
268   VarPtr buffer_var = v->buffer_var();
269   buffer_var->accept(this);
270 
271   putHash(v, hash_combine("free", hashOf(buffer_var)));
272 }
273 
visit(const CondPtr & v)274 void HashProvider::visit(const CondPtr& v) {
275   CACHE_GUARD();
276   ExprPtr condition = v->condition();
277   StmtPtr true_stmt = v->true_stmt();
278   StmtPtr false_stmt = v->false_stmt();
279   condition->accept(this);
280 
281   SimplifierHashType hash = hash_combine("cond", hashOf(condition));
282   if (true_stmt) {
283     true_stmt->accept(this);
284     hash = hash_combine(hash, hashOf(true_stmt));
285   }
286   if (false_stmt) {
287     false_stmt->accept(this);
288     hash = hash_combine(hash, hashOf(false_stmt));
289   }
290 
291   putHash(v, hash);
292 }
293 
visit(const TermPtr & v)294 void HashProvider::visit(const TermPtr& v) {
295   CACHE_GUARD();
296   v->scalar()->accept(this);
297 
298   SimplifierHashType hash = hash_combine("term", hashOf(v->scalar()));
299   for (const auto& c : v->variables()) {
300     c->accept(this);
301     hash = hash_combine(hash, hashOf(c));
302   }
303 
304   putHash(v, hash);
305 }
306 
visit(const PolynomialPtr & v)307 void HashProvider::visit(const PolynomialPtr& v) {
308   CACHE_GUARD();
309   v->scalar()->accept(this);
310 
311   SimplifierHashType hash = hash_combine("term", hashOf(v->scalar()));
312   for (const auto& c : v->variables()) {
313     c->accept(this);
314     hash = hash_combine(hash, hashOf(c));
315   }
316 
317   putHash(v, hash);
318 }
319 
visit(const MaxTermPtr & v)320 void HashProvider::visit(const MaxTermPtr& v) {
321   CACHE_GUARD();
322   SimplifierHashType hash = hash_combine("maxterm");
323   if (v->scalar()) {
324     v->scalar()->accept(this);
325     hash = hash_combine(hash, hashOf(v->scalar()));
326   }
327 
328   for (const auto& c : v->variables()) {
329     c->accept(this);
330     hash = hash_combine(hash, hashOf(c));
331   }
332 
333   putHash(v, hash);
334 }
335 
visit(const MinTermPtr & v)336 void HashProvider::visit(const MinTermPtr& v) {
337   CACHE_GUARD();
338   SimplifierHashType hash = hash_combine("minterm");
339   if (v->scalar()) {
340     v->scalar()->accept(this);
341     hash = hash_combine(hash, hashOf(v->scalar()));
342   }
343 
344   for (const auto& c : v->variables()) {
345     c->accept(this);
346     hash = hash_combine(hash, hashOf(c));
347   }
348 
349   putHash(v, hash);
350 }
351 
352 } // namespace torch::jit::tensorexpr
353