xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/stmt.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <algorithm>
4 #include <list>
5 #include <string>
6 #include <unordered_set>
7 #include <utility>
8 #include <vector>
9 
10 #include <torch/csrc/jit/tensorexpr/expr.h>
11 
12 namespace torch::jit::tensorexpr {
13 
14 // The common base between all statement node.
15 class TORCH_API Stmt : public std::enable_shared_from_this<Stmt> {
16  public:
17   Stmt() = default;
18   virtual ~Stmt() = default;
19   virtual void accept(IRVisitor* visitor) = 0;
20   virtual StmtPtr accept_mutator(IRMutator* mutator) = 0;
21 
get_parent()22   StmtPtr get_parent() const {
23     return parent_ ? parent_->getptr() : nullptr;
24   }
25 
26   /*
27    * Make a deep copy of the given statement.
28    *
29    * All statements and expressions used in children of the statement are
30    * cloned. Note that the variables are not deep-copied since they are
31    * immutable.
32    */
33   static StmtPtr clone(const StmtPtr& s);
34 
35  protected:
set_parent(const StmtPtr & s,Stmt * new_parent)36   static void set_parent(const StmtPtr& s, Stmt* new_parent) {
37     s->parent_ = new_parent;
38   }
getptr()39   std::shared_ptr<Stmt> getptr() {
40     return shared_from_this();
41   }
42 
43  private:
44   Stmt* parent_ = nullptr;
45 };
46 
47 template <class Op>
48 class StmtNode : public Stmt {
49  public:
50   using StmtNodeBase = StmtNode<Op>;
accept(IRVisitor * visitor)51   void accept(IRVisitor* visitor) override {
52     visitor->visit(static_to<Op>(getptr()));
53   }
54   StmtPtr accept_mutator(IRMutator* mutator) override;
55   StmtNode() = default;
56 };
57 
58 template <class Op>
accept_mutator(IRMutator * mutator)59 StmtPtr StmtNode<Op>::accept_mutator(IRMutator* mutator) {
60   return mutator->mutate(static_to<Op>(getptr()));
61 }
62 
63 // Concrete Stmt classes
64 class TORCH_API Block : public StmtNode<Block> {
65  public:
make(const std::vector<StmtPtr> & stmts)66   static BlockPtr make(const std::vector<StmtPtr>& stmts) {
67     std::vector<StmtPtr> valid_stmts;
68     for (auto& stmt : stmts) {
69       if (!stmt) {
70         continue;
71       }
72       valid_stmts.push_back(stmt);
73     }
74     if (valid_stmts.empty()) {
75       return nullptr;
76     }
77     return alloc<Block>(valid_stmts);
78   }
79 
nstmts()80   size_t nstmts() const {
81     return stmts_.size();
82   }
empty()83   bool empty() const {
84     return stmts_.empty();
85   }
86 
prepend_stmt(StmtPtr s)87   void prepend_stmt(StmtPtr s) {
88     if (s->get_parent()) {
89       throw malformed_input(
90           "Block prepend Stmt with existing parent", std::move(s));
91     }
92 
93     stmts_.push_front(s);
94     set_parent(s, this);
95   }
append_stmt(StmtPtr s)96   void append_stmt(StmtPtr s) {
97     if (s->get_parent()) {
98       throw malformed_input(
99           "Block append Stmt with existing parent", std::move(s));
100     }
101 
102     stmts_.push_back(s);
103     set_parent(s, this);
104   }
105 
insert_stmt_before(StmtPtr s,const StmtPtr & before)106   void insert_stmt_before(StmtPtr s, const StmtPtr& before) {
107     if (s->get_parent()) {
108       throw malformed_input(
109           "Block append Stmt with existing parent", std::move(s));
110     }
111 
112     auto pos = std::find(stmts_.begin(), stmts_.end(), before);
113     if (pos == stmts_.end()) {
114       throw malformed_input(
115           "Inserting after statement that is not in block", std::move(s));
116     }
117 
118     stmts_.insert(pos, s);
119     set_parent(s, this);
120   }
121 
insert_stmt_after(StmtPtr s,const StmtPtr & after)122   void insert_stmt_after(StmtPtr s, const StmtPtr& after) {
123     if (s->get_parent()) {
124       throw malformed_input(
125           "Block append Stmt with existing parent", std::move(s));
126     }
127 
128     auto pos = std::find(stmts_.begin(), stmts_.end(), after);
129     if (pos == stmts_.end()) {
130       throw malformed_input(
131           "Inserting after statement that is not in block", std::move(s));
132     }
133 
134     ++pos;
135 
136     stmts_.insert(pos, s);
137     set_parent(s, this);
138   }
139 
replace_stmt(const StmtPtr & old_stmt,StmtPtr new_stmt)140   bool replace_stmt(const StmtPtr& old_stmt, StmtPtr new_stmt) {
141     if (new_stmt->get_parent()) {
142       throw malformed_input(
143           "Block replace Stmt with existing parent", std::move(new_stmt));
144     }
145 
146     auto pos = std::find(stmts_.begin(), stmts_.end(), old_stmt);
147     if (pos == stmts_.end()) {
148       return false;
149     }
150     stmts_.insert(pos, new_stmt);
151     stmts_.erase(pos);
152     set_parent(old_stmt, nullptr);
153     set_parent(new_stmt, this);
154     return true;
155   }
156 
157   // Creates a new block by cloning `this` block and replacing the given
158   // statement with a new statement. Note that `old_stmt` refers to a statement
159   // in `this` block. If the `old_stmt` is not found, it will return `nullptr`.
clone_and_replace(const StmtPtr & old_stmt,StmtPtr new_stmt)160   BlockPtr clone_and_replace(const StmtPtr& old_stmt, StmtPtr new_stmt) {
161     if (new_stmt->get_parent()) {
162       throw malformed_input(
163           "Block replace Stmt with existing parent", std::move(new_stmt));
164     }
165 
166     std::vector<StmtPtr> stmts(stmts_.begin(), stmts_.end());
167     std::vector<StmtPtr> cloned_stmts(stmts.size());
168     bool found = false;
169     for (int i = 0; i < static_cast<int>(stmts.size()); ++i) {
170       if (stmts[i] == old_stmt) {
171         found = true;
172         cloned_stmts[i] = new_stmt;
173       } else {
174         cloned_stmts[i] = Stmt::clone(stmts[i]);
175       }
176     }
177     if (!found) {
178       return nullptr;
179     }
180     return alloc<Block>(cloned_stmts);
181   }
182 
remove_stmt(const StmtPtr & stmt)183   bool remove_stmt(const StmtPtr& stmt) {
184     auto pos = std::find(stmts_.begin(), stmts_.end(), stmt);
185     if (pos == stmts_.end()) {
186       return false;
187     }
188 
189     set_parent(stmt, nullptr);
190     stmts_.erase(pos);
191     return true;
192   }
193 
stmts()194   std::list<StmtPtr> stmts() const {
195     return stmts_;
196   }
197 
clear()198   void clear() {
199     for (const auto& s : stmts_) {
200       set_parent(s, nullptr);
201     }
202     stmts_.clear();
203   }
204 
set_stmts(const std::vector<StmtPtr> & stmts)205   void set_stmts(const std::vector<StmtPtr>& stmts) {
206     clear();
207     init(stmts);
208   }
209 
Block(const std::vector<StmtPtr> & stmts)210   explicit Block(const std::vector<StmtPtr>& stmts) {
211     init(stmts);
212   }
213 
214   typedef std::list<StmtPtr>::iterator iterator;
215   typedef std::list<StmtPtr>::const_iterator const_iterator;
216 
begin()217   iterator begin() {
218     return stmts_.begin();
219   }
220 
begin()221   const_iterator begin() const {
222     return stmts_.begin();
223   }
224 
end()225   iterator end() {
226     return stmts_.end();
227   }
228 
end()229   const_iterator end() const {
230     return stmts_.end();
231   }
232 
front()233   StmtPtr front() {
234     return stmts_.front();
235   }
236 
front()237   StmtPtr front() const {
238     return stmts_.front();
239   }
240 
back()241   StmtPtr back() {
242     return stmts_.back();
243   }
244 
back()245   StmtPtr back() const {
246     return stmts_.back();
247   }
248 
splice(Block::iterator it,const BlockPtr & other)249   void splice(Block::iterator it, const BlockPtr& other) {
250     for (const StmtPtr& s : *other) {
251       set_parent(s, this);
252     }
253 
254     stmts_.splice(it, other->stmts_);
255   }
256 
getSharedParent(StmtPtr p1,StmtPtr p2)257   static BlockPtr getSharedParent(StmtPtr p1, StmtPtr p2) {
258     std::unordered_set<BlockPtr> enclosing;
259 
260     StmtPtr p1_p = std::move(p1);
261     while (p1_p) {
262       if (BlockPtr b = to<Block>(p1_p)) {
263         if (b) {
264           enclosing.insert(b);
265         }
266       }
267       p1_p = p1_p->get_parent();
268     }
269 
270     StmtPtr p2_p = std::move(p2);
271     while (p2_p) {
272       if (BlockPtr b = to<Block>(p2_p)) {
273         if (enclosing.count(b) != 0) {
274           return b;
275         }
276       }
277       p2_p = p2_p->get_parent();
278     }
279 
280     return nullptr;
281   }
282 
283   // returns the immediate child containing statement s.
getEnclosedRoot(StmtPtr s)284   StmtPtr getEnclosedRoot(StmtPtr s) const {
285     while (s && s->get_parent().get() != this) {
286       s = s->get_parent();
287     }
288     return s;
289   }
290 
291  private:
292   std::list<StmtPtr> stmts_;
293 
init(const std::vector<StmtPtr> & stmts)294   void init(const std::vector<StmtPtr>& stmts) {
295     for (const StmtPtr& s : stmts) {
296       if (!s) {
297         continue;
298       }
299       if (!s->get_parent()) {
300         // If we get here, it's a bug, but we cannot throw an error from a
301         // constructor. But IR verifier would catch this.
302         set_parent(s, this);
303       }
304 
305       stmts_.push_back(s);
306     }
307   }
308 };
309 
310 class TORCH_API Store : public StmtNode<Store> {
311  public:
base_handle()312   VarPtr base_handle() const {
313     return buf_->base_handle();
314   }
indices()315   std::vector<ExprPtr> indices() const {
316     return indices_;
317   }
flat_index()318   ExprPtr flat_index() const {
319     TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
320     return indices_[0];
321   }
value()322   ExprPtr value() const {
323     return value_;
324   }
buf()325   BufPtr buf() const {
326     return buf_;
327   }
328 
set_buf(BufPtr buf)329   void set_buf(BufPtr buf) {
330     buf_ = std::move(buf);
331   }
332 
set_indices(std::vector<ExprPtr> indices)333   void set_indices(std::vector<ExprPtr> indices) {
334     indices_ = std::move(indices);
335   }
336 
set_value(ExprPtr value)337   void set_value(ExprPtr value) {
338     value_ = std::move(value);
339   }
340 
341   static StorePtr make(
342       const BufHandle& buf,
343       const std::vector<ExprHandle>& indices,
344       const ExprHandle& value);
345 
346   Store(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value);
347 
348  private:
349   BufPtr buf_;
350   std::vector<ExprPtr> indices_;
351   ExprPtr value_;
352 };
353 
354 // Allocate a buffer of given shapes and dtypes and bind it with the given
355 // buffer var. The life span is at most through the current program, until it is
356 // explicitly freed. An unfreed memory is likely considered an error.
357 class TORCH_API Allocate : public StmtNode<Allocate> {
358  public:
make(const BufHandle & buf_handle)359   static AllocatePtr make(const BufHandle& buf_handle) {
360     return alloc<Allocate>(buf_handle.node());
361   }
362 
buffer_var()363   VarPtr buffer_var() const {
364     return buf_->base_handle();
365   }
366 
dtype()367   Dtype dtype() const {
368     return buf_->dtype();
369   }
370 
dims()371   const std::vector<ExprPtr> dims() const {
372     return buf_->dims();
373   }
374 
buf()375   BufPtr buf() const {
376     return buf_;
377   }
378 
set_buf(BufPtr buf)379   void set_buf(BufPtr buf) {
380     buf_ = std::move(buf);
381   }
382 
Allocate(BufPtr buf)383   explicit Allocate(BufPtr buf) : buf_(std::move(buf)) {}
384 
385  private:
386   BufPtr buf_;
387   // TODO: add memory types.
388 };
389 
390 // PlacementAllocate is a variation of the Allocate operator in NNC IR. It does
391 // not allocate memory but reuse the memory of another buffer for the given
392 // buffer.
393 class TORCH_API PlacementAllocate : public StmtNode<PlacementAllocate> {
394  public:
make(const BufHandle & buf_handle,const BufHandle & buf_handle_to_reuse)395   static PlacementAllocatePtr make(
396       const BufHandle& buf_handle,
397       const BufHandle& buf_handle_to_reuse) {
398     return alloc<PlacementAllocate>(
399         buf_handle.node(), buf_handle_to_reuse.node());
400   }
401 
buf()402   BufPtr buf() const {
403     return buf_;
404   }
405 
buf_to_reuse()406   BufPtr buf_to_reuse() const {
407     return buf_to_reuse_;
408   }
409 
set_buf(BufPtr buf)410   void set_buf(BufPtr buf) {
411     buf_ = std::move(buf);
412   }
413 
set_buf_to_reuse(BufPtr buf)414   void set_buf_to_reuse(BufPtr buf) {
415     buf_to_reuse_ = std::move(buf);
416   }
417 
PlacementAllocate(BufPtr buf,BufPtr buf_to_reuse)418   explicit PlacementAllocate(BufPtr buf, BufPtr buf_to_reuse)
419       : buf_(std::move(buf)), buf_to_reuse_(std::move(buf_to_reuse)) {}
420 
421  private:
422   BufPtr buf_;
423   BufPtr buf_to_reuse_;
424 };
425 
426 // Free the specific buffer. It is an error.
427 class TORCH_API Free : public StmtNode<Free> {
428  public:
make(const BufHandle & buf_handle)429   static FreePtr make(const BufHandle& buf_handle) {
430     return alloc<Free>(buf_handle.node());
431   }
432 
buffer_var()433   VarPtr buffer_var() const {
434     return buf_->base_handle();
435   }
436 
buf()437   BufPtr buf() const {
438     return buf_;
439   }
440 
set_buf(BufPtr buf)441   void set_buf(BufPtr buf) {
442     buf_ = std::move(buf);
443   }
444 
Free(BufPtr buf)445   explicit Free(BufPtr buf) : buf_(std::move(buf)) {}
446 
447  private:
448   BufPtr buf_;
449 };
450 
451 class TORCH_API FreeExt : public StmtNode<FreeExt> {
452  public:
453   static FreeExtPtr make(const std::vector<BufHandle>& bufs);
454 
bufs()455   std::vector<BufPtr> bufs() const {
456     return bufs_;
457   }
458 
set_bufs(std::vector<BufPtr> bufs)459   void set_bufs(std::vector<BufPtr> bufs) {
460     bufs_ = std::move(bufs);
461   }
462 
FreeExt(std::vector<BufPtr> bufs)463   explicit FreeExt(std::vector<BufPtr> bufs) : bufs_(std::move(bufs)) {}
464 
465  private:
466   std::vector<BufPtr> bufs_;
467 };
468 
469 class TORCH_API Let : public StmtNode<Let> {
470  public:
make(const VarHandle & var,const ExprHandle & val)471   static LetPtr make(const VarHandle& var, const ExprHandle& val) {
472     return alloc<Let>(var.node(), val.node());
473   }
474 
Let(VarPtr var,ExprPtr val)475   Let(VarPtr var, ExprPtr val) : var_(std::move(var)), val_(std::move(val)) {}
476 
var()477   VarPtr var() const {
478     return var_;
479   }
480 
value()481   ExprPtr value() const {
482     return val_;
483   }
484 
set_var(VarPtr var)485   void set_var(VarPtr var) {
486     var_ = std::move(var);
487   }
488 
set_val(ExprPtr val)489   void set_val(ExprPtr val) {
490     val_ = std::move(val);
491   }
492 
493  private:
494   VarPtr var_;
495   ExprPtr val_;
496 };
497 
498 class TORCH_API Cond : public StmtNode<Cond> {
499  public:
make(const ExprHandle & condition,const StmtPtr & true_stmt,const StmtPtr & false_stmt)500   static CondPtr make(
501       const ExprHandle& condition,
502       const StmtPtr& true_stmt,
503       const StmtPtr& false_stmt) {
504     return alloc<Cond>(condition.node(), true_stmt, false_stmt);
505   }
506 
condition()507   ExprPtr condition() const {
508     return condition_;
509   }
510 
true_stmt()511   BlockPtr true_stmt() const {
512     return true_stmt_;
513   }
514 
false_stmt()515   BlockPtr false_stmt() const {
516     return false_stmt_;
517   }
518 
set_condition(ExprPtr condition)519   void set_condition(ExprPtr condition) {
520     condition_ = std::move(condition);
521   }
522 
set_true_stmt(StmtPtr true_stmt)523   void set_true_stmt(StmtPtr true_stmt) {
524     if (true_stmt) {
525       BlockPtr b = to<Block>(true_stmt);
526       if (!b) {
527         b = alloc<Block>(std::vector<StmtPtr>({std::move(true_stmt)}));
528       }
529       true_stmt_ = b;
530       set_parent(true_stmt_, this);
531     }
532   }
533 
set_false_stmt(StmtPtr false_stmt)534   void set_false_stmt(StmtPtr false_stmt) {
535     if (false_stmt) {
536       BlockPtr b = to<Block>(false_stmt);
537       if (!b) {
538         b = alloc<Block>(std::vector<StmtPtr>({std::move(false_stmt)}));
539       }
540       false_stmt_ = b;
541       set_parent(false_stmt_, this);
542     }
543   }
544 
Cond(ExprPtr condition,StmtPtr true_stmt,StmtPtr false_stmt)545   Cond(ExprPtr condition, StmtPtr true_stmt, StmtPtr false_stmt)
546       : condition_(std::move(condition)) {
547     set_true_stmt(std::move(true_stmt));
548     set_false_stmt(std::move(false_stmt));
549   }
550 
cloneWithNewBodies(const StmtPtr & true_stmt,const StmtPtr & false_stmt)551   CondPtr cloneWithNewBodies(
552       const StmtPtr& true_stmt,
553       const StmtPtr& false_stmt) {
554     return alloc<Cond>(condition_, true_stmt, false_stmt);
555   }
556 
cloneWithNewBody(const StmtPtr & true_stmt)557   CondPtr cloneWithNewBody(const StmtPtr& true_stmt) {
558     return alloc<Cond>(condition_, true_stmt, nullptr);
559   }
560 
561  private:
562   ExprPtr condition_;
563   BlockPtr true_stmt_ = nullptr;
564   BlockPtr false_stmt_ = nullptr;
565 };
566 
567 class TORCH_API LoopOptions {
568  public:
569   enum {
570     IDX_UNSET = -1,
571     IDX_X = 0,
572     IDX_Y = 1,
573     IDX_Z = 2,
574     IDX_W = 3,
575     IDX_MAX = IDX_W,
576   };
577   // GPU Block Index
is_gpu_block_index()578   bool is_gpu_block_index() const {
579     return gpu_block_index_ != IDX_UNSET;
580   }
581 
gpu_block_index()582   int gpu_block_index() const {
583     return gpu_block_index_;
584   }
585 
gpu_block_index_str()586   std::string gpu_block_index_str() const {
587     if (!is_gpu_block_index()) {
588       throw malformed_input("Has no GPU block index");
589     }
590 
591     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
592     static const char* kBlockIndexNames[] = {
593         "blockIdx.x",
594         "blockIdx.y",
595         "blockIdx.z",
596         "blockIdx.w",
597     };
598 
599     if (gpu_block_index_ < IDX_X || gpu_block_index_ > IDX_MAX) {
600       throw malformed_input("invalid GPU block index");
601     }
602 
603     return kBlockIndexNames[gpu_block_index_];
604   }
605 
set_gpu_block_index(int index)606   void set_gpu_block_index(int index) {
607     if (index == IDX_UNSET) {
608       gpu_block_index_ = IDX_UNSET;
609     }
610 
611     if (is_gpu_thread_index()) {
612       throw std::runtime_error("Cannot set both gpu block and thread index");
613     }
614     if (is_gpu_block_index() && gpu_block_index() != index) {
615       throw std::runtime_error("Cannot set a previously set block index");
616     }
617     gpu_block_index_ = index;
618   }
619 
620   // GPU Thread Index
is_gpu_thread_index()621   bool is_gpu_thread_index() const {
622     return gpu_thread_index() != IDX_UNSET;
623   }
624 
gpu_thread_index()625   int gpu_thread_index() const {
626     return gpu_thread_index_;
627   }
628 
gpu_thread_index_str()629   std::string gpu_thread_index_str() const {
630     if (!is_gpu_thread_index()) {
631       throw malformed_input("has no GPU thread index");
632     }
633 
634     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
635     static const char* kThreadIndexNames[] = {
636         "threadIdx.x", "threadIdx.y", "threadIdx.z", "threadIdx.w"};
637 
638     if (gpu_thread_index_ < IDX_X || gpu_thread_index_ > IDX_MAX) {
639       throw malformed_input("invalid GPU thread index");
640     }
641 
642     return kThreadIndexNames[gpu_thread_index_];
643   }
644 
set_gpu_thread_index(int index)645   void set_gpu_thread_index(int index) {
646     if (index == IDX_UNSET) {
647       gpu_thread_index_ = IDX_UNSET;
648     }
649 
650     if (is_gpu_block_index()) {
651       throw std::runtime_error("Cannot set both gpu thread and block index");
652     }
653     if (is_gpu_thread_index() && gpu_thread_index() != index) {
654       throw std::runtime_error("Cannot set a previously set thread index");
655     }
656     gpu_thread_index_ = index;
657   }
658 
set_parallel()659   void set_parallel() {
660     is_parallel_ = true;
661   }
662 
is_parallel()663   bool is_parallel() const {
664     return is_parallel_;
665   }
666 
ToString()667   std::string ToString() const {
668     if (is_gpu_block_index()) {
669       return gpu_block_index_str();
670     } else if (is_gpu_thread_index()) {
671       return gpu_thread_index_str();
672     } else if (is_parallel()) {
673       return "parallel";
674     }
675     return "";
676   }
677 
isDefault()678   bool isDefault() const {
679     return gpu_block_index_ == IDX_UNSET && gpu_thread_index_ == IDX_UNSET &&
680         !is_parallel_;
681   }
682 
set_buffer_mapping(const std::unordered_map<std::string,BufPtr> & map)683   void set_buffer_mapping(const std::unordered_map<std::string, BufPtr>& map) {
684     map_input_to_tensor_bufs_ = map;
685   }
686 
get_buffer_mapping()687   std::unordered_map<std::string, BufPtr> get_buffer_mapping() const {
688     return map_input_to_tensor_bufs_;
689   }
690 
691  private:
692   int gpu_block_index_{IDX_UNSET};
693   int gpu_thread_index_{IDX_UNSET};
694   bool is_parallel_{false};
695   std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
696 };
697 
698 class TORCH_API For : public StmtNode<For> {
699  public:
var()700   VarPtr var() const {
701     return var_;
702   }
start()703   ExprPtr start() const {
704     return start_;
705   }
stop()706   ExprPtr stop() const {
707     return stop_;
708   }
body()709   BlockPtr body() const {
710     return body_;
711   }
make(const VarHandle & var,const ExprHandle & start,const ExprHandle & stop,const StmtPtr & body)712   static ForPtr make(
713       const VarHandle& var,
714       const ExprHandle& start,
715       const ExprHandle& stop,
716       const StmtPtr& body) {
717     if (!body) {
718       return nullptr;
719     }
720     return alloc<For>(var.node(), start.node(), stop.node(), body);
721   }
make(const VarHandle & var,const ExprHandle & start,const ExprHandle & stop,const StmtPtr & body,const LoopOptions & loop_options)722   static ForPtr make(
723       const VarHandle& var,
724       const ExprHandle& start,
725       const ExprHandle& stop,
726       const StmtPtr& body,
727       const LoopOptions& loop_options) {
728     if (!body) {
729       return nullptr;
730     }
731     return alloc<For>(
732         var.node(), start.node(), stop.node(), body, loop_options);
733   }
loop_options()734   const LoopOptions loop_options() const {
735     return loop_options_;
736   }
737 
For(VarPtr var,ExprPtr start,ExprPtr stop,StmtPtr body)738   For(VarPtr var, ExprPtr start, ExprPtr stop, StmtPtr body)
739       : var_(std::move(var)), start_(std::move(start)), stop_(std::move(stop)) {
740     BlockPtr b = to<Block>(body);
741     if (!b) {
742       b = alloc<Block>(std::vector<StmtPtr>({std::move(body)}));
743     }
744     body_ = b;
745     set_parent(body_, this);
746   }
747 
For(VarPtr var,ExprPtr start,ExprPtr stop,StmtPtr body,LoopOptions loop_options)748   For(VarPtr var,
749       ExprPtr start,
750       ExprPtr stop,
751       StmtPtr body,
752       LoopOptions loop_options)
753       : var_(std::move(var)),
754         start_(std::move(start)),
755         stop_(std::move(stop)),
756         loop_options_(std::move(loop_options)) {
757     if (!var_) {
758       throw malformed_input("invalid Var in For loop");
759     } else if (!start_) {
760       throw malformed_input("invalid Start in For loop");
761     } else if (!stop_) {
762       throw malformed_input("invalid Stop in For loop");
763     } else if (!body || body->get_parent()) {
764       throw malformed_input("invalid Body in For loop");
765     }
766 
767     BlockPtr b = to<Block>(body);
768     if (!b) {
769       b = alloc<Block>(std::vector<StmtPtr>({std::move(body)}));
770     }
771     body_ = b;
772     set_parent(body_, this);
773   }
774 
set_gpu_block_index(int block_index)775   void set_gpu_block_index(int block_index) {
776     loop_options_.set_gpu_block_index(block_index);
777   }
778 
set_gpu_thread_index(int thread_index)779   void set_gpu_thread_index(int thread_index) {
780     loop_options_.set_gpu_thread_index(thread_index);
781   }
782 
set_parallel()783   void set_parallel() {
784     loop_options_.set_parallel();
785   }
786 
is_parallel()787   bool is_parallel() const {
788     return loop_options_.is_parallel();
789   }
790 
set_buffer_map(const std::unordered_map<std::string,BufPtr> & map)791   void set_buffer_map(const std::unordered_map<std::string, BufPtr>& map) {
792     loop_options_.set_buffer_mapping(map);
793   }
794 
cloneWithNewBody(const StmtPtr & body)795   ForPtr cloneWithNewBody(const StmtPtr& body) const {
796     return alloc<For>(var_, start_, stop_, body, loop_options_);
797   }
798 
removeBody()799   BlockPtr removeBody() {
800     auto res = body_;
801     set_parent(res, nullptr);
802     body_ = nullptr;
803     return res;
804   }
805 
set_body(StmtPtr body)806   void set_body(StmtPtr body) {
807     BlockPtr b = to<Block>(body);
808     if (!b) {
809       b = alloc<Block>(std::vector<StmtPtr>({std::move(body)}));
810     }
811     body_ = b;
812     set_parent(body_, this);
813   }
814 
set_start(ExprPtr start)815   void set_start(ExprPtr start) {
816     start_ = std::move(start);
817   }
818 
set_stop(ExprPtr stop)819   void set_stop(ExprPtr stop) {
820     stop_ = std::move(stop);
821   }
822 
set_var(VarPtr var)823   void set_var(VarPtr var) {
824     var_ = std::move(var);
825   }
826 
827  private:
828   VarPtr var_;
829   ExprPtr start_;
830   ExprPtr stop_;
831   BlockPtr body_;
832   LoopOptions loop_options_;
833 };
834 
835 // A backend specific IR Node that implements atomic-add.
836 // This node could only shows up as an internal with GPU backends.
837 // TODO: move to this an internal IR.
838 // TODO: make IR nodes extensible.
839 class TORCH_API AtomicAdd : public StmtNode<AtomicAdd> {
840  public:
AtomicAdd(BufPtr buf,std::vector<ExprPtr> indices,ExprPtr value)841   AtomicAdd(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value)
842       : buf_(std::move(buf)),
843         indices_(std::move(indices)),
844         value_(std::move(value)) {}
845 
base_handle()846   VarPtr base_handle() const {
847     return buf_->base_handle();
848   }
849 
buf()850   BufPtr buf() const {
851     return buf_;
852   }
853 
flat_index()854   ExprPtr flat_index() const {
855     TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
856     return indices_[0];
857   }
858 
value()859   ExprPtr value() const {
860     return value_;
861   }
862 
indices()863   const std::vector<ExprPtr>& indices() const {
864     return indices_;
865   }
866 
set_buf(BufPtr buf)867   void set_buf(BufPtr buf) {
868     buf_ = std::move(buf);
869   }
870 
set_indices(std::vector<ExprPtr> indices)871   void set_indices(std::vector<ExprPtr> indices) {
872     indices_ = std::move(indices);
873   }
874 
set_value(ExprPtr value)875   void set_value(ExprPtr value) {
876     value_ = std::move(value);
877   }
878 
879  private:
880   BufPtr buf_;
881   std::vector<ExprPtr> indices_;
882   ExprPtr value_;
883 };
884 
885 class TORCH_API SyncThreads : public StmtNode<SyncThreads> {
886  public:
887   SyncThreads() = default;
888 };
889 
890 /*
891  * ExternalCall statement represents a call to an external function that would
892  * compute the contents of the output buffer. An ExternalCall statement consists
893  * of:
894  *   1) output buffer - the buffer that'll be initialized by the call
895  *   2) external function name - a key from the NNC function registry to lookup
896  *      the actual function to call
897  *   3) buffer arguments - the input buffers used by the function
898  *   4) non-buffer arguments - scalar arguments to pass to the function
899  *
900  * An example:
901  *   A = nnc_conv2d(buf_args={Input, Weight, Bias}, args={1})
902  * Here 'A' is the output buffer, "nnc_conv2d" is the function name, the buffer
903  * arguments are 'Input', 'Weight', and 'Bias', and there is a single non-buffer
904  * argument - 1.
905  *
906  * The semantics of the scalar arguments is defined solely by the implementation
907  * of the external function.
908  */
909 class TORCH_API ExternalCall : public StmtNode<ExternalCall> {
910  public:
911   static ExternalCallPtr make(
912       BufHandle buf,
913       const std::string& func_name,
914       const std::vector<BufHandle>& buf_args,
915       const std::vector<ExprHandle>& args);
916 
buf()917   BufPtr buf() const {
918     return buf_;
919   }
920 
func_name()921   std::string func_name() const {
922     return func_name_;
923   }
924 
buf_args()925   std::vector<BufPtr> buf_args() const {
926     return buf_args_;
927   }
928 
args()929   std::vector<ExprPtr> args() const {
930     return args_;
931   }
932 
set_buf(BufPtr buf)933   void set_buf(BufPtr buf) {
934     buf_ = std::move(buf);
935   }
936 
set_buf_args(std::vector<BufPtr> buf_args)937   void set_buf_args(std::vector<BufPtr> buf_args) {
938     buf_args_ = std::move(buf_args);
939   }
940 
set_args(std::vector<ExprPtr> args)941   void set_args(std::vector<ExprPtr> args) {
942     args_ = std::move(args);
943   }
944 
ExternalCall(BufPtr buf,std::string func_name,std::vector<BufPtr> buf_args,std::vector<ExprPtr> args)945   ExternalCall(
946       BufPtr buf,
947       std::string func_name,
948       std::vector<BufPtr> buf_args,
949       std::vector<ExprPtr> args)
950       : buf_(std::move(buf)),
951         func_name_(std::move(func_name)),
952         buf_args_(std::move(buf_args)),
953         args_(std::move(args)) {}
954 
955  private:
956   BufPtr buf_;
957   std::string func_name_;
958   std::vector<BufPtr> buf_args_;
959   std::vector<ExprPtr> args_;
960 };
961 
962 class TORCH_API ExternalCallWithAlloc : public StmtNode<ExternalCallWithAlloc> {
963  public:
964   static ExternalCallWithAllocPtr make(
965       const std::string& func_name,
966       const std::vector<BufHandle>& buf_out_args,
967       const std::vector<BufHandle>& buf_args,
968       const std::vector<ExprHandle>& args);
969 
buf_out_args()970   std::vector<BufPtr> buf_out_args() const {
971     return buf_out_args_;
972   }
973 
func_name()974   std::string func_name() const {
975     return func_name_;
976   }
977 
buf_args()978   std::vector<BufPtr> buf_args() const {
979     return buf_args_;
980   }
981 
args()982   std::vector<ExprPtr> args() const {
983     return args_;
984   }
985 
set_buf_out_args(std::vector<BufPtr> buf_out_args)986   void set_buf_out_args(std::vector<BufPtr> buf_out_args) {
987     buf_out_args_ = std::move(buf_out_args);
988   }
989 
set_buf_args(std::vector<BufPtr> buf_args)990   void set_buf_args(std::vector<BufPtr> buf_args) {
991     buf_args_ = std::move(buf_args);
992   }
993 
set_args(std::vector<ExprPtr> args)994   void set_args(std::vector<ExprPtr> args) {
995     args_ = std::move(args);
996   }
997 
ExternalCallWithAlloc(std::string func_name,std::vector<BufPtr> buf_out_args,std::vector<BufPtr> buf_args,std::vector<ExprPtr> args)998   ExternalCallWithAlloc(
999       std::string func_name,
1000       std::vector<BufPtr> buf_out_args,
1001       std::vector<BufPtr> buf_args,
1002       std::vector<ExprPtr> args)
1003       : func_name_(std::move(func_name)),
1004         buf_out_args_(std::move(buf_out_args)),
1005         buf_args_(std::move(buf_args)),
1006         args_(std::move(args)) {}
1007 
1008  private:
1009   std::string func_name_;
1010   std::vector<BufPtr> buf_out_args_;
1011   std::vector<BufPtr> buf_args_;
1012   std::vector<ExprPtr> args_;
1013 };
1014 
1015 } // namespace torch::jit::tensorexpr
1016