xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir_simplifier.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/hash_provider.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#include <set>
#include <unordered_map>
#include <utility>
#include <vector>
12 
/* IR Simplification
 *
 * Simplifies expressions in two stages:
 *  1. Recursively traverse the expression tree, combining similar operations
 * into Terms (combined via multiplication) and Polynomials (combined via
 * addition). We reorder the components of each Term or Polynomial into a
 * consistent order to allow combination or cancellation of like terms.
 *  2. Once the format of the tree is minimal, expand each Term into a sequence
 * of Muls, and each Polynomial into a sequence of Adds.
 */
23 
24 namespace torch::jit::tensorexpr {
25 
26 // A bunch of helpers for determine the Dtype of the output of a multi argument
27 // Term or Polynomial.
28 template <class ExprType>
promoteTypesVec(const ExprPtr & s,const std::vector<ExprType> & v)29 Dtype promoteTypesVec(const ExprPtr& s, const std::vector<ExprType>& v) {
30   Dtype t = s->dtype();
31   bool first = true;
32 
33   for (const auto& e : v) {
34     if (first) {
35       t = Dtype(t.scalar_type(), e->dtype().lanes());
36       first = false;
37     }
38     t = promoteTypes(t, e->dtype());
39   }
40   return t;
41 }
42 
43 template <class ExprType>
promoteTypesVec(const std::vector<ExprType> & v)44 Dtype promoteTypesVec(const std::vector<ExprType>& v) {
45   if (v.empty()) {
46     throw malformed_input("empty list of types");
47   }
48 
49   Dtype t = v[0]->dtype();
50   for (const auto& e : v) {
51     t = promoteTypes(t, e->dtype());
52   }
53   return t;
54 }
55 
56 template <class ExprType>
promoteTypesMap(const ExprPtr & s,std::unordered_map<SimplifierHashType,ExprType> & m)57 Dtype promoteTypesMap(
58     const ExprPtr& s,
59     std::unordered_map<SimplifierHashType, ExprType>& m) {
60   Dtype t = s->dtype();
61   bool first = true;
62   for (auto& e : m) {
63     if (first) {
64       t = Dtype(t.scalar_type(), e.second->dtype().lanes());
65       first = false;
66     }
67     t = promoteTypes(t, e.second->dtype());
68   }
69   return t;
70 }
71 
72 template <class ExprType>
promoteTypesVar(ExprType e)73 Dtype promoteTypesVar(ExprType e) {
74   return e->dtype();
75 }
76 
77 template <class ExprType, class... Args>
promoteTypesVar(ExprType e,Args...es)78 Dtype promoteTypesVar(ExprType e, Args... es) {
79   Dtype lhs = e->dtype();
80   Dtype rhs = promoteTypesVar(es...);
81   if (e->isConstant()) {
82     lhs = Dtype(lhs.scalar_type(), rhs.lanes());
83   }
84 
85   return promoteTypes(lhs, rhs);
86 }
87 
88 // Uses the evaluator to fold an Expression with constant terms.
89 // E.g. evaluateOp(Add(3, 4)) => 7.
90 // Expr v must not have any unbound Vars.
evaluateOp(const ExprPtr & v)91 inline ExprPtr evaluateOp(const ExprPtr& v) {
92   ExprHandle handle(v);
93   ExprEval<SimpleIREvaluator> eval(handle);
94 
95   switch (v->dtype().scalar_type()) {
96 #define TYPE_CASE(Type, Name)                                 \
97   case ScalarType::Name: {                                    \
98     Type val = eval.value<Type>();                            \
99     return getImmediateByType(v->dtype().scalar_type(), val); \
100   }
101     AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
102 #undef TYPE_CASE
103     default:
104       LOG(FATAL) << "Unsupported datatype: " << v->dtype();
105       return nullptr;
106   }
107   return nullptr;
108 }
109 
110 // A Term represents a grouping of Exprs through multiplication.
111 // E.g. product(scalar, *variables).
112 class Term : public ExprNode<Term> {
113  public:
114   template <class... Args>
Term(HashProvider & hasher,ExprPtr s,Args...ts)115   Term(HashProvider& hasher, ExprPtr s, Args... ts)
116       : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
117     CHECK(s->isConstant());
118     addComponent(ts...);
119     sort();
120   }
121 
Term(HashProvider & hasher,ExprPtr s,std::vector<ExprPtr> v)122   Term(HashProvider& hasher, ExprPtr s, std::vector<ExprPtr> v)
123       : ExprNodeBase(promoteTypesVec(s, v)),
124         variables_(std::move(v)),
125         scalar_(std::move(s)),
126         hasher_(hasher) {
127     sort();
128   }
129 
130   // Convenience constructor from a map of hash -> var, used when merging Terms.
Term(HashProvider & hasher,const ExprPtr & s,std::unordered_map<SimplifierHashType,ExprPtr> varmap)131   Term(
132       HashProvider& hasher,
133       const ExprPtr& s,
134       std::unordered_map<SimplifierHashType, ExprPtr> varmap)
135       : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
136     for (auto& p : varmap) {
137       addComponent(p.second);
138     }
139     sort();
140   }
141 
scalar()142   ExprPtr scalar() const {
143     return scalar_;
144   }
variables()145   const std::vector<ExprPtr>& variables() const {
146     return variables_;
147   }
hasher()148   HashProvider& hasher() const {
149     return hasher_;
150   }
151 
152   // Produce a hash of just the variable components of this term, to determine
153   // if it can be combined with another term.
154   SimplifierHashType hashVars() const;
155 
156  private:
157   std::vector<ExprPtr> variables_;
158   ExprPtr scalar_;
159   HashProvider& hasher_;
160 
addComponent()161   void addComponent() {}
addComponent(ExprPtr e)162   void addComponent(ExprPtr e) {
163     variables_.push_back(std::move(e));
164   }
165   template <class... Es>
addComponent(ExprPtr e,Es &&...es)166   void addComponent(ExprPtr e, Es&&... es) {
167     addComponent(std::move(e));
168     addComponent(std::forward<Es>(es)...);
169   }
170 
171   // Sort by hash to normalize order of components.
172   void sort();
173 };
174 
175 // Polynomial represents a grouping of Exprs by addition.
176 // E.g. sum(*variables, scalar).
177 // This would better be called Expression, but, naming conflict...
178 class Polynomial : public ExprNode<Polynomial> {
179  public:
180   template <class... Args>
Polynomial(HashProvider & hasher,ExprPtr s,Args...ts)181   Polynomial(HashProvider& hasher, ExprPtr s, Args... ts)
182       : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
183     CHECK(s->isConstant());
184     addTerm(ts...);
185     sort();
186   }
187 
Polynomial(HashProvider & hasher,const ExprPtr & s,std::vector<TermPtr> v)188   Polynomial(HashProvider& hasher, const ExprPtr& s, std::vector<TermPtr> v)
189       : ExprNodeBase(promoteTypesVec(s, v)),
190         variables_(std::move(v)),
191         scalar_(s),
192         hasher_(hasher) {
193     sort();
194   }
195 
196   // Helper constructor for list of terms with no scalar component.
Polynomial(HashProvider & hasher,std::vector<TermPtr> terms)197   Polynomial(HashProvider& hasher, std::vector<TermPtr> terms)
198       : ExprNodeBase(promoteTypesVec(terms)),
199         variables_(std::move(terms)),
200         scalar_(getImmediateByType(dtype(), 0)),
201         hasher_(hasher) {
202     sort();
203   }
204 
205   // Convenience constructor for map of hash -> var, used when merging
206   // Polynomials.
Polynomial(HashProvider & hasher,const ExprPtr & s,std::unordered_map<SimplifierHashType,TermPtr> varmap)207   Polynomial(
208       HashProvider& hasher,
209       const ExprPtr& s,
210       std::unordered_map<SimplifierHashType, TermPtr> varmap)
211       : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
212     for (auto& p : varmap) {
213       addTerm(p.second);
214     }
215     sort();
216   }
217 
scalar()218   ExprPtr scalar() const {
219     return scalar_;
220   }
variables()221   const std::vector<TermPtr>& variables() const {
222     return variables_;
223   }
hasher()224   HashProvider& hasher() const {
225     return hasher_;
226   }
227 
228   SimplifierHashType hashVars() const;
229 
230  private:
231   std::vector<TermPtr> variables_;
232   ExprPtr scalar_;
233   HashProvider& hasher_;
234 
addTerm(TermPtr t)235   void addTerm(TermPtr t) {
236     variables_.push_back(std::move(t));
237   }
238   template <class... Ts>
addTerm(TermPtr t,Ts &&...ts)239   void addTerm(TermPtr t, Ts&&... ts) {
240     addTerm(std::move(t));
241     addTerm(std::forward<Ts>(ts)...);
242   }
243 
244   // Sort by hash to normalize order of terms.
245   void sort();
246 };
247 
248 class RoundOff : public BinaryOpNode<RoundOff> {
249  public:
RoundOff(ExprPtr lhs,ExprPtr rhs)250   RoundOff(ExprPtr lhs, ExprPtr rhs)
251       : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kOther) {}
252 };
253 
254 class MaxTerm : public ExprNode<MaxTerm> {
255  public:
256   template <class... Args>
MaxTerm(HashProvider & hasher,ExprPtr s,bool p,Args...ts)257   MaxTerm(HashProvider& hasher, ExprPtr s, bool p, Args... ts)
258       : ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
259         scalar_(s),
260         hasher_(hasher),
261         propagate_nans_(p) {
262     addComponent(ts...);
263     uniquefy();
264   }
265 
MaxTerm(HashProvider & hasher,const ExprPtr & s,bool p,std::vector<ExprPtr> v)266   MaxTerm(
267       HashProvider& hasher,
268       const ExprPtr& s,
269       bool p,
270       std::vector<ExprPtr> v)
271       : ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
272         variables_(std::move(v)),
273         scalar_(s),
274         hasher_(hasher),
275         propagate_nans_(p) {
276     uniquefy();
277   }
278 
propagate_nans()279   bool propagate_nans() const {
280     return propagate_nans_;
281   }
282 
scalar()283   ExprPtr scalar() const {
284     return scalar_;
285   }
variables()286   const std::vector<ExprPtr>& variables() const {
287     return variables_;
288   }
hasher()289   HashProvider& hasher() const {
290     return hasher_;
291   }
292 
293  private:
294   std::vector<ExprPtr> variables_;
295   ExprPtr scalar_;
296   HashProvider& hasher_;
297   bool propagate_nans_;
298 
addComponent()299   void addComponent() {}
addComponent(ExprPtr e)300   void addComponent(ExprPtr e) {
301     variables_.push_back(std::move(e));
302   }
303   template <class... Es>
addComponent(ExprPtr e,Es &&...es)304   void addComponent(ExprPtr e, Es&&... es) {
305     addComponent(std::move(e));
306     addComponent(std::forward<Es>(es)...);
307   }
308 
309   // Uniquefy the terms using their hash.
310   void uniquefy();
311 };
312 
313 class MinTerm : public ExprNode<MinTerm> {
314  public:
315   template <class... Args>
MinTerm(HashProvider & hasher,ExprPtr s,bool p,Args...ts)316   MinTerm(HashProvider& hasher, ExprPtr s, bool p, Args... ts)
317       : ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
318         scalar_(s),
319         hasher_(hasher),
320         propagate_nans_(p) {
321     addComponent(ts...);
322     uniquefy();
323   }
324 
MinTerm(HashProvider & hasher,const ExprPtr & s,bool p,std::vector<ExprPtr> v)325   MinTerm(
326       HashProvider& hasher,
327       const ExprPtr& s,
328       bool p,
329       std::vector<ExprPtr> v)
330       : ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
331         variables_(std::move(v)),
332         scalar_(s),
333         hasher_(hasher),
334         propagate_nans_(p) {
335     uniquefy();
336   }
337 
propagate_nans()338   bool propagate_nans() const {
339     return propagate_nans_;
340   }
341 
scalar()342   ExprPtr scalar() const {
343     return scalar_;
344   }
variables()345   const std::vector<ExprPtr>& variables() const {
346     return variables_;
347   }
hasher()348   HashProvider& hasher() const {
349     return hasher_;
350   }
351 
352  private:
353   std::vector<ExprPtr> variables_;
354   ExprPtr scalar_;
355   HashProvider& hasher_;
356   bool propagate_nans_;
357 
addComponent()358   void addComponent() {}
addComponent(ExprPtr e)359   void addComponent(ExprPtr e) {
360     variables_.push_back(std::move(e));
361   }
362   template <class... Es>
addComponent(ExprPtr e,Es &&...es)363   void addComponent(ExprPtr e, Es&&... es) {
364     addComponent(std::move(e));
365     addComponent(std::forward<Es>(es)...);
366   }
367 
368   // Uniquefy the terms using their hash.
369   void uniquefy();
370 };
371 
372 // Context-sensitive IR simplification
373 using VarBoundInfo = std::unordered_map<VarPtr, analysis::Bound>;
374 
375 class TORCH_API SimplifierUnderContext : public IRMutator {
376  public:
377   ~SimplifierUnderContext() override = default;
378   // Add boundary info for index variables in for-loops
379   StmtPtr mutate(const ForPtr& v) override;
380 
381   ExprPtr mutate(const DivPtr& v) override;
382   ExprPtr mutate(const ModPtr& v) override;
383   ExprPtr mutate(const CompareSelectPtr& v) override;
384   ExprPtr mutate(const IfThenElsePtr& v) override;
385 
386  protected:
387   bool getLoopBoundInfo(const ExprPtr& expr, analysis::Bound* loop_bound_info);
388 
389  protected:
390   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
391   HashProvider hasher_;
392   VarBoundInfo var_bound_info_;
393 };
394 
395 // Stmt simplification should occur in both modes.
396 class TORCH_API PolynomialBase : public IRMutator {
397  public:
398   ~PolynomialBase() override = default;
399 
400   StmtPtr mutate(const BlockPtr& v) override;
401 
402   StmtPtr mutate(const CondPtr& v) override;
403 
404   StmtPtr mutate(const ForPtr& v) override;
405 
406   // Trivially factorize terms by GCD of scalar components.
407   TermPtr factorizePolynomial(const PolynomialPtr& poly);
408 
hasher()409   HashProvider& hasher() {
410     return hasher_;
411   }
412 
413  protected:
414   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
415   HashProvider hasher_;
416 };
417 
418 // Simplify the IR by combining arithmetic expressions over common terms.
419 class TORCH_API PolynomialTransformer : public PolynomialBase {
420  public:
421   using PolynomialBase::mutate;
422   // Inserts term into the provided map, in the case of a hash collision
423   // combines the term with the existing and updates the map.
424   void addOrUpdateTerm(
425       std::unordered_map<SimplifierHashType, TermPtr>& varmap,
426       const TermPtr& term);
427 
428   // Add Polynomial expressions, combining Terms representing the same
429   // variables.
430   ExprPtr addPolynomials(const PolynomialPtr& lhs, const PolynomialPtr& rhs);
431 
432   // Insert a new Term into the provided polynomial. If the new term has
433   // common variables to an existing term it is combined.
434   ExprPtr insertTerm(const PolynomialPtr& poly, const TermPtr& term);
435 
436   // Merge and simplify addition.
437   ExprPtr mutate(const AddPtr& v) override;
438 
439   // Subtract one term from another, cancelling if necessary.
440   ExprPtr subTerms(const TermPtr& lhs, TermPtr rhs, bool negated);
441 
442   // Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
443   // possible.
444   ExprPtr subPolynomials(const PolynomialPtr& lhs, const PolynomialPtr& rhs);
445 
446   // Merge and simplify subtraction.
447   ExprPtr mutate(const SubPtr& v) override;
448 
449   // Multiply two terms together, usually creating a new term with the variable
450   // lists concatenated.
451   TermPtr mulTerms(const TermPtr& lhs, const TermPtr& rhs);
452 
453   // Multiply a Polynomial by a Term.
454   ExprPtr polyByTerm(const PolynomialPtr& poly, const TermPtr& term);
455 
456   // Match a rounding pattern and create a RoundOff if found.
457   ExprPtr isRoundOff(const ExprPtr& lhs, const ExprPtr& rhs);
458 
459   // Inserts a new component into a term, simplifying if possible.
460   ExprPtr insertIntoTerm(const TermPtr& term, const ExprPtr& expr);
461 
462   // Merge and simplify multiplication.
463   ExprPtr mutate(const MulPtr& v) override;
464 
465   ExprPtr mutate(const DivPtr& v) override;
466 
467   ExprPtr mutate(const ModPtr& v) override;
468 
469   ExprPtr mutate(const AndPtr& v) override;
470 
471   ExprPtr mutate(const XorPtr& v) override;
472 
473   ExprPtr mutate(const LshiftPtr& v) override;
474 
475   ExprPtr mutate(const RshiftPtr& v) override;
476 
477   ExprPtr mutate(const MaxPtr& v) override;
478 
479   ExprPtr mutate(const MinPtr& v) override;
480 
481   ExprPtr mutate(const CompareSelectPtr& v) override;
482 
483   ExprPtr mutate(const IntrinsicsPtr& v) override;
484 
485   ExprPtr mutate(const CastPtr& v) override;
486 
487   ExprPtr mutate(const IfThenElsePtr& v) override;
488 
489   static ExprPtr simplify(ExprPtr e);
490   static ExprHandle simplify(const ExprHandle& e);
491   static StmtPtr simplify(StmtPtr e);
492 };
493 
494 // Expands Terms and Polynomial expressions into primitive operations.
495 // Does some simple factorization and reordering.
496 class TORCH_API TermExpander : public PolynomialBase {
497   PolynomialTransformer* simplifier_;
498   std::set<VarPtr> eliminated_allocations_;
499 
500  public:
501   using PolynomialBase::mutate;
TermExpander(PolynomialTransformer * simplifier)502   TermExpander(PolynomialTransformer* simplifier) : simplifier_(simplifier) {}
check_safe()503   bool check_safe() {
504     return eliminated_allocations_.empty();
505   }
506 
507   // Expand Terms out to a series of Muls.
508   ExprPtr mutate(const TermPtr& v) override;
509 
510   // Expand Polynomials out to a series of Adds.
511   ExprPtr mutate(const PolynomialPtr& v) override;
512 
513   // Expand MaxTerms to a series of Max ops.
514   ExprPtr mutate(const MaxTermPtr& v) override;
515 
516   // Expand MinTerms to a series of Min ops.
517   ExprPtr mutate(const MinTermPtr& v) override;
518 
519   // Expand RoundOff to it's component: Mul(Div(lhs, rhs), rhs).
520   ExprPtr mutate(const RoundOffPtr& v) override;
521 
522   // Eliminate zero length allocations.
523   StmtPtr mutate(const AllocatePtr& v) override;
524   StmtPtr mutate(const FreePtr& v) override;
525 
526   // Override to enable condition fusing.
527   BlockPtr fuseConditions(BlockPtr v);
528   StmtPtr fuseSyncThreads(BlockPtr block);
529   StmtPtr mutate(const BlockPtr& v) override;
530 };
531 
532 class TORCH_API IRSimplifier {
533  public:
534   static StmtPtr simplify(StmtPtr s);
535   static ExprPtr simplify(ExprPtr e);
simplify(const ExprHandle & e)536   static ExprHandle simplify(const ExprHandle& e) {
537     return ExprHandle(simplify(e.node()));
538   }
539 };
540 
541 // Flattens the buf and performs the simplifier on the flattened dims.
542 ExprPtr buf_flat_size(const BufPtr& v);
543 // Returns true if expressions A and B can be simplified to an equal expression.
544 TORCH_API bool exprEquals(const ExprPtr& A, const ExprPtr& B);
545 
546 } // namespace torch::jit::tensorexpr
547