1 #pragma once
2
3 #include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
4 #include <torch/csrc/jit/tensorexpr/eval.h>
5 #include <torch/csrc/jit/tensorexpr/hash_provider.h>
6 #include <torch/csrc/jit/tensorexpr/ir.h>
7 #include <torch/csrc/jit/tensorexpr/ir_mutator.h>
8 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
9 #include <torch/csrc/jit/tensorexpr/types.h>
10
11 #include <utility>
12
/* IR Simplification
 *
 * Simplifies expressions in two stages:
 *  1. Recursively traverse the IR tree, combining similar operations into
 *     Terms (interacted via Multiplication) and Polynomials (interacted via
 *     Addition). We reorder the components of each Term or Polynomial into a
 *     consistent order to allow combination or cancelling of like terms.
 *  2. Once the format of the tree is minimal, expand each Term into a sequence
 *     of Muls, and each Polynomial into a sequence of Adds.
 */
23
24 namespace torch::jit::tensorexpr {
25
26 // A bunch of helpers for determine the Dtype of the output of a multi argument
27 // Term or Polynomial.
28 template <class ExprType>
promoteTypesVec(const ExprPtr & s,const std::vector<ExprType> & v)29 Dtype promoteTypesVec(const ExprPtr& s, const std::vector<ExprType>& v) {
30 Dtype t = s->dtype();
31 bool first = true;
32
33 for (const auto& e : v) {
34 if (first) {
35 t = Dtype(t.scalar_type(), e->dtype().lanes());
36 first = false;
37 }
38 t = promoteTypes(t, e->dtype());
39 }
40 return t;
41 }
42
43 template <class ExprType>
promoteTypesVec(const std::vector<ExprType> & v)44 Dtype promoteTypesVec(const std::vector<ExprType>& v) {
45 if (v.empty()) {
46 throw malformed_input("empty list of types");
47 }
48
49 Dtype t = v[0]->dtype();
50 for (const auto& e : v) {
51 t = promoteTypes(t, e->dtype());
52 }
53 return t;
54 }
55
56 template <class ExprType>
promoteTypesMap(const ExprPtr & s,std::unordered_map<SimplifierHashType,ExprType> & m)57 Dtype promoteTypesMap(
58 const ExprPtr& s,
59 std::unordered_map<SimplifierHashType, ExprType>& m) {
60 Dtype t = s->dtype();
61 bool first = true;
62 for (auto& e : m) {
63 if (first) {
64 t = Dtype(t.scalar_type(), e.second->dtype().lanes());
65 first = false;
66 }
67 t = promoteTypes(t, e.second->dtype());
68 }
69 return t;
70 }
71
72 template <class ExprType>
promoteTypesVar(ExprType e)73 Dtype promoteTypesVar(ExprType e) {
74 return e->dtype();
75 }
76
77 template <class ExprType, class... Args>
promoteTypesVar(ExprType e,Args...es)78 Dtype promoteTypesVar(ExprType e, Args... es) {
79 Dtype lhs = e->dtype();
80 Dtype rhs = promoteTypesVar(es...);
81 if (e->isConstant()) {
82 lhs = Dtype(lhs.scalar_type(), rhs.lanes());
83 }
84
85 return promoteTypes(lhs, rhs);
86 }
87
88 // Uses the evaluator to fold an Expression with constant terms.
89 // E.g. evaluateOp(Add(3, 4)) => 7.
90 // Expr v must not have any unbound Vars.
evaluateOp(const ExprPtr & v)91 inline ExprPtr evaluateOp(const ExprPtr& v) {
92 ExprHandle handle(v);
93 ExprEval<SimpleIREvaluator> eval(handle);
94
95 switch (v->dtype().scalar_type()) {
96 #define TYPE_CASE(Type, Name) \
97 case ScalarType::Name: { \
98 Type val = eval.value<Type>(); \
99 return getImmediateByType(v->dtype().scalar_type(), val); \
100 }
101 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
102 #undef TYPE_CASE
103 default:
104 LOG(FATAL) << "Unsupported datatype: " << v->dtype();
105 return nullptr;
106 }
107 return nullptr;
108 }
109
110 // A Term represents a grouping of Exprs through multiplication.
111 // E.g. product(scalar, *variables).
112 class Term : public ExprNode<Term> {
113 public:
114 template <class... Args>
Term(HashProvider & hasher,ExprPtr s,Args...ts)115 Term(HashProvider& hasher, ExprPtr s, Args... ts)
116 : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
117 CHECK(s->isConstant());
118 addComponent(ts...);
119 sort();
120 }
121
Term(HashProvider & hasher,ExprPtr s,std::vector<ExprPtr> v)122 Term(HashProvider& hasher, ExprPtr s, std::vector<ExprPtr> v)
123 : ExprNodeBase(promoteTypesVec(s, v)),
124 variables_(std::move(v)),
125 scalar_(std::move(s)),
126 hasher_(hasher) {
127 sort();
128 }
129
130 // Convenience constructor from a map of hash -> var, used when merging Terms.
Term(HashProvider & hasher,const ExprPtr & s,std::unordered_map<SimplifierHashType,ExprPtr> varmap)131 Term(
132 HashProvider& hasher,
133 const ExprPtr& s,
134 std::unordered_map<SimplifierHashType, ExprPtr> varmap)
135 : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
136 for (auto& p : varmap) {
137 addComponent(p.second);
138 }
139 sort();
140 }
141
scalar()142 ExprPtr scalar() const {
143 return scalar_;
144 }
variables()145 const std::vector<ExprPtr>& variables() const {
146 return variables_;
147 }
hasher()148 HashProvider& hasher() const {
149 return hasher_;
150 }
151
152 // Produce a hash of just the variable components of this term, to determine
153 // if it can be combined with another term.
154 SimplifierHashType hashVars() const;
155
156 private:
157 std::vector<ExprPtr> variables_;
158 ExprPtr scalar_;
159 HashProvider& hasher_;
160
addComponent()161 void addComponent() {}
addComponent(ExprPtr e)162 void addComponent(ExprPtr e) {
163 variables_.push_back(std::move(e));
164 }
165 template <class... Es>
addComponent(ExprPtr e,Es &&...es)166 void addComponent(ExprPtr e, Es&&... es) {
167 addComponent(std::move(e));
168 addComponent(std::forward<Es>(es)...);
169 }
170
171 // Sort by hash to normalize order of components.
172 void sort();
173 };
174
175 // Polynomial represents a grouping of Exprs by addition.
176 // E.g. sum(*variables, scalar).
177 // This would better be called Expression, but, naming conflict...
178 class Polynomial : public ExprNode<Polynomial> {
179 public:
180 template <class... Args>
Polynomial(HashProvider & hasher,ExprPtr s,Args...ts)181 Polynomial(HashProvider& hasher, ExprPtr s, Args... ts)
182 : ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
183 CHECK(s->isConstant());
184 addTerm(ts...);
185 sort();
186 }
187
Polynomial(HashProvider & hasher,const ExprPtr & s,std::vector<TermPtr> v)188 Polynomial(HashProvider& hasher, const ExprPtr& s, std::vector<TermPtr> v)
189 : ExprNodeBase(promoteTypesVec(s, v)),
190 variables_(std::move(v)),
191 scalar_(s),
192 hasher_(hasher) {
193 sort();
194 }
195
196 // Helper constructor for list of terms with no scalar component.
Polynomial(HashProvider & hasher,std::vector<TermPtr> terms)197 Polynomial(HashProvider& hasher, std::vector<TermPtr> terms)
198 : ExprNodeBase(promoteTypesVec(terms)),
199 variables_(std::move(terms)),
200 scalar_(getImmediateByType(dtype(), 0)),
201 hasher_(hasher) {
202 sort();
203 }
204
205 // Convenience constructor for map of hash -> var, used when merging
206 // Polynomials.
Polynomial(HashProvider & hasher,const ExprPtr & s,std::unordered_map<SimplifierHashType,TermPtr> varmap)207 Polynomial(
208 HashProvider& hasher,
209 const ExprPtr& s,
210 std::unordered_map<SimplifierHashType, TermPtr> varmap)
211 : ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
212 for (auto& p : varmap) {
213 addTerm(p.second);
214 }
215 sort();
216 }
217
scalar()218 ExprPtr scalar() const {
219 return scalar_;
220 }
variables()221 const std::vector<TermPtr>& variables() const {
222 return variables_;
223 }
hasher()224 HashProvider& hasher() const {
225 return hasher_;
226 }
227
228 SimplifierHashType hashVars() const;
229
230 private:
231 std::vector<TermPtr> variables_;
232 ExprPtr scalar_;
233 HashProvider& hasher_;
234
addTerm(TermPtr t)235 void addTerm(TermPtr t) {
236 variables_.push_back(std::move(t));
237 }
238 template <class... Ts>
addTerm(TermPtr t,Ts &&...ts)239 void addTerm(TermPtr t, Ts&&... ts) {
240 addTerm(std::move(t));
241 addTerm(std::forward<Ts>(ts)...);
242 }
243
244 // Sort by hash to normalize order of terms.
245 void sort();
246 };
247
248 class RoundOff : public BinaryOpNode<RoundOff> {
249 public:
RoundOff(ExprPtr lhs,ExprPtr rhs)250 RoundOff(ExprPtr lhs, ExprPtr rhs)
251 : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kOther) {}
252 };
253
254 class MaxTerm : public ExprNode<MaxTerm> {
255 public:
256 template <class... Args>
MaxTerm(HashProvider & hasher,ExprPtr s,bool p,Args...ts)257 MaxTerm(HashProvider& hasher, ExprPtr s, bool p, Args... ts)
258 : ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
259 scalar_(s),
260 hasher_(hasher),
261 propagate_nans_(p) {
262 addComponent(ts...);
263 uniquefy();
264 }
265
MaxTerm(HashProvider & hasher,const ExprPtr & s,bool p,std::vector<ExprPtr> v)266 MaxTerm(
267 HashProvider& hasher,
268 const ExprPtr& s,
269 bool p,
270 std::vector<ExprPtr> v)
271 : ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
272 variables_(std::move(v)),
273 scalar_(s),
274 hasher_(hasher),
275 propagate_nans_(p) {
276 uniquefy();
277 }
278
propagate_nans()279 bool propagate_nans() const {
280 return propagate_nans_;
281 }
282
scalar()283 ExprPtr scalar() const {
284 return scalar_;
285 }
variables()286 const std::vector<ExprPtr>& variables() const {
287 return variables_;
288 }
hasher()289 HashProvider& hasher() const {
290 return hasher_;
291 }
292
293 private:
294 std::vector<ExprPtr> variables_;
295 ExprPtr scalar_;
296 HashProvider& hasher_;
297 bool propagate_nans_;
298
addComponent()299 void addComponent() {}
addComponent(ExprPtr e)300 void addComponent(ExprPtr e) {
301 variables_.push_back(std::move(e));
302 }
303 template <class... Es>
addComponent(ExprPtr e,Es &&...es)304 void addComponent(ExprPtr e, Es&&... es) {
305 addComponent(std::move(e));
306 addComponent(std::forward<Es>(es)...);
307 }
308
309 // Uniquefy the terms using their hash.
310 void uniquefy();
311 };
312
313 class MinTerm : public ExprNode<MinTerm> {
314 public:
315 template <class... Args>
MinTerm(HashProvider & hasher,ExprPtr s,bool p,Args...ts)316 MinTerm(HashProvider& hasher, ExprPtr s, bool p, Args... ts)
317 : ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
318 scalar_(s),
319 hasher_(hasher),
320 propagate_nans_(p) {
321 addComponent(ts...);
322 uniquefy();
323 }
324
MinTerm(HashProvider & hasher,const ExprPtr & s,bool p,std::vector<ExprPtr> v)325 MinTerm(
326 HashProvider& hasher,
327 const ExprPtr& s,
328 bool p,
329 std::vector<ExprPtr> v)
330 : ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
331 variables_(std::move(v)),
332 scalar_(s),
333 hasher_(hasher),
334 propagate_nans_(p) {
335 uniquefy();
336 }
337
propagate_nans()338 bool propagate_nans() const {
339 return propagate_nans_;
340 }
341
scalar()342 ExprPtr scalar() const {
343 return scalar_;
344 }
variables()345 const std::vector<ExprPtr>& variables() const {
346 return variables_;
347 }
hasher()348 HashProvider& hasher() const {
349 return hasher_;
350 }
351
352 private:
353 std::vector<ExprPtr> variables_;
354 ExprPtr scalar_;
355 HashProvider& hasher_;
356 bool propagate_nans_;
357
addComponent()358 void addComponent() {}
addComponent(ExprPtr e)359 void addComponent(ExprPtr e) {
360 variables_.push_back(std::move(e));
361 }
362 template <class... Es>
addComponent(ExprPtr e,Es &&...es)363 void addComponent(ExprPtr e, Es&&... es) {
364 addComponent(std::move(e));
365 addComponent(std::forward<Es>(es)...);
366 }
367
368 // Uniquefy the terms using their hash.
369 void uniquefy();
370 };
371
372 // Context-sensitive IR simplification
373 using VarBoundInfo = std::unordered_map<VarPtr, analysis::Bound>;
374
375 class TORCH_API SimplifierUnderContext : public IRMutator {
376 public:
377 ~SimplifierUnderContext() override = default;
378 // Add boundary info for index variables in for-loops
379 StmtPtr mutate(const ForPtr& v) override;
380
381 ExprPtr mutate(const DivPtr& v) override;
382 ExprPtr mutate(const ModPtr& v) override;
383 ExprPtr mutate(const CompareSelectPtr& v) override;
384 ExprPtr mutate(const IfThenElsePtr& v) override;
385
386 protected:
387 bool getLoopBoundInfo(const ExprPtr& expr, analysis::Bound* loop_bound_info);
388
389 protected:
390 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
391 HashProvider hasher_;
392 VarBoundInfo var_bound_info_;
393 };
394
395 // Stmt simplification should occur in both modes.
396 class TORCH_API PolynomialBase : public IRMutator {
397 public:
398 ~PolynomialBase() override = default;
399
400 StmtPtr mutate(const BlockPtr& v) override;
401
402 StmtPtr mutate(const CondPtr& v) override;
403
404 StmtPtr mutate(const ForPtr& v) override;
405
406 // Trivially factorize terms by GCD of scalar components.
407 TermPtr factorizePolynomial(const PolynomialPtr& poly);
408
hasher()409 HashProvider& hasher() {
410 return hasher_;
411 }
412
413 protected:
414 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
415 HashProvider hasher_;
416 };
417
418 // Simplify the IR by combining arithmetic expressions over common terms.
419 class TORCH_API PolynomialTransformer : public PolynomialBase {
420 public:
421 using PolynomialBase::mutate;
422 // Inserts term into the provided map, in the case of a hash collision
423 // combines the term with the existing and updates the map.
424 void addOrUpdateTerm(
425 std::unordered_map<SimplifierHashType, TermPtr>& varmap,
426 const TermPtr& term);
427
428 // Add Polynomial expressions, combining Terms representing the same
429 // variables.
430 ExprPtr addPolynomials(const PolynomialPtr& lhs, const PolynomialPtr& rhs);
431
432 // Insert a new Term into the provided polynomial. If the new term has
433 // common variables to an existing term it is combined.
434 ExprPtr insertTerm(const PolynomialPtr& poly, const TermPtr& term);
435
436 // Merge and simplify addition.
437 ExprPtr mutate(const AddPtr& v) override;
438
439 // Subtract one term from another, cancelling if necessary.
440 ExprPtr subTerms(const TermPtr& lhs, TermPtr rhs, bool negated);
441
442 // Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
443 // possible.
444 ExprPtr subPolynomials(const PolynomialPtr& lhs, const PolynomialPtr& rhs);
445
446 // Merge and simplify subtraction.
447 ExprPtr mutate(const SubPtr& v) override;
448
449 // Multiply two terms together, usually creating a new term with the variable
450 // lists concatenated.
451 TermPtr mulTerms(const TermPtr& lhs, const TermPtr& rhs);
452
453 // Multiply a Polynomial by a Term.
454 ExprPtr polyByTerm(const PolynomialPtr& poly, const TermPtr& term);
455
456 // Match a rounding pattern and create a RoundOff if found.
457 ExprPtr isRoundOff(const ExprPtr& lhs, const ExprPtr& rhs);
458
459 // Inserts a new component into a term, simplifying if possible.
460 ExprPtr insertIntoTerm(const TermPtr& term, const ExprPtr& expr);
461
462 // Merge and simplify multiplication.
463 ExprPtr mutate(const MulPtr& v) override;
464
465 ExprPtr mutate(const DivPtr& v) override;
466
467 ExprPtr mutate(const ModPtr& v) override;
468
469 ExprPtr mutate(const AndPtr& v) override;
470
471 ExprPtr mutate(const XorPtr& v) override;
472
473 ExprPtr mutate(const LshiftPtr& v) override;
474
475 ExprPtr mutate(const RshiftPtr& v) override;
476
477 ExprPtr mutate(const MaxPtr& v) override;
478
479 ExprPtr mutate(const MinPtr& v) override;
480
481 ExprPtr mutate(const CompareSelectPtr& v) override;
482
483 ExprPtr mutate(const IntrinsicsPtr& v) override;
484
485 ExprPtr mutate(const CastPtr& v) override;
486
487 ExprPtr mutate(const IfThenElsePtr& v) override;
488
489 static ExprPtr simplify(ExprPtr e);
490 static ExprHandle simplify(const ExprHandle& e);
491 static StmtPtr simplify(StmtPtr e);
492 };
493
494 // Expands Terms and Polynomial expressions into primitive operations.
495 // Does some simple factorization and reordering.
496 class TORCH_API TermExpander : public PolynomialBase {
497 PolynomialTransformer* simplifier_;
498 std::set<VarPtr> eliminated_allocations_;
499
500 public:
501 using PolynomialBase::mutate;
TermExpander(PolynomialTransformer * simplifier)502 TermExpander(PolynomialTransformer* simplifier) : simplifier_(simplifier) {}
check_safe()503 bool check_safe() {
504 return eliminated_allocations_.empty();
505 }
506
507 // Expand Terms out to a series of Muls.
508 ExprPtr mutate(const TermPtr& v) override;
509
510 // Expand Polynomials out to a series of Adds.
511 ExprPtr mutate(const PolynomialPtr& v) override;
512
513 // Expand MaxTerms to a series of Max ops.
514 ExprPtr mutate(const MaxTermPtr& v) override;
515
516 // Expand MinTerms to a series of Min ops.
517 ExprPtr mutate(const MinTermPtr& v) override;
518
519 // Expand RoundOff to it's component: Mul(Div(lhs, rhs), rhs).
520 ExprPtr mutate(const RoundOffPtr& v) override;
521
522 // Eliminate zero length allocations.
523 StmtPtr mutate(const AllocatePtr& v) override;
524 StmtPtr mutate(const FreePtr& v) override;
525
526 // Override to enable condition fusing.
527 BlockPtr fuseConditions(BlockPtr v);
528 StmtPtr fuseSyncThreads(BlockPtr block);
529 StmtPtr mutate(const BlockPtr& v) override;
530 };
531
532 class TORCH_API IRSimplifier {
533 public:
534 static StmtPtr simplify(StmtPtr s);
535 static ExprPtr simplify(ExprPtr e);
simplify(const ExprHandle & e)536 static ExprHandle simplify(const ExprHandle& e) {
537 return ExprHandle(simplify(e.node()));
538 }
539 };
540
541 // Flattens the buf and performs the simplifier on the flattened dims.
542 ExprPtr buf_flat_size(const BufPtr& v);
543 // Returns true if expressions A and B can be simplified to an equal expression.
544 TORCH_API bool exprEquals(const ExprPtr& A, const ExprPtr& B);
545
546 } // namespace torch::jit::tensorexpr
547